diff --git a/tagstudio/src/core/library/alchemy/library.py b/tagstudio/src/core/library/alchemy/library.py index 0a54f049d..fc8efbac5 100644 --- a/tagstudio/src/core/library/alchemy/library.py +++ b/tagstudio/src/core/library/alchemy/library.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from datetime import datetime, UTC import shutil from os import makedirs @@ -87,6 +88,33 @@ def get_default_tags() -> tuple[Tag, ...]: return archive_tag, favorite_tag +@dataclass(frozen=True) +class SearchResult: + """Wrapper for search results. + + :param total_count: total number of items for given query, might be different than len(items) + :param items: items for current page (size matches filter.page_size) + """ + + total_count: int + items: list[Entry] + + def __bool__(self) -> bool: + """Boolean evaluation for the wrapper. + + :return: True if there are items in the result. + """ + return self.total_count > 0 + + def __len__(self) -> int: + """Return the total number of items in the result.""" + return len(self.items) + + def __getitem__(self, index: int) -> Entry: + """Allow to access items via index directly on the wrapper.""" + return self.items[index] + + class Library: """Class for the Library object, and all CRUD operations made upon it.""" @@ -325,7 +353,7 @@ def has_path_entry(self, path: Path) -> bool: def search_library( self, search: FilterState, - ) -> tuple[int, list[Entry]]: + ) -> SearchResult: """Filter library by search query. :return: number of entries matching the query and one page of results. @@ -401,11 +429,14 @@ def search_library( ), ) - entries_ = list(session.scalars(statement).unique()) + res = SearchResult( + total_count=count_all, + items=list(session.scalars(statement).unique()), + ) session.expunge_all() - return count_all, entries_ + return res def search_tags( self, diff --git a/tagstudio/src/core/utils/dupe_files.py b/tagstudio/src/core/utils/dupe_files.py index 719470189..425785ad7 100644 --- a/tagstudio/src/core/utils/dupe_files.py +++ b/tagstudio/src/core/utils/dupe_files.py @@ -50,15 +50,15 @@ def refresh_dupe_files(self, results_filepath: str | Path): # The file is not in the library directory continue - _, entries = self.library.search_library( + results = self.library.search_library( FilterState(path=path_relative), ) - if not entries: + if not results: # file not in library continue - files.append(entries[0]) + files.append(results[0]) if not len(files) > 1: # only one file in the group, nothing to do diff --git a/tagstudio/src/qt/ts_qt.py b/tagstudio/src/qt/ts_qt.py index 0b215090a..92d5318ad 100644 --- a/tagstudio/src/qt/ts_qt.py +++ b/tagstudio/src/qt/ts_qt.py @@ -1009,26 +1009,26 @@ def filter_items(self, filter: FilterState | None = None) -> None: self.main_window.statusbar.repaint() start_time = time.time() - query_count, page_items = self.lib.search_library(self.filter) + results = self.lib.search_library(self.filter) - logger.info("items to render", count=len(page_items)) + logger.info("items to render", count=len(results)) end_time = time.time() if self.filter.summary: self.main_window.statusbar.showMessage( - f'{query_count} Results Found for "{self.filter.summary}" ({format_timespan(end_time - start_time)})' + f'{results.total_count} Results Found for "{self.filter.summary}" ({format_timespan(end_time - start_time)})' ) else: self.main_window.statusbar.showMessage( - f"{query_count} Results ({format_timespan(end_time - start_time)})" + f"{results.total_count} Results ({format_timespan(end_time - start_time)})" ) # update page content - self.frame_content = list(page_items) + self.frame_content = results.items self.update_thumbs() # update pagination - self.pages_count = math.ceil(query_count / self.filter.page_size) + self.pages_count = math.ceil(results.total_count / self.filter.page_size) self.main_window.pagination.update_buttons( self.pages_count, self.filter.page_index, emit=False ) diff --git a/tagstudio/src/qt/widgets/item_thumb.py b/tagstudio/src/qt/widgets/item_thumb.py index 05a02d060..04f61d77f 100644 --- a/tagstudio/src/qt/widgets/item_thumb.py +++ b/tagstudio/src/qt/widgets/item_thumb.py @@ -487,7 +487,7 @@ def on_badge_check(self, badge_type: BadgeType): # update the entry self.driver.frame_content[idx] = self.lib.search_library( FilterState(id=entry.id) - )[1][0] + ).items[0] self.driver.update_badges(update_items) diff --git a/tagstudio/src/qt/widgets/preview_panel.py b/tagstudio/src/qt/widgets/preview_panel.py index 386fa9ddc..2b31afc98 100644 --- a/tagstudio/src/qt/widgets/preview_panel.py +++ b/tagstudio/src/qt/widgets/preview_panel.py @@ -63,12 +63,12 @@ def update_selected_entry(driver: "QtDriver"): for grid_idx in driver.selected: entry = driver.frame_content[grid_idx] # reload entry - _, entries = driver.lib.search_library(FilterState(id=entry.id)) + results = driver.lib.search_library(FilterState(id=entry.id)) logger.info( - "found item", entries=entries, grid_idx=grid_idx, lookup_id=entry.id + "found item", entries=len(results), grid_idx=grid_idx, lookup_id=entry.id ) - assert entries, f"Entry not found: {entry.id}" - driver.frame_content[grid_idx] = entries[0] + assert results, f"Entry not found: {entry.id}" + driver.frame_content[grid_idx] = next(results) class PreviewPanel(QWidget): @@ -499,11 +499,14 @@ def update_widgets(self) -> bool: # TODO - Entry reload is maybe not necessary for grid_idx in self.driver.selected: entry = self.driver.frame_content[grid_idx] - _, entries = self.lib.search_library(FilterState(id=entry.id)) + results = self.lib.search_library(FilterState(id=entry.id)) logger.info( - "found item", entries=entries, grid_idx=grid_idx, lookup_id=entry.id + "found item", + entries=len(results.items), + grid_idx=grid_idx, + lookup_id=entry.id, ) - self.driver.frame_content[grid_idx] = entries[0] + self.driver.frame_content[grid_idx] = results[0] if len(self.driver.selected) == 1: # 1 Selected Entry diff --git a/tagstudio/tests/macros/test_missing_files.py b/tagstudio/tests/macros/test_missing_files.py index 68e9c2572..f269d249e 100644 --- a/tagstudio/tests/macros/test_missing_files.py +++ b/tagstudio/tests/macros/test_missing_files.py @@ -27,5 +27,5 @@ def test_refresh_missing_files(library: Library): assert list(registry.fix_missing_files()) == [1, 2] # `bar.md` should be relinked to new correct path - _, entries = library.search_library(FilterState(path="bar.md")) - assert entries[0].path == pathlib.Path("bar.md") + results = library.search_library(FilterState(path="bar.md")) + assert results[0].path == pathlib.Path("bar.md") diff --git a/tagstudio/tests/test_library.py b/tagstudio/tests/test_library.py index 2232626c3..584eec5df 100644 --- a/tagstudio/tests/test_library.py +++ b/tagstudio/tests/test_library.py @@ -57,16 +57,16 @@ def test_library_search(library, generate_tag, entry_full): assert library.entries_count == 2 tag = list(entry_full.tags)[0] - query_count, items = library.search_library( + results = library.search_library( FilterState( tag=tag.name, ), ) - assert query_count == 1 - assert len(items) == 1 + assert results.total_count == 1 + assert len(results) == 1 - entry = items[0] + entry = results[0] assert {x.name for x in entry.tags} == { "foo", } @@ -94,9 +94,9 @@ def test_tag_search(library): def test_get_entry(library, entry_min): assert entry_min.id - cnt, entries = library.search_library(FilterState(id=entry_min.id)) - assert len(entries) == cnt == 1 - assert entries[0].tags + results = library.search_library(FilterState(id=entry_min.id)) + assert len(results) == results.total_count == 1 + assert results[0].tags def test_entries_count(library): @@ -105,14 +105,14 @@ def test_entries_count(library): for x in range(10) ] library.add_entries(entries) - matches, page = library.search_library( + results = library.search_library( FilterState( page_size=5, ) ) - assert matches == 12 - assert len(page) == 5 + assert results.total_count == 12 + assert len(results) == 5 def test_add_field_to_entry(library): @@ -146,8 +146,8 @@ def test_add_field_tag(library, entry_full, generate_tag): library.add_field_tag(entry_full, tag, tag_field.type_key) # Then - _, entries = library.search_library(FilterState(id=entry_full.id)) - tag_field = entries[0].tag_box_fields[0] + results = library.search_library(FilterState(id=entry_full.id)) + tag_field = results[0].tag_box_fields[0] assert [x.name for x in tag_field.tags if x.name == tag_name] @@ -179,15 +179,15 @@ def test_search_filter_extensions(library, is_exclude): library.set_prefs(LibraryPrefs.EXTENSION_LIST, ["md"]) # When - query_count, items = library.search_library( + results = library.search_library( FilterState(), ) # Then - assert query_count == 1 - assert len(items) == 1 + assert results.total_count == 1 + assert len(results) == 1 - entry = items[0] + entry = results[0] assert (entry.path.suffix == ".txt") == is_exclude @@ -200,15 +200,15 @@ def test_search_library_case_insensitive(library): tag = list(entry.tags)[0] # When - query_count, items = library.search_library( + results = library.search_library( FilterState(tag=tag.name.upper()), ) # Then - assert query_count == 1 - assert len(items) == 1 + assert results.total_count == 1 + assert len(results) == 1 - assert items[0].id == entry.id + assert results[0].id == entry.id def test_preferences(library): @@ -231,11 +231,11 @@ def test_save_windows_path(library, generate_tag): # library.add_tag(tag) library.add_field_tag(entry, tag, create_field=True) - _, found = library.search_library(FilterState(tag=tag_name)) - assert found + results = library.search_library(FilterState(tag=tag_name)) + assert results # path should be saved in posix format - assert str(found[0].path) == "foo/bar.txt" + assert str(results[0].path) == "foo/bar.txt" def test_remove_entry_field(library, entry_full): @@ -312,13 +312,13 @@ def test_mirror_entry_fields(library, entry_full): entry_id = library.add_entries([target_entry])[0] - _, entries = library.search_library(FilterState(id=entry_id)) - new_entry = entries[0] + results = library.search_library(FilterState(id=entry_id)) + new_entry = results[0] library.mirror_entry_fields(new_entry, entry_full) - _, entries = library.search_library(FilterState(id=entry_id)) - entry = entries[0] + results = library.search_library(FilterState(id=entry_id)) + entry = results[0] assert len(entry.fields) == 4 assert {x.type_key for x in entry.fields} == { @@ -350,13 +350,11 @@ def test_remove_tag_from_field(library, entry_full): ], ) def test_search_file_name(library, query_name, has_result): - res_count, items = library.search_library( + results = library.search_library( FilterState(name=query_name), ) - assert ( - res_count == has_result - ), f"mismatch with query: {query_name}, result: {res_count}" + assert results.total_count == has_result @pytest.mark.parametrize( @@ -369,13 +367,11 @@ def test_search_file_name(library, query_name, has_result): ], ) def test_search_entry_id(library, query_name, has_result): - res_count, items = library.search_library( + results = library.search_library( FilterState(id=query_name), ) - assert ( - res_count == has_result - ), f"mismatch with query: {query_name}, result: {res_count}" + assert results.total_count == has_result def test_update_field_order(library, entry_full):