diff --git a/docling/backend/html_backend.py b/docling/backend/html_backend.py index 286dfbf..234e5da 100644 --- a/docling/backend/html_backend.py +++ b/docling/backend/html_backend.py @@ -1,9 +1,9 @@ import logging from io import BytesIO from pathlib import Path -from typing import Optional, Set, Union +from typing import Optional, Union, cast -from bs4 import BeautifulSoup, Tag +from bs4 import BeautifulSoup, NavigableString, PageElement, Tag from docling_core.types.doc import ( DocItemLabel, DoclingDocument, @@ -12,6 +12,7 @@ from docling_core.types.doc import ( TableCell, TableData, ) +from typing_extensions import override from docling.backend.abstract_backend import DeclarativeDocumentBackend from docling.datamodel.base_models import InputFormat @@ -21,6 +22,7 @@ _log = logging.getLogger(__name__) class HTMLDocumentBackend(DeclarativeDocumentBackend): + @override def __init__(self, in_doc: "InputDocument", path_or_stream: Union[BytesIO, Path]): super().__init__(in_doc, path_or_stream) _log.debug("About to init HTML backend...") @@ -48,13 +50,16 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend): f"Could not initialize HTML backend for file with hash {self.document_hash}." ) from e + @override def is_valid(self) -> bool: return self.soup is not None @classmethod + @override def supports_pagination(cls) -> bool: return False + @override def unload(self): if isinstance(self.path_or_stream, BytesIO): self.path_or_stream.close() @@ -62,9 +67,11 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend): self.path_or_stream = None @classmethod - def supported_formats(cls) -> Set[InputFormat]: + @override + def supported_formats(cls) -> set[InputFormat]: return {InputFormat.HTML} + @override def convert(self) -> DoclingDocument: # access self.path_or_stream to load stuff origin = DocumentOrigin( @@ -80,98 +87,78 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend): assert self.soup is not None content = self.soup.body or self.soup # Replace
tags with newline characters - for br in content.find_all("br"): - br.replace_with("\n") - doc = self.walk(content, doc) + for br in content("br"): + br.replace_with(NavigableString("\n")) + self.walk(content, doc) else: raise RuntimeError( f"Cannot convert doc with {self.document_hash} because the backend failed to init." ) return doc - def walk(self, element: Tag, doc: DoclingDocument): - try: - # Iterate over elements in the body of the document - for idx, element in enumerate(element.children): + def walk(self, tag: Tag, doc: DoclingDocument) -> None: + # Iterate over elements in the body of the document + for element in tag.children: + if isinstance(element, Tag): try: - self.analyse_element(element, idx, doc) + self.analyze_tag(cast(Tag, element), doc) except Exception as exc_child: - - _log.error(" -> error treating child: ", exc_child) - _log.error(" => element: ", element, "\n") + _log.error( + f"Error processing child from tag{tag.name}: {exc_child}" + ) raise exc_child - except Exception as exc: - pass + return - return doc - - def analyse_element(self, element: Tag, idx: int, doc: DoclingDocument): - """ - if element.name!=None: - _log.debug("\t"*self.level, idx, "\t", f"{element.name} ({self.level})") - """ - - if element.name in self.labels: - self.labels[element.name] += 1 + def analyze_tag(self, tag: Tag, doc: DoclingDocument) -> None: + if tag.name in self.labels: + self.labels[tag.name] += 1 else: - self.labels[element.name] = 1 + self.labels[tag.name] = 1 - if element.name in ["h1", "h2", "h3", "h4", "h5", "h6"]: - self.handle_header(element, idx, doc) - elif element.name in ["p"]: - self.handle_paragraph(element, idx, doc) - elif element.name in ["pre"]: - self.handle_code(element, idx, doc) - elif element.name in ["ul", "ol"]: - self.handle_list(element, idx, doc) - elif element.name in ["li"]: - self.handle_listitem(element, idx, doc) - elif element.name == "table": - self.handle_table(element, idx, doc) - elif element.name == "figure": - self.handle_figure(element, idx, doc) - elif element.name == "img": - self.handle_image(element, idx, doc) + if tag.name in ["h1", "h2", "h3", "h4", "h5", "h6"]: + self.handle_header(tag, doc) + elif tag.name in ["p"]: + self.handle_paragraph(tag, doc) + elif tag.name in ["pre"]: + self.handle_code(tag, doc) + elif tag.name in ["ul", "ol"]: + self.handle_list(tag, doc) + elif tag.name in ["li"]: + self.handle_list_item(tag, doc) + elif tag.name == "table": + self.handle_table(tag, doc) + elif tag.name == "figure": + self.handle_figure(tag, doc) + elif tag.name == "img": + self.handle_image(doc) else: - self.walk(element, doc) + self.walk(tag, doc) - def get_direct_text(self, item: Tag): - """Get the direct text of the
  • element (ignoring nested lists).""" - text = item.find(string=True, recursive=False) - if isinstance(text, str): - return text.strip() + def get_text(self, item: PageElement) -> str: + """Get the text content of a tag.""" + parts: list[str] = self.extract_text_recursively(item) - return "" + return "".join(parts) + " " # Function to recursively extract text from all child nodes - def extract_text_recursively(self, item: Tag): - result = [] + def extract_text_recursively(self, item: PageElement) -> list[str]: + result: list[str] = [] - if isinstance(item, str): + if isinstance(item, NavigableString): return [item] - if item.name not in ["ul", "ol"]: - try: - # Iterate over the children (and their text and tails) - for child in item: - try: - # Recursively get the child's text content - result.extend(self.extract_text_recursively(child)) - except: - pass - except: - _log.warn("item has no children") - pass + tag = cast(Tag, item) + if tag.name not in ["ul", "ol"]: + for child in tag: + # Recursively get the child's text content + result.extend(self.extract_text_recursively(child)) - return "".join(result) + " " + return ["".join(result) + " "] - def handle_header(self, element: Tag, idx: int, doc: DoclingDocument): + def handle_header(self, element: Tag, doc: DoclingDocument) -> None: """Handles header tags (h1, h2, etc.).""" hlevel = int(element.name.replace("h", "")) - slevel = hlevel - 1 - - label = DocItemLabel.SECTION_HEADER text = element.text.strip() if hlevel == 1: @@ -197,7 +184,7 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend): elif hlevel < self.level: # remove the tail - for key, val in self.parents.items(): + for key in self.parents.keys(): if key > hlevel: self.parents[key] = None self.level = hlevel @@ -208,27 +195,24 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend): level=hlevel, ) - def handle_code(self, element: Tag, idx: int, doc: DoclingDocument): + def handle_code(self, element: Tag, doc: DoclingDocument) -> None: """Handles monospace code snippets (pre).""" if element.text is None: return text = element.text.strip() - label = DocItemLabel.CODE - if len(text) == 0: - return - doc.add_code(parent=self.parents[self.level], text=text) + if text: + doc.add_code(parent=self.parents[self.level], text=text) - def handle_paragraph(self, element: Tag, idx: int, doc: DoclingDocument): + def handle_paragraph(self, element: Tag, doc: DoclingDocument) -> None: """Handles paragraph tags (p).""" if element.text is None: return text = element.text.strip() label = DocItemLabel.PARAGRAPH - if len(text) == 0: - return - doc.add_text(parent=self.parents[self.level], label=label, text=text) + if text: + doc.add_text(parent=self.parents[self.level], label=label, text=text) - def handle_list(self, element: Tag, idx: int, doc: DoclingDocument): + def handle_list(self, element: Tag, doc: DoclingDocument) -> None: """Handles list tags (ul, ol) and their list items.""" if element.name == "ul": @@ -250,18 +234,17 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend): self.parents[self.level + 1] = None self.level -= 1 - def handle_listitem(self, element: Tag, idx: int, doc: DoclingDocument): + def handle_list_item(self, element: Tag, doc: DoclingDocument) -> None: """Handles listitem tags (li).""" - nested_lists = element.find(["ul", "ol"]) + nested_list = element.find(["ul", "ol"]) parent_list_label = self.parents[self.level].label index_in_list = len(self.parents[self.level].children) + 1 - if nested_lists: - name = element.name + if nested_list: # Text in list item can be hidden within hierarchy, hence # we need to extract it recursively - text = self.extract_text_recursively(element) + text: str = self.get_text(element) # Flatten text, remove break lines: text = text.replace("\n", "").replace("\r", "") text = " ".join(text.split()).strip() @@ -287,7 +270,7 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend): self.parents[self.level + 1] = None self.level -= 1 - elif isinstance(element.text, str): + elif element.text.strip(): text = element.text.strip() marker = "" @@ -302,59 +285,79 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend): parent=self.parents[self.level], ) else: - _log.warn("list-item has no text: ", element) - - def handle_table(self, element: Tag, idx: int, doc: DoclingDocument): - """Handles table tags.""" + _log.warning(f"list-item has no text: {element}") + @staticmethod + def parse_table_data(element: Tag) -> Optional[TableData]: nested_tables = element.find("table") if nested_tables is not None: - _log.warn("detected nested tables: skipping for now") - return + _log.warning("Skipping nested table.") + return None # Count the number of rows (number of elements) - num_rows = len(element.find_all("tr")) + num_rows = len(element("tr")) # Find the number of columns (taking into account colspan) num_cols = 0 - for row in element.find_all("tr"): + for row in element("tr"): col_count = 0 - for cell in row.find_all(["td", "th"]): - colspan = int(cell.get("colspan", 1)) + if not isinstance(row, Tag): + continue + for cell in row(["td", "th"]): + if not isinstance(row, Tag): + continue + val = cast(Tag, cell).get("colspan", "1") + colspan = int(val) if (isinstance(val, str) and val.isnumeric()) else 1 col_count += colspan num_cols = max(num_cols, col_count) - grid = [[None for _ in range(num_cols)] for _ in range(num_rows)] + grid: list = [[None for _ in range(num_cols)] for _ in range(num_rows)] data = TableData(num_rows=num_rows, num_cols=num_cols, table_cells=[]) # Iterate over the rows in the table - for row_idx, row in enumerate(element.find_all("tr")): + for row_idx, row in enumerate(element("tr")): + if not isinstance(row, Tag): + continue # For each row, find all the column cells (both and ) - cells = row.find_all(["td", "th"]) + cells = row(["td", "th"]) # Check if each cell in the row is a header -> means it is a column header col_header = True - for j, html_cell in enumerate(cells): - if html_cell.name == "td": + for html_cell in cells: + if isinstance(html_cell, Tag) and html_cell.name == "td": col_header = False + # Extract the text content of each cell col_idx = 0 - # Extract and print the text content of each cell - for _, html_cell in enumerate(cells): + for html_cell in cells: + if not isinstance(html_cell, Tag): + continue + # extract inline formulas + for formula in html_cell("inline-formula"): + math_parts = formula.text.split("$$") + if len(math_parts) == 3: + math_formula = f"$${math_parts[1]}$$" + formula.replace_with(NavigableString(math_formula)) + + # TODO: extract content correctly from table-cells with lists text = html_cell.text - try: - text = self.extract_table_cell_text(html_cell) - except Exception as exc: - _log.warn("exception: ", exc) - exit(-1) # label = html_cell.name - - col_span = int(html_cell.get("colspan", 1)) - row_span = int(html_cell.get("rowspan", 1)) + col_val = html_cell.get("colspan", "1") + col_span = ( + int(col_val) + if isinstance(col_val, str) and col_val.isnumeric() + else 1 + ) + row_val = html_cell.get("rowspan", "1") + row_span = ( + int(row_val) + if isinstance(row_val, str) and row_val.isnumeric() + else 1 + ) while grid[row_idx][col_idx] is not None: col_idx += 1 @@ -362,7 +365,7 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend): for c in range(col_span): grid[row_idx + r][col_idx + c] = text - cell = TableCell( + table_cell = TableCell( text=text, row_span=row_span, col_span=col_span, @@ -373,57 +376,57 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend): col_header=col_header, row_header=((not col_header) and html_cell.name == "th"), ) - data.table_cells.append(cell) + data.table_cells.append(table_cell) - doc.add_table(data=data, parent=self.parents[self.level]) + return data - def get_list_text(self, list_element: Tag, level=0): + def handle_table(self, element: Tag, doc: DoclingDocument) -> None: + """Handles table tags.""" + + table_data = HTMLDocumentBackend.parse_table_data(element) + + if table_data is not None: + doc.add_table(data=table_data, parent=self.parents[self.level]) + + def get_list_text(self, list_element: Tag, level: int = 0) -> list[str]: """Recursively extract text from