chore: expose draw_clusters function (#803)

feat: expose draw_clusters function



add type annotations to function signature

Signed-off-by: Yusik Kim <kmyusk@gmail.com>
This commit is contained in:
Yusik Kim 2025-01-24 17:35:29 +01:00 committed by GitHub
parent 3213b247ad
commit e9768ae6a5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 88 additions and 84 deletions

View File

@ -1,28 +1,21 @@
import copy import copy
import logging import logging
import random
import time
from pathlib import Path from pathlib import Path
from typing import Iterable, List from typing import Iterable
from docling_core.types.doc import CoordOrigin, DocItemLabel from docling_core.types.doc import DocItemLabel
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
from PIL import Image, ImageDraw, ImageFont from PIL import Image
from docling.datamodel.base_models import ( from docling.datamodel.base_models import BoundingBox, Cluster, LayoutPrediction, Page
BoundingBox,
Cell,
Cluster,
LayoutPrediction,
Page,
)
from docling.datamodel.document import ConversionResult from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import AcceleratorDevice, AcceleratorOptions from docling.datamodel.pipeline_options import AcceleratorOptions
from docling.datamodel.settings import settings from docling.datamodel.settings import settings
from docling.models.base_model import BasePageModel from docling.models.base_model import BasePageModel
from docling.utils.accelerator_utils import decide_device from docling.utils.accelerator_utils import decide_device
from docling.utils.layout_postprocessor import LayoutPostprocessor from docling.utils.layout_postprocessor import LayoutPostprocessor
from docling.utils.profiling import TimeRecorder from docling.utils.profiling import TimeRecorder
from docling.utils.visualization import draw_clusters
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
@ -82,78 +75,9 @@ class LayoutModel(BasePageModel):
left_image = copy.deepcopy(page.image) left_image = copy.deepcopy(page.image)
right_image = copy.deepcopy(page.image) right_image = copy.deepcopy(page.image)
# Function to draw clusters on an image
def draw_clusters(image, clusters):
draw = ImageDraw.Draw(image, "RGBA")
# Create a smaller font for the labels
try:
font = ImageFont.truetype("arial.ttf", 12)
except OSError:
# Fallback to default font if arial is not available
font = ImageFont.load_default()
for c_tl in clusters:
all_clusters = [c_tl, *c_tl.children]
for c in all_clusters:
# Draw cells first (underneath)
cell_color = (0, 0, 0, 40) # Transparent black for cells
for tc in c.cells:
cx0, cy0, cx1, cy1 = tc.bbox.as_tuple()
cx0 *= scale_x
cx1 *= scale_x
cy0 *= scale_x
cy1 *= scale_y
draw.rectangle(
[(cx0, cy0), (cx1, cy1)],
outline=None,
fill=cell_color,
)
# Draw cluster rectangle
x0, y0, x1, y1 = c.bbox.as_tuple()
x0 *= scale_x
x1 *= scale_x
y0 *= scale_x
y1 *= scale_y
cluster_fill_color = (*list(DocItemLabel.get_color(c.label)), 70)
cluster_outline_color = (
*list(DocItemLabel.get_color(c.label)),
255,
)
draw.rectangle(
[(x0, y0), (x1, y1)],
outline=cluster_outline_color,
fill=cluster_fill_color,
)
# Add label name and confidence
label_text = f"{c.label.name} ({c.confidence:.2f})"
# Create semi-transparent background for text
text_bbox = draw.textbbox((x0, y0), label_text, font=font)
text_bg_padding = 2
draw.rectangle(
[
(
text_bbox[0] - text_bg_padding,
text_bbox[1] - text_bg_padding,
),
(
text_bbox[2] + text_bg_padding,
text_bbox[3] + text_bg_padding,
),
],
fill=(255, 255, 255, 180), # Semi-transparent white
)
# Draw text
draw.text(
(x0, y0),
label_text,
fill=(0, 0, 0, 255), # Solid black
font=font,
)
# Draw clusters on both images # Draw clusters on both images
draw_clusters(left_image, left_clusters) draw_clusters(left_image, left_clusters, scale_x, scale_y)
draw_clusters(right_image, right_clusters) draw_clusters(right_image, right_clusters, scale_x, scale_y)
# Combine the images side by side # Combine the images side by side
combined_width = left_image.width * 2 combined_width = left_image.width * 2
combined_height = left_image.height combined_height = left_image.height

View File

@ -0,0 +1,80 @@
from docling_core.types.doc import DocItemLabel
from PIL import Image, ImageDraw, ImageFont
from PIL.ImageFont import FreeTypeFont
from docling.datamodel.base_models import Cluster
def draw_clusters(
image: Image.Image, clusters: list[Cluster], scale_x: float, scale_y: float
) -> None:
"""
Draw clusters on an image
"""
draw = ImageDraw.Draw(image, "RGBA")
# Create a smaller font for the labels
font: ImageFont.ImageFont | FreeTypeFont
try:
font = ImageFont.truetype("arial.ttf", 12)
except OSError:
# Fallback to default font if arial is not available
font = ImageFont.load_default()
for c_tl in clusters:
all_clusters = [c_tl, *c_tl.children]
for c in all_clusters:
# Draw cells first (underneath)
cell_color = (0, 0, 0, 40) # Transparent black for cells
for tc in c.cells:
cx0, cy0, cx1, cy1 = tc.bbox.as_tuple()
cx0 *= scale_x
cx1 *= scale_x
cy0 *= scale_x
cy1 *= scale_y
draw.rectangle(
[(cx0, cy0), (cx1, cy1)],
outline=None,
fill=cell_color,
)
# Draw cluster rectangle
x0, y0, x1, y1 = c.bbox.as_tuple()
x0 *= scale_x
x1 *= scale_x
y0 *= scale_x
y1 *= scale_y
cluster_fill_color = (*list(DocItemLabel.get_color(c.label)), 70)
cluster_outline_color = (
*list(DocItemLabel.get_color(c.label)),
255,
)
draw.rectangle(
[(x0, y0), (x1, y1)],
outline=cluster_outline_color,
fill=cluster_fill_color,
)
# Add label name and confidence
label_text = f"{c.label.name} ({c.confidence:.2f})"
# Create semi-transparent background for text
text_bbox = draw.textbbox((x0, y0), label_text, font=font)
text_bg_padding = 2
draw.rectangle(
[
(
text_bbox[0] - text_bg_padding,
text_bbox[1] - text_bg_padding,
),
(
text_bbox[2] + text_bg_padding,
text_bbox[3] + text_bg_padding,
),
],
fill=(255, 255, 255, 180), # Semi-transparent white
)
# Draw text
draw.text(
(x0, y0),
label_text,
fill=(0, 0, 0, 255), # Solid black
font=font,
)