Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Kylel/2023/box query doc #27

Merged
merged 5 commits into from
Aug 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions papermage/magelib/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,28 @@ def from_json(cls, annotation_json: Union[Dict, List]) -> "Annotation":
pass

def __getattr__(self, field: str) -> List["Annotation"]:
"""This method allows you to access overlapping Annotations within the Document"""
"""This Overloading is convenient syntax since the `entity.layer` operation is intuitive for folks."""
return self.find_by_span(field=field)

def find_by_span(self, field: str) -> List["Annotation"]:
"""This method allows you to access overlapping Annotations
within the Document based on Span"""
if self.doc is None:
raise ValueError("This annotation is not attached to a document")

if field in self.doc.fields:
return self.doc.find_by_span(self, field)

return self.__getattribute__(field)

def find_by_box(self, field: str) -> List["Annotation"]:
"""This method allows you to access overlapping Annotations
within the Document based on Box"""

if self.doc is None:
raise ValueError("This annotation is not attached to a document")

if field in self.doc.fields:
return self.doc.find_span_overlap_entities(self, field)
return self.doc.find_by_box(self, field)

return self.__getattribute__(field)
15 changes: 13 additions & 2 deletions papermage/magelib/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@

from typing import Dict, Iterable, List, Optional

from papermage.magelib import Entity, EntitySpanIndexer, Image, Metadata
from papermage.magelib import (
Entity,
EntityBoxIndexer,
EntitySpanIndexer,
Image,
Metadata,
)

# document field names
SymbolsFieldName = "symbols"
Expand All @@ -28,14 +34,18 @@ def __init__(self, symbols: str, metadata: Optional[Metadata] = None):
self.symbols = symbols
self.metadata = metadata if metadata else Metadata()
self.__entity_span_indexers: Dict[str, EntitySpanIndexer] = {}
self.__entity_box_indexers: Dict[str, EntityBoxIndexer] = {}

@property
def fields(self) -> List[str]:
return list(self.__entity_span_indexers.keys()) + self.SPECIAL_FIELDS

def find_span_overlap_entities(self, query: Entity, field_name: str) -> List[Entity]:
def find_by_span(self, query: Entity, field_name: str) -> List[Entity]:
return self.__entity_span_indexers[field_name].find(query=query)

def find_by_box(self, query: Entity, field_name: str) -> List[Entity]:
return self.__entity_box_indexers[field_name].find(query=query)

def check_field_name_availability(self, field_name: str) -> None:
if field_name in self.SPECIAL_FIELDS:
raise AssertionError(f"{field_name} not allowed Document.SPECIAL_FIELDS.")
Expand All @@ -55,6 +65,7 @@ def annotate_entity(self, field_name: str, entities: List[Entity]) -> None:

setattr(self, field_name, entities)
self.__entity_span_indexers[field_name] = EntitySpanIndexer(entities=entities)
self.__entity_box_indexers[field_name] = EntityBoxIndexer(entities=entities)
Comment on lines 67 to +68
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens here if any/all of the entities have only boxes or spans?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mmmm lemme test


def remove_entity(self, field_name: str):
for entity in getattr(self, field_name):
Expand Down
5 changes: 3 additions & 2 deletions papermage/magelib/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class EntityBoxIndexer(Indexer):
@kylel
"""

def __init__(self, entities: List[Entity]) -> None:
def __init__(self, entities: List[Entity], allow_overlap: bool = True) -> None:
self._entities = entities

self._box_id_to_entity_id = {}
Expand All @@ -125,7 +125,8 @@ def __init__(self, entities: List[Entity]) -> None:
self._np_boxes_y2 = np.array([b.t + b.h for b in self._boxes])
self._np_boxes_page = np.array([b.page for b in self._boxes])

self._ensure_disjoint()
if not allow_overlap:
self._ensure_disjoint()

def _find_overlap_boxes(self, query: Box) -> List[int]:
x1, y1, x2, y2 = query.xy_coordinates
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = 'papermage'
version = '0.4.0'
version = '0.5.0'
description = 'Papermage. Casting magic over scientific PDFs.'
license = {text = 'Apache-2.0'}
readme = 'README.md'
Expand Down
91 changes: 91 additions & 0 deletions tests/test_magelib/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,94 @@ def test_metadata_deserializes_when_empty(self):

self.assertEqual(symbols, doc.symbols)
self.assertEqual(0, len(doc.metadata))

def test_cross_referencing(self):
doc = Document("This is a test document!")
# boxes are in a top-left to bottom-right diagonal fashion (same page)
tokens = [
Entity.from_json({"spans": [[0, 4]], "boxes": [[0, 0, 0.5, 0.5, 0]]}),
Entity.from_json({"spans": [[5, 7]], "boxes": [[1, 1, 0.5, 0.5, 0]]}),
Entity.from_json({"spans": [[8, 9]], "boxes": [[2, 2, 0.5, 0.5, 0]]}),
Entity.from_json({"spans": [[10, 14]], "boxes": [[3, 3, 0.5, 0.5, 0]]}),
Entity.from_json({"spans": [[15, 23]], "boxes": [[4, 4, 0.5, 0.5, 0]]}),
Entity.from_json({"spans": [[23, 24]], "boxes": [[5, 5, 0.5, 0.5, 0]]}),
]
# boxes are also same diagonal fashion, but bigger.
# last box super big on wrong page.
chunks = [
Entity.from_json({"spans": [[0, 9]], "boxes": [[0, 0, 2.01, 2.01, 0]]}),
Entity.from_json({"spans": [[12, 23]], "boxes": [[3.0, 3.0, 4.0, 4.0, 0]]}),
Entity.from_json({"spans": [[23, 24]], "boxes": [[0, 0, 10.0, 10.0, 1]]}),
]
doc.annotate_entity(field_name="tokens", entities=tokens)
doc.annotate_entity(field_name="chunks", entities=chunks)

# find by span is the default overload of Entity.__attr__
self.assertListEqual(doc.chunks[0].tokens, tokens[0:3])
self.assertListEqual(doc.chunks[1].tokens, tokens[3:5])
self.assertListEqual(doc.chunks[2].tokens, [tokens[5]])

# find by span works fine
self.assertListEqual(doc.chunks[0].tokens, doc.find_by_span(query=doc.chunks[0], field_name="tokens"))
self.assertListEqual(doc.chunks[1].tokens, doc.find_by_span(query=doc.chunks[1], field_name="tokens"))
self.assertListEqual(doc.chunks[2].tokens, doc.find_by_span(query=doc.chunks[2], field_name="tokens"))

# find by box
self.assertListEqual(doc.find_by_box(query=doc.chunks[0], field_name="tokens"), doc.tokens[0:3])
self.assertListEqual(doc.find_by_box(query=doc.chunks[1], field_name="tokens"), doc.tokens[3:6])
self.assertListEqual(doc.find_by_box(query=doc.chunks[2], field_name="tokens"), [])

def test_cross_referencing_with_missing_entity_fields(self):
"""What happens when annotate a Doc with entiites missing spans or boxes?
How does the cross-referencing operation behave?"""
# this is same test as above but removing select fields
doc = Document("This is a test document!")

# 1) tokens no boxes
tokens = [
Entity.from_json({"spans": [[0, 4]]}),
Entity.from_json({"spans": [[5, 7]]}),
Entity.from_json({"spans": [[8, 9]]}),
Entity.from_json({"spans": [[10, 14]]}),
Entity.from_json({"spans": [[15, 23]]}),
Entity.from_json({"spans": [[23, 24]]}),
]
chunks = [
Entity.from_json({"spans": [[0, 9]], "boxes": [[0, 0, 2.01, 2.01, 0]]}),
Entity.from_json({"spans": [[12, 23]], "boxes": [[3.0, 3.0, 4.0, 4.0, 0]]}),
Entity.from_json({"spans": [[23, 24]], "boxes": [[0, 0, 10.0, 10.0, 1]]}),
]
doc.annotate_entity(field_name="tokens", entities=tokens)
doc.annotate_entity(field_name="chunks", entities=chunks)
self.assertListEqual(doc.find_by_box(query=doc.chunks[0], field_name="tokens"), [])
self.assertListEqual(doc.find_by_box(query=doc.chunks[1], field_name="tokens"), [])
self.assertListEqual(doc.find_by_box(query=doc.chunks[2], field_name="tokens"), [])
# the reverse works just fine
self.assertListEqual(doc.find_by_box(query=doc.tokens[0], field_name="chunks"), [])
self.assertListEqual(doc.find_by_box(query=doc.tokens[1], field_name="chunks"), [])
self.assertListEqual(doc.find_by_box(query=doc.tokens[2], field_name="chunks"), [])

# 2) tokens no spans
doc = Document("This is a test document!")
tokens = [
Entity.from_json({"boxes": [[0, 0, 0.5, 0.5, 0]]}),
Entity.from_json({"boxes": [[1, 1, 0.5, 0.5, 0]]}),
Entity.from_json({"boxes": [[2, 2, 0.5, 0.5, 0]]}),
Entity.from_json({"boxes": [[3, 3, 0.5, 0.5, 0]]}),
Entity.from_json({"boxes": [[4, 4, 0.5, 0.5, 0]]}),
Entity.from_json({"boxes": [[5, 5, 0.5, 0.5, 0]]}),
]
chunks = [
Entity.from_json({"spans": [[0, 9]], "boxes": [[0, 0, 2.01, 2.01, 0]]}),
Entity.from_json({"spans": [[12, 23]], "boxes": [[3.0, 3.0, 4.0, 4.0, 0]]}),
Entity.from_json({"spans": [[23, 24]], "boxes": [[0, 0, 10.0, 10.0, 1]]}),
]
doc.annotate_entity(field_name="tokens", entities=tokens)
doc.annotate_entity(field_name="chunks", entities=chunks)
self.assertListEqual(doc.find_by_span(query=doc.chunks[0], field_name="tokens"), [])
self.assertListEqual(doc.find_by_span(query=doc.chunks[1], field_name="tokens"), [])
self.assertListEqual(doc.find_by_span(query=doc.chunks[2], field_name="tokens"), [])
# the reverse works just fine
self.assertListEqual(doc.find_by_span(query=doc.tokens[0], field_name="chunks"), [])
self.assertListEqual(doc.find_by_span(query=doc.tokens[1], field_name="chunks"), [])
self.assertListEqual(doc.find_by_span(query=doc.tokens[2], field_name="chunks"), [])
17 changes: 9 additions & 8 deletions tests/test_magelib/test_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ def test_overlap_within_single_entity_fails_checks(self):
entities = [Entity(boxes=[Box(0, 0, 5, 5, page=0), Box(4, 4, 7, 7, page=0)])]

with self.assertRaises(ValueError):
EntityBoxIndexer(entities)
EntityBoxIndexer(entities=entities, allow_overlap=False)
EntityBoxIndexer(entities=entities, allow_overlap=True)

def test_overlap_between_entities_fails_checks(self):
entities = [
Expand All @@ -83,15 +84,15 @@ def test_overlap_between_entities_fails_checks(self):
]

with self.assertRaises(ValueError):
EntityBoxIndexer(entities)
EntityBoxIndexer(entities=entities, allow_overlap=False)
EntityBoxIndexer(entities=entities, allow_overlap=True)

def test_finds_matching_entities_in_doc_order(self):
entities_to_index = [
Entity(boxes=[Box(0, 0, 1, 1, page=0), Box(2, 2, 1, 1, page=0)]),
Entity(boxes=[Box(4, 4, 1, 1, page=0)]),
Entity(boxes=[Box(100, 100, 1, 1, page=0)]),
]

index = EntityBoxIndexer(entities_to_index)

# should intersect 1 and 2 but not 3
Expand All @@ -103,22 +104,22 @@ def test_finds_matching_entities_in_doc_order(self):

def test_finds_matching_entities_accounts_for_pages(self):
entities_to_index = [
Entity(boxes=[Box(0, 0, 1, 1, page=0), Box(2, 2, 1, 1, page=1)]),
Entity(boxes=[Box(4, 4, 1, 1, page=1)]),
Entity(boxes=[Box(100, 100, 1, 1, page=0)]),
Entity(boxes=[Box(0.0, 0.0, 0.1, 0.1, page=0), Box(0.2, 0.2, 0.1, 0.1, page=1)]),
Entity(boxes=[Box(0.4, 0.4, 0.1, 0.1, page=1)]),
Entity(boxes=[Box(10.0, 10.0, 0.1, 0.1, page=0)]),
]

index = EntityBoxIndexer(entities_to_index)

# shouldnt intersect any given page 0
probe = Entity(boxes=[Box(1, 1, 5, 5, page=0), Box(9, 9, 5, 5, page=0)])
probe = Entity(boxes=[Box(0.1, 0.1, 0.5, 0.5, page=0), Box(0.9, 0.9, 0.5, 0.5, page=0)])
matches = index.find(probe)

self.assertEqual(len(matches), 1)
self.assertEqual(matches, [entities_to_index[0]])

# shoudl intersect after switching to page 1 (and the page 2 box doesnt intersect)
probe = Entity(boxes=[Box(1, 1, 5, 5, page=1), Box(100, 100, 1, 1, page=2)])
probe = Entity(boxes=[Box(0.1, 0.1, 0.5, 0.5, page=1), Box(10.0, 10.0, 0.1, 0.1, page=2)])
matches = index.find(probe)

self.assertEqual(len(matches), 2)
Expand Down
Loading