-
Notifications
You must be signed in to change notification settings - Fork 71
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
experimental: introduce img understand pipeline #95
Draft
dolfim-ibm
wants to merge
1
commit into
main
Choose a base branch
from
feat/image-description
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
import base64 | ||
import datetime | ||
import io | ||
import logging | ||
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple | ||
|
||
import httpx | ||
from PIL import Image | ||
from pydantic import AnyUrl, BaseModel, ConfigDict | ||
|
||
from docling.datamodel.base_models import Cluster, FigureDescriptionData | ||
from docling.models.img_understand_base_model import ( | ||
ImgUnderstandBaseModel, | ||
ImgUnderstandOptions, | ||
) | ||
|
||
_log = logging.getLogger(__name__) | ||
|
||
|
||
class ImgUnderstandApiOptions(ImgUnderstandOptions):
    """Options for an image-understanding model backed by a chat-completions HTTP API."""

    kind: Literal["api"] = "api"

    # HTTP endpoint configuration: target URL and headers sent with each request.
    url: AnyUrl
    headers: Dict[str, str]
    # Extra JSON fields merged into the request payload (e.g. model name).
    params: Dict[str, Any]
    # Request timeout in seconds (passed to httpx.post).
    timeout: float = 20

    # Prompt sent together with every image.
    llm_prompt: str
    # Identifier stored as the provenance of each generated description.
    provenance: str
|
||
|
||
class ChatMessage(BaseModel):
    """A single chat message in the API response (role plus text content)."""

    role: str
    content: str
|
||
|
||
class ResponseChoice(BaseModel):
    """One completion choice returned by the chat-completions API."""

    index: int
    message: ChatMessage
    finish_reason: str
|
||
|
||
class ResponseUsage(BaseModel):
    """Token accounting reported by the API for a single request."""

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
|
||
|
||
class ApiResponse(BaseModel):
    """Top-level chat-completions response envelope (OpenAI / watsonx compatible)."""

    model_config = ConfigDict(
        # Allow field names starting with "model_" (e.g. model_id below).
        protected_namespaces=(),
    )

    id: str
    model_id: Optional[str] = None  # returned by watsonx
    model: Optional[str] = None  # returned by openai
    choices: List[ResponseChoice]
    created: int
    usage: ResponseUsage
|
||
|
||
class ImgUnderstandApiModel(ImgUnderstandBaseModel):
    """Image-understanding model that describes figures via a chat-completions
    HTTP API (OpenAI- or watsonx-style endpoint).

    Each image in a batch is sent as a base64-encoded PNG data URL together
    with the configured prompt; the content of the first response choice is
    used as the figure description.
    """

    def __init__(self, enabled: bool, options: ImgUnderstandApiOptions):
        super().__init__(enabled=enabled, options=options)
        # Narrow the type of the inherited attribute for static checkers.
        self.options: ImgUnderstandApiOptions

    @staticmethod
    def _encode_png_base64(image: Image.Image) -> str:
        """Return *image* serialized as PNG and base64-encoded (utf-8 string)."""
        img_io = io.BytesIO()
        image.save(img_io, "PNG")
        return base64.b64encode(img_io.getvalue()).decode("utf-8")

    def _annotate_image_batch(
        self, batch: Iterable[Tuple[Cluster, Image.Image]]
    ) -> List[FigureDescriptionData]:
        """Describe each (cluster, image) pair in *batch* with one API call per image.

        Returns one FigureDescriptionData per input item, in input order.
        Raises httpx.HTTPStatusError when the API responds with an error status.
        """
        if not self.enabled:
            # Disabled model: keep the contract (one result per input) with
            # empty descriptions.
            return [FigureDescriptionData() for _ in batch]

        results: List[FigureDescriptionData] = []
        for cluster, image in batch:
            image_base64 = self._encode_png_base64(image)

            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": self.options.llm_prompt,
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/png;base64,{image_base64}"
                            },
                        },
                    ],
                }
            ]

            # options.params carries backend-specific fields (e.g. model name).
            payload = {
                "messages": messages,
                **self.options.params,
            }

            r = httpx.post(
                str(self.options.url),
                headers=self.options.headers,
                json=payload,
                timeout=self.options.timeout,
            )
            if not r.is_success:
                # Log the response body before raising: the status code alone
                # rarely explains what the backend rejected.
                # (Fixed typo: "Reponse" -> "Response".)
                _log.error("Error calling the API. Response was %s", r.text)
                r.raise_for_status()

            api_resp = ApiResponse.model_validate_json(r.text)
            generated_text = api_resp.choices[0].message.content.strip()
            results.append(
                FigureDescriptionData(
                    text=generated_text, provenance=self.options.provenance
                )
            )
            _log.info("Generated description: %s", generated_text)

        return results
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
import logging | ||
import time | ||
from typing import Iterable, List, Literal, Tuple | ||
|
||
from PIL import Image | ||
from pydantic import BaseModel | ||
|
||
from docling.datamodel.base_models import ( | ||
Cluster, | ||
FigureData, | ||
FigureDescriptionData, | ||
FigureElement, | ||
FigurePrediction, | ||
Page, | ||
) | ||
|
||
_log = logging.getLogger(__name__) | ||
|
||
|
||
class ImgUnderstandOptions(BaseModel):
    """Common configuration shared by all image-understanding models."""

    # Discriminator identifying the concrete options subclass (e.g. "api").
    kind: str
    # Number of figures sent to the model per batch.
    batch_size: int = 8
    # Scale factor applied to the page image before cropping figures.
    scale: float = 2

    # if the relative area of the image with respect to the whole image page
    # is larger than this threshold it will be processed, otherwise not.
    # TODO: implement the skip logic
    # NOTE(review): consider renaming to `min_area_frac` for alignment
    # (suggested in review); renaming here would break existing configs.
    min_area: float = 0.05
|
||
|
||
class ImgUnderstandBaseModel:
    """Base class for image-understanding (figure description) models.

    Subclasses implement ``_annotate_image_batch``; this class crops picture
    clusters out of the scaled page image, groups them into batches, and
    merges the resulting descriptions into the page's figure predictions.
    """

    def __init__(self, enabled: bool, options: ImgUnderstandOptions):
        self.enabled = enabled
        self.options = options

    def _annotate_image_batch(
        self, batch: Iterable[Tuple[Cluster, Image.Image]]
    ) -> List[FigureDescriptionData]:
        """Return one FigureDescriptionData per (cluster, image) pair in *batch*.

        Must be overridden by subclasses.
        """
        # Fix: the original raised NotImplemented(), but NotImplemented is a
        # non-callable constant, so the call itself raised a confusing
        # TypeError instead of signalling "override me".
        raise NotImplementedError()

    def _flush_merge(
        self,
        page: Page,
        cluster_figure_batch: List[Tuple[Cluster, Image.Image]],
        figures_prediction: FigurePrediction,
    ):
        """Annotate one batch of cropped figures and merge the descriptions
        into *figures_prediction*, logging the per-image processing time.

        Existing descriptions from other models are never overwritten; a
        conflicting new prediction is logged and skipped.
        """
        start_time = time.time()
        results_batch = self._annotate_image_batch(cluster_figure_batch)
        assert len(results_batch) == len(
            cluster_figure_batch
        ), "The returned annotations is not matching the input size"
        elapsed = time.time() - start_time
        _log.info(
            "Batch of %d images processed in %.1f seconds. Time per image is %.3f seconds.",
            len(results_batch),
            elapsed,
            elapsed / len(results_batch),
        )

        for (cluster, _), desc_data in zip(cluster_figure_batch, results_batch):
            if cluster.id not in figures_prediction.figure_map:
                figures_prediction.figure_map[cluster.id] = FigureElement(
                    label=cluster.label,
                    id=cluster.id,
                    # Fix: the original passed `desciption=` (typo). The rest
                    # of this method reads `.data.description`, and pydantic's
                    # default extra-field handling silently dropped the
                    # misspelled keyword, losing the generated description.
                    data=FigureData(description=desc_data),
                    cluster=cluster,
                    page_no=page.page_no,
                )
            elif figures_prediction.figure_map[cluster.id].data.description is None:
                figures_prediction.figure_map[cluster.id].data.description = desc_data
            else:
                _log.warning(
                    "Conflicting predictions. "
                    "Another model (%s) "
                    "was already predicting an image description. The new prediction will be skipped.",
                    figures_prediction.figure_map[cluster.id].data.description.provenance,
                )

    def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
        """Annotate the picture clusters of each page and yield the pages.

        When the model is disabled, pages pass through unchanged.
        Raises RuntimeError when the page image cannot be generated or when
        another model predicted a different number of figures.
        """
        if not self.enabled:
            yield from page_batch
            return

        for page in page_batch:

            # This model could be the first one initializing figures_prediction
            if page.predictions.figures_prediction is None:
                page.predictions.figures_prediction = FigurePrediction()

            # Select the picture clusters and their crop boxes in the scaled,
            # top-left-origin coordinate system of the page image.
            in_clusters = []
            for cluster in page.predictions.layout.clusters:
                if cluster.label != "Picture":
                    continue

                crop_bbox = cluster.bbox.scaled(
                    scale=self.options.scale
                ).to_top_left_origin(page_height=page.size.height * self.options.scale)
                in_clusters.append((cluster, crop_bbox.as_tuple()))

            if not in_clusters:
                yield page
                continue

            # Sanity check: another model may already have counted figures;
            # disagreement means the pipeline is inconsistent.
            if (
                page.predictions.figures_prediction.figure_count > 0
                and page.predictions.figures_prediction.figure_count != len(in_clusters)
            ):
                raise RuntimeError(
                    "Different models predicted a different number of figures."
                )
            page.predictions.figures_prediction.figure_count = len(in_clusters)

            page_image = page.get_image(scale=self.options.scale)
            if page_image is None:
                raise RuntimeError("The page image cannot be generated.")

            cluster_figure_batch = []
            for cluster, figure_bbox in in_clusters:
                cluster_figure_batch.append((cluster, page_image.crop(figure_bbox)))

                # Flush as soon as a full batch is accumulated.
                if len(cluster_figure_batch) == self.options.batch_size:
                    self._flush_merge(
                        page=page,
                        cluster_figure_batch=cluster_figure_batch,
                        figures_prediction=page.predictions.figures_prediction,
                    )
                    cluster_figure_batch = []

            # Final flush of the remaining partial batch.
            if cluster_figure_batch:
                self._flush_merge(
                    page=page,
                    cluster_figure_batch=cluster_figure_batch,
                    figures_prediction=page.predictions.figures_prediction,
                )

            yield page
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's call that
min_area_frac
to be aligned.