Skip to content

Commit

Permalink
feat: make search results more ergonomic (TagStudioDev#498)
Browse files Browse the repository at this point in the history
  • Loading branch information
yedpodtrzitko authored Sep 13, 2024
1 parent a8fdae8 commit c159638
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 57 deletions.
37 changes: 34 additions & 3 deletions tagstudio/src/core/library/alchemy/library.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from dataclasses import dataclass
from datetime import datetime, UTC
import shutil
from os import makedirs
Expand Down Expand Up @@ -87,6 +88,33 @@ def get_default_tags() -> tuple[Tag, ...]:
return archive_tag, favorite_tag


@dataclass(frozen=True)
class SearchResult:
"""Wrapper for search results.
:param total_count: total number of items for given query, might be different than len(items)
:param items: items for current page (size matches filter.page_size)
"""

total_count: int
items: list[Entry]

def __bool__(self) -> bool:
"""Boolean evaluation for the wrapper.
:return: True if there are items in the result.
"""
return self.total_count > 0

def __len__(self) -> int:
"""Return the total number of items in the result."""
return len(self.items)

def __getitem__(self, index: int) -> Entry:
"""Allow to access items via index directly on the wrapper."""
return self.items[index]


class Library:
"""Class for the Library object, and all CRUD operations made upon it."""

Expand Down Expand Up @@ -325,7 +353,7 @@ def has_path_entry(self, path: Path) -> bool:
def search_library(
self,
search: FilterState,
) -> tuple[int, list[Entry]]:
) -> SearchResult:
"""Filter library by search query.
:return: number of entries matching the query and one page of results.
Expand Down Expand Up @@ -401,11 +429,14 @@ def search_library(
),
)

entries_ = list(session.scalars(statement).unique())
res = SearchResult(
total_count=count_all,
items=list(session.scalars(statement).unique()),
)

session.expunge_all()

return count_all, entries_
return res

def search_tags(
self,
Expand Down
6 changes: 3 additions & 3 deletions tagstudio/src/core/utils/dupe_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,15 @@ def refresh_dupe_files(self, results_filepath: str | Path):
# The file is not in the library directory
continue

_, entries = self.library.search_library(
results = self.library.search_library(
FilterState(path=path_relative),
)

if not entries:
if not results:
# file not in library
continue

files.append(entries[0])
files.append(results[0])

if not len(files) > 1:
# only one file in the group, nothing to do
Expand Down
12 changes: 6 additions & 6 deletions tagstudio/src/qt/ts_qt.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,26 +1009,26 @@ def filter_items(self, filter: FilterState | None = None) -> None:
self.main_window.statusbar.repaint()
start_time = time.time()

query_count, page_items = self.lib.search_library(self.filter)
results = self.lib.search_library(self.filter)

logger.info("items to render", count=len(page_items))
logger.info("items to render", count=len(results))

end_time = time.time()
if self.filter.summary:
self.main_window.statusbar.showMessage(
f'{query_count} Results Found for "{self.filter.summary}" ({format_timespan(end_time - start_time)})'
f'{results.total_count} Results Found for "{self.filter.summary}" ({format_timespan(end_time - start_time)})'
)
else:
self.main_window.statusbar.showMessage(
f"{query_count} Results ({format_timespan(end_time - start_time)})"
f"{results.total_count} Results ({format_timespan(end_time - start_time)})"
)

# update page content
self.frame_content = list(page_items)
self.frame_content = results.items
self.update_thumbs()

# update pagination
self.pages_count = math.ceil(query_count / self.filter.page_size)
self.pages_count = math.ceil(results.total_count / self.filter.page_size)
self.main_window.pagination.update_buttons(
self.pages_count, self.filter.page_index, emit=False
)
Expand Down
2 changes: 1 addition & 1 deletion tagstudio/src/qt/widgets/item_thumb.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def on_badge_check(self, badge_type: BadgeType):
# update the entry
self.driver.frame_content[idx] = self.lib.search_library(
FilterState(id=entry.id)
)[1][0]
).items[0]

self.driver.update_badges(update_items)

Expand Down
17 changes: 10 additions & 7 deletions tagstudio/src/qt/widgets/preview_panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,12 @@ def update_selected_entry(driver: "QtDriver"):
for grid_idx in driver.selected:
entry = driver.frame_content[grid_idx]
# reload entry
_, entries = driver.lib.search_library(FilterState(id=entry.id))
results = driver.lib.search_library(FilterState(id=entry.id))
logger.info(
"found item", entries=entries, grid_idx=grid_idx, lookup_id=entry.id
"found item", entries=len(results), grid_idx=grid_idx, lookup_id=entry.id
)
assert entries, f"Entry not found: {entry.id}"
driver.frame_content[grid_idx] = entries[0]
assert results, f"Entry not found: {entry.id}"
driver.frame_content[grid_idx] = next(results)


class PreviewPanel(QWidget):
Expand Down Expand Up @@ -499,11 +499,14 @@ def update_widgets(self) -> bool:
# TODO - Entry reload is maybe not necessary
for grid_idx in self.driver.selected:
entry = self.driver.frame_content[grid_idx]
_, entries = self.lib.search_library(FilterState(id=entry.id))
results = self.lib.search_library(FilterState(id=entry.id))
logger.info(
"found item", entries=entries, grid_idx=grid_idx, lookup_id=entry.id
"found item",
entries=len(results.items),
grid_idx=grid_idx,
lookup_id=entry.id,
)
self.driver.frame_content[grid_idx] = entries[0]
self.driver.frame_content[grid_idx] = results[0]

if len(self.driver.selected) == 1:
# 1 Selected Entry
Expand Down
4 changes: 2 additions & 2 deletions tagstudio/tests/macros/test_missing_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,5 @@ def test_refresh_missing_files(library: Library):
assert list(registry.fix_missing_files()) == [1, 2]

# `bar.md` should be relinked to new correct path
_, entries = library.search_library(FilterState(path="bar.md"))
assert entries[0].path == pathlib.Path("bar.md")
results = library.search_library(FilterState(path="bar.md"))
assert results[0].path == pathlib.Path("bar.md")
66 changes: 31 additions & 35 deletions tagstudio/tests/test_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,16 @@ def test_library_search(library, generate_tag, entry_full):
assert library.entries_count == 2
tag = list(entry_full.tags)[0]

query_count, items = library.search_library(
results = library.search_library(
FilterState(
tag=tag.name,
),
)

assert query_count == 1
assert len(items) == 1
assert results.total_count == 1
assert len(results) == 1

entry = items[0]
entry = results[0]
assert {x.name for x in entry.tags} == {
"foo",
}
Expand Down Expand Up @@ -94,9 +94,9 @@ def test_tag_search(library):

def test_get_entry(library, entry_min):
assert entry_min.id
cnt, entries = library.search_library(FilterState(id=entry_min.id))
assert len(entries) == cnt == 1
assert entries[0].tags
results = library.search_library(FilterState(id=entry_min.id))
assert len(results) == results.total_count == 1
assert results[0].tags


def test_entries_count(library):
Expand All @@ -105,14 +105,14 @@ def test_entries_count(library):
for x in range(10)
]
library.add_entries(entries)
matches, page = library.search_library(
results = library.search_library(
FilterState(
page_size=5,
)
)

assert matches == 12
assert len(page) == 5
assert results.total_count == 12
assert len(results) == 5


def test_add_field_to_entry(library):
Expand Down Expand Up @@ -146,8 +146,8 @@ def test_add_field_tag(library, entry_full, generate_tag):
library.add_field_tag(entry_full, tag, tag_field.type_key)

# Then
_, entries = library.search_library(FilterState(id=entry_full.id))
tag_field = entries[0].tag_box_fields[0]
results = library.search_library(FilterState(id=entry_full.id))
tag_field = results[0].tag_box_fields[0]
assert [x.name for x in tag_field.tags if x.name == tag_name]


Expand Down Expand Up @@ -179,15 +179,15 @@ def test_search_filter_extensions(library, is_exclude):
library.set_prefs(LibraryPrefs.EXTENSION_LIST, ["md"])

# When
query_count, items = library.search_library(
results = library.search_library(
FilterState(),
)

# Then
assert query_count == 1
assert len(items) == 1
assert results.total_count == 1
assert len(results) == 1

entry = items[0]
entry = results[0]
assert (entry.path.suffix == ".txt") == is_exclude


Expand All @@ -200,15 +200,15 @@ def test_search_library_case_insensitive(library):
tag = list(entry.tags)[0]

# When
query_count, items = library.search_library(
results = library.search_library(
FilterState(tag=tag.name.upper()),
)

# Then
assert query_count == 1
assert len(items) == 1
assert results.total_count == 1
assert len(results) == 1

assert items[0].id == entry.id
assert results[0].id == entry.id


def test_preferences(library):
Expand All @@ -231,11 +231,11 @@ def test_save_windows_path(library, generate_tag):
# library.add_tag(tag)
library.add_field_tag(entry, tag, create_field=True)

_, found = library.search_library(FilterState(tag=tag_name))
assert found
results = library.search_library(FilterState(tag=tag_name))
assert results

# path should be saved in posix format
assert str(found[0].path) == "foo/bar.txt"
assert str(results[0].path) == "foo/bar.txt"


def test_remove_entry_field(library, entry_full):
Expand Down Expand Up @@ -312,13 +312,13 @@ def test_mirror_entry_fields(library, entry_full):

entry_id = library.add_entries([target_entry])[0]

_, entries = library.search_library(FilterState(id=entry_id))
new_entry = entries[0]
results = library.search_library(FilterState(id=entry_id))
new_entry = results[0]

library.mirror_entry_fields(new_entry, entry_full)

_, entries = library.search_library(FilterState(id=entry_id))
entry = entries[0]
results = library.search_library(FilterState(id=entry_id))
entry = results[0]

assert len(entry.fields) == 4
assert {x.type_key for x in entry.fields} == {
Expand Down Expand Up @@ -350,13 +350,11 @@ def test_remove_tag_from_field(library, entry_full):
],
)
def test_search_file_name(library, query_name, has_result):
res_count, items = library.search_library(
results = library.search_library(
FilterState(name=query_name),
)

assert (
res_count == has_result
), f"mismatch with query: {query_name}, result: {res_count}"
assert results.total_count == has_result


@pytest.mark.parametrize(
Expand All @@ -369,13 +367,11 @@ def test_search_file_name(library, query_name, has_result):
],
)
def test_search_entry_id(library, query_name, has_result):
res_count, items = library.search_library(
results = library.search_library(
FilterState(id=query_name),
)

assert (
res_count == has_result
), f"mismatch with query: {query_name}, result: {res_count}"
assert results.total_count == has_result


def test_update_field_order(library, entry_full):
Expand Down

0 comments on commit c159638

Please sign in to comment.