Skip to content

Commit

Permalink
fix: 1. resolve uncorrect pair relation of figure and footnote, 2. re…
Browse files Browse the repository at this point in the history
…solve uncorrect pair relation of table and caption #590
  • Loading branch information
icecraft committed Sep 12, 2024
1 parent 0140d7d commit faac0a4
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 46 deletions.
19 changes: 19 additions & 0 deletions magic_pdf/libs/boxbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,3 +426,22 @@ def dist(point1, point2):
elif top:
return y2 - y1b
return 0.0


def box_area(bbox):
return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])


def get_overlap_area(bbox1, bbox2):
"""计算box1和box2的重叠面积占bbox1的比例."""
# Determine the coordinates of the intersection rectangle
x_left = max(bbox1[0], bbox2[0])
y_top = max(bbox1[1], bbox2[1])
x_right = min(bbox1[2], bbox2[2])
y_bottom = min(bbox1[3], bbox2[3])

if x_right < x_left or y_bottom < y_top:
return 0.0

# The area of overlap area
return (x_right - x_left) * (y_bottom - y_top)
141 changes: 96 additions & 45 deletions magic_pdf/model/magic_model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import json

from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance,
bbox_relative_pos, calculate_iou,
calculate_overlap_area_in_bbox1_area_ratio)
bbox_relative_pos, box_area, calculate_iou,
calculate_overlap_area_in_bbox1_area_ratio,
get_overlap_area)
from magic_pdf.libs.commons import fitz, join_path
from magic_pdf.libs.coordinate_transform import get_scale_ratio
from magic_pdf.libs.local_math import float_gt
Expand All @@ -12,6 +13,7 @@
from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter

CAPATION_OVERLAP_AREA_RATIO = 0.6
MERGE_BOX_OVERLAP_AREA_RATIO = 1.1


class MagicModel:
Expand Down Expand Up @@ -124,49 +126,51 @@ def __fix_footnote(self):
tables.append(obj)
if len(footnotes) * len(figures) == 0:
continue
dis_figure_footnote = {}
dis_table_footnote = {}

for i in range(len(footnotes)):
for j in range(len(figures)):
pos_flag_count = sum(
list(
map(
lambda x: 1 if x else 0,
bbox_relative_pos(
footnotes[i]['bbox'], figures[j]['bbox']
),
)
dis_figure_footnote = {}
dis_table_footnote = {}

for i in range(len(footnotes)):
for j in range(len(figures)):
pos_flag_count = sum(
list(
map(
lambda x: 1 if x else 0,
bbox_relative_pos(
footnotes[i]['bbox'], figures[j]['bbox']
),
)
)
if pos_flag_count > 1:
continue
dis_figure_footnote[i] = min(
bbox_distance(figures[j]['bbox'], footnotes[i]['bbox']),
dis_figure_footnote.get(i, float('inf')),
)
for i in range(len(footnotes)):
for j in range(len(tables)):
pos_flag_count = sum(
list(
map(
lambda x: 1 if x else 0,
bbox_relative_pos(
footnotes[i]['bbox'], tables[j]['bbox']
),
)
)
if pos_flag_count > 1:
continue
dis_figure_footnote[i] = min(
bbox_distance(figures[j]['bbox'], footnotes[i]['bbox']),
dis_figure_footnote.get(i, float('inf')),
)
for i in range(len(footnotes)):
for j in range(len(tables)):
pos_flag_count = sum(
list(
map(
lambda x: 1 if x else 0,
bbox_relative_pos(
footnotes[i]['bbox'], tables[j]['bbox']
),
)
)
if pos_flag_count > 1:
continue
)
if pos_flag_count > 1:
continue

dis_table_footnote[i] = min(
bbox_distance(tables[j]['bbox'], footnotes[i]['bbox']),
dis_table_footnote.get(i, float('inf')),
)
for i in range(len(footnotes)):
if dis_table_footnote.get(i, float('inf')) > dis_figure_footnote[i]:
footnotes[i]['category_id'] = CategoryId.ImageFootnote
dis_table_footnote[i] = min(
bbox_distance(tables[j]['bbox'], footnotes[i]['bbox']),
dis_table_footnote.get(i, float('inf')),
)
for i in range(len(footnotes)):
if i not in dis_figure_footnote:
continue
if dis_table_footnote.get(i, float('inf')) > dis_figure_footnote[i]:
footnotes[i]['category_id'] = CategoryId.ImageFootnote

def __reduct_overlap(self, bboxes):
N = len(bboxes)
Expand All @@ -191,6 +195,44 @@ def __tie_up_category_by_distance(
筛选出所有和 merged bbox 有 overlap 且 overlap 面积大于 object 的面积的 subjects。
再求出筛选出的 subjects 和 object 的最短距离
"""
def search_overlap_between_boxes(
subject_idx, object_idx
):
idxes = [subject_idx, object_idx]
x0s = [all_bboxes[idx]['bbox'][0] for idx in idxes]
y0s = [all_bboxes[idx]['bbox'][1] for idx in idxes]
x1s = [all_bboxes[idx]['bbox'][2] for idx in idxes]
y1s = [all_bboxes[idx]['bbox'][3] for idx in idxes]

merged_bbox = [
min(x0s),
min(y0s),
max(x1s),
max(y1s),
]
ratio = 0

other_objects = list(
map(
lambda x: {'bbox': x['bbox'], 'score': x['score']},
filter(
lambda x: x['category_id']
not in (object_category_id, subject_category_id),
self.__model_list[page_no]['layout_dets'],
),
)
)
for other_object in other_objects:
ratio = max(
ratio,
get_overlap_area(
merged_bbox, other_object['bbox']
) * 1.0 / box_area(all_bboxes[object_idx]['bbox'])
)
if ratio >= MERGE_BOX_OVERLAP_AREA_RATIO:
break

return ratio

def may_find_other_nearest_bbox(subject_idx, object_idx):
ret = float('inf')
Expand Down Expand Up @@ -299,6 +341,15 @@ def expand_bbbox(idxes):
):
continue

subject_idx, object_idx = i, j
if all_bboxes[j]['category_id'] == subject_category_id:
subject_idx, object_idx = j, i

if search_overlap_between_boxes(subject_idx, object_idx) >= MERGE_BOX_OVERLAP_AREA_RATIO:
dis[i][j] = float('inf')
dis[j][i] = dis[i][j]
continue

dis[i][j] = bbox_distance(all_bboxes[i]['bbox'], all_bboxes[j]['bbox'])
dis[j][i] = dis[i][j]

Expand Down Expand Up @@ -627,13 +678,13 @@ def remove_duplicate_spans(spans):
span['type'] = ContentType.Image
elif category_id == 5:
# 获取table模型结果
latex = layout_det.get("latex", None)
html = layout_det.get("html", None)
latex = layout_det.get('latex', None)
html = layout_det.get('html', None)
if latex:
span["latex"] = latex
span['latex'] = latex
elif html:
span["html"] = html
span["type"] = ContentType.Table
span['html'] = html
span['type'] = ContentType.Table
elif category_id == 13:
span['content'] = layout_det['latex']
span['type'] = ContentType.InlineEquation
Expand Down
2 changes: 1 addition & 1 deletion magic_pdf/tools/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def do_parse(
end_page_id=None,
):
if debug_able:
logger.warning("debug mode is on")
logger.warning('debug mode is on')
f_dump_content_list = True
f_draw_model_bbox = True

Expand Down

0 comments on commit faac0a4

Please sign in to comment.