diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..8141b74 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,3 @@ +* text=auto eol=lf +/examples/* filter=lfs diff=lfs merge=lfs -text +/models/* filter=lfs diff=lfs merge=lfs -text diff --git a/GLAM/__init__.py b/GLAM/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/GLAM/common.py b/GLAM/common.py new file mode 100644 index 0000000..23a5efe --- /dev/null +++ b/GLAM/common.py @@ -0,0 +1,386 @@ +from scipy.spatial import KDTree +import logging +import math +import re +import time +from collections import namedtuple +from dataclasses import dataclass +from typing import NamedTuple, Sequence, MutableSequence, Optional, Generator, Iterable, Any, Mapping + +import cv2 +import easyocr +import numpy as np +import tesserocr +import torch +from tesserocr import PyTessBaseAPI + +import fitz # PyMuPDF +from PIL import Image, ImageDraw + + +INVALID_UNICODE = chr(0xFFFD) +EasyocrTextResult = namedtuple("EasyocrTextResult", ["bbox", "text", "confidence"]) +MuPDFTextTraceChar = namedtuple("MuPDFTextTraceChar", ["unicode", "glyph", "origin", "bbox"]) +logger = logging.getLogger(__name__) +HI_RES_MATRIX = fitz.Matrix(5, 5) +# HI_RES_MATRIX = fitz.Matrix(1, 1) # enough for tesseract in english + + +def truncate_with_ellipsis(s: str, max_length=128): + return (s[:max_length - 3] + '...') if len(s) > max_length else s + + +def get_color_mode(colorspace, alpha: bool, unmultiply: bool = False) -> str: + if colorspace is None: + return "L" + if colorspace.n == 1: + return "LA" if alpha else "L" + if colorspace.n == 3: + return "RGBa" if (alpha and unmultiply) else "RGBA" if alpha else "RGB" + return "CMYK" + + +def get_bytes_per_pixel(colorspace, alpha: bool) -> int: + if colorspace is None: + return 1 + if colorspace.n == 1: + return 1 + int(alpha) + if colorspace.n == 3: + return 3 + int(alpha) + return 4 + + +def pixmap_to_image(pixmap: fitz.Pixmap, unmultiply: bool = False, mode: Optional[str] = None) -> Image.Image: + """ + Makes a view of the pixmap samples as a PIL image, copies if colorspace is odd. + """ + mode = mode or get_color_mode(pixmap.colorspace, pixmap.alpha, unmultiply) + image = Image.frombuffer(mode, (int(pixmap.w), int(pixmap.h)), pixmap.samples, "raw", mode, 0, 1) + return image + + +def pixmap_to_ndarray(pixmap: fitz.Pixmap) -> np.ndarray: + """ + Makes a view of the pixmap samples as a numpy array. + """ + bpp = get_bytes_per_pixel(pixmap.colorspace, pixmap.alpha) + image_data = np.frombuffer(pixmap.samples, dtype=np.uint8) + if bpp == 1: + image_data = image_data.reshape((pixmap.height, pixmap.stride)) + else: + image_data = image_data.reshape((pixmap.height, pixmap.stride // bpp, bpp)) + # image_data = image_data[:, :pixmap.width, ...] + return image_data + + +def image_to_ndarray(image: Image.Image) -> np.ndarray: + """ + Makes a view of the image samples as a numpy array. + """ + image_data = np.frombuffer(image.tobytes(), dtype=np.uint8) + image_data = image_data.reshape((image.height, image.width, 4)) + return image_data + + +# Frozen is used to implement __hash__ and __eq__ +@dataclass(frozen=True) +class Node: + bbox_min_x: float + bbox_min_y: float + bbox_max_x: float + bbox_max_y: float + + def centroid(self) -> tuple[float, float]: + return ( + (self.bbox_min_x + self.bbox_max_x) / 2, + (self.bbox_min_y + self.bbox_max_y) / 2, + ) + + +@dataclass(frozen=True) +class TextNode(Node): + """ + Text node in the GLAM graph. 
+ """ + text: str + font_name: str + font_size: float + font_color_r: float # 0.0-1.0 + font_color_g: float # 0.0-1.0 + font_color_b: float # 0.0-1.0 + + @classmethod + def from_span( + cls, + span: dict, # span from PyMuPDF page.get_text("dict")["blocks"][...]["lines"][...]["spans"][...] + *, + text: Optional[str] = None, + ) -> "TextNode": + return cls( + bbox_min_x=span["bbox"][0], + bbox_min_y=span["bbox"][1], + bbox_max_x=span["bbox"][2], + bbox_max_y=span["bbox"][3], + text=text or span["text"], + font_name=span["font"], + font_size=span["size"], + font_color_r=((span["color"]) & 0xFF) / 255, # TODO: check if correct order + font_color_g=((span["color"] >> 8) & 0xFF) / 255, + font_color_b=((span["color"] >> 16) & 0xFF) / 255, + ) + + +@dataclass(frozen=True) +class ImageNode(Node): + """ + Image node in the GLAM graph. + """ + image: Optional[Image] = None + + @classmethod + def from_page_block( + cls, + page_block: Mapping[str, Any], # page_block from PyMuPDF page.get_text("dict")["blocks"][...] + ) -> "ImageNode": + return cls( + bbox_min_x=page_block["bbox"][0], + bbox_min_y=page_block["bbox"][1], + bbox_max_x=page_block["bbox"][2], + bbox_max_y=page_block["bbox"][3], + image=None # TODO: implement + ) + + +class Edge(NamedTuple): + """ + Edge in the GLAM graph. + """ + node_index_1: int + node_index_2: int + centroid_distance: float # pixels + centroid_angle: float # radians + ordered_hint: int # 0-1 + + @classmethod + def from_node_pair( + cls, + node_1: Node, + node_2: Node, + node_index_1: int, + node_index_2: int, + ordered_hint: int, + ) -> "Edge": + node_1_centroid = node_1.centroid() + node_2_centroid = node_2.centroid() + delta_x = node_2_centroid[0] - node_1_centroid[0] + delta_y = node_2_centroid[1] - node_1_centroid[1] + centroid_distance = math.hypot(delta_x, delta_y) + centroid_angle = math.atan2(delta_y, delta_x) + + return cls( + node_index_1=node_index_1, + node_index_2=node_index_2, + centroid_distance=centroid_distance, + centroid_angle=centroid_angle, + ordered_hint=ordered_hint, + ) + + def copy_inverted(self) -> "Edge": + return Edge( + node_index_1=self.node_index_2, + node_index_2=self.node_index_1, + centroid_distance=self.centroid_distance, + centroid_angle=self.centroid_angle + math.pi, + ordered_hint=self.ordered_hint, + ) + + +class PageNodes(list[Node]): + features_len = 14 + re_text_spaces = re.compile(r"\s+") + + def to_node_features(self) -> torch.Tensor: + node_list = [ + [ + # Node - type + int(isinstance(node, TextNode)), + int(isinstance(node, ImageNode)), + # Node + node.bbox_min_x, + node.bbox_min_y, + node.bbox_max_x, + node.bbox_max_y, + # TextNode + node.font_color_r if isinstance(node, TextNode) else 0, + node.font_color_g if isinstance(node, TextNode) else 0, + node.font_color_b if isinstance(node, TextNode) else 0, + node.font_size if isinstance(node, TextNode) else 0, + "bold" in node.font_name.lower() if isinstance(node, TextNode) else 0, + "italic" in node.font_name.lower() if isinstance(node, TextNode) else 0, + len(node.text) if isinstance(node, TextNode) else 0, + len(self.re_text_spaces.split(node.text)) if isinstance(node, TextNode) else 0, + # ImageNode - nothing + ] + for node in self + ] + return torch.tensor(node_list, dtype=torch.float32) + + +def is_angle_in_range(angle, range): + start, end = range + if start > end: # This handles the wrap-around case + return start <= angle or angle < end + else: + return start <= angle < end + + +class PageEdges(list[Edge]): + features_len = 3 + + def to_edge_index(self) -> 
torch.Tensor: + edge_list = [ + [edge.node_index_1, edge.node_index_2] + for edge in self + ] + return torch.tensor(edge_list, dtype=torch.int64) + + def to_edge_features(self) -> torch.Tensor: + return torch.tensor( + [ + [ + edge.centroid_distance, + edge.centroid_angle, + edge.ordered_hint, + ] + for edge in self + ], + dtype=torch.float32, + ) + + @classmethod + def from_page_nodes_as_complete_graph(cls, page_nodes) -> "PageEdges": + page_edges = cls() + for i in range(len(page_nodes)): + for j in range(len(page_nodes)): + if i == j: + continue + edge = Edge.from_node_pair( + node_1=page_nodes[i], + node_2=page_nodes[j], + node_index_1=i, + node_index_2=j, + ordered_hint=int(i + 1 == j), + ) + page_edges.extend([edge, edge.copy_inverted()]) + return page_edges + + @classmethod + def from_page_nodes_by_top_closest(cls, page_nodes, always_has_next=True, k=10 + 1) -> "PageEdges": + page_edges = cls() + centroids = np.array([node.centroid() for node in page_nodes]) + k = min(len(centroids), k) + tree = KDTree(centroids) + + for i, centroid in enumerate(centroids): + if always_has_next and i + 1 < len(page_nodes): + edge = Edge.from_node_pair( + node_1=page_nodes[i], + node_2=page_nodes[i + 1], + node_index_1=i, + node_index_2=i + 1, + ordered_hint=int(True), + ) + page_edges.extend([edge, edge.copy_inverted()]) + + distances, indices = tree.query(centroid, k=k) + for j, distance in zip(indices, distances): + if i == j or always_has_next and i + 1 == j: + continue + + delta_x = centroids[j][0] - centroid[0] + delta_y = centroids[j][1] - centroid[1] + centroid_angle = math.atan2(delta_y, delta_x) + + edge = Edge( + node_index_1=i, + node_index_2=j, + centroid_distance=distance, + centroid_angle=centroid_angle, + ordered_hint=int(i + 1 == j), + ) + page_edges.extend([edge, edge.copy_inverted()]) + + return page_edges + + @classmethod + def from_page_nodes_by_directions( + cls, + page_nodes, + always_has_next=True, + # Left, Up, Right, Down (because in both PIL and MuPDF coordinate system Y axis is inverted) + directions=((3 * math.pi / 4, -3 * math.pi / 4), (-3 * math.pi / 4, -1 * math.pi / 4), + (-1 * math.pi / 4, 1 * math.pi / 4), (1 * math.pi / 4, 3 * math.pi / 4)), + k=10 + 1 + ) -> "PageEdges": + page_edges = cls() + centroids = np.array([node.centroid() for node in page_nodes]) + k = min(len(centroids), k) + tree = KDTree(centroids) + + for i, centroid in enumerate(centroids): + if always_has_next and i + 1 < len(page_nodes): + edge = Edge.from_node_pair( + node_1=page_nodes[i], + node_2=page_nodes[i + 1], + node_index_1=i, + node_index_2=i + 1, + ordered_hint=int(True), + ) + page_edges.extend([ + edge, + edge.copy_inverted() + ]) + + closest_nodes = {dir_range: None for dir_range in directions} + distances, indices = tree.query(centroid, k=k) + + for j, distance in zip(indices, distances): + if i == j: + continue + + delta_x = centroids[j][0] - centroid[0] + delta_y = centroids[j][1] - centroid[1] + centroid_angle = math.atan2(delta_y, delta_x) # -pi..pi + + # Check each direction + for dir_range in directions: + if is_angle_in_range(centroid_angle, dir_range): + if closest_nodes[dir_range] is None or distance < closest_nodes[dir_range][1]: + closest_nodes[dir_range] = (j, distance, centroid_angle) + + # Create edges for each direction + for dir_range, node_info in closest_nodes.items(): + # if True: + # dir_range = directions[1] + node_info = closest_nodes[dir_range] + + if node_info is None: + continue + + j, distance, centroid_angle = node_info + if always_has_next and i + 1 == j: 
+ continue + + edge = Edge( + node_index_1=i, + node_index_2=j, + centroid_distance=distance, + centroid_angle=centroid_angle, + ordered_hint=int(i + 1 == j), + ) + page_edges.extend([ + edge, + edge.copy_inverted(), + ]) + + return page_edges diff --git a/GLAM/models.py b/GLAM/models.py new file mode 100644 index 0000000..c533c12 --- /dev/null +++ b/GLAM/models.py @@ -0,0 +1,100 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_geometric.nn import GCNConv, TAGConv +from torch_geometric.data import Data + + +class NodeEncoder(nn.Module): + def __init__(self, node_features_len: int, initial_hidden_len: int, activation=nn.GELU()): + super(NodeEncoder, self).__init__() + self.batch_norm = nn.BatchNorm1d(node_features_len) + self.activation = activation + self.linear1 = nn.Linear(node_features_len, (cur_len := initial_hidden_len)) + self.conv1 = TAGConv(cur_len, (cur_len := cur_len // 2)) + self.linear2 = nn.Linear(cur_len, (cur_len := cur_len // 2)) + self.conv2 = TAGConv(cur_len, (cur_len := cur_len // 2)) + + self.linear3 = nn.Linear((cur_len := cur_len + node_features_len), (cur_len := cur_len // 2)) + self.linear4 = nn.Linear(cur_len, (cur_len := cur_len // 2)) + self.embeddings_len = cur_len + + def forward(self, data: Data) -> torch.Tensor: + x = data.node_features + x = self.batch_norm(x) + x = self.activation(self.linear1(x)) + x = self.activation(self.conv1(x, data.edge_index)) + x = self.activation(self.linear2(x)) + x = self.activation(self.conv2(x, data.edge_index)) + x = torch.cat([x, data.node_features], dim=1) + x = self.activation(self.linear3(x)) + x = self.linear4(x) + return x + + +class NodeClassifier(nn.Module): + def __init__(self, node_embeddings_len: int, classes_len: int, activation=nn.GELU()): + super(NodeClassifier, self).__init__() + self.activation = activation + self.linear = nn.Linear(node_embeddings_len, classes_len) + self.classes_len = classes_len + + def forward(self, node_embeddings: torch.Tensor) -> torch.Tensor: + return self.linear(self.activation(node_embeddings)) + + +class EdgeEncoder(nn.Module): + def __init__(self, node_embeddings_len: int, edge_feature_len: int, activation=nn.GELU()): + super(EdgeEncoder, self).__init__() + cur_len = node_embeddings_len + edge_feature_len + self.batch_norm = nn.BatchNorm1d(cur_len) + self.activation = activation + self.linear1 = nn.Linear(cur_len, (cur_len := cur_len // 2)) + self.linear2 = nn.Linear(cur_len, (cur_len := cur_len // 2)) + self.embeddings_len = cur_len + + def forward(self, node_embeddings: torch.Tensor, data: Data) -> torch.Tensor: + aggregated_node_features = (node_embeddings[data.edge_index[0]] + node_embeddings[data.edge_index[1]]) / 2 + x = torch.cat([aggregated_node_features, data.edge_features], dim=1) + x = self.batch_norm(x) + x = self.activation(self.linear1(x)) + x = self.linear2(x) + return x + + +class EdgeClassifier(nn.Module): + def __init__(self, edge_embeddings_len: int, activation=nn.GELU()): + super(EdgeClassifier, self).__init__() + self.activation = activation + self.linear = nn.Linear(edge_embeddings_len, (cur_len := 1)) + self.classes_len = cur_len + + def forward(self, edge_embeddings: torch.Tensor) -> torch.Tensor: + return self.linear(self.activation(edge_embeddings)) + + +class GLAMGraphNetwork(nn.Module): + """https://arxiv.org/abs/2308.02051""" + + def __init__(self, node_features_len, edge_feature_len, initial_hidden_len, node_classes_len): + super(GLAMGraphNetwork, self).__init__() + self.node_encoder = 
NodeEncoder(node_features_len=node_features_len, initial_hidden_len=initial_hidden_len) + self.node_classifier = NodeClassifier(node_embeddings_len=self.node_encoder.embeddings_len, classes_len=node_classes_len) + self.edge_encoder = EdgeEncoder(node_embeddings_len=self.node_encoder.embeddings_len, edge_feature_len=edge_feature_len) + self.edge_classifier = EdgeClassifier(edge_embeddings_len=self.edge_encoder.embeddings_len) + + def forward(self, data: Data) -> (torch.Tensor, torch.Tensor): + node_embeddings = self.node_encoder(data) + node_class_scores = self.node_classifier(node_embeddings) + edge_embeddings = self.edge_encoder(node_embeddings, data) + edge_class_scores = self.edge_classifier(edge_embeddings) + return node_class_scores, edge_class_scores + + +def main(): + model = GLAMGraphNetwork(10, 20, 30, 40) + print(model) + + +if __name__ == '__main__': + main() diff --git a/LICENSE-APACHE-2.0 b/LICENSE-APACHE-2.0 new file mode 100644 index 0000000..adcf15c --- /dev/null +++ b/LICENSE-APACHE-2.0 @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+
+END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!) The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright 2024 Ivan Stepanov
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
diff --git a/LICENSE-MIT b/LICENSE-MIT
new file mode 100644
index 0000000..7dc6f87
--- /dev/null
+++ b/LICENSE-MIT
@@ -0,0 +1,25 @@
+Copyright (c) 2024 Ivan Stepanov
+
+Permission is hereby granted, free of charge, to any
+person obtaining a copy of this software and associated
+documentation files (the "Software"), to deal in the
+Software without restriction, including without
+limitation the rights to use, copy, modify, merge,
+publish, distribute, sublicense, and/or sell copies of
+the Software, and to permit persons to whom the Software
+is furnished to do so, subject to the following
+conditions:
+
+The above copyright notice and this permission notice
+shall be included in all copies or substantial portions
+of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
+ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
+TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
+PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
+SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
+IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+DEALINGS IN THE SOFTWARE.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..29f343f
--- /dev/null
+++ b/README.md
@@ -0,0 +1,105 @@
+# GLAM
+
+Graph-based Layout Analysis Model (GLAM) is a deep learning model for document layout analysis.
+
+Unofficial implementation in PyTorch of "A Graphical Approach to Document Layout Analysis" [[arXiv]](https://arxiv.org/abs/2308.02051).
+
+## Introduction
+
+The Graph-based Layout Analysis Model (GLAM) is a novel deep learning model designed for advanced document layout analysis. This repository contains an unofficial PyTorch implementation of the model as described in the paper "A Graphical Approach to Document Layout Analysis". You can find the original paper [here](https://arxiv.org/abs/2308.02051).
+
+Large language models do not have the ability to read PDFs, since PDF is not a text format but a mapping of (sometimes unlabelled) font glyphs to positions on a page. Of course, you can use image embeddings or OCR, but you still need to identify what is a title, what is a paragraph, what is a table, what is a figure, etc., in order to turn an unstructured PDF into structured data. GLAM does that for you; a minimal sketch of the inference flow is shown below.
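+
+As a quick orientation, the snippet below sketches how the pieces of this repository fit together for single-page inference. It mirrors `dln_glam_inference.py` (the file paths and the 512 hidden size come from there) but omits the Tesseract fallback for spans with unreadable glyphs and the final edge-clustering step, so treat it as a sketch rather than a drop-in script:
+
+```python
+import fitz  # PyMuPDF
+import torch
+from torch_geometric.data import Data
+
+from GLAM.common import PageNodes, PageEdges, TextNode, ImageNode
+from GLAM.models import GLAMGraphNetwork
+from dln_glam_prepare import CLASSES_MAP
+
+# Load the trained checkpoint (shipped under models/ via Git LFS).
+model = GLAMGraphNetwork(PageNodes.features_len, PageEdges.features_len, 512, len(CLASSES_MAP))
+model.load_state_dict(torch.load("models/glam_dln.pt", map_location="cpu"))
+model.eval()
+
+# One graph node per text span or image block on the page.
+page = fitz.Document("examples/pdf/book law.pdf")[0]
+nodes = PageNodes()
+for block in fitz.utils.get_text(page=page, option="dict", flags=fitz.TEXT_PRESERVE_IMAGES)["blocks"]:
+    if block["type"] == 0:  # text block
+        for line in block["lines"]:
+            for span in line["spans"]:
+                nodes.append(TextNode.from_span(span))
+    elif block["type"] == 1:  # image block
+        nodes.append(ImageNode.from_page_block(block))
+
+# Connect the nodes and run the graph network.
+edges = PageEdges.from_page_nodes_as_complete_graph(nodes)
+data = Data(
+    node_features=nodes.to_node_features(),
+    edge_index=edges.to_edge_index().t(),
+    edge_features=edges.to_edge_features(),
+)
+with torch.no_grad():
+    node_class_scores, edge_class_scores = model(data)  # per-node class logits, per-edge link scores
+```
+
+Thresholding the edge scores and taking connected components then groups the nodes into layout segments, and each segment's class is read off its summed node scores; see `dln_glam_inference.py` for the full pipeline.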
+After you have structured data, you can build datasets for retrieval-augmented generation (RAG) tasks.
+
+## Prerequisites
+
+- Python 3.9+
+- pip
+- Optional: Tesseract or EasyOCR
+- Optional: Git with [Git LFS](https://git-lfs.github.com/) support
+
+### Ubuntu/Debian
+
+```shell
+apt-get update -q -y
+apt-get install -q -y tesseract-ocr tesseract-ocr-eng tesseract-ocr-deu tesseract-ocr-fra tesseract-ocr-jpn
+python -m pip install -q -U -r requirements.txt
+export TESSDATA_PREFIX=/usr/share/tesseract-ocr/4.00/tessdata
+```
+
+## Dataset preparation
+
+Download and extract the DocLayNet dataset:
+
+```shell
+python dln_download_and_extract.py --download-path /home/i/dataset/DocLayNet/raw --extract-path /home/i/dataset/DocLayNet/raw/DocLayNet
+```
+
+Build your own DocLayNet-v1.1, [free from bugs](https://huggingface.co/datasets/ds4sd/DocLayNet-v1.1/discussions/1), parsing spans that contain unlabelled glyphs with Tesseract:
+
+```shell
+python dln_parse_pdf.py --dataset-path /home/i/dataset/DocLayNet/raw/DocLayNet --image-scale 1
+```
+
+Make the training examples:
+
+```shell
+python dln_glam_prepare.py --dataset-path /home/i/dataset/DocLayNet/raw/DocLayNet/DATA --output-path /home/i/dataset/DocLayNet/glam
+```
+
+## Training
+
+Some paths are hardcoded in `dln_glam_train.py`. Please change them before training.
+
+```shell
+python dln_glam_train.py
+```
+
+## Evaluation
+
+Please change the paths in `dln_glam_inference.py` before evaluation.
+
+```shell
+python dln_glam_inference.py
+```
+
+## Features
+
+- Simple architecture.
+- Fast. With a batch size of 128 examples, one epoch takes 00:11:35 for training on 507 batches and 00:02:17 for validation on 48 batches on CPU.
+
+## Limitations
+
+- No reading-order prediction, though this is not an objective of the model, and the dataset does not contain such information.
+
+## TODO
+
+- Implement the mAP@IoU\[0.5:0.05:0.95] metric; there is currently no way to compare against other models.
+- Implement input feature normalization.
+- Implement text and image features.
+- Batching in inference. Currently, only one page is processed at a time.
+- W&B integration for training.
+- Some text spans in PDFs contain unlabelled font glyphs. Currently, the whole span is passed to OCR. It would be faster to OCR the glyphs separately and then merge them back into spans.
+
+## Alternatives
+
+* [Kensho Extract](https://kensho.com/extract) (the GLAM authors' closed-source SaaS implementation)
+* [Unstructured](https://github.com/Unstructured-IO/unstructured)
+
+## License
+
+Licensed under either of
+
+* Apache License, Version 2.0, ([LICENSE-APACHE-2.0](LICENSE-APACHE-2.0) or http://www.apache.org/licenses/LICENSE-2.0)
+* MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT)
+
+at your option.
+
+## Contribution
+
+Unless you explicitly state otherwise, any contribution intentionally submitted
+for inclusion in the work by you, as defined in the Apache-2.0 license, shall be dual licensed as above, without any
+additional terms or conditions.
+
+## Acknowledgements
+
+- Jilin Wang, Michael Krumdick, Baojia Tong, Hamima Halim, Maxim Sokolov, Vadym Barda, Delphine Vendryes, and Chris Tanner. "A Graphical Approach to Document Layout Analysis". 2023.
arXiv: [2308.02051](https://arxiv.org/abs/2308.02051) diff --git a/dln_download_and_extract.py b/dln_download_and_extract.py new file mode 100644 index 0000000..5d9e436 --- /dev/null +++ b/dln_download_and_extract.py @@ -0,0 +1,125 @@ +import argparse +import logging +import os +import signal +from typing import Optional, Callable +from urllib.parse import urlparse +from zipfile import ZipFile + +import requests +from tqdm import tqdm + + +def download(url: str, output_filepath: Optional[str] = None) -> str: + if not output_filepath: + parsed_url = urlparse(url) + output_filepath = os.path.basename(parsed_url.path) + + if os.path.exists(output_filepath): + return output_filepath + + tmp_output_filepath = output_filepath + ".download" + + if os.path.exists(tmp_output_filepath): + resume_byte_pos = os.path.getsize(tmp_output_filepath) + headers = {'Range': f'bytes={resume_byte_pos}-'} if resume_byte_pos else {} + else: + resume_byte_pos = 0 + headers = {} + + with requests.get(url, stream=True, headers=headers) as response: + if response.status_code == 206: + pass + elif response.status_code == 200: + resume_byte_pos = 0 + else: + raise Exception(f"Unexpected response status code: {response.status_code}") + + total_size_in_bytes = int(response.headers.get('content-length', 0)) + resume_byte_pos + block_size = 4 << 20 # 4 MiB + progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True, initial=resume_byte_pos) + + # Write or append to the file + mode = 'ab' if resume_byte_pos else 'wb' + with open(tmp_output_filepath, mode) as file: + for data in response.iter_content(block_size): + file.write(data) + progress_bar.update(len(data)) + + progress_bar.close() + + if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: + raise Exception("Something went wrong during the download") + + os.rename(tmp_output_filepath, output_filepath) + + return output_filepath + + +def extract_zip( + zip_filepath: str, + extract_path: Optional[str] = None, + filename_predicate: Optional[Callable[[str], bool]] = None +) -> str: + extract_path = extract_path if extract_path else os.getcwd() + + with ZipFile(zip_filepath, 'r') as zip_ref: + file_list = [file for file in zip_ref.namelist() if filename_predicate(file)] if filename_predicate else zip_ref.namelist() + with tqdm(total=len(file_list), desc="Extracting files") as progress_bar: + for file in file_list: + zip_ref.extract(member=file, path=extract_path) + progress_bar.update(1) + + return extract_path + + +def download_extract_cached( + url: str, + *, + output_filepath: Optional[str] = None, + extract_path: Optional[str] = None, + filename_predicate: Optional[Callable[[str], bool]] = None, + extracted_marker_filepath: Optional[str] = None +) -> str: + extract_path = extract_path if extract_path else os.getcwd() + if extracted_marker_filepath and os.path.exists(extracted_marker_filepath): + return extract_path + output_filepath = download(url, output_filepath=output_filepath) + extract_zip(output_filepath, extract_path=extract_path, filename_predicate=filename_predicate) + if extracted_marker_filepath: + with open(extracted_marker_filepath, "w") as f: + f.write("") + return extract_path + + +def main(): + parser = argparse.ArgumentParser("DocLayNet dataset downloader and extractor") + parser.add_argument("--download-path", type=str, default="/home/i/dataset/DocLayNet/raw", + help="Directory for the raw dataset (default: %(default)s)") + parser.add_argument("--extract-path", type=str, 
default="/home/i/dataset/DocLayNet/raw/DocLayNet", + help="Directory for the processed dataset (default: %(default)s)") + args = parser.parse_args() + + os.makedirs(args.extract_path, exist_ok=True) + download_extract_cached( + "https://codait-cos-dax.s3.us.cloud-object-storage.appdomain.cloud/dax-doclaynet/1.0.0/DocLayNet_core.zip", + output_filepath=os.path.join(args.download_path, "DocLayNet_core.zip"), + extract_path=args.extract_path, + # filename_predicate=lambda filename: (filename.startswith("COCO/")), + filename_predicate=lambda filename: (filename.startswith("COCO/") + and not os.path.exists(os.path.join(args.extract_path, filename))), + extracted_marker_filepath=os.path.join(args.extract_path, ".core_extracted.txt"), + ) + download_extract_cached( + "https://codait-cos-dax.s3.us.cloud-object-storage.appdomain.cloud/dax-doclaynet/1.0.0/DocLayNet_extra.zip", + output_filepath=os.path.join(args.download_path, "DocLayNet_extra.zip"), + extract_path=args.extract_path, + # filename_predicate=lambda filename: (filename.startswith("PDF/")), + filename_predicate=lambda filename: (filename.startswith("PDF/") + and not os.path.exists(os.path.join(args.extract_path, filename))), + extracted_marker_filepath=os.path.join(args.extract_path, ".extra_extracted.txt"), + ) + + +if __name__ == '__main__': + main() diff --git a/dln_glam_inference.py b/dln_glam_inference.py new file mode 100644 index 0000000..90f260b --- /dev/null +++ b/dln_glam_inference.py @@ -0,0 +1,150 @@ +import logging +import math +import re +import time +from collections import namedtuple +from dataclasses import dataclass +from typing import NamedTuple, Sequence, MutableSequence, Optional, Generator, Iterable, Any, Mapping + +import cv2 +import easyocr +import networkx as nx +import numpy as np +import tesserocr +import torch +from shapely import Polygon +from tesserocr import PyTessBaseAPI +from torch_geometric.data import Data + +import models +import fitz # PyMuPDF +from PIL import Image + +from GLAM.common import PageEdges, ImageNode, TextNode, get_bytes_per_pixel, PageNodes +from GLAM.models import GLAMGraphNetwork +from dln_glam_prepare import CLASSES_MAP + +INVALID_UNICODE = chr(0xFFFD) +EasyocrTextResult = namedtuple("EasyocrTextResult", ["bbox", "text", "confidence"]) +MuPDFTextTraceChar = namedtuple("MuPDFTextTraceChar", ["unicode", "glyph", "origin", "bbox"]) +logger = logging.getLogger(__name__) + + +def main(): + pdf_filepath = "examples/pdf/book law.pdf" + model_filepath = "models/glam_dln.pt" + easyocr_languages = ["en", "ar"] + TESSDATA_PREFIX = "/usr/share/tesseract/tessdata" + tesserocr_languages = ["eng", "ara"] + + device = ("cuda" if torch.cuda.is_available() else "cpu") + + logger.setLevel(logging.DEBUG) + logger.addHandler(logging.StreamHandler()) + # reader = easyocr.Reader(easyocr_languages) + api = PyTessBaseAPI(path=TESSDATA_PREFIX, lang="+".join(tesserocr_languages)) + + model = GLAMGraphNetwork(PageNodes.features_len, PageEdges.features_len, 512, len(CLASSES_MAP)) + model.load_state_dict(torch.load(model_filepath)) + model = model.to(device) + model.eval() + + doc = fitz.Document(pdf_filepath) + for page in doc: + # Find all nodes + page_nodes = PageNodes() + page_dict = fitz.utils.get_text( + page=page, + option="dict", + flags=fitz.TEXT_PRESERVE_IMAGES + ) + for block in page_dict["blocks"]: + if block["type"] == 0: + for line in block["lines"]: + for span in line["spans"]: + text = span["text"] + + if INVALID_UNICODE in text: + ls = " " * (len(text) - len(text.lstrip())) + rs = " " * 
(len(text) - len(text.rstrip())) + pixmap = fitz.utils.get_pixmap( + page=page, + matrix=fitz.Matrix(5, 5), + clip=span["bbox"], + colorspace=fitz.csGRAY, + ) + + bpp = get_bytes_per_pixel(pixmap.colorspace, pixmap.alpha) + api.SetImageBytes( + imagedata=pixmap.samples, + width=pixmap.w, + height=pixmap.h, + bytes_per_pixel=bpp, + bytes_per_line=pixmap.stride, + ) + api.SetPageSegMode(tesserocr.PSM.RAW_LINE) + api.Recognize() + ocr_text = api.GetUTF8Text().rstrip() + + old_text, text = text, ls + ocr_text + rs + logger.debug(f"Replaced {old_text!r} with {text!r}") + + page_nodes.append(TextNode.from_span(span, text=text)) + elif block["type"] == 1: + page_nodes.append(ImageNode.from_page_block(block)) + else: + raise ValueError(f"Unknown block type {block['type']}") + + # Find all edges + page_edges = PageEdges.from_page_nodes_as_complete_graph(page_nodes) + + node_features = page_nodes.to_node_features() + edge_index = page_edges.to_edge_index().t() + edge_features = page_edges.to_edge_features() + # print("node_features.shape", node_features.shape, "edge_index.shape", edge_index.shape, "edge_features.shape", edge_features.shape) + + if edge_index.shape[0] == 0: + continue + + example = Data( + node_features=node_features, + edge_index=edge_index, + edge_features=edge_features, + ) + + with torch.no_grad(): + node_class_scores, edge_class_scores = model(example) + print("node_class_scores", node_class_scores.shape, "edge_class_scores", edge_class_scores.shape) + + edge_prob_threshold = 0.5 + graph = nx.Graph() + for k in range(example.edge_index.shape[1]): + src_node_i = example.edge_index[0, k].item() + dst_node_i = example.edge_index[1, k].item() + edge_prob = edge_class_scores[k].item() + + if edge_prob >= edge_prob_threshold: + graph.add_edge(src_node_i, dst_node_i, weight=edge_prob) + else: + graph.add_node(src_node_i) + graph.add_node(dst_node_i) + + clusters: list[set[int]] = list(nx.connected_components(graph)) + cluster_min_spanning_boxes: list[Polygon] = [ + Polygon([ + (min(page_nodes[node_i].bbox_min_x for node_i in cluster), min(page_nodes[node_i].bbox_min_y for node_i in cluster)), + (max(page_nodes[node_i].bbox_max_x for node_i in cluster), min(page_nodes[node_i].bbox_min_y for node_i in cluster)), + (max(page_nodes[node_i].bbox_max_x for node_i in cluster), max(page_nodes[node_i].bbox_max_y for node_i in cluster)), + (min(page_nodes[node_i].bbox_min_x for node_i in cluster), max(page_nodes[node_i].bbox_max_y for node_i in cluster)), + ]) + for cluster in clusters + ] + cluster_classes: list[int] = torch.stack([node_class_scores[torch.tensor(list(cluster))].sum(dim=0) for cluster in clusters]).argmax(dim=1).tolist() + + print("clusters", clusters) + print("cluster_min_spanning_boxes", cluster_min_spanning_boxes) + print("cluster_classes", cluster_classes) + + +if __name__ == '__main__': + main() diff --git a/dln_glam_prepare.py b/dln_glam_prepare.py new file mode 100644 index 0000000..37bb4f3 --- /dev/null +++ b/dln_glam_prepare.py @@ -0,0 +1,513 @@ +import argparse +import io +import json +import logging +import multiprocessing +import os +import pickle +import sys +import time +from collections import defaultdict +from typing import Mapping, Optional, Any, Iterable + +import networkx as nx +import numpy as np +import psutil +import torch +import torch.nn.functional as F +import torch_geometric +from PIL import Image, ImageDraw +from scipy.spatial import ConvexHull +from shapely.geometry import Polygon, box +from torch_geometric.data import Data +from tqdm import tqdm 
+from tqdm.contrib.logging import logging_redirect_tqdm + +from GLAM.common import PageNodes, PageEdges, TextNode, ImageNode + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) +# logger.addHandler(logging.StreamHandler()) + + +CLASSES_MAP = { + 0: "Unknown", + 1: "Caption", + 2: "Footnote", + 3: "Formula", + 4: "List-item", + 5: "Page-footer", + 6: "Page-header", + 7: "Picture", + 8: "Section-header", + 9: "Table", + 10: "Text", + 11: "Title", +} + + +def iou(polygon1: Polygon, polygon2: Polygon) -> float: + """Calculate intersection over union between two polygons. May return NaN if polygons are invalid.""" + intersection = polygon1.intersection(polygon2).area + union = polygon1.union(polygon2).area + iou = intersection / union + return iou + + +# # @lru_cache(maxsize=1) +# def compare_segmentations(segmentations1, segmentations2): +# """Function to compare segmentations and return a similarity score""" +# if not segmentations1 or not segmentations2: +# return 0 # No segmentations to compare +# +# max_similarity = 0 +# for seg1 in segmentations1: +# polygon1 = Polygon([(seg1[i], seg1[i + 1]) for i in range(0, len(seg1), 2)]) +# for seg2 in segmentations2: +# polygon2 = Polygon([(seg2[i], seg2[i + 1]) for i in range(0, len(seg2), 2)]) +# max_similarity = max(max_similarity, compare_polygons(polygon1, polygon2)) +# +# return max_similarity + + +class DLNDataset(torch_geometric.data.Dataset): + index_to_example_filename: Mapping[int, str] + + def __init__(self, root, split_name, transform=None, pre_transform=None): + super().__init__(root, transform, pre_transform) + + self.split_name = split_name + image_id_to_example_filename = { + int(filename.split(".")[0]): filename + for filename in os.listdir(os.path.join(self.root, split_name)) + } + self.index_to_example_filename = { + i: image_id_to_example_filename[image_id] + for i, image_id in enumerate(sorted(image_id_to_example_filename)) + } + + # @property + # def raw_file_names(self): + # return ['some_file_1', 'some_file_2', ...] + + @property + def processed_file_names(self): + return tuple(self.index_to_example_filename.values()) + # + # def download(self): + # # Download to `self.raw_dir`. + # path = download_url(url, self.raw_dir) + # ... + # + # def process(self): + # idx = 0 + # for raw_path in self.raw_paths: + # # Read data from `raw_path`. + # data = Data(...) 
+ # + # if self.pre_filter is not None and not self.pre_filter(data): + # continue + # + # if self.pre_transform is not None: + # data = self.pre_transform(data) + # + # torch.save(data, osp.join(self.processed_dir, f'data_{idx}.pt')) + # idx += 1 + + def len(self): + return len(self.processed_file_names) + + def get_filepath(self, idx): + example_filename = self.index_to_example_filename[idx] + return os.path.join(self.root, self.split_name, example_filename) + + def get(self, idx): + example = torch.load(self.get_filepath(idx)) + # assert example.node_features.isfinite().all() + # assert example.edge_features.isfinite().all() + return example + + +def process_row(dataset_dir, output_dir, split_name, image_id) -> Optional[Data]: + id_dir = os.path.join(dataset_dir, split_name, "by-id", str(image_id)) + pdf_filepath = os.path.join(id_dir, "page.pdf") + image_filepath = os.path.join(id_dir, "image.webp") + row_filepath = os.path.join(id_dir, "row.json") + annotations_filepath = os.path.join(id_dir, "annotations.json") + page_dict_filepath = os.path.join(id_dir, "page_dict.pkl") + + with open(row_filepath, "rb") as f: + row = json.load(f) + + with open(annotations_filepath, "rb") as f: + annotations = json.load(f) + + with open(page_dict_filepath, "rb") as f: + page_dict = pickle.load(f) + + # with open(image_filepath, "rb") as f: + # image = Image.open(f).convert("L").convert("RGBA") + + nodes = PageNodes() + for block in page_dict["blocks"]: + if block["type"] == 0: + for line in block["lines"]: + for span in line["spans"]: + nodes.append(TextNode.from_span(span)) + elif block["type"] == 1: + try: + nodes.append(ImageNode.from_page_block(block)) + except ValueError as e: + logger.warning(f"{split_name}/{image_id}: Could not parse image block {block}: {e}") + else: + raise ValueError(f"{split_name}/{image_id}: Unknown block type {block['type']}") + + if len(nodes) <= 0: + logger.warning(f"{split_name}/{image_id}: Skipping: No nodes found") + return + + if len(nodes) == 1: + logger.warning(f"{split_name}/{image_id}: Skipping: Only one node found, cannot make edges") + return + + if len(nodes) > 1024: + logger.warning(f"{split_name}/{image_id}: Skipping: Too many nodes ({len(nodes)}), slow to process") + return + + edges = PageEdges.from_page_nodes_by_directions(nodes, k=31) + # edges = PageEdges.from_page_nodes_by_top_closest(nodes, k=4+1) + + segmentations: list[Polygon] = [ + Polygon([ + (segmentation[i1], segmentation[i1 + 1]) + for i1 in range(0, len(segmentation), 2) + ]) + for annotation in annotations + for segmentation in annotation["segmentation"] + ] + + # Calculate probabilities for each class for each node. This is a target for the node classification model. 
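+    # Each node gets a soft label: every COCO segmentation polygon that overlaps the node's bounding
+    # box votes for its annotation's category, weighted by (overlap area / node area). The accumulated
+    # votes are then divided by the number of overlapping segmentations, which keeps every entry of
+    # node_probs within [0, 1]. Nodes that no segmentation touches keep an all-zero row and are
+    # treated as "unlabelled" further down.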
+ node_probs = torch.zeros(len(nodes), len(CLASSES_MAP), dtype=torch.float32) + node_segmentations: dict[int, list[int]] = {} # Mapping from node index to a list of segmentation indices + uncovered_segmentations = set(range(len(segmentations))) # Set of uncovered segmentations for cleaning dataset + for node_i, node in enumerate(nodes): + # Get bounding box + node_bbox = box(node.bbox_min_x, node.bbox_min_y, node.bbox_max_x, node.bbox_max_y) + if node_bbox.area <= 0: + logger.warning(f"{split_name}/{image_id}: Node {node} has zero area.") + continue + + # Iterate over the segmentations + segmentation_i = 0 + for annotation in annotations: + for _ in annotation["segmentation"]: + overlap_area = segmentations[segmentation_i].intersection(node_bbox).area + if overlap_area > 0: + # Weighted votes by the proportion of overlap + node_probs[node_i][annotation["category_id"]] += overlap_area / node_bbox.area + node_segmentations.setdefault(node_i, []).append(segmentation_i) + uncovered_segmentations.discard(segmentation_i) + segmentation_i += 1 + + # Normalize node_probs by the number of segmentations for each node + node_probs /= torch.tensor([len(node_segmentations.get(node_i, [None])) for node_i in range(len(nodes))], dtype=torch.float32).unsqueeze(1) + # if node_probs.max() < 0.95: + # logger.debug(f"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!11") + # logger.debug(f"node_probs: {node_probs.min():.2f}, {node_probs.max():.2f}, {node_probs.mean():.2f}, {node_probs.std():.2f}") + + unlabelled_nodes = [node_i for node_i, node in enumerate(nodes) if node_i not in node_segmentations] + + # # Debug draw + # if unlabelled_nodes or uncovered_segmentations: + # logger.warning(f"pdf_filepath: {pdf_filepath}, unlabelled_nodes: {unlabelled_nodes}, uncovered_segmentations: {uncovered_segmentations}") + # # Render page + # #image = Image.open(io.BytesIO(data["image"]["bytes"])).convert("L").convert("RGBA") + # overlay = Image.new("RGBA", image.size, (0, 0, 0, 0)) + # draw = ImageDraw.Draw(overlay) + # + # # Render segmentation + # for annotation in annotations: + # segmentation = annotation["segmentation"][0] + # segmentation_polygon = Polygon(( + # (segmentation[i], segmentation[i + 1]) + # for i in range(0, len(segmentation), 2) + # )) + # draw.polygon(segmentation_polygon.exterior.coords, fill=(200, 200, 200, 200), outline="magenta", width=4) + # # area = segmentation_polygon.intersection(node_bbox).area + # # print("area:", area, "node_bbox.area:", node_bbox.area, "area / node_bbox.area:", area / node_bbox.area) + # + # image = Image.alpha_composite(image, overlay) + # draw = ImageDraw.Draw(image) + # + # # Render bad nodes + # for node_i in nodes: + # outline = "green" if isinstance(node_i, TextNode) else "blue" + # draw.rectangle((node_i.bbox_min_x, node_i.bbox_min_y, node_i.bbox_max_x, node_i.bbox_max_y), outline=outline, width=1) + # for node_i in unlabelled_nodes: + # outline = "red" if isinstance(node_i, TextNode) else "orange" + # draw.rectangle((node_i.bbox_min_x, node_i.bbox_min_y, node_i.bbox_max_x, node_i.bbox_max_y), outline=outline, width=1) + # + # image.show() + # print("showing", len(unlabelled_nodes) / len(nodes)) + # # breakpoint() + # # time.sleep(3) + + if len(unlabelled_nodes) >= 5: + logger.warning(f"{split_name}/{image_id}: Skipping: Too many # unlabelled nodes ({len(unlabelled_nodes)}, {len(unlabelled_nodes) / len(nodes):.2f}%)") + return + + if len(unlabelled_nodes) >= 3 and len(unlabelled_nodes) / len(nodes) > 0.05: + logger.warning(f"{split_name}/{image_id}: Skipping: Too 
many % unlabelled nodes ({len(unlabelled_nodes)}, {len(unlabelled_nodes) / len(nodes):.2f}%)") + return + + node_features = nodes.to_node_features() + edge_index = edges.to_edge_index().t() + edge_features = edges.to_edge_features() + + # Calculate probability of same segmentation for each edge. This is a target for the edge classification model. + edge_probs = torch.zeros(edge_index.shape[1]) + edge_connections: dict[int, list[tuple[int, int]]] = {} # Mapping from edge index to a list of (segmentation_i1, segmentation_i2) tuples + + def precompute_iou(polygons): + n = len(polygons) + ious = np.zeros((n, n)) + for i in range(n): + ious[i, i] = 1 + for j in range(i + 1, n): + ious[i, j] = ious[j, i] = iou(polygons[i], polygons[j]) + return ious + + segmentation_ious = precompute_iou(segmentations) + + # Iterate over the edges to label them + for k in range(edge_index.shape[1]): + src_node_i = edge_index[0, k].item() + dst_node_i = edge_index[1, k].item() + + # Get segmentations for both nodes + src_node_segmentations = node_segmentations.get(src_node_i, []) + dst_node_segmentations = node_segmentations.get(dst_node_i, []) + + # Calculate similarity between segmentations + for src_node_segmentation in src_node_segmentations: + for dst_node_segmentation in dst_node_segmentations: + segmentation_iou = segmentation_ious[src_node_segmentation, dst_node_segmentation].item() + if segmentation_iou > 0: + edge_probs[k] += segmentation_iou + edge_connections.setdefault(k, []).append((src_node_segmentation, dst_node_segmentation)) + + # Normalize edge_probs + edge_probs /= torch.tensor([len(edge_connections.get(k, [None])) for k in range(edge_index.shape[1])], dtype=torch.float32) + # logger.debug(f"edge_probs: {edge_probs.min():.2f}, {edge_probs.max():.2f}, {edge_probs.mean():.2f}, {edge_probs.std():.2f}") + + example = Data( + split_name=split_name, # metadata + image_id=image_id, # metadata + node_features=node_features, # input + edge_index=edge_index, # input + edge_features=edge_features, # input + node_probs=node_probs, # target + edge_probs=edge_probs, # target + ) + + ######################################### + # Calculate mAP@IoU[0.5:0.95:0.05] + ######################################### + # def single_image_evaluation(bbox_preds, class_preds, bbox_gts, class_gts, iou_threshold): + # """Evaluate a single image for precision and recall at a specific IoU threshold.""" + # tp = 0 + # fp = 0 + # gt_used = [False] * len(bbox_gts) + # for bbox_pred, class_pred in zip(bbox_preds, class_preds): + # matched = False + # for i, (bbox_gt, class_gt) in enumerate(zip(bbox_gts, class_gts)): + # if gt_used[i] or class_pred != class_gt: + # continue + # if iou(bbox_pred, bbox_gt) >= iou_threshold: + # gt_used[i] = True + # tp += 1 + # matched = True + # break + # if not matched: + # fp += 1 + # + # fn = sum(1 for used in gt_used if not used) + # precision = tp / (tp + fp) if (tp + fp) > 0 else 0 + # recall = tp / (tp + fn) if (tp + fn) > 0 else 0 + # return precision, recall + # + # def mean_average_precision(bbox_preds, class_preds, bbox_gts, class_gts, iou_thresholds: Iterable[float] = np.arange(0.5, 1, 0.05)): + # """Calculate the mean Average Precision over a range of IoU thresholds.""" + # average_precisions = [] + # for iou_threshold in iou_thresholds: + # precisions = [] + # recalls = [] + # # for preds, gts in zip(predictions, ground_truths): + # for bbox_pred, class_pred, bbox_gt, class_gt in zip(bbox_preds, class_preds, bbox_gts, class_gts): + # precision, recall = 
single_image_evaluation(bbox_preds, class_preds, bbox_gts, class_gts, iou_threshold) + # precisions.append(precision) + # recalls.append(recall) + # ap = np.mean(precisions) # Average precision for this IoU threshold + # average_precisions.append(ap) + # return np.mean(average_precisions) # Mean over all IoU thresholds + + # edge_prob_threshold = 0.5 + # graph = nx.Graph() + # for k in range(example.edge_index.shape[1]): + # src_node_i = example.edge_index[0, k].item() + # dst_node_i = example.edge_index[1, k].item() + # edge_prob = example.edge_probs[k].item() + # + # if edge_prob >= edge_prob_threshold: + # graph.add_edge(src_node_i, dst_node_i, weight=edge_prob) + # else: + # graph.add_node(src_node_i) + # graph.add_node(dst_node_i) + # + # clusters: list[set[int]] = list(nx.connected_components(graph)) + # # cluster_min_spanning_boxes: list[Polygon] = [ + # # Polygon([ + # # (min(nodes[node_i].bbox_min_x for node_i in cluster), min(nodes[node_i].bbox_min_y for node_i in cluster)), + # # (max(nodes[node_i].bbox_max_x for node_i in cluster), min(nodes[node_i].bbox_min_y for node_i in cluster)), + # # (max(nodes[node_i].bbox_max_x for node_i in cluster), max(nodes[node_i].bbox_max_y for node_i in cluster)), + # # (min(nodes[node_i].bbox_min_x for node_i in cluster), max(nodes[node_i].bbox_max_y for node_i in cluster)), + # # ]) + # # for cluster in clusters + # # ] + # cluster_classes: list[int] = torch.stack([example.node_probs[torch.tensor(list(cluster))].sum(dim=0) for cluster in clusters]).argmax(dim=1).tolist() + # + # node_class_accuracy = 0 + # for cluster, cluster_class in zip(clusters, cluster_classes): + # for node_index in cluster: + # node_class = example.node_probs[node_index].argmax().item() + # if node_class == cluster_class: + # node_class_accuracy += 1 + # node_class_accuracy /= len(nodes) + # logger.debug(f"{split_name}/{image_id}: node_class_accuracy: {node_class_accuracy:.2f}, len(unlabelled_nodes): {len(unlabelled_nodes)}, len(nodes): {len(nodes)}, len(uncovered_segmentations): {len(uncovered_segmentations)}, len(annotation): {len(annotations)}, len(segmentations): {len(segmentations)}, len(clusters): {len(clusters)}") + + # # Render page + # if node_class_accuracy < 0.98: + # draw = ImageDraw.Draw(image) + # + # # for annotation in annotations: + # # for segmentation in annotation["segmentation"]: + # # draw.polygon(segmentation, outline=(0, 0, 255), width=6) + # + # for cluster, cluster_class in zip(clusters, cluster_classes): + # cluster_bbox = ( + # min(example.node_features[node_i][2] for node_i in cluster), + # min(example.node_features[node_i][3] for node_i in cluster), + # max(example.node_features[node_i][4] for node_i in cluster), + # max(example.node_features[node_i][5] for node_i in cluster), + # ) + # draw.rectangle(cluster_bbox, outline=(0, 255, 0), width=3) + # draw.text(cluster_bbox[:2], CLASSES_MAP[cluster_class], fill=(0, 0, 0)) + # + # # for k, node_features in zip(range(example.node_features.size(0)), example.node_features): + # # node_bbox = (node_features[2], node_features[3], node_features[4], node_features[5]) + # # draw.rectangle(node_bbox, outline=(255, 0, 0), width=1) + # # + # # for annotation in annotations: + # # for segmentation in annotation["segmentation"]: + # # draw.text(segmentation, CLASSES_MAP[annotation["category_id"]], fill=(0, 0, 0)) + # + # logger.debug(f"{split_name}/{image_id}: node_class_accuracy: {node_class_accuracy}") + # image.show(title=f"{split_name}/{image_id}") + # breakpoint() + + return example + + +def 
process(dataset_dir, output_dir, split_name, image_id) -> tuple[Optional[Data], Any]: + return process_row(dataset_dir, output_dir, split_name, image_id), (dataset_dir, output_dir, split_name, image_id) + + +def main(): + parser = argparse.ArgumentParser("DocLayNet dataset") + parser.add_argument("--dataset-path", type=str, default="/home/i/dataset/DocLayNet/raw/DocLayNet/DATA", + help="Directory for the raw dataset (default: %(default)s)") + parser.add_argument("--output-path", type=str, default="/home/i/dataset/DocLayNet/glam", + help="Directory for the processed dataset (default: %(default)s)") + args = parser.parse_args() + + split_names = ["train", "test", "val"] + # split_names = ["val"] + split_image_ids = {} + + for split_name in split_names: + image_ids = os.listdir(os.path.join(args.dataset_path, split_name, "by-id")) + image_ids = [int(x.split(".")[0]) for x in image_ids] + image_ids = sorted(image_ids) + split_image_ids[split_name] = image_ids + + num_processes = psutil.cpu_count(logical=False) + # num_processes = 1 + logger.debug(f"Using {num_processes} processes.") + tasks_in_pool = 0 + max_tasks_in_pool = 100 + num_processes + + pbar = tqdm(desc=f"Processing...", total=sum(len(image_ids) for image_ids in split_image_ids.values()), smoothing=0.001, position=0, leave=False) + + with logging_redirect_tqdm(), multiprocessing.Pool(num_processes) as pool: + def callback(result): + nonlocal tasks_in_pool + tasks_in_pool -= 1 + pbar.update(1) + + example, context = result + dataset_dir, output_path, split_name, image_id = context + + if not example: + return + + assert example.node_features.isfinite().all() + assert example.edge_features.isfinite().all() + + example_filepath = os.path.join(output_path, f'{image_id}.pt') + torch.save(example, example_filepath) + + def my_error_callback(e): + nonlocal tasks_in_pool + tasks_in_pool -= 1 + pbar.update(1) + # logger.exception(e) + + for split_name in split_names: + output_path = os.path.join(args.output_path, split_name) + os.makedirs(output_path, exist_ok=True) + + image_ids = split_image_ids[split_name] + for image_id in image_ids: + example_filepath = os.path.join(output_path, f'{image_id}.pt') + if os.path.exists(example_filepath): + pbar.update(1) + # pbar.total -= 1 + continue + + while tasks_in_pool >= max_tasks_in_pool: + time.sleep(0.1) + + tasks_in_pool += 1 + pool.apply_async(process, args=(args.dataset_path, output_path, split_name, image_id), callback=callback, error_callback=my_error_callback) + # callback(process(args.dataset_path, output_path, split_name, image_id)) + + while tasks_in_pool > 0: + pbar.refresh() + print("Tasks in pool:", tasks_in_pool) + print("Waiting for following tasks:") + # print(pool._cache) + print(pool._taskqueue) + time.sleep(1) + + pool.close() + pool.join() + + pbar.refresh() + pbar.close() + + print("Done.") + + +if __name__ == '__main__': + main() diff --git a/dln_glam_train.py b/dln_glam_train.py new file mode 100644 index 0000000..06cf473 --- /dev/null +++ b/dln_glam_train.py @@ -0,0 +1,177 @@ +import logging +import random +import time + +import networkx as nx +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.data +import torch_geometric +import torch_geometric.nn.inits +import tqdm +from tqdm.contrib.logging import logging_redirect_tqdm + +from GLAM.common import PageNodes, PageEdges +from GLAM.models import GLAMGraphNetwork +from dln_glam_prepare import DLNDataset, CLASSES_MAP + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + 
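# ------------------------------------------------------------------------------------------
# Illustrative sketch (not part of the committed files): the bounded apply_async pattern used
# by the dataset-preparation main() above. Rather than submitting every image at once, it
# counts in-flight tasks and sleeps until the pool drains below a cap, which keeps memory
# bounded while workers run. The names slow_square, run_throttled and max_in_flight are
# hypothetical; only the throttling pattern mirrors the code above.
import multiprocessing
import time


def slow_square(x: int) -> int:
    # Hypothetical stand-in for the expensive per-image work done by process_row().
    time.sleep(0.01)
    return x * x


def run_throttled(items, num_processes: int = 4, max_in_flight: int = 32) -> list[int]:
    results: list[int] = []
    in_flight = 0

    with multiprocessing.Pool(num_processes) as pool:
        def on_done(value):
            nonlocal in_flight
            in_flight -= 1
            results.append(value)

        def on_error(exc):
            nonlocal in_flight
            in_flight -= 1  # keep the counter consistent even when a task fails

        for item in items:
            while in_flight >= max_in_flight:  # back-pressure before submitting more work
                time.sleep(0.1)
            in_flight += 1
            pool.apply_async(slow_square, args=(item,), callback=on_done, error_callback=on_error)

        while in_flight > 0:  # drain the pool before closing it
            time.sleep(0.1)
        pool.close()
        pool.join()

    return results


if __name__ == "__main__":
    print(run_throttled(range(100)))
# ------------------------------------------------------------------------------------------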
+ +class Stopwatch: + def __init__(self): + self.elapsed = 0. + + def __enter__(self): + self.start = time.time() + return self + + def __exit__(self, *args): + self.end = time.time() + self.elapsed = self.end - self.start + + def __str__(self): + return f"{self.elapsed:.4f} seconds" + + +def set_seed(n): + torch.manual_seed(n) + random.seed(n) + np.random.seed(n) + + +def data_reset_dtype(data: torch_geometric.data.data.BaseData) -> torch_geometric.data.data.BaseData: + data.node_features = data.node_features.to(torch.float32) + data.edge_index = data.edge_index.to(torch.int64) + data.edge_features = data.edge_features.to(torch.float32) + data.node_probs = data.node_probs.to(torch.float32) + data.edge_probs = data.edge_probs.to(torch.float32) + return data + + +def main(): + device = ("cuda" if torch.cuda.is_available() else "cpu") + + # Create or load model + model_filepath = "models/glam_dln.pt" + set_seed(42) + + # if os.path.exists(model_filepath): + # model = glam.glam.GLAMGraphNetwork(PageNodes.features_len, PageEdges.features_len, 512, len(CLASSES_MAP)) + # model.load_state_dict(torch.load(model_filepath)) + # else: + model = GLAMGraphNetwork(PageNodes.features_len, PageEdges.features_len, 512, len(CLASSES_MAP)) + model = model.to(device) + + # TODO: normalize + # transforms = torch_geometric.transforms.Compose([ + # torch_geometric.transforms.NormalizeFeatures(attrs=['node_features', 'edge_features']), + # ]) + + train_dataset = DLNDataset("/home/i/dataset/DocLayNet/glam", 'train', transform=None, pre_transform=None) + val_dataset = DLNDataset("/home/i/dataset/DocLayNet/glam", 'val', transform=None, pre_transform=None) + # test_dataset = DLNDataset("/home/i/dataset/DocLayNet/glam", 'test', transform=None, pre_transform=None) + + train_loader = torch_geometric.loader.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=8) + val_loader = torch_geometric.loader.DataLoader(val_dataset, batch_size=128, shuffle=False, num_workers=8) + # test_loader = torch_geometric.loader.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0) + + # Train parameters + epochs = 1 + + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2) + edge_loss_scale = 4 + + # Progress bars + pbar_epoch = tqdm.tqdm(total=epochs, unit="epoch", position=0) + pbar_train = tqdm.tqdm(total=train_loader.__len__(), unit="batch", position=1) + pbar_val = tqdm.tqdm(total=val_loader.__len__(), unit="batch", position=2) + + # Train + def closure() -> float: + optimizer.zero_grad() + + node_class_scores, edge_class_scores = model(example) + + node_class_loss = F.cross_entropy(node_class_scores, example.node_probs) # multi-class classification problem + edge_class_loss = F.binary_cross_entropy_with_logits(edge_class_scores, example.edge_probs[..., None]) # multi-label classification problem + loss = node_class_loss + edge_loss_scale * edge_class_loss + loss.backward() + return loss.item() + + with logging_redirect_tqdm(): + for epoch in range(epochs): + # Train + pbar_train.reset() + model.train() + for i, data in enumerate(train_loader.__iter__()): + assert isinstance(data, torch_geometric.data.Data) + example = data.clone() + example = data_reset_dtype(example) + example = example.to(device) + + loss = optimizer.step(closure) + logger.info(f"Loss: {loss:.4f}") + pbar_train.update(1) + + # Validation + pbar_val.reset() + model.eval() + with torch.no_grad(): + for i, data in enumerate(val_loader.__iter__()): + assert isinstance(data, torch_geometric.data.Data) + example = 
data.clone() + example = data_reset_dtype(example) + example = example.to(device) + + node_class_scores, edge_class_scores = model(example) + node_class_loss = F.cross_entropy(node_class_scores, example.node_probs) # multi-class classification problem + edge_class_loss = F.binary_cross_entropy_with_logits(edge_class_scores, example.edge_probs[..., None]) # multi-label classification problem + loss = node_class_loss + edge_loss_scale * edge_class_loss + + edge_prob_threshold = 0.5 + graph = nx.Graph() + for k in range(example.edge_index.shape[1]): + src_node_i = example.edge_index[0, k].item() + dst_node_i = example.edge_index[1, k].item() + edge_prob = example.edge_probs[k].item() + + if edge_prob >= edge_prob_threshold: + graph.add_edge(src_node_i, dst_node_i, weight=edge_prob) + else: + graph.add_node(src_node_i) + graph.add_node(dst_node_i) + + clusters: list[set[int]] = list(nx.connected_components(graph)) + # cluster_min_spanning_boxes: list[Polygon] = [ + # Polygon([ + # # (min(nodes[node_i].bbox_min_x for node_i in cluster), min(nodes[node_i].bbox_min_y for node_i in cluster)), + # # (max(nodes[node_i].bbox_max_x for node_i in cluster), min(nodes[node_i].bbox_min_y for node_i in cluster)), + # # (max(nodes[node_i].bbox_max_x for node_i in cluster), max(nodes[node_i].bbox_max_y for node_i in cluster)), + # # (min(nodes[node_i].bbox_min_x for node_i in cluster), max(nodes[node_i].bbox_max_y for node_i in cluster)), + # (min(example.node_features[cluster, 2]), min(example.node_features[cluster, 3])), + # (max(example.node_features[cluster, 4]), min(example.node_features[cluster, 3])), + # (max(example.node_features[cluster, 4]), max(example.node_features[cluster, 5])), + # (min(example.node_features[cluster, 2]), max(example.node_features[cluster, 5])), + # ]) + # for cluster in clusters + # ] + cluster_classes: list[int] = torch.stack([example.node_probs[torch.tensor(list(cluster))].sum(dim=0) for cluster in clusters]).argmax(dim=1).tolist() + + pbar_val.update(1) + + pbar_epoch.update(1) + + pbar_val.close() + pbar_train.close() + pbar_epoch.close() + + # Save model + torch.save(model.state_dict(), model_filepath) + + +if __name__ == '__main__': + main() diff --git a/dln_parse_pdf.py b/dln_parse_pdf.py new file mode 100644 index 0000000..7ea190e --- /dev/null +++ b/dln_parse_pdf.py @@ -0,0 +1,304 @@ +import argparse +import io +import json +import logging +import multiprocessing +import os +import pickle +import signal +import time +from typing import Optional + +import qoi +import fitz +import numpy as np +import polars as pl +import psutil +import tesserocr +from PIL import Image +from tesserocr import PyTessBaseAPI +from tesserocr import get_languages +from tqdm import tqdm + +from GLAM.common import get_bytes_per_pixel, truncate_with_ellipsis, pixmap_to_image, pixmap_to_ndarray + +# logging.basicConfig(filename='dln.log', level=logging.DEBUG) +logger = logging.getLogger(__name__) +logger_stream_handler = logging.StreamHandler() +logger.addHandler(logger_stream_handler) +logger.setLevel(logging.DEBUG) + +INVALID_UNICODE = chr(0xFFFD) + +tessdata_prefix = os.environ.get("TESSDATA_PREFIX", "/usr/share/tesseract/tessdata") +tesseract_languages_required = ["eng", "deu", "fra", "jpn"] +api = PyTessBaseAPI(path=tessdata_prefix, lang="+".join(tesseract_languages_required)) + + +CLASSES_MAP = { + 1: "Caption", + 2: "Footnote", + 3: "Formula", + 4: "List-item", + 5: "Page-footer", + 6: "Page-header", + 7: "Picture", + 8: "Section-header", + 9: "Table", + 10: "Text", + 11: "Title", +} + + 
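# ------------------------------------------------------------------------------------------
# Illustrative sketch (not part of the committed files): how the validation loop in
# dln_glam_train.py turns edge probabilities into layout clusters. Edges whose probability
# clears a threshold connect nodes in a networkx graph, connected components become segments,
# and each segment's class is the argmax of its nodes' summed class probabilities.
# cluster_nodes is a hypothetical helper and the tensors under __main__ are toy values; only
# the procedure corresponds to the training/validation code above.
import networkx as nx
import torch


def cluster_nodes(edge_index: torch.Tensor,   # [2, num_edges], int64 node index pairs
                  edge_probs: torch.Tensor,   # [num_edges], probabilities in [0, 1]
                  node_probs: torch.Tensor,   # [num_nodes, num_classes]
                  threshold: float = 0.5) -> list[tuple[set[int], int]]:
    graph = nx.Graph()
    graph.add_nodes_from(range(node_probs.shape[0]))  # isolated nodes stay as singleton clusters
    for k in range(edge_index.shape[1]):
        if edge_probs[k].item() >= threshold:
            graph.add_edge(int(edge_index[0, k]), int(edge_index[1, k]), weight=edge_probs[k].item())

    clusters = list(nx.connected_components(graph))
    # Soft majority vote: sum the class probabilities over each cluster and take the argmax.
    cluster_classes = [int(node_probs[torch.tensor(sorted(c))].sum(dim=0).argmax()) for c in clusters]
    return list(zip(clusters, cluster_classes))


if __name__ == "__main__":
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
    edge_probs = torch.tensor([0.9, 0.2, 0.8])  # the 1-2 edge falls below the threshold
    node_probs = torch.tensor([[0.9, 0.1], [0.8, 0.2], [0.3, 0.7], [0.2, 0.8]])
    print(cluster_nodes(edge_index, edge_probs, node_probs))
# ------------------------------------------------------------------------------------------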
+def worker_init(): + # global api + # api = PyTessBaseAPI(path=tessdata_prefix, lang="+".join(tesseract_languages_required)) + pass + + +def pdf_extract( + pdf_filepath: str, + scale: float = 1, +) -> list[tuple[np.ndarray, int, int, dict]]: + """Returns a list of tuples (image_webp, width, height, page_dict)""" + doc = fitz.Document(pdf_filepath) + result: list[tuple[np.ndarray, int, int, dict]] = [] + + for page_i in range(doc.page_count): + page: fitz.Page = doc.load_page(page_i) + page_dict = fitz.utils.get_text(page=page, option="dict", clip=page.rect, flags=fitz.TEXT_PRESERVE_IMAGES) + + # Filter out empty image blocks + page_dict["blocks"] = [ + block + for block in page_dict["blocks"] + if not (block["type"] == 1 and len(block["image"]) == 0) + ] + # Filter out empty spans and resolve invalid unicode + for block in page_dict["blocks"]: + if block["type"] == 0: + for line in block["lines"]: + for span in line["spans"]: + text = span["text"] + + if INVALID_UNICODE in text: + page_pixmap = fitz.utils.get_pixmap( + page=page, + matrix=fitz.Matrix(5, 5), # 360 dpi + clip=span["bbox"], + colorspace=fitz.csRGB, + alpha=False, + ) + if page_pixmap.samples_ptr != 0: + bpp = get_bytes_per_pixel(page_pixmap.colorspace, page_pixmap.alpha) + api.SetImageBytes( + imagedata=page_pixmap.samples, + width=page_pixmap.w, + height=page_pixmap.h, + bytes_per_pixel=bpp, + bytes_per_line=page_pixmap.stride, + ) + api.SetPageSegMode(tesserocr.PSM.RAW_LINE) + api.Recognize() + ocr_text = api.GetUTF8Text().rstrip() + + ls = " " * (len(text) - len(text.lstrip())) + rs = " " * (len(text) - len(text.rstrip())) + old_text, text = text, ls + ocr_text + rs + span["text_ocr"] = text + # logger.debug(f"Replaced {old_text!r} with {text!r}") + + # Use list comprehension to filter empty spans + line["spans"] = [span for span in line["spans"] if not span["text"].strip() == ""] + elif block["type"] == 1: + # QOI + block_image: Image.Image = Image.open(io.BytesIO(block["image"])) + block_image: np.ndarray = np.array(block_image) # Makes a copy + if block_image.ndim in (3, 4): + image_qoi = qoi.encode(block_image) + else: + image_qoi = None + + # WebP + try: + block_image = Image.open(io.BytesIO(block["image"])) + block_webp_buffer = io.BytesIO() + block_image.save(block_webp_buffer, format="WEBP", lossless=True, quality=100, method=1) + image_webp = block_webp_buffer.getvalue() + except OSError as e: + print(f"Failed to open image block {truncate_with_ellipsis(str(block), 128)} in pdf {pdf_filepath}: {e}") + image_webp = None + + # Select the best image format + smallest_image = block["image"] + if image_qoi is not None and len(image_qoi) < len(smallest_image): + smallest_image = image_qoi + if image_webp is not None and len(image_webp) < len(smallest_image): + smallest_image = image_webp + block["image"] = smallest_image + else: + raise ValueError(f"Unknown block type {block['type']} in pdf {pdf_filepath}") + + if scale != 0: + page_pixmap = fitz.utils.get_pixmap( + page=page, + matrix=fitz.Matrix(scale, scale), + colorspace=fitz.csRGB, + alpha=False, + ) + assert page_pixmap.samples_ptr != 0 + image = pixmap_to_ndarray(page_pixmap) + else: + image = None + + result.append((image, page.rect.width, page.rect.height, page_dict)) + + return result + + +def process( + pdf_filepath: str, + scale: float, + split_name: str, + row: dict, +): + try: + return pdf_extract(pdf_filepath, scale), (pdf_filepath, scale, split_name, row) + except Exception as e: + print(f"pdf_ser failed: {e}") + return None + + +def main(): + # 
signal.signal(signal.SIGINT, lambda sig, frame: exit(sig)) + + parser = argparse.ArgumentParser("DocLayNet dataset preparation. Using paper proposed dataset splits.") + parser.add_argument("--dataset-path", type=str, default="/home/i/dataset/DocLayNet/raw/DocLayNet", + help="Directory for the raw dataset (default: %(default)s)") + parser.add_argument("--image-scale", type=float, default=1, + help="Set scaling factor for an image. A scale of 1 is 72 dpi. (default: %(default)s)") + args = parser.parse_args() + + print("Processing DocLayNet dataset") + split_names = ["train", "test", "val"] + + num_processes = psutil.cpu_count(logical=False) + logger.debug(f"Using {num_processes} processes.") + tasks_in_pool = 0 + max_tasks_in_pool = 100 + num_processes + + pbar = tqdm(desc=f"Processing...", smoothing=0.001) + + with multiprocessing.Pool(num_processes, initializer=worker_init) as pool: + def callback(result): + nonlocal tasks_in_pool + tasks_in_pool -= 1 + pbar.update(1) + + if result is None: + return + + example, (orig_pdf_filepath, scale, split_name, row) = result + assert len(example) == 1, f"Expected 1 page, got {len(result)} pages" + image, width, height, page_dict = example[0] + + id_file = os.path.join(args.dataset_path, "DATA", split_name, "by-id", str(row["id"])) + os.makedirs(id_file, exist_ok=True) + pdf_filepath = os.path.join(id_file, "page.pdf") + row_filepath = os.path.join(id_file, "row.json") + webp_filepath = os.path.join(id_file, "image.webp") + qoi_filepath = os.path.join(id_file, "image.qoi") + page_dict_filepath = os.path.join(id_file, "page_dict.pkl") + annotations_filepath = os.path.join(id_file, "annotations.json") + + # Convert annotations to original coordinates + scale_x = width / row["width"] + scale_y = height / row["height"] + annotations = [] + for ann_id in image_id_to_annotations_index.get(row["id"], []): + ann = split_coco['annotations'][ann_id] + for b in range(0, len(ann['bbox']), 2): + ann['bbox'][b] *= scale_x + ann['bbox'][b + 1] *= scale_y + for seg in ann['segmentation']: + for s in range(0, len(seg), 2): + seg[s] *= scale_x + seg[s + 1] *= scale_y + annotations.append(ann) + + with open(annotations_filepath, "w", encoding="utf-8") as f: + json.dump(annotations, f) + + if image is not None: + # Save image as QOI + _ = qoi.write(qoi_filepath, image) + + # Save image as WebP + image = Image.fromarray(image) + image.save(webp_filepath, format="WEBP", lossless=True, quality=100, method=1) + + if page_dict is not None: + with open(page_dict_filepath, "wb") as f: + pickle.dump(page_dict, f) + + # Save row + row["width"] = width + row["height"] = height + with open(row_filepath, "w", encoding="utf-8") as f: + json.dump(row, f) + + # Hard link PDF from pdf_filepath to orig_pdf_filepath + try: + os.unlink(pdf_filepath) + except FileNotFoundError: + pass + os.link(orig_pdf_filepath, pdf_filepath) + + def my_error_callback(e): + nonlocal tasks_in_pool + tasks_in_pool -= 1 + pbar.update(1) + # logger.exception(e) + + for split_name in split_names: + coco_filepath = os.path.join(args.dataset_path, "COCO", f"{split_name}.json") + with open(coco_filepath, "r", encoding="utf-8") as f: + split_coco = json.load(f) + + image_id_to_annotations_index = {} + for i, ann in enumerate(split_coco['annotations']): + image_id_to_annotations_index.setdefault(ann['image_id'], []).append(i) + + pbar.reset(total=len(split_coco["images"])) + + for row in split_coco["images"]: + page_hash = row["file_name"][:-4] + id_file = os.path.join(args.dataset_path, "DATA", split_name, "by-id", 
str(row["id"])) + row_filepath = os.path.join(id_file, "row.json") + pdf_filepath = os.path.join(args.dataset_path, "PDF", page_hash + ".pdf") + + # Skip if already processed + if os.path.exists(row_filepath): + pbar.update(1) + continue + + while tasks_in_pool >= max_tasks_in_pool: + time.sleep(0.1) + + tasks_in_pool += 1 + pool.apply_async(process, args=(pdf_filepath, args.image_scale, split_name, row), callback=callback, error_callback=my_error_callback) + # callback(process(orig_pdf_filepath, args.image_scale, split_name, row)) + + while tasks_in_pool > 0: + pbar.refresh() + print("Tasks in pool:", tasks_in_pool) + print("Waiting for following tasks:") + # print(pool._cache) + print(pool._taskqueue) + time.sleep(1) + + print("Finished processing DocLayNet dataset") + + +if __name__ == '__main__': + main() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..92e4054 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,4 @@ +pytorch +PyMuPDF +tesserocr +qoi