Skip to content

Commit

Permalink
Fix for sg overlap error in box_groups_to_span_groups when center=true (
Browse files Browse the repository at this point in the history
#276)

* stuff for solving sg overlap when center=true

* use Doc method instead of annotating

* increase version
  • Loading branch information
geli-gel authored Sep 1, 2023
1 parent e9708d6 commit 28cdf50
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = 'mmda'
version = '0.9.12'
version = '0.9.13'
description = 'MMDA - multimodal document analysis'
authors = [
{name = 'Allen Institute for Artificial Intelligence', email = '[email protected]'},
Expand Down
19 changes: 16 additions & 3 deletions src/mmda/utils/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

def allocate_overlapping_tokens_for_box(
tokens: List[SpanGroup], box, token_box_in_box_group: bool = False, x: float = 0.0, y: float = 0.0, center: bool = False
) -> Tuple[List[Span], List[Span]]:
) -> Tuple[List[SpanGroup], List[SpanGroup]]:
"""Finds overlap of tokens for given box
Args
`tokens` (List[SpanGroup])
Expand Down Expand Up @@ -102,7 +102,6 @@ def box_groups_to_span_groups(
center=center
)
all_page_tokens[box.page] = remaining_tokens

all_tokens_overlapping_box_group.extend(tokens_in_box)

merge_spans = (
Expand All @@ -119,10 +118,24 @@ def box_groups_to_span_groups(
index_distance=1,
)
)
derived_spans = merge_spans.merge_neighbor_spans_by_symbol_distance()

# tokens overlapping with derived spans:
sg_tokens = doc.find_overlapping(SpanGroup(spans=derived_spans), "tokens")

# remove any additional tokens added to the spangroup via MergeSpans from the list of available page tokens
# (this can happen if the MergeSpans algorithm merges tokens that are not adjacent, e.g. if `center` is True and
# a token is not found to be overlapping with the box, but MergeSpans decides it is close enough to be merged)
for sg_token in sg_tokens:
if sg_token not in all_tokens_overlapping_box_group:
if token_box_in_box_group:
all_page_tokens[sg_token.box_group.boxes[0].page].remove(sg_token)
else:
all_page_tokens[sg_token.spans[0].box.page].remove(sg_token)

derived_span_groups.append(
SpanGroup(
spans=merge_spans.merge_neighbor_spans_by_symbol_distance(),
spans=derived_spans,
box_group=box_group,
# id = box_id,
)
Expand Down

0 comments on commit 28cdf50

Please sign in to comment.