Skip to content

Commit

Permalink
Optimizations for table extraction quality, configurable options for …
Browse files Browse the repository at this point in the history
…cell matching
  • Loading branch information
cau-git committed Jul 17, 2024
1 parent 3ae8207 commit 5acb7b5
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 18 deletions.
28 changes: 24 additions & 4 deletions docling/datamodel/base_models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
from enum import Enum, auto
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple, Union
Expand Down Expand Up @@ -47,6 +48,15 @@ def width(self):
def height(self):
return abs(self.t - self.b)

def scaled(self, scale: float) -> "BoundingBox":
    """Return a scaled copy of this bounding box.

    The receiver is deep-copied first, so it is left untouched; the copy's
    ``l``, ``r``, ``t`` and ``b`` values are each multiplied by ``scale``.
    """
    result = copy.deepcopy(self)
    for edge in ("l", "r", "t", "b"):
        setattr(result, edge, getattr(result, edge) * scale)
    return result

def as_tuple(self):
if self.coord_origin == CoordOrigin.TOPLEFT:
return (self.l, self.t, self.r, self.b)
Expand Down Expand Up @@ -180,8 +190,7 @@ class TableStructurePrediction(BaseModel):
table_map: Dict[int, TableElement] = {}


# Text page element: declares no members of its own; everything comes from BasePageElement.
# NOTE(review): the class appears twice below (multi-line and one-line form) --
# presumably the before/after sides of a formatting-only change; both spell the
# same empty class body.
class TextElement(BasePageElement):
...
class TextElement(BasePageElement): ...


class FigureData(BaseModel):
Expand Down Expand Up @@ -242,6 +251,17 @@ class DocumentStream(BaseModel):
stream: BytesIO


class TableStructureOptions(BaseModel):
    # do_cell_matching
    #   True:  match predictions back to PDF cells. Can break table output if
    #          PDF cells are merged across table columns.
    #   False: let the table structure model define the text cells and ignore
    #          the PDF cells.
    do_cell_matching: bool = True


# Option switches consumed by the model pipeline (table structure, OCR, and the
# nested table-structure settings).
# NOTE(review): do_table_structure and do_ocr are each declared twice with the
# same defaults -- presumably the before/after sides of a diff; as Python, the
# later declarations are the ones that take effect.
class PipelineOptions(BaseModel):
do_table_structure: bool = True
do_ocr: bool = False
do_table_structure: bool = True # True: perform table structure extraction
do_ocr: bool = False # True: perform OCR, replace programmatic PDF text

# Nested options for the table structure model; defaults to TableStructureOptions().
table_structure_options: TableStructureOptions = TableStructureOptions()
54 changes: 43 additions & 11 deletions docling/models/table_structure_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Iterable
import copy
import random
from typing import Iterable, List

import numpy
from docling_ibm_models.tableformer.data_management.tf_predictor import TFPredictor
from PIL import ImageDraw

from docling.datamodel.base_models import (
BoundingBox,
Expand All @@ -28,6 +31,21 @@ def __init__(self, config):
self.tm_model_type = self.tm_config["model"]["type"]

self.tf_predictor = TFPredictor(self.tm_config)
self.scale = 2.0 # Scale up table input images to 144 dpi

def draw_table_and_cells(self, page: Page, tbl_list: Iterable[TableElement]) -> None:
    """Debug helper: display the page image with table and cell outlines.

    Every table's cluster bounding box is outlined in red and every one of
    its table cells in blue, then the annotated image is opened with
    ``image.show()``.

    Args:
        page: Page whose backend supplies the image to draw on.
        tbl_list: Table elements to visualise. Any iterable is accepted
            (for example the ``values()`` view of a table map); it is
            traversed exactly once.
    """
    image = page._backend.get_page_image()
    draw = ImageDraw.Draw(image)

    def outline(bbox, color: str) -> None:
        # NOTE(review): assumes as_tuple() yields coordinates that line up with
        # the pixel grid of the image returned above -- TODO confirm.
        x0, y0, x1, y1 = bbox.as_tuple()
        draw.rectangle([(x0, y0), (x1, y1)], outline=color)

    for table_element in tbl_list:
        outline(table_element.cluster.bbox, "red")
        for cell in table_element.table_cells:
            outline(cell.bbox, "blue")

    image.show()

def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:

Expand All @@ -36,16 +54,17 @@ def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
return

for page in page_batch:

page.predictions.tablestructure = TableStructurePrediction() # dummy

in_tables = [
(
cluster,
[
round(cluster.bbox.l),
round(cluster.bbox.t),
round(cluster.bbox.r),
round(cluster.bbox.b),
round(cluster.bbox.l) * self.scale,
round(cluster.bbox.t) * self.scale,
round(cluster.bbox.r) * self.scale,
round(cluster.bbox.b) * self.scale,
],
)
for cluster in page.predictions.layout.clusters
Expand All @@ -65,20 +84,29 @@ def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
):
# Only allow non-empty strings (not just whitespace) into the cells of a table
if len(c.text.strip()) > 0:
tokens.append(c.model_dump())
new_cell = copy.deepcopy(c)
new_cell.bbox = new_cell.bbox.scaled(scale=self.scale)

tokens.append(new_cell.model_dump())

iocr_page = {
"image": numpy.asarray(page.image),
page_input = {
"tokens": tokens,
"width": page.size.width,
"height": page.size.height,
"width": page.size.width * self.scale,
"height": page.size.height * self.scale,
}
# add image to page input.
if self.scale == 1.0:
page_input["image"] = numpy.asarray(page.image)
else: # render new page image on the fly at desired scale
page_input["image"] = numpy.asarray(
page._backend.get_page_image(scale=self.scale)
)

table_clusters, table_bboxes = zip(*in_tables)

if len(table_bboxes):
tf_output = self.tf_predictor.multi_table_predict(
iocr_page, table_bboxes, do_matching=self.do_cell_matching
page_input, table_bboxes, do_matching=self.do_cell_matching
)

for table_cluster, table_out in zip(table_clusters, tf_output):
Expand All @@ -91,6 +119,7 @@ def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
element["bbox"]["token"] = text_piece

tc = TableCell.model_validate(element)
tc.bbox = tc.bbox.scaled(1 / self.scale)
table_cells.append(tc)

# Retrieving cols/rows, after post processing:
Expand All @@ -111,4 +140,7 @@ def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:

page.predictions.tablestructure.table_map[table_cluster.id] = tbl

# For debugging purposes:
# self.draw_table_and_cells(page, page.predictions.tablestructure.table_map.values())

yield page
2 changes: 1 addition & 1 deletion docling/pipeline/standard_model_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(self, artifacts_path: Path, pipeline_options: PipelineOptions):
"artifacts_path": artifacts_path
/ StandardModelPipeline._table_model_path,
"enabled": pipeline_options.do_table_structure,
"do_cell_matching": False,
"do_cell_matching": pipeline_options.table_structure_options.do_cell_matching,
}
),
]
2 changes: 0 additions & 2 deletions examples/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@ def main():
logging.basicConfig(level=logging.INFO)

input_doc_paths = [
# Path("/Users/cau/Downloads/Issue-36122.pdf"),
# Path("/Users/cau/Downloads/IBM_Storage_Insights_Fact_Sheet.pdf"),
Path("./test/data/2206.01062.pdf"),
Path("./test/data/2203.01017v2.pdf"),
Path("./test/data/2305.03393v1.pdf"),
Expand Down

0 comments on commit 5acb7b5

Please sign in to comment.