Skip to content

Commit

Permalink
Resolve conflict by removing core recipe tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben Newman committed Aug 7, 2023
2 parents 64710c9 + 4b3187e commit 97e4b68
Show file tree
Hide file tree
Showing 48 changed files with 59,594 additions and 500 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
/pdfs/
/temp/
scripts/
notebooks/
lightning_logs/

Expand Down
158 changes: 158 additions & 0 deletions examples/quick_start_demo.ipynb

Large diffs are not rendered by default.

File renamed without changes.
3 changes: 3 additions & 0 deletions papermage/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .utils.version import get_version

__version__ = get_version()
20 changes: 14 additions & 6 deletions papermage/magelib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,18 @@
from papermage.magelib.indexer import EntitySpanIndexer, EntityBoxIndexer
from papermage.magelib.document import Document
from papermage.magelib.document import (
MetadataFieldName,
EntitiesFieldName,
MetadataFieldName,
EntitiesFieldName,
SymbolsFieldName,
RelationsFieldName,
PagesFieldName,
TokensFieldName,
RelationsFieldName,
PagesFieldName,
TokensFieldName,
RowsFieldName,
ImagesFieldName
BlocksFieldName,
ImagesFieldName,
WordsFieldName,
SentencesFieldName,
ParagraphsFieldName
)

__all__ = [
Expand All @@ -44,4 +48,8 @@
"PagesFieldName",
"TokensFieldName",
"RowsFieldName",
"BlocksFieldName",
"WordsFieldName",
"SentencesFieldName",
"ParagraphsFieldName",
]
14 changes: 9 additions & 5 deletions papermage/magelib/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,11 @@ def from_json(cls, annotation_json: Union[Dict, List]) -> "Annotation":

def __getattr__(self, field: str) -> List["Annotation"]:
"""This Overloading is convenient syntax since the `entity.layer` operation is intuitive for folks."""
return self.find_by_span(field=field)
try:
return self.find_by_span(field=field)
except ValueError:
# maybe users just want some attribute of the Annotation object
return self.__getattribute__(field)

def find_by_span(self, field: str) -> List["Annotation"]:
"""This method allows you to access overlapping Annotations
Expand All @@ -77,8 +81,8 @@ def find_by_span(self, field: str) -> List["Annotation"]:

if field in self.doc.fields:
return self.doc.find_by_span(self, field)

return self.__getattribute__(field)
else:
raise ValueError(f"Field {field} not found in Document")

def find_by_box(self, field: str) -> List["Annotation"]:
"""This method allows you to access overlapping Annotations
Expand All @@ -89,5 +93,5 @@ def find_by_box(self, field: str) -> List["Annotation"]:

if field in self.doc.fields:
return self.doc.find_by_box(self, field)

return self.__getattribute__(field)
else:
raise ValueError(f"Field {field} not found in Document")
30 changes: 25 additions & 5 deletions papermage/magelib/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,23 @@
"""

import logging
from typing import Dict, List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union

import numpy as np

from papermage.magelib import Span


# Span subclass used for box-coordinate overlap checks: its endpoints are
# floats, whereas Span.__init__ annotates start/end as int. The constructor
# assigns start/end directly and does not call Span.__init__.
# NOTE(review): the `# type: ignore` markers presumably silence the
# int-vs-float annotation mismatch with Span — confirm.
class _BoxSpan(Span):
def __init__(self, start: float, end: float):
self.start = start # type: ignore
self.end = end # type: ignore


class Box:

__slots__ = ["l", "t", "w", "h", "page"]

def __init__(self, l: float, t: float, w: float, h: float, page: int):
assert w >= 0.0, "Box width cant be negative"
assert h >= 0.0, "Box height cant be negative"
Expand Down Expand Up @@ -73,6 +82,17 @@ def from_xy_coordinates(
def xy_coordinates(self) -> Tuple[float, float, float, float]:
return self.l, self.t, self.l + self.w, self.t + self.h

# Equality for Box: any non-Box object compares unequal; two Boxes are equal
# only when l, t, w, h and page all match under plain `==` (no float tolerance).
# NOTE(review): no `__hash__` is visible in this hunk. In Python, a class that
# defines `__eq__` without also defining `__hash__` becomes unhashable —
# confirm whether Box instances need to be usable in sets / as dict keys.
def __eq__(self, other: object) -> bool:
if not isinstance(other, Box):
return False
return (
self.l == other.l
and self.t == other.t
and self.w == other.w
and self.h == other.h
and self.page == other.page
)

def to_relative(self, page_width: float, page_height: float) -> "Box":
"""Get the relative coordinates of self based on page_width, page_height."""
return self.__class__(
Expand Down Expand Up @@ -107,14 +127,14 @@ def is_overlap(self, other: "Box") -> bool:
other_x1, other_y1, other_x2, other_y2 = other.xy_coordinates

# check x-axis
span_x_self = Span(start=self_x1, end=self_x2)
span_x_other = Span(start=other_x1, end=other_x2)
span_x_self = _BoxSpan(start=self_x1, end=self_x2)
span_x_other = _BoxSpan(start=other_x1, end=other_x2)
if not span_x_self.is_overlap(span_x_other):
return False

# check y-axis
span_y_self = Span(start=self_y1, end=self_y2)
span_y_other = Span(start=other_y1, end=other_y2)
span_y_self = _BoxSpan(start=self_y1, end=self_y2)
span_y_other = _BoxSpan(start=other_y1, end=other_y2)
if not span_y_self.is_overlap(span_y_other):
return False

Expand Down
32 changes: 28 additions & 4 deletions papermage/magelib/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@
"""

from typing import Dict, Iterable, List, Optional
from typing import Dict, Iterable, List, Optional, Union

from papermage.magelib import (
Box,
Entity,
EntityBoxIndexer,
EntitySpanIndexer,
Image,
Metadata,
Span,
)

# document field names
Expand All @@ -25,6 +27,10 @@
PagesFieldName = "pages"
TokensFieldName = "tokens"
RowsFieldName = "rows"
BlocksFieldName = "blocks"
WordsFieldName = "words"
SentencesFieldName = "sentences"
ParagraphsFieldName = "paragraphs"


class Document:
Expand All @@ -40,39 +46,57 @@ def __init__(self, symbols: str, metadata: Optional[Metadata] = None):
def fields(self) -> List[str]:
return list(self.__entity_span_indexers.keys()) + self.SPECIAL_FIELDS

# Query the field `field_name` with a bare Span or Box instead of an Entity.
# The query is wrapped in a one-span (or one-box) Entity and passed to the
# span indexer (or box indexer) stored under `field_name`.
# Raises TypeError when `query` is neither a Span nor a Box.
def find(self, query: Union[Span, Box], field_name: str) -> List[Entity]:
if isinstance(query, Span):
return self.__entity_span_indexers[field_name].find(query=Entity(spans=[query]))
elif isinstance(query, Box):
return self.__entity_box_indexers[field_name].find(query=Entity(boxes=[query]))
else:
raise TypeError(f"Unsupported query type {type(query)}")

def find_by_span(self, query: Entity, field_name: str) -> List[Entity]:
# TODO: will rename this to `intersect_by_span`
return self.__entity_span_indexers[field_name].find(query=query)

# Return what the box indexer stored under `field_name` finds for `query`.
def find_by_box(self, query: Entity, field_name: str) -> List[Entity]:
# TODO: will rename this to `intersect_by_box`
# (the note previously said `intersect_by_span`, apparently copied from find_by_span)
return self.__entity_box_indexers[field_name].find(query=query)

def check_field_name_availability(self, field_name: str) -> None:
if field_name in self.SPECIAL_FIELDS:
raise AssertionError(f"{field_name} not allowed Document.SPECIAL_FIELDS.")
if field_name in self.__entity_span_indexers.keys():
raise AssertionError(f"{field_name} already exists. Try `is_overwrite=True`")
raise AssertionError(f'{field_name} already exists. Try `doc.remove_entity("{field_name}")` first.')
if field_name in dir(self):
raise AssertionError(f"{field_name} clashes with Document class properties.")

def get_entity(self, field_name: str) -> List[Entity]:
return getattr(self, field_name)

# Attach a list of annotations to this document under `field_name`.
# When every element is an Entity the work is delegated to annotate_entity;
# otherwise NotImplementedError is raised, listing the element types found.
# Note: an empty list satisfies all(...) and is therefore also forwarded
# to annotate_entity.
def annotate(self, field_name: str, entities: List[Entity]) -> None:
if all(isinstance(e, Entity) for e in entities):
self.annotate_entity(field_name=field_name, entities=entities)
else:
raise NotImplementedError(f"entity list contains non-entities: {[type(e) for e in entities]}")

def annotate_entity(self, field_name: str, entities: List[Entity]) -> None:
self.check_field_name_availability(field_name=field_name)

for entity in entities:
for i, entity in enumerate(entities):
entity.doc = self
entity.id = i

setattr(self, field_name, entities)
self.__entity_span_indexers[field_name] = EntitySpanIndexer(entities=entities)
self.__entity_box_indexers[field_name] = EntityBoxIndexer(entities=entities)
setattr(self, field_name, entities)

# Remove the field `field_name` from this document.
# Each entity currently stored under the field is detached first (its `doc`
# is set to None); then the attribute itself and both the span indexer and
# the box indexer for that field are deleted.
def remove_entity(self, field_name: str):
for entity in getattr(self, field_name):
entity.doc = None

delattr(self, field_name)
del self.__entity_span_indexers[field_name]
del self.__entity_box_indexers[field_name]

def get_relation(self, name: str) -> List["Relation"]:
raise NotImplementedError
Expand Down
30 changes: 22 additions & 8 deletions papermage/magelib/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,9 @@ def __init__(
super().__init__()

def __repr__(self):
if self.doc and self.spans:
symbols: List[str] = [self.doc.symbols[span.start : span.end] for span in self.spans]
return f"Entity\tSymbols\t{symbols}"
return f"Entity{self.to_json()}"
if self.doc:
return f"Annotated Entity:\tSpans: {True if self.spans else False}\tBoxes: {True if self.boxes else False}\nText: {self.text}"
return f"Unannotated Entity: {self.to_json()}"

def to_json(self) -> Dict:
entity_dict = dict(
Expand Down Expand Up @@ -60,18 +59,33 @@ def end(self) -> Union[int, float]:
return max([span.end for span in self.spans]) if len(self.spans) > 0 else float("inf")

@property
def symbols(self) -> List[str]:
def symbols_from_spans(self) -> List[str]:
if self.doc is not None:
return [self.doc.symbols[span.start : span.end] for span in self.spans]
else:
return []

# Symbol strings recovered through boxes rather than this entity's own spans.
# With a document attached, the document's "tokens" field is queried via
# find_by_box using this entity, and doc.symbols is sliced once for every
# span of every token returned. Without a document, returns an empty list.
# The field name "tokens" is hard-coded here.
@property
def symbols_from_boxes(self) -> List[str]:
if self.doc is not None:
matched_tokens = self.doc.find_by_box(query=self, field_name="tokens")
return [self.doc.symbols[span.start : span.end] for t in matched_tokens for span in t.spans]
else:
return []

@property
def text(self) -> str:
# return stored metadata
maybe_text = self.metadata.get("text", None)
if maybe_text is None:
return " ".join(self.symbols)
return maybe_text
if maybe_text:
return maybe_text
# return derived from symbols
if self.symbols_from_spans:
return " ".join(self.symbols_from_spans)
# return derived from boxes and tokens
if self.symbols_from_boxes:
return " ".join(self.symbols_from_boxes)
return ""

@text.setter
def text(self, text: Union[str, None]) -> None:
Expand Down
4 changes: 2 additions & 2 deletions papermage/magelib/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def _ensure_disjoint(self) -> None:

def find(self, query: Entity) -> List[Entity]:
if not isinstance(query, Entity):
raise ValueError(f"EntityIndexer only works with `query` that is Entity type")
raise TypeError(f"EntityIndexer only works with `query` that is Entity type")

if not query.spans:
return []
Expand Down Expand Up @@ -159,7 +159,7 @@ def _ensure_disjoint(self) -> None:

def find(self, query: Entity) -> List[Entity]:
if not isinstance(query, Entity):
raise ValueError(f"EntityBoxIndexer only works with `query` that is Entity type")
raise TypeError(f"EntityBoxIndexer only works with `query` that is Entity type")

if not query.boxes:
return []
Expand Down
33 changes: 22 additions & 11 deletions papermage/magelib/span.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@


class Span:
__slots__ = ['start', 'end']

def __init__(self, start: int, end: int):
self.start = start
self.end = end
Expand All @@ -29,7 +31,9 @@ def from_json(cls, span_json: List) -> "Span":
def __repr__(self):
return f'Span{self.to_json()}'

def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if not isinstance(other, Span):
return False
return self.start == other.start and self.end == other.end

def __lt__(self, other: 'Span'):
Expand All @@ -39,11 +43,16 @@ def __lt__(self, other: 'Span'):
return self.end < other.end
return self.start < other.start

# Hash on the (start, end) pair — the same two fields __eq__ compares —
# so Spans that compare equal also hash equal.
def __hash__(self) -> int:
return hash((self.start, self.end))

def is_overlap(self, other: 'Span') -> bool:
"""Whether self overlaps with the other Span object."""
return self.start <= other.start < self.end or \
other.start <= self.start < other.end or \
self == other
return (
self.start <= other.start < self.end
or other.start <= self.start < other.end
or self == other
)

@classmethod
def create_enclosing_span(cls, spans: List['Span']) -> 'Span':
Expand Down Expand Up @@ -105,20 +114,22 @@ def clusters(self) -> List[List[Span]]:
self._clusters = self._cluster(spans=self.spans)
return self._clusters

def _build_graph(self, spans: List[Span], index_distance: int) -> defaultdict(list):
@staticmethod
def _is_neighboring_spans(span1: Span, span2: Span, index_distance: int) -> bool:
"""Whether two spans are considered neighboring"""
# Neighboring: the smaller of the two start-to-end gaps,
# |span1.start - span2.end| and |span1.end - span2.start|,
# is at most index_distance.
return min(
abs(span1.start - span2.end), abs(span1.end - span2.start)
) <= index_distance

def _build_graph(self, spans: List[Span], index_distance: int) -> Dict[int, List[int]]:
"""
Build graph, each node is the position within the input list of spans.
Spans are considered neighboring if they are at most index_distance apart
"""
is_neighboring_spans = (
lambda span1, span2: min(
abs(span1.start - span2.end), abs(span1.end - span2.start)
) <= index_distance
)
graph = defaultdict(list)
for i, span_i in enumerate(spans):
for j in range(i + 1, len(spans)):
if is_neighboring_spans(span_i, spans[j]):
if self._is_neighboring_spans(span1=span_i, span2=spans[j], index_distance=index_distance):
graph[i].append(j)
graph[j].append(i)
return graph
Expand Down
Loading

0 comments on commit 97e4b68

Please sign in to comment.