experimental: introduce img understand pipeline #95

Draft: wants to merge 1 commit into main
23 changes: 16 additions & 7 deletions docling/datamodel/base_models.py
@@ -224,18 +224,27 @@ class TableStructurePrediction(BaseModel):
 class TextElement(BasePageElement): ...


+class FigureClassificationData(BaseModel):
+    provenance: str
+    predicted_class: str
+    confidence: float
+
+
+class FigureDescriptionData(BaseModel):
+    text: str
+    provenance: str = ""
+
+
 class FigureData(BaseModel):
-    pass
+    classification: Optional[FigureClassificationData] = None
+    description: Optional[FigureDescriptionData] = None


 class FigureElement(BasePageElement):
-    data: Optional[FigureData] = None
-    provenance: Optional[str] = None
-    predicted_class: Optional[str] = None
-    confidence: Optional[float] = None
+    data: FigureData = FigureData()


-class FigureClassificationPrediction(BaseModel):
+class FigurePrediction(BaseModel):
     figure_count: int = 0
     figure_map: Dict[int, FigureElement] = {}


@@ -248,7 +257,7 @@ class EquationPrediction(BaseModel):
 class PagePredictions(BaseModel):
     layout: Optional[LayoutPrediction] = None
     tablestructure: Optional[TableStructurePrediction] = None
-    figures_classification: Optional[FigureClassificationPrediction] = None
+    figures_prediction: Optional[FigurePrediction] = None
     equations_prediction: Optional[EquationPrediction] = None

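
For reference, a minimal sketch of how the reworked figure models fit together after this diff. The import path and field names come from the changes above; the concrete values (model names, class label, confidence, caption text) are made up for illustration:

```python
from docling.datamodel.base_models import (
    FigureClassificationData,
    FigureData,
    FigureDescriptionData,
)

# Each enrichment model fills in its own slot of FigureData; slots left
# untouched stay None, so classification and description can come from
# different models at different times.
data = FigureData(
    classification=FigureClassificationData(
        provenance="figure-classifier",  # illustrative model name
        predicted_class="bar_chart",
        confidence=0.92,
    ),
    description=FigureDescriptionData(
        text="A bar chart comparing quarterly revenue.",
        provenance="api",
    ),
)

assert data.description is not None and data.description.provenance == "api"
```
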
123 changes: 123 additions & 0 deletions docling/models/img_understand_api_model.py
@@ -0,0 +1,123 @@
import base64
import datetime
import io
import logging
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple

import httpx
from PIL import Image
from pydantic import AnyUrl, BaseModel, ConfigDict

from docling.datamodel.base_models import Cluster, FigureDescriptionData
from docling.models.img_understand_base_model import (
    ImgUnderstandBaseModel,
    ImgUnderstandOptions,
)

_log = logging.getLogger(__name__)


class ImgUnderstandApiOptions(ImgUnderstandOptions):
    kind: Literal["api"] = "api"

    url: AnyUrl
    headers: Dict[str, str]
    params: Dict[str, Any]
    timeout: float = 20

    llm_prompt: str
    provenance: str


class ChatMessage(BaseModel):
    role: str
    content: str


class ResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: str


class ResponseUsage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ApiResponse(BaseModel):
    model_config = ConfigDict(
        protected_namespaces=(),
    )

    id: str
    model_id: Optional[str] = None  # returned by watsonx
    model: Optional[str] = None  # returned by openai
    choices: List[ResponseChoice]
    created: int
    usage: ResponseUsage


class ImgUnderstandApiModel(ImgUnderstandBaseModel):

    def __init__(self, enabled: bool, options: ImgUnderstandApiOptions):
        super().__init__(enabled=enabled, options=options)
        self.options: ImgUnderstandApiOptions

    def _annotate_image_batch(
        self, batch: Iterable[Tuple[Cluster, Image.Image]]
    ) -> List[FigureDescriptionData]:

        if not self.enabled:
            return [FigureDescriptionData(text="") for _ in batch]

        results = []
        for cluster, image in batch:
            img_io = io.BytesIO()
            image.save(img_io, "PNG")
            image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")

            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": self.options.llm_prompt,
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/png;base64,{image_base64}"
                            },
                        },
                    ],
                }
            ]

            payload = {
                "messages": messages,
                **self.options.params,
            }

            r = httpx.post(
                str(self.options.url),
                headers=self.options.headers,
                json=payload,
                timeout=self.options.timeout,
            )
            if not r.is_success:
                _log.error(f"Error calling the API. Response was {r.text}")
                r.raise_for_status()

            api_resp = ApiResponse.model_validate_json(r.text)
            generated_text = api_resp.choices[0].message.content.strip()
            results.append(
                FigureDescriptionData(
                    text=generated_text, provenance=self.options.provenance
                )
            )
            _log.info(f"Generated description: {generated_text}")

        return results
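
For reviewers who want to try the API backend, a minimal usage sketch. It assumes an OpenAI-style chat-completions endpoint that accepts `image_url` content parts, which is what the request payload and `ApiResponse` schema above appear to target; the URL, token, model name, and prompt below are placeholders, not values from this PR:

```python
from docling.models.img_understand_api_model import (
    ImgUnderstandApiModel,
    ImgUnderstandApiOptions,
)

# Placeholder endpoint and credentials for an OpenAI-compatible server.
options = ImgUnderstandApiOptions(
    url="http://localhost:8000/v1/chat/completions",
    headers={"Authorization": "Bearer <token>"},
    params={"model": "my-vision-model"},  # forwarded verbatim into the request payload
    timeout=60,
    llm_prompt="Describe the image in three sentences.",
    provenance="my-vision-model",
)

model = ImgUnderstandApiModel(enabled=True, options=options)

# Like the other pipeline models, it is applied to an iterable of pages:
# for page in model(page_batch):
#     ...
```
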
145 changes: 145 additions & 0 deletions docling/models/img_understand_base_model.py
@@ -0,0 +1,145 @@
import logging
import time
from typing import Iterable, List, Literal, Tuple

from PIL import Image
from pydantic import BaseModel

from docling.datamodel.base_models import (
    Cluster,
    FigureData,
    FigureDescriptionData,
    FigureElement,
    FigurePrediction,
    Page,
)

_log = logging.getLogger(__name__)


class ImgUnderstandOptions(BaseModel):
    kind: str
    batch_size: int = 8
    scale: float = 2

    # if the relative area of the image with respect to the whole image page
    # is larger than this threshold it will be processed, otherwise not.
    # TODO: implement the skip logic
    min_area: float = 0.05

Review comment (Contributor): Let's call that min_area_frac to be aligned.


class ImgUnderstandBaseModel:

    def __init__(self, enabled: bool, options: ImgUnderstandOptions):
        self.enabled = enabled
        self.options = options

    def _annotate_image_batch(
        self, batch: Iterable[Tuple[Cluster, Image.Image]]
    ) -> List[FigureDescriptionData]:
        raise NotImplementedError()

    def _flush_merge(
        self,
        page: Page,
        cluster_figure_batch: List[Tuple[Cluster, Image.Image]],
        figures_prediction: FigurePrediction,
    ):
        start_time = time.time()
        results_batch = self._annotate_image_batch(cluster_figure_batch)
        assert len(results_batch) == len(
            cluster_figure_batch
        ), "The number of returned annotations does not match the input size"
        end_time = time.time()
        _log.info(
            f"Batch of {len(results_batch)} images processed in {end_time-start_time:.1f} seconds. Time per image is {(end_time-start_time) / len(results_batch):.3f} seconds."
        )

        for (cluster, _), desc_data in zip(cluster_figure_batch, results_batch):
            if cluster.id not in figures_prediction.figure_map:
                figures_prediction.figure_map[cluster.id] = FigureElement(
                    label=cluster.label,
                    id=cluster.id,
                    data=FigureData(description=desc_data),
                    cluster=cluster,
                    page_no=page.page_no,
                )
            elif figures_prediction.figure_map[cluster.id].data.description is None:
                figures_prediction.figure_map[cluster.id].data.description = desc_data
            else:
                _log.warning(
                    f"Conflicting predictions. "
                    f"Another model ({figures_prediction.figure_map[cluster.id].data.description.provenance}) "
                    f"was already predicting an image description. The new prediction will be skipped."
                )

    def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:

        if not self.enabled:
            yield from page_batch
            return

        for page in page_batch:

            # This model could be the first one initializing figures_prediction
            if page.predictions.figures_prediction is None:
                page.predictions.figures_prediction = FigurePrediction()

            # Select the picture clusters
            in_clusters = []
            for cluster in page.predictions.layout.clusters:
                if cluster.label != "Picture":
                    continue

                crop_bbox = cluster.bbox.scaled(
                    scale=self.options.scale
                ).to_top_left_origin(page_height=page.size.height * self.options.scale)
                in_clusters.append(
                    (
                        cluster,
                        crop_bbox.as_tuple(),
                    )
                )

            if not len(in_clusters):
                yield page
                continue

            # save classifications using proper object
            if (
                page.predictions.figures_prediction.figure_count > 0
                and page.predictions.figures_prediction.figure_count != len(in_clusters)
            ):
                raise RuntimeError(
                    "Different models predicted a different number of figures."
                )
            page.predictions.figures_prediction.figure_count = len(in_clusters)

            cluster_figure_batch = []
            page_image = page.get_image(scale=self.options.scale)
            if page_image is None:
                raise RuntimeError("The page image cannot be generated.")

            for cluster, figure_bbox in in_clusters:
                figure = page_image.crop(figure_bbox)
                cluster_figure_batch.append((cluster, figure))

                # if enough figures then flush
                if len(cluster_figure_batch) == self.options.batch_size:
                    self._flush_merge(
                        page=page,
                        cluster_figure_batch=cluster_figure_batch,
                        figures_prediction=page.predictions.figures_prediction,
                    )
                    cluster_figure_batch = []

            # final flush
            if len(cluster_figure_batch) > 0:
                self._flush_merge(
                    page=page,
                    cluster_figure_batch=cluster_figure_batch,
                    figures_prediction=page.predictions.figures_prediction,
                )
                cluster_figure_batch = []

            yield page
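
To illustrate the extension point, a minimal sketch of a custom backend built only on the interfaces above; the `ImgUnderstandDummyModel` class and its fixed captions are hypothetical and not part of this PR:

```python
from typing import Iterable, List, Tuple

from PIL import Image

from docling.datamodel.base_models import Cluster, FigureDescriptionData
from docling.models.img_understand_base_model import (
    ImgUnderstandBaseModel,
    ImgUnderstandOptions,
)


class ImgUnderstandDummyModel(ImgUnderstandBaseModel):
    """Hypothetical backend: emits a fixed caption per figure instead of calling a VLM."""

    def _annotate_image_batch(
        self, batch: Iterable[Tuple[Cluster, Image.Image]]
    ) -> List[FigureDescriptionData]:
        return [
            FigureDescriptionData(
                text=f"Figure of size {image.width}x{image.height} px",
                provenance="dummy",
            )
            for _, image in batch
        ]


# The base class takes care of cropping, batching and merging into
# page.predictions.figures_prediction:
# dummy = ImgUnderstandDummyModel(enabled=True, options=ImgUnderstandOptions(kind="dummy"))
# annotated_pages = list(dummy(page_batch))
```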