From 5acb7b51cff07756e9a7f3dd2366f0b6f8c53a55 Mon Sep 17 00:00:00 2001 From: Christoph Auer Date: Wed, 17 Jul 2024 15:21:13 +0200 Subject: [PATCH] Optimizations for table extraction quality, configurable options for cell matching --- docling/datamodel/base_models.py | 28 +++++++++-- docling/models/table_structure_model.py | 54 ++++++++++++++++----- docling/pipeline/standard_model_pipeline.py | 2 +- examples/convert.py | 2 - 4 files changed, 68 insertions(+), 18 deletions(-) diff --git a/docling/datamodel/base_models.py b/docling/datamodel/base_models.py index dd9795a..8b6796d 100644 --- a/docling/datamodel/base_models.py +++ b/docling/datamodel/base_models.py @@ -1,3 +1,4 @@ +import copy from enum import Enum, auto from io import BytesIO from typing import Any, Dict, List, Optional, Tuple, Union @@ -47,6 +48,15 @@ def width(self): def height(self): return abs(self.t - self.b) + def scaled(self, scale: float) -> "BoundingBox": + out_bbox = copy.deepcopy(self) + out_bbox.l *= scale + out_bbox.r *= scale + out_bbox.t *= scale + out_bbox.b *= scale + + return out_bbox + def as_tuple(self): if self.coord_origin == CoordOrigin.TOPLEFT: return (self.l, self.t, self.r, self.b) @@ -180,8 +190,7 @@ class TableStructurePrediction(BaseModel): table_map: Dict[int, TableElement] = {} -class TextElement(BasePageElement): - ... +class TextElement(BasePageElement): ... class FigureData(BaseModel): @@ -242,6 +251,17 @@ class DocumentStream(BaseModel): stream: BytesIO +class TableStructureOptions(BaseModel): + do_cell_matching: bool = ( + True + # True: Matches predictions back to PDF cells. Can break table output if PDF cells + # are merged across table columns. + # False: Let table structure model define the text cells, ignore PDF cells. + ) + + class PipelineOptions(BaseModel): - do_table_structure: bool = True - do_ocr: bool = False + do_table_structure: bool = True # True: perform table structure extraction + do_ocr: bool = False # True: perform OCR, replace programmatic PDF text + + table_structure_options: TableStructureOptions = TableStructureOptions() diff --git a/docling/models/table_structure_model.py b/docling/models/table_structure_model.py index 8ee4bda..132b141 100644 --- a/docling/models/table_structure_model.py +++ b/docling/models/table_structure_model.py @@ -1,7 +1,10 @@ -from typing import Iterable +import copy +import random +from typing import Iterable, List import numpy from docling_ibm_models.tableformer.data_management.tf_predictor import TFPredictor +from PIL import ImageDraw from docling.datamodel.base_models import ( BoundingBox, @@ -28,6 +31,21 @@ def __init__(self, config): self.tm_model_type = self.tm_config["model"]["type"] self.tf_predictor = TFPredictor(self.tm_config) + self.scale = 2.0 # Scale up table input images to 144 dpi + + def draw_table_and_cells(self, page: Page, tbl_list: List[TableElement]): + image = page._backend.get_page_image() + draw = ImageDraw.Draw(image) + + for table_element in tbl_list: + x0, y0, x1, y1 = table_element.cluster.bbox.as_tuple() + draw.rectangle([(x0, y0), (x1, y1)], outline="red") + + for tc in table_element.table_cells: + x0, y0, x1, y1 = tc.bbox.as_tuple() + draw.rectangle([(x0, y0), (x1, y1)], outline="blue") + + image.show() def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]: @@ -36,16 +54,17 @@ def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]: return for page in page_batch: + page.predictions.tablestructure = TableStructurePrediction() # dummy in_tables = [ ( cluster, [ - round(cluster.bbox.l), - round(cluster.bbox.t), - round(cluster.bbox.r), - round(cluster.bbox.b), + round(cluster.bbox.l) * self.scale, + round(cluster.bbox.t) * self.scale, + round(cluster.bbox.r) * self.scale, + round(cluster.bbox.b) * self.scale, ], ) for cluster in page.predictions.layout.clusters @@ -65,20 +84,29 @@ def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]: ): # Only allow non empty stings (spaces) into the cells of a table if len(c.text.strip()) > 0: - tokens.append(c.model_dump()) + new_cell = copy.deepcopy(c) + new_cell.bbox = new_cell.bbox.scaled(scale=self.scale) + + tokens.append(new_cell.model_dump()) - iocr_page = { - "image": numpy.asarray(page.image), + page_input = { "tokens": tokens, - "width": page.size.width, - "height": page.size.height, + "width": page.size.width * self.scale, + "height": page.size.height * self.scale, } + # add image to page input. + if self.scale == 1.0: + page_input["image"] = numpy.asarray(page.image) + else: # render new page image on the fly at desired scale + page_input["image"] = numpy.asarray( + page._backend.get_page_image(scale=self.scale) + ) table_clusters, table_bboxes = zip(*in_tables) if len(table_bboxes): tf_output = self.tf_predictor.multi_table_predict( - iocr_page, table_bboxes, do_matching=self.do_cell_matching + page_input, table_bboxes, do_matching=self.do_cell_matching ) for table_cluster, table_out in zip(table_clusters, tf_output): @@ -91,6 +119,7 @@ def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]: element["bbox"]["token"] = text_piece tc = TableCell.model_validate(element) + tc.bbox = tc.bbox.scaled(1 / self.scale) table_cells.append(tc) # Retrieving cols/rows, after post processing: @@ -111,4 +140,7 @@ def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]: page.predictions.tablestructure.table_map[table_cluster.id] = tbl + # For debugging purposes: + # self.draw_table_and_cells(page, page.predictions.tablestructure.table_map.values()) + yield page diff --git a/docling/pipeline/standard_model_pipeline.py b/docling/pipeline/standard_model_pipeline.py index 07c0113..33fee75 100644 --- a/docling/pipeline/standard_model_pipeline.py +++ b/docling/pipeline/standard_model_pipeline.py @@ -34,7 +34,7 @@ def __init__(self, artifacts_path: Path, pipeline_options: PipelineOptions): "artifacts_path": artifacts_path / StandardModelPipeline._table_model_path, "enabled": pipeline_options.do_table_structure, - "do_cell_matching": False, + "do_cell_matching": pipeline_options.table_structure_options.do_cell_matching, } ), ] diff --git a/examples/convert.py b/examples/convert.py index 89b3726..26a38c5 100644 --- a/examples/convert.py +++ b/examples/convert.py @@ -46,8 +46,6 @@ def main(): logging.basicConfig(level=logging.INFO) input_doc_paths = [ - # Path("/Users/cau/Downloads/Issue-36122.pdf"), - # Path("/Users/cau/Downloads/IBM_Storage_Insights_Fact_Sheet.pdf"), Path("./test/data/2206.01062.pdf"), Path("./test/data/2203.01017v2.pdf"), Path("./test/data/2305.03393v1.pdf"),