Docling/docling/backend/docx/latex/omml.py
Michele Dolfi 5458a88464
ci: add coverage and ruff (#1383)
* add coverage calculation and push

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* new codecov version and usage of token

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* enable ruff formatter instead of black and isort

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* apply ruff lint fixes

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* apply ruff unsafe fixes

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* add removed imports

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* runs 1 on linter issues

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* finalize linter fixes

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* Update pyproject.toml

Co-authored-by: Cesar Berrospi Ramis <75900930+ceberam@users.noreply.github.com>
Signed-off-by: Michele Dolfi <97102151+dolfim-ibm@users.noreply.github.com>

---------

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
Signed-off-by: Michele Dolfi <97102151+dolfim-ibm@users.noreply.github.com>
Co-authored-by: Cesar Berrospi Ramis <75900930+ceberam@users.noreply.github.com>
2025-04-14 18:01:26 +02:00

451 lines
12 KiB
Python

"""
Office Math Markup Language (OMML)
Adapted from https://github.com/xiilei/dwml/blob/master/dwml/omml.py
On 23/01/2025
"""
import logging
import lxml.etree as ET
from pylatexenc.latexencode import UnicodeToLatexEncoder
from docling.backend.docx.latex.latex_dict import (
ALN,
ARR,
BACKSLASH,
BLANK,
BRK,
CHARS,
CHR,
CHR_BO,
CHR_DEFAULT,
D_DEFAULT,
F_DEFAULT,
FUNC,
FUNC_PLACE,
LIM_FUNC,
LIM_TO,
LIM_UPP,
POS,
POS_DEFAULT,
RAD,
RAD_DEFAULT,
SUB,
SUP,
D,
F,
M,
T,
)
# XML namespace of Office Math Markup Language in "{uri}" form; it is
# prepended to tag/attribute names for lookups and stripped from element tags
# to obtain the short tag name used for handler dispatch.
OMML_NS = "{http://schemas.openxmlformats.org/officeDocument/2006/math}"
# Module-level logger named after this module.
_log = logging.getLogger(__name__)
def load(stream):
    """
    Parse an XML document from ``stream`` and lazily yield one
    ``oMath2Latex`` converter for every ``oMath`` element that
    ``findall`` matches on the parsed tree.
    """
    document = ET.parse(stream)
    math_tag = f"{OMML_NS}oMath"
    yield from map(oMath2Latex, document.findall(math_tag))
def load_string(string):
    """
    Parse XML markup held in ``string`` and lazily yield one ``oMath2Latex``
    converter for every ``oMath`` element that ``findall`` matches on the
    root element.
    """
    root_element = ET.fromstring(string)
    math_tag = f"{OMML_NS}oMath"
    yield from map(oMath2Latex, root_element.findall(math_tag))
def escape_latex(strs):
    """
    Return ``strs`` with every character listed in ``CHARS`` prefixed by
    ``BACKSLASH``, unless the character right before it already equals
    ``BACKSLASH``.

    Doubled backslashes in the input are collapsed to a single one first.
    """
    collapsed = strs.replace(r"\\", "\\")
    pieces = []
    previous = None
    for char in collapsed:
        needs_escape = char in CHARS and previous != BACKSLASH
        pieces.append(BACKSLASH + char if needs_escape else char)
        previous = char
    return BLANK.join(pieces)
def get_val(key, default=None, store=CHR):
    """
    Translate ``key`` through the lookup table ``store``.

    Returns ``default`` when ``key`` is None. Otherwise returns the entry of
    ``store`` for ``key``, falling back to ``key`` itself when there is no
    entry or when ``store`` is empty/None.
    """
    if key is None:
        return default
    if not store:
        return key
    return store.get(key, key)
class Tag2Method:
    """
    Tag-name based dispatcher.

    Subclasses are expected to provide a ``tag2meth`` mapping from short tag
    names (namespace removed) to handlers invoked as ``handler(self, elm)``.
    """

    def call_method(self, elm, stag=None):
        """
        Run the handler registered for ``stag`` (derived from ``elm.tag`` when
        not given) and return its result; return None if there is no handler.
        """
        handlers = self.tag2meth
        short_tag = elm.tag.replace(OMML_NS, "") if stag is None else stag
        handler = handlers.get(short_tag)
        return handler(self, elm) if handler else None

    def process_children_list(self, elm, include=None):
        """
        Convert the children of ``elm`` and yield ``(short_tag, result, child)``
        for each child that produced a result.

        Children outside the OMML namespace are skipped, as are children whose
        short tag is not in ``include`` (when ``include`` is given).
        """
        for child in list(elm):
            full_tag = child.tag
            if OMML_NS not in full_tag:
                continue
            short_tag = full_tag.replace(OMML_NS, "")
            if include and short_tag not in include:
                continue
            converted = self.call_method(child, stag=short_tag)
            if converted is None:
                # No handler (or it produced nothing): try the fallback hook.
                converted = self.process_unknow(child, short_tag)
            if converted is not None:
                yield (short_tag, converted, child)

    def process_children_dict(self, elm, include=None):
        """
        Convert the children of ``elm`` and return a dict keyed by short tag.
        When a tag occurs more than once, the last result wins.
        """
        return {
            short_tag: converted
            for short_tag, converted, _child in self.process_children_list(
                elm, include
            )
        }

    def process_children(self, elm, include=None):
        """
        Convert the children of ``elm`` and return the results joined into one
        string; results that are Tag2Method instances contribute ``str(...)``.
        """
        parts = []
        for _tag, converted, _child in self.process_children_list(elm, include):
            if isinstance(converted, Tag2Method):
                parts.append(str(converted))
            else:
                parts.append(converted)
        return BLANK.join(parts)

    def process_unknow(self, elm, stag):
        """
        Fallback hook for children without a usable handler result.
        The base implementation ignores the child by returning None.
        """
        return None
class Pr(Tag2Method):
    """
    Common properties of an element (a ``*Pr`` child).

    On construction the children of the property element are processed:
    the ``val`` attribute of each ``chr``, ``pos``, ``begChr``, ``endChr``
    and ``type`` child is recorded and becomes readable as an attribute of
    the instance (e.g. ``pr.chr``); a ``brk`` child records ``BRK`` and adds
    it to ``text``. Reading an attribute that was not recorded yields None
    instead of raising AttributeError.
    """

    # Text contributed by the property children; what ``str(pr)`` returns.
    text = ""

    # Child tags whose ``val`` attribute is recorded.
    __val_tags = ("chr", "pos", "begChr", "endChr", "type")

    # Recorded values live in a separate dict so that lookups go through
    # __getattr__ (can't use the __dict__).
    __innerdict = None

    def __init__(self, elm):
        """Record the properties found among the children of ``elm``."""
        self.__innerdict = {}
        self.text = self.process_children(elm)

    def __str__(self):
        return self.text

    def __unicode__(self):
        # Was ``self.__str__(self)``: the bound call already supplies self,
        # so the extra argument raised TypeError.
        return self.__str__()

    def __getattr__(self, name):
        """Return the recorded property ``name``, or None if it is absent."""
        return self.__innerdict.get(name, None)

    def do_brk(self, elm):
        """Record a break and return ``BRK`` so that it ends up in ``text``."""
        self.__innerdict["brk"] = BRK
        return BRK

    def do_common(self, elm):
        """
        Record the ``val`` attribute of a value-carrying child under its short
        tag name. Always returns None, so the child adds nothing to ``text``.
        """
        stag = elm.tag.replace(OMML_NS, "")
        if stag in self.__val_tags:
            self.__innerdict[stag] = elm.get(f"{OMML_NS}val")
        return None

    tag2meth = {
        "brk": do_brk,
        "chr": do_common,
        "pos": do_common,
        "begChr": do_common,
        "endChr": do_common,
        "type": do_common,
    }
class oMath2Latex(Tag2Method):
    """
    Convert oMath element of omml to latex.

    The conversion runs once in the constructor; the result is available
    through the ``latex`` property, and through ``str()`` with doubled spaces
    collapsed.
    """

    _t_dict = T

    # Tags without a handler of their own: their converted children are
    # simply concatenated (see process_unknow).
    __direct_tags = ("box", "sSub", "sSup", "sSubSup", "num", "den", "deg", "e")

    # Shared encoder used to turn single unicode characters into LaTeX.
    u = UnicodeToLatexEncoder(
        replacement_latex_protection="braces-all",
        unknown_char_policy="keep",
        unknown_char_warning=False,
    )

    def __init__(self, element):
        """Convert ``element`` and keep the resulting LaTeX string."""
        self._latex = self.process_children(element)

    def __str__(self):
        # process_unicode pads some converted symbols with a space on each
        # side, so neighbouring symbols produce doubled spaces; collapse them.
        # (The previous replace(" ", " ") was a no-op.)
        return self.latex.replace("  ", " ")

    def __unicode__(self):
        # Was ``self.__str__(self)``: the bound call already supplies self,
        # so the extra argument raised TypeError.
        return self.__str__()

    def process_unknow(self, elm, stag):
        """
        Fallback for children without a handler result: container tags are
        converted through their children, ``*Pr`` tags become ``Pr`` objects,
        anything else is ignored (None).
        """
        if stag in self.__direct_tags:
            return self.process_children(elm)
        if stag[-2:] == "Pr":
            return Pr(elm)
        return None

    @property
    def latex(self):
        """The LaTeX string produced for the element given to the constructor."""
        return self._latex

    def do_acc(self, elm):
        """
        the accent function

        Raises KeyError if the element has no ``accPr`` or no ``e`` child.
        """
        c_dict = self.process_children_dict(elm)
        latex_s = get_val(
            c_dict["accPr"].chr, default=CHR_DEFAULT.get("ACC_VAL"), store=CHR
        )
        return latex_s.format(c_dict["e"])

    def do_bar(self, elm):
        """
        the bar function

        Raises KeyError if the element has no ``barPr`` or no ``e`` child.
        """
        c_dict = self.process_children_dict(elm)
        pr = c_dict["barPr"]
        latex_s = get_val(pr.pos, default=POS_DEFAULT.get("BAR_VAL"), store=POS)
        return pr.text + latex_s.format(c_dict["e"])

    def do_d(self, elm):
        """
        the delimiter object

        Raises KeyError if the element has no ``dPr`` or no ``e`` child.
        """
        c_dict = self.process_children_dict(elm)
        pr = c_dict["dPr"]
        null = D_DEFAULT.get("null")
        s_val = get_val(pr.begChr, default=D_DEFAULT.get("left"), store=T)
        e_val = get_val(pr.endChr, default=D_DEFAULT.get("right"), store=T)
        # An empty delimiter value is rendered with the "null" delimiter.
        return pr.text + D.format(
            left=escape_latex(s_val) if s_val else null,
            text=c_dict["e"],
            right=escape_latex(e_val) if e_val else null,
        )

    def do_spre(self, elm):
        """
        the Pre-Sub-Superscript object -- Not support yet

        Not registered in ``tag2meth``; always returns None.
        """
        return None

    def do_sub(self, elm):
        """the subscript: converted children wrapped in the SUB template"""
        return SUB.format(self.process_children(elm))

    def do_sup(self, elm):
        """the superscript: converted children wrapped in the SUP template"""
        return SUP.format(self.process_children(elm))

    def do_f(self, elm):
        """
        the fraction object

        Raises KeyError if the element has no ``fPr`` child.
        """
        c_dict = self.process_children_dict(elm)
        pr = c_dict["fPr"]
        latex_s = get_val(pr.type, default=F_DEFAULT, store=F)
        return pr.text + latex_s.format(num=c_dict.get("num"), den=c_dict.get("den"))

    def do_func(self, elm):
        """
        the Function-Apply object (Examples:sin cos)

        The converted ``fName`` carries the FUNC_PLACE placeholder (see
        do_fname); it is substituted here with the converted argument ``e``.
        """
        c_dict = self.process_children_dict(elm)
        func_name = c_dict.get("fName")
        return func_name.replace(FUNC_PLACE, c_dict.get("e"))

    def do_fname(self, elm):
        """
        the func name

        Returns the converted name, guaranteed to contain FUNC_PLACE.
        """
        latex_chars = []
        for stag, t, _e in self.process_children_list(elm):
            if stag == "r":
                if FUNC.get(t):
                    latex_chars.append(FUNC[t])
                else:
                    _log.warning("Function not supported, will default to text: %s", t)
                    if isinstance(t, str):
                        latex_chars.append(t)
            elif isinstance(t, str):
                latex_chars.append(t)
        t = BLANK.join(latex_chars)
        return t if FUNC_PLACE in t else t + FUNC_PLACE  # do_func will replace this

    def do_groupchr(self, elm):
        """
        the Group-Character object

        Raises KeyError if the element has no ``groupChrPr`` or no ``e`` child.
        """
        c_dict = self.process_children_dict(elm)
        pr = c_dict["groupChrPr"]
        latex_s = get_val(pr.chr)
        return pr.text + latex_s.format(c_dict["e"])

    def do_rad(self, elm):
        """
        the radical object

        Uses the RAD template when a non-empty degree is present and
        RAD_DEFAULT otherwise.
        """
        c_dict = self.process_children_dict(elm)
        text = c_dict.get("e")
        deg_text = c_dict.get("deg")
        if deg_text:
            return RAD.format(deg=deg_text, text=text)
        return RAD_DEFAULT.format(text=text)

    def do_eqarr(self, elm):
        """
        the Array object: the converted ``e`` children joined with BRK
        """
        lines = [t for _stag, t, _e in self.process_children_list(elm, include=("e",))]
        return ARR.format(text=BRK.join(lines))

    def do_limlow(self, elm):
        """
        the Lower-Limit object

        Raises RuntimeError when the converted base ``e`` has no entry in
        LIM_FUNC, and KeyError when the element has no ``e`` child.
        """
        t_dict = self.process_children_dict(elm, include=("e", "lim"))
        latex_s = LIM_FUNC.get(t_dict["e"])
        if not latex_s:
            raise RuntimeError("Not support lim {}".format(t_dict["e"]))
        return latex_s.format(lim=t_dict.get("lim"))

    def do_limupp(self, elm):
        """
        the Upper-Limit object
        """
        t_dict = self.process_children_dict(elm, include=("e", "lim"))
        return LIM_UPP.format(lim=t_dict.get("lim"), text=t_dict.get("e"))

    def do_lim(self, elm):
        """
        the lower limit of the limLow object and the upper limit of the limUpp function
        """
        return self.process_children(elm).replace(LIM_TO[0], LIM_TO[1])

    def do_m(self, elm):
        """
        the Matrix object: the converted ``mr`` rows joined with BRK
        (``mPr`` and any other children are ignored)
        """
        rows = [t for stag, t, _e in self.process_children_list(elm) if stag == "mr"]
        return M.format(text=BRK.join(rows))

    def do_mr(self, elm):
        """
        a single row of the matrix m: the converted ``e`` cells joined with ALN
        """
        return ALN.join(
            [t for _stag, t, _e in self.process_children_list(elm, include=("e",))]
        )

    def do_nary(self, elm):
        """
        the n-ary object

        Returns the operator followed by the converted remaining children.
        When ``naryPr`` carries no operator character the integral sign is
        used, which is the operator OMML assumes when ``chr`` is omitted;
        previously this case raised TypeError (None + str).
        """
        integral = "\u222b"
        res = []
        bo = ""
        for stag, t, _e in self.process_children_list(elm):
            if stag == "naryPr":
                bo = get_val(
                    t.chr, default=CHR_BO.get(integral, integral), store=CHR_BO
                )
            else:
                res.append(t)
        return bo + BLANK.join(res)

    def process_unicode(self, s):
        """
        Convert the text ``s`` (do_r passes one character at a time) to LaTeX
        with the shared encoder and tidy up the encoder's wrapping.
        """
        out_latex_str = self.u.unicode_to_latex(s)

        # Braces added by the encoder (not present in the input) are replaced
        # by a space on each side.
        if (
            not s.startswith("{")
            and out_latex_str.startswith("{")
            and not s.endswith("}")
            and out_latex_str.endswith("}")
        ):
            out_latex_str = f" {out_latex_str[1:-1]} "

        # Drop the \ensuremath{...} wrapper; note that every closing brace in
        # the string is turned into a space, not only the wrapper's.
        if "ensuremath" in out_latex_str:
            out_latex_str = out_latex_str.replace("\\ensuremath{", " ")
            out_latex_str = out_latex_str.replace("}", " ")

        if out_latex_str.strip().startswith("\\text"):
            out_latex_str = f" \\text{{{out_latex_str}}} "

        return out_latex_str

    def do_r(self, elm):
        """
        Get text from 'r' element,And try convert them to latex symbols

        Each character of the run's ``t`` text is converted separately and the
        joined result is escaped. Braces that the escaping step escaped are
        unescaped again when the source text contained no such brace, i.e.
        when they were introduced by the conversion itself.

        @todo text style support , (sty)
        @todo latex pure text support
        """
        found_text = elm.findtext(f"./{OMML_NS}t") or ""

        proc_str = escape_latex(
            BLANK.join([self.process_unicode(ch) for ch in found_text])
        )
        base_proc_str = BLANK.join(found_text)

        if "{" not in base_proc_str and "\\{" in proc_str:
            proc_str = proc_str.replace("\\{", "{")

        if "}" not in base_proc_str and "\\}" in proc_str:
            proc_str = proc_str.replace("\\}", "}")

        return proc_str

    tag2meth = {
        "acc": do_acc,
        "r": do_r,
        "bar": do_bar,
        "sub": do_sub,
        "sup": do_sup,
        "f": do_f,
        "func": do_func,
        "fName": do_fname,
        "groupChr": do_groupchr,
        "d": do_d,
        "rad": do_rad,
        "eqArr": do_eqarr,
        "limLow": do_limlow,
        "limUpp": do_limupp,
        "lim": do_lim,
        "m": do_m,
        "mr": do_mr,
        "nary": do_nary,
    }