From 28cdf50d5f3639bb45e4e5de067c83ecb9bf8058 Mon Sep 17 00:00:00 2001 From: Angele Zamarron Date: Fri, 1 Sep 2023 08:25:13 -0700 Subject: [PATCH] Fix for sg overlap error in box_groups_to_span_groups when center=true (#276) * stuff for solving sg overlap when center=true * use Doc method instead of annotating * increase version --- pyproject.toml | 2 +- src/mmda/utils/tools.py | 19 ++++++++++++++++--- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5e3ae43d..16265414 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = 'mmda' -version = '0.9.12' +version = '0.9.13' description = 'MMDA - multimodal document analysis' authors = [ {name = 'Allen Institute for Artificial Intelligence', email = 'contact@allenai.org'}, diff --git a/src/mmda/utils/tools.py b/src/mmda/utils/tools.py index 9effac49..d06821c2 100644 --- a/src/mmda/utils/tools.py +++ b/src/mmda/utils/tools.py @@ -15,7 +15,7 @@ def allocate_overlapping_tokens_for_box( tokens: List[SpanGroup], box, token_box_in_box_group: bool = False, x: float = 0.0, y: float = 0.0, center: bool = False -) -> Tuple[List[Span], List[Span]]: +) -> Tuple[List[SpanGroup], List[SpanGroup]]: """Finds overlap of tokens for given box Args `tokens` (List[SpanGroup]) @@ -102,7 +102,6 @@ def box_groups_to_span_groups( center=center ) all_page_tokens[box.page] = remaining_tokens - all_tokens_overlapping_box_group.extend(tokens_in_box) merge_spans = ( @@ -119,10 +118,24 @@ def box_groups_to_span_groups( index_distance=1, ) ) + derived_spans = merge_spans.merge_neighbor_spans_by_symbol_distance() + + # tokens overlapping with derived spans: + sg_tokens = doc.find_overlapping(SpanGroup(spans=derived_spans), "tokens") + + # remove any additional tokens added to the spangroup via MergeSpans from the list of available page tokens + # (this can happen if the MergeSpans algorithm merges tokens that are not adjacent, e.g. if `center` is True and + # a token is not found to be overlapping with the box, but MergeSpans decides it is close enough to be merged) + for sg_token in sg_tokens: + if sg_token not in all_tokens_overlapping_box_group: + if token_box_in_box_group: + all_page_tokens[sg_token.box_group.boxes[0].page].remove(sg_token) + else: + all_page_tokens[sg_token.spans[0].box.page].remove(sg_token) derived_span_groups.append( SpanGroup( - spans=merge_spans.merge_neighbor_spans_by_symbol_distance(), + spans=derived_spans, box_group=box_group, # id = box_id, )