From 12c3b3c06d25f4a3725f41bbae6466276895839b Mon Sep 17 00:00:00 2001 From: kyleclo Date: Sat, 5 Aug 2023 11:48:12 -0700 Subject: [PATCH 1/5] just modify box indexr test to be floats based making sure it works on sub integer --- tests/test_magelib/test_indexer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_magelib/test_indexer.py b/tests/test_magelib/test_indexer.py index 7aaba51..d503c84 100644 --- a/tests/test_magelib/test_indexer.py +++ b/tests/test_magelib/test_indexer.py @@ -103,22 +103,22 @@ def test_finds_matching_entities_in_doc_order(self): def test_finds_matching_entities_accounts_for_pages(self): entities_to_index = [ - Entity(boxes=[Box(0, 0, 1, 1, page=0), Box(2, 2, 1, 1, page=1)]), - Entity(boxes=[Box(4, 4, 1, 1, page=1)]), - Entity(boxes=[Box(100, 100, 1, 1, page=0)]), + Entity(boxes=[Box(0.0, 0.0, 0.1, 0.1, page=0), Box(0.2, 0.2, 0.1, 0.1, page=1)]), + Entity(boxes=[Box(0.4, 0.4, 0.1, 0.1, page=1)]), + Entity(boxes=[Box(10.0, 10.0, 0.1, 0.1, page=0)]), ] index = EntityBoxIndexer(entities_to_index) # shouldnt intersect any given page 0 - probe = Entity(boxes=[Box(1, 1, 5, 5, page=0), Box(9, 9, 5, 5, page=0)]) + probe = Entity(boxes=[Box(0.1, 0.1, 0.5, 0.5, page=0), Box(0.9, 0.9, 0.5, 0.5, page=0)]) matches = index.find(probe) self.assertEqual(len(matches), 1) self.assertEqual(matches, [entities_to_index[0]]) # shoudl intersect after switching to page 1 (and the page 2 box doesnt intersect) - probe = Entity(boxes=[Box(1, 1, 5, 5, page=1), Box(100, 100, 1, 1, page=2)]) + probe = Entity(boxes=[Box(0.1, 0.1, 0.5, 0.5, page=1), Box(10.0, 10.0, 0.1, 0.1, page=2)]) matches = index.find(probe) self.assertEqual(len(matches), 2) From d0452083b6ef4be0b84790eea8bc0cbbf517483e Mon Sep 17 00:00:00 2001 From: kyleclo Date: Sat, 5 Aug 2023 12:08:24 -0700 Subject: [PATCH 2/5] add find by box --- papermage/magelib/annotation.py | 21 +++++++++++++++-- papermage/magelib/document.py | 15 ++++++++++-- tests/test_magelib/test_document.py | 36 +++++++++++++++++++++++++++++ 3 files changed, 68 insertions(+), 4 deletions(-) diff --git a/papermage/magelib/annotation.py b/papermage/magelib/annotation.py index a29680f..4b25680 100644 --- a/papermage/magelib/annotation.py +++ b/papermage/magelib/annotation.py @@ -66,11 +66,28 @@ def from_json(cls, annotation_json: Union[Dict, List]) -> "Annotation": pass def __getattr__(self, field: str) -> List["Annotation"]: - """This method allows you to access overlapping Annotations within the Document""" + """This Overloading is convenient syntax since the `entity.layer` operation is intuitive for folks.""" + return self.find_by_span(field=field) + + def find_by_span(self, field: str) -> List["Annotation"]: + """This method allows you to access overlapping Annotations + within the Document based on Span""" + if self.doc is None: + raise ValueError("This annotation is not attached to a document") + + if field in self.doc.fields: + return self.doc.find_by_span(self, field) + + return self.__getattribute__(field) + + def find_by_box(self, field: str) -> List["Annotation"]: + """This method allows you to access overlapping Annotations + within the Document based on Box""" + if self.doc is None: raise ValueError("This annotation is not attached to a document") if field in self.doc.fields: - return self.doc.find_span_overlap_entities(self, field) + return self.doc.find_by_box(self, field) return self.__getattribute__(field) diff --git a/papermage/magelib/document.py b/papermage/magelib/document.py index 409ab23..21c1890 100644 --- a/papermage/magelib/document.py +++ b/papermage/magelib/document.py @@ -7,7 +7,13 @@ from typing import Dict, Iterable, List, Optional -from papermage.magelib import Entity, EntitySpanIndexer, Image, Metadata +from papermage.magelib import ( + Entity, + EntityBoxIndexer, + EntitySpanIndexer, + Image, + Metadata, +) # document field names SymbolsFieldName = "symbols" @@ -28,14 +34,18 @@ def __init__(self, symbols: str, metadata: Optional[Metadata] = None): self.symbols = symbols self.metadata = metadata if metadata else Metadata() self.__entity_span_indexers: Dict[str, EntitySpanIndexer] = {} + self.__entity_box_indexers: Dict[str, EntityBoxIndexer] = {} @property def fields(self) -> List[str]: return list(self.__entity_span_indexers.keys()) + self.SPECIAL_FIELDS - def find_span_overlap_entities(self, query: Entity, field_name: str) -> List[Entity]: + def find_by_span(self, query: Entity, field_name: str) -> List[Entity]: return self.__entity_span_indexers[field_name].find(query=query) + def find_by_box(self, query: Entity, field_name: str) -> List[Entity]: + return self.__entity_box_indexers[field_name].find(query=query) + def check_field_name_availability(self, field_name: str) -> None: if field_name in self.SPECIAL_FIELDS: raise AssertionError(f"{field_name} not allowed Document.SPECIAL_FIELDS.") @@ -55,6 +65,7 @@ def annotate_entity(self, field_name: str, entities: List[Entity]) -> None: setattr(self, field_name, entities) self.__entity_span_indexers[field_name] = EntitySpanIndexer(entities=entities) + self.__entity_box_indexers[field_name] = EntityBoxIndexer(entities=entities) def remove_entity(self, field_name: str): for entity in getattr(self, field_name): diff --git a/tests/test_magelib/test_document.py b/tests/test_magelib/test_document.py index 47c421c..51922c4 100644 --- a/tests/test_magelib/test_document.py +++ b/tests/test_magelib/test_document.py @@ -88,3 +88,39 @@ def test_metadata_deserializes_when_empty(self): self.assertEqual(symbols, doc.symbols) self.assertEqual(0, len(doc.metadata)) + + def test_cross_referencing(self): + doc = Document("This is a test document!") + # boxes are in a top-left to bottom-right diagonal fashion (same page) + tokens = [ + Entity.from_json({"spans": [[0, 4]], "boxes": [[0, 0, 0.5, 0.5, 0]]}), + Entity.from_json({"spans": [[5, 7]], "boxes": [[1, 1, 0.5, 0.5, 0]]}), + Entity.from_json({"spans": [[8, 9]], "boxes": [[2, 2, 0.5, 0.5, 0]]}), + Entity.from_json({"spans": [[10, 14]], "boxes": [[3, 3, 0.5, 0.5, 0]]}), + Entity.from_json({"spans": [[15, 23]], "boxes": [[4, 4, 0.5, 0.5, 0]]}), + Entity.from_json({"spans": [[23, 24]], "boxes": [[5, 5, 0.5, 0.5, 0]]}), + ] + # boxes are also same diagonal fashion, but bigger. + # last box super big on wrong page. + chunks = [ + Entity.from_json({"spans": [[0, 9]], "boxes": [[0, 0, 2.01, 2.01, 0]]}), + Entity.from_json({"spans": [[12, 23]], "boxes": [[3.0, 3.0, 4.0, 4.0, 0]]}), + Entity.from_json({"spans": [[23, 24]], "boxes": [[0, 0, 10.0, 10.0, 1]]}), + ] + doc.annotate_entity(field_name="tokens", entities=tokens) + doc.annotate_entity(field_name="chunks", entities=chunks) + + # find by span is the default overload of Entity.__attr__ + self.assertListEqual(doc.chunks[0].tokens, tokens[0:3]) + self.assertListEqual(doc.chunks[1].tokens, tokens[3:5]) + self.assertListEqual(doc.chunks[2].tokens, [tokens[5]]) + + # find by span works fine + self.assertListEqual(doc.chunks[0].tokens, doc.find_by_span(query=doc.chunks[0], field_name="tokens")) + self.assertListEqual(doc.chunks[1].tokens, doc.find_by_span(query=doc.chunks[1], field_name="tokens")) + self.assertListEqual(doc.chunks[2].tokens, doc.find_by_span(query=doc.chunks[2], field_name="tokens")) + + # find by box + self.assertListEqual(doc.find_by_box(query=doc.chunks[0], field_name="tokens"), doc.tokens[0:3]) + self.assertListEqual(doc.find_by_box(query=doc.chunks[1], field_name="tokens"), doc.tokens[3:6]) + self.assertListEqual(doc.find_by_box(query=doc.chunks[2], field_name="tokens"), []) From 23c91a7aa04d4e58a7fac7e5092a3a087825cc15 Mon Sep 17 00:00:00 2001 From: kyleclo Date: Sat, 5 Aug 2023 12:08:35 -0700 Subject: [PATCH 3/5] toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d88206d..3bcb3f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = 'papermage' -version = '0.4.0' +version = '0.5.0' description = 'Papermage. Casting magic over scientific PDFs.' license = {text = 'Apache-2.0'} readme = 'README.md' From d365d01f68d497f56b9532e011c42e2ddaa62a19 Mon Sep 17 00:00:00 2001 From: kyleclo Date: Sat, 5 Aug 2023 12:43:40 -0700 Subject: [PATCH 4/5] relax overlap for boxes --- papermage/magelib/indexer.py | 5 +++-- tests/test_magelib/test_indexer.py | 7 ++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/papermage/magelib/indexer.py b/papermage/magelib/indexer.py index 0a19730..463aae2 100644 --- a/papermage/magelib/indexer.py +++ b/papermage/magelib/indexer.py @@ -107,7 +107,7 @@ class EntityBoxIndexer(Indexer): @kylel """ - def __init__(self, entities: List[Entity]) -> None: + def __init__(self, entities: List[Entity], allow_overlap: bool = True) -> None: self._entities = entities self._box_id_to_entity_id = {} @@ -125,7 +125,8 @@ def __init__(self, entities: List[Entity]) -> None: self._np_boxes_y2 = np.array([b.t + b.h for b in self._boxes]) self._np_boxes_page = np.array([b.page for b in self._boxes]) - self._ensure_disjoint() + if not allow_overlap: + self._ensure_disjoint() def _find_overlap_boxes(self, query: Box) -> List[int]: x1, y1, x2, y2 = query.xy_coordinates diff --git a/tests/test_magelib/test_indexer.py b/tests/test_magelib/test_indexer.py index d503c84..4505647 100644 --- a/tests/test_magelib/test_indexer.py +++ b/tests/test_magelib/test_indexer.py @@ -74,7 +74,8 @@ def test_overlap_within_single_entity_fails_checks(self): entities = [Entity(boxes=[Box(0, 0, 5, 5, page=0), Box(4, 4, 7, 7, page=0)])] with self.assertRaises(ValueError): - EntityBoxIndexer(entities) + EntityBoxIndexer(entities=entities, allow_overlap=False) + EntityBoxIndexer(entities=entities, allow_overlap=True) def test_overlap_between_entities_fails_checks(self): entities = [ @@ -83,7 +84,8 @@ def test_overlap_between_entities_fails_checks(self): ] with self.assertRaises(ValueError): - EntityBoxIndexer(entities) + EntityBoxIndexer(entities=entities, allow_overlap=False) + EntityBoxIndexer(entities=entities, allow_overlap=True) def test_finds_matching_entities_in_doc_order(self): entities_to_index = [ @@ -91,7 +93,6 @@ def test_finds_matching_entities_in_doc_order(self): Entity(boxes=[Box(4, 4, 1, 1, page=0)]), Entity(boxes=[Box(100, 100, 1, 1, page=0)]), ] - index = EntityBoxIndexer(entities_to_index) # should intersect 1 and 2 but not 3 From cea9976a826dc3cf80e79bbf75cc9fa389f39d84 Mon Sep 17 00:00:00 2001 From: kyleclo Date: Sat, 5 Aug 2023 12:50:26 -0700 Subject: [PATCH 5/5] add tests --- tests/test_magelib/test_document.py | 55 +++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/tests/test_magelib/test_document.py b/tests/test_magelib/test_document.py index 51922c4..6cd17ac 100644 --- a/tests/test_magelib/test_document.py +++ b/tests/test_magelib/test_document.py @@ -124,3 +124,58 @@ def test_cross_referencing(self): self.assertListEqual(doc.find_by_box(query=doc.chunks[0], field_name="tokens"), doc.tokens[0:3]) self.assertListEqual(doc.find_by_box(query=doc.chunks[1], field_name="tokens"), doc.tokens[3:6]) self.assertListEqual(doc.find_by_box(query=doc.chunks[2], field_name="tokens"), []) + + def test_cross_referencing_with_missing_entity_fields(self): + """What happens when annotate a Doc with entiites missing spans or boxes? + How does the cross-referencing operation behave?""" + # this is same test as above but removing select fields + doc = Document("This is a test document!") + + # 1) tokens no boxes + tokens = [ + Entity.from_json({"spans": [[0, 4]]}), + Entity.from_json({"spans": [[5, 7]]}), + Entity.from_json({"spans": [[8, 9]]}), + Entity.from_json({"spans": [[10, 14]]}), + Entity.from_json({"spans": [[15, 23]]}), + Entity.from_json({"spans": [[23, 24]]}), + ] + chunks = [ + Entity.from_json({"spans": [[0, 9]], "boxes": [[0, 0, 2.01, 2.01, 0]]}), + Entity.from_json({"spans": [[12, 23]], "boxes": [[3.0, 3.0, 4.0, 4.0, 0]]}), + Entity.from_json({"spans": [[23, 24]], "boxes": [[0, 0, 10.0, 10.0, 1]]}), + ] + doc.annotate_entity(field_name="tokens", entities=tokens) + doc.annotate_entity(field_name="chunks", entities=chunks) + self.assertListEqual(doc.find_by_box(query=doc.chunks[0], field_name="tokens"), []) + self.assertListEqual(doc.find_by_box(query=doc.chunks[1], field_name="tokens"), []) + self.assertListEqual(doc.find_by_box(query=doc.chunks[2], field_name="tokens"), []) + # the reverse works just fine + self.assertListEqual(doc.find_by_box(query=doc.tokens[0], field_name="chunks"), []) + self.assertListEqual(doc.find_by_box(query=doc.tokens[1], field_name="chunks"), []) + self.assertListEqual(doc.find_by_box(query=doc.tokens[2], field_name="chunks"), []) + + # 2) tokens no spans + doc = Document("This is a test document!") + tokens = [ + Entity.from_json({"boxes": [[0, 0, 0.5, 0.5, 0]]}), + Entity.from_json({"boxes": [[1, 1, 0.5, 0.5, 0]]}), + Entity.from_json({"boxes": [[2, 2, 0.5, 0.5, 0]]}), + Entity.from_json({"boxes": [[3, 3, 0.5, 0.5, 0]]}), + Entity.from_json({"boxes": [[4, 4, 0.5, 0.5, 0]]}), + Entity.from_json({"boxes": [[5, 5, 0.5, 0.5, 0]]}), + ] + chunks = [ + Entity.from_json({"spans": [[0, 9]], "boxes": [[0, 0, 2.01, 2.01, 0]]}), + Entity.from_json({"spans": [[12, 23]], "boxes": [[3.0, 3.0, 4.0, 4.0, 0]]}), + Entity.from_json({"spans": [[23, 24]], "boxes": [[0, 0, 10.0, 10.0, 1]]}), + ] + doc.annotate_entity(field_name="tokens", entities=tokens) + doc.annotate_entity(field_name="chunks", entities=chunks) + self.assertListEqual(doc.find_by_span(query=doc.chunks[0], field_name="tokens"), []) + self.assertListEqual(doc.find_by_span(query=doc.chunks[1], field_name="tokens"), []) + self.assertListEqual(doc.find_by_span(query=doc.chunks[2], field_name="tokens"), []) + # the reverse works just fine + self.assertListEqual(doc.find_by_span(query=doc.tokens[0], field_name="chunks"), []) + self.assertListEqual(doc.find_by_span(query=doc.tokens[1], field_name="chunks"), []) + self.assertListEqual(doc.find_by_span(query=doc.tokens[2], field_name="chunks"), [])