diff --git a/.condarc b/.condarc new file mode 100644 index 00000000..4883d181 --- /dev/null +++ b/.condarc @@ -0,0 +1,10 @@ +channels: + - defaults +show_channel_urls: true +default_channels: + - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main + - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/r + - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/msys2 +custom_channels: + conda-forge: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud + pytorch: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 00000000..7b7c5a91 --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1,12 @@ +# These are supported funding model platforms + +github: YaoFANGUK +patreon: # Replace with a single Patreon username +open_collective: # Replace with a single Open Collective username +ko_fi: # Replace with a single Ko-fi username +tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel +community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry +liberapay: # Replace with a single Liberapay username +issuehunt: # Replace with a single IssueHunt username +otechie: # Replace with a single Otechie username +custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] diff --git a/.gitignore b/.gitignore index 31cb5801..3b0d3dfe 100644 --- a/.gitignore +++ b/.gitignore @@ -347,5 +347,15 @@ $RECYCLE.BIN/ # Windows shortcuts *.lnk +*.srt # End of https://www.toptal.com/developers/gitignore/api/intellij+all,python,pycharm+all,macos,windows +/backend/models/V2/ch_det/inference.pdiparams +/backend/models/V4/ch_det/inference.pdiparams +/output/ +/backend/test.py +/dylib/ +/settings.ini +/test.py +/test2.py +/subtitle.ini diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..261eeb9e --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. 
+ + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/README.md b/README.md index f49be413..66f8b452 100755 --- a/README.md +++ b/README.md @@ -1,60 +1,326 @@ -## 项目特色 +简体中文 | [English](README_en.md) -- 提取视频中的字幕,生成字幕文件,将水印(台标)等文本信息去除 -- 采用PaddleOCR,无需设置调用任何API,不需要接入百度、阿里等OCR服务即可本地完成文本识别 -- 无需用户手动设置字幕区域,该项目引入Paddle文本检测模型自动检测字幕区域 +## 项目简介 -## 使用说明 +![License](https://img.shields.io/badge/License-Apache%202-red.svg) +![python version](https://img.shields.io/badge/Python-3.12+-blue.svg) +![support os](https://img.shields.io/badge/OS-Windows/macOS/Linux-green.svg) -- 下载项目文件,将工作目录切换到项目文件所在目录 +Video-subtitle-extractor (VSE) 是一款将视频中的硬字幕提取为外挂字幕文件(srt格式)的软件。 +主要实现了以下功能: -```shell -cd video-subtitle-extractor +- 提取视频中的关键帧 +- 检测视频帧中文本的所在位置 +- 识别视频帧中文本的内容 +- 过滤非字幕区域的文本 +- [去除水印、台标文本、原视频硬字幕,可配合:video-subtitle-remover (VSR) ](https://github.com/YaoFANGUK/video-subtitle-remover/tree/main) +- 去除重复字幕行,生成srt字幕文件/txt文本文件 +> 若需要生成txt文本,可以在backend/config.py中设置```GENERATE_TXT=True``` +- 支持视频字幕**批量提取** +- 多语言:支持**简体中文(中英双语)**、**繁体中文**、**英文**、**日语**、**韩语**、**越南语**、**阿拉伯语**、**法语**、**德语**、**俄语**、**西班牙语**、**葡萄牙语**、**意大利语**等**87种**语言的字幕提取 +- 多模式: + - **快速**:(推荐)使用轻量模型,快速提取字幕,可能丢少量字幕、存在少量错别字 + - **自动**:(推荐)自动判断模型,CPU下使用轻量模型;GPU下使用精准模型,提取字幕速度较慢,可能丢少量字幕、几乎不存在错别字 + - **精准**:(不推荐)使用精准模型,GPU下逐帧检测,不丢字幕,几乎不存在错别字,但速度**非常慢** + +> 请优先使用快速/自动模式,如果前两种模式存在较多丢字幕轴情况时,再使用精准模式 + +
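+上面提到的 txt 输出等选项集中在 ``backend/config.py`` 中,下面节选几个常用开关(数值取自本仓库的默认配置,可按需修改):
+
+```python
+# backend/config.py(节选)
+GENERATE_TXT = True              # 生成 srt 字幕的同时输出 txt 文本
+EXTRACT_FREQUENCY = 3            # 每一秒抓取多少帧进行 OCR 识别
+THRESHOLD_TEXT_SIMILARITY = 0.8  # 去重时判断两行字幕是否为同一行的相似度阈值
+DROP_SCORE = 0.75                # 低于该置信度的识别结果不保留
+```
+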

demo.png

+ +**项目特色**: + +- 采用本地进行OCR识别,无需设置调用任何API,不需要接入百度、阿里等在线OCR服务即可本地完成文本识别 +- 支持GPU加速,GPU加速后可以获得更高的准确率与更快的提取速度 + +**使用说明**: + +- 有使用问题请加群讨论,QQ群:210150985、816881808 + +- 点击【打开】后选择视频文件,调整字幕区域,点击【运行】 + - 单文件提取:打开文件的时候选择**单个**视频 + - **批量提取**:打开文件的时候选择**多个**视频,确保每个视频的分辨率、字幕区域保持一致 + +- 去除水印文本/替换特定文本: +> 如果视频中出现特定的文本需要删除,或者特定的文本需要替换,可以编辑 ``backend/configs/typoMap.json``文件,加入你要替换或去除的内容 + +```json +{ + "l'm": "I'm", + "l just": "I just", + "Let'sqo": "Let's go", + "Iife": "life", + "威筋": "威胁", + "性感荷官在线发牌": "" +} ``` +> 这样就可以把文本中出现的所有“威筋”替换为“威胁”,所有的“性感荷官在线发牌”文本删除 + +- 视频以及程序路径请**不要带中文和空格**,否则可能出现未知错误!!! + + > 如:以下存放视频和代码的路径都不行 + > + > D:\下载\vse\运行程序.exe(路径含中文) + > + > E:\study\kaoyan\sanshang youya.mp4 (路径含空格) + +- 直接下载压缩包解压运行,如果不能运行再按照下面的教程,尝试源码安装conda环境运行 + +**下载地址**: + +- Windows 绿色版本v2.0.0(CPU): vse_windows_cpu_v2.0.0.zip 提取码:**vse2** + +> **推荐使用,启动速度较快** + +- Windows 单文件版本v2.0.0(CPU): vse.exe 提取码:**rl02** + +> 双击直接运行,每次打开时会有一点慢,**若出现误报毒,使用绿色版** + +- Windows GPU版本v2.0.0(GPU): vse_windows_gpu_v2.0.0.7z 提取码:**vse2** + +> **仅供具有Nvidia显卡的用户使用(AMD的显卡不行),提取速度非常快** + +- MacOS 版本v0.1.0(CPU): vse_macOS_CPU.dmg 提取码:**7gbo** + +> PS: 若无法下载,请前往 Release 下载 + +> **有任何改进意见请在ISSUES和DISCUSSION中提出** -#### 1. 下载安装Anaconda -https://www.anaconda.com/products/individual#Downloads +## 演示 + +- GUI版: + +

demo.gif

+ +- 点击查看视频教程 👇 + +[![GPU版本安装教程](https://s1.ax1x.com/2022/04/15/L3KzLR.png)](https://www.bilibili.com/video/bv11L4y1Y7Tj "GUP版本安装教程") + + + +## 在线运行 + +- 使用**Google Colab Notebook**(免费GPU): Open In Colab + +> PS: Google Colab只能运行CLI版本 + + + +## 源码使用说明 + +#### 1. 下载安装Miniconda + +- Windows: Miniconda3-py312_24.7.1-0-Windows-x86_64.exe + + +- MacOS:Miniconda3-py312_24.7.1-0-MacOSX-x86_64.pkg + + +- Linux: Miniconda3-py312_24.7.1-0-Linux-x86_64.sh + +#### 2. 创建并激活虚机环境 + +(1)切换到源码所在目录: +```shell +cd <源码所在目录> +``` +> 例如:如果你的源代码放在D盘的tools文件下,并且源代码的文件夹名为video-subtitle-extractor,就输入 ```cd D:/tools/video-subtitle-extractor-main``` -#### 2. 使用conda创建项目虚拟环境并激活环境 (建议创建虚拟环境运行,也可以不用conda) +(2)创建激活conda环境 +```shell +conda create -n videoEnv python=3.12 +``` ```shell -conda create --name videoEnv python=3.7 -conda activate videoEnv +conda activate videoEnv ``` -#### 3. 使用pip安装依赖文件 +#### 3. 安装依赖文件 + +请确保你已经安装 python 3.12+,使用conda创建项目虚拟环境并激活环境 (建议创建虚拟环境运行,以免后续出现问题) + +- 安装依赖: + + ```shell + pip install -r requirements.txt + ``` + +- 安装CUDA和cuDNN + +> 请确保有拥有Nvidia的显卡,**30系列以上的显卡驱动可能不支持 cuda 11.2及以下版本的安装** +> +> 如果安装cuda 10.2,请对应安装7.6.5的cuDNN,并使用对应cuda版本的paddlepaddle,**请不要使用cuDNN v8.x 和 cuda 10.2的组合** +> +> 如果安装cuda 11.2,请对应安装8.1.1的cuDNN,并使用对应cuda版本的paddlepaddle +> +> 如果安装cuda 11.6,请对应安装8.4.0的cuDNN,并使用对应cuda版本的paddlepaddle +> +> 如果安装cuda 11.8,请对应安装8.6.0的cuDNN,并使用对应cuda版本的paddlepaddle +> +> 如果安装cuda 12.0,请对应安装8.9.1的cuDNN,并使用对应cuda版本的paddlepaddle + + + + -- mac用户, cpu用户: +
+ Linux用户 +
(1) 下载CUDA 11.7
+
wget https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_515.43.04_linux.run
+
(2) 安装CUDA 11.7
+
sudo sh cuda_11.7.0_515.43.04_linux.run
+

1. 输入accept

+ +

2. 选中CUDA Toolkit 11.7(如果你没有安装nvidia驱动则选中Driver,如果你已经安装了nvidia驱动请不要选中driver),之后选中install,回车

+ +

3. 添加环境变量

+

在 ~/.bashrc 加入以下内容

+
# CUDA
+  export PATH=/usr/local/cuda-11.7/bin${PATH:+:${PATH}}
+  export LD_LIBRARY_PATH=/usr/local/cuda-11.7/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
+

使其生效

+
source ~/.bashrc
+
(3) 下载cuDNN 8.4.1
+

国内:cudnn-linux-x86_64-8.4.1.50_cuda11.6-archive.tar.xz 提取码:57mg

+

国外:cudnn-linux-x86_64-8.4.1.50_cuda11.6-archive.tar.xz

+
(4) 安装cuDNN 8.4.1
+
 tar -xf cudnn-linux-x86_64-8.4.1.50_cuda11.6-archive.tar.xz
+   mv cudnn-linux-x86_64-8.4.1.50_cuda11.6-archive cuda
+   sudo cp ./cuda/include/* /usr/local/cuda-11.7/include/
+   sudo cp ./cuda/lib/* /usr/local/cuda-11.7/lib64/
+   sudo chmod a+r /usr/local/cuda-11.7/lib64/*
+   sudo chmod a+r /usr/local/cuda-11.7/include/*
+
+ +
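+拷贝完成后,可以用下面的 Python 片段做个简单检查(示意代码,路径与上文的安装命令一致),确认 cuDNN 的头文件和库已经位于 CUDA 11.7 目录下:
+
+```python
+# 示意:确认 cuDNN 文件已拷贝到 CUDA 11.7 目录(路径见上文命令)
+import glob
+print(glob.glob('/usr/local/cuda-11.7/include/cudnn*.h'))  # 应能列出 cudnn.h 等头文件
+print(glob.glob('/usr/local/cuda-11.7/lib64/libcudnn*'))   # 应能列出 libcudnn 系列库文件
+```
+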
+ Windows用户 +
(1) 下载CUDA 11.7
+ cuda_11.7.0_516.01_windows.exe +
(2) 安装CUDA 11.7
+
(3) 下载cuDNN 8.4.0
+

cudnn-windows-x86_64-8.4.0.27_cuda11.6-archive.zip

+
(4) 安装cuDNN 8.4.0
+

+ 将cuDNN解压后的cuda文件夹中的bin, include, lib目录下的文件复制到C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\对应目录下 +

+
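+装好 CUDA/cuDNN 并按下文安装对应版本的 paddlepaddle 之后,可以用下面的 Python 片段快速确认 GPU 环境是否被正确识别(示意代码,使用的均为 paddle 的公开接口):
+
+```python
+# 示意:确认安装的是 GPU 版 paddlepaddle,以及其编译所用的 CUDA/cuDNN 版本
+import paddle
+
+print(paddle.is_compiled_with_cuda())  # True 表示安装的是 GPU 版本
+print(paddle.version.cuda())           # 该 wheel 编译所用的 CUDA 版本
+print(paddle.version.cudnn())          # 该 wheel 编译所用的 cuDNN 版本
+paddle.utils.run_check()               # 运行一个最小示例,验证 GPU 能否正常工作
+```
+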
+ + +- 安装paddlepaddle: + + - windows: + + ```shell + python -m pip install paddlepaddle-gpu==2.6.1.post117 -f https://www.paddlepaddle.org.cn/whl/windows/mkl/avx/stable.html + ``` + + - Linux: + + ```shell + python -m pip install paddlepaddle-gpu==2.6.1.post117 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html + ``` + +#### 4. 运行程序 + +- 运行图形化界面版本(GUI) ```shell -pip install -r requirements.txt +python gui.py ``` -- gpu用户:(使用cuda版本10.2) +- 运行命令行版本(CLI) ```shell -pip install -r requirements(gpu).txt +python ./backend/main.py ``` -#### 4. 运行程序 -> 对于gpu用户: -> 请将config.py中第48行 USE_GPU = False 改为: +## 常见问题与解决方案 + +#### 1. 运行不正常/没有结果/cuda及cudnn问题 + +解决方案:根据自己的显卡型号、显卡驱动版本,安装对应的cuda与cudnn -```python -USE_GPU = True + +#### 2. CondaHTTPError + +将项目中的.condarc放在用户目录下(C:\Users\\<你的用户名>),如果用户目录已经存在该文件则覆盖 + +解决方案:https://zhuanlan.zhihu.com/p/260034241 + +#### 3. Windows下出现geos_c.dll错误 + +```text + _lgeos = CDLL(os.path.join(sys.prefix, 'Library', 'bin', 'geos_c.dll')) + File "C:\Users\Flavi\anaconda3\envs\subEnv\lib\ctypes\__init__.py", line 364, in __init__ + self._handle = _dlopen(self._name, mode) +OSError: [WinError 126] 找不到指定的模块。 ``` -> 运行 +解决方案: + +(1) 卸载Shapely ```shell -python main.py +pip uninstall Shapely -y ``` -## 演示视频 +(2) 使用conda重新安装Shapely + +```shell +conda install Shapely +``` + +#### 4. 7z文件解压错误 + +解决方案:升级7-zip解压程序到最新版本 + +#### 5. Nuitka打包代码闪退 + +使用Nuitka版本```0.6.19```,将conda虚拟环境Lib文件夹下site-packages的所有文件复制到dependencies文件夹中,把paddle库dataset下image.py的有关subprocess代码全部注释了,使用以下打包命令: + +```shell + python -m nuitka --standalone --mingw64 --include-data-dir=D:\vse\backend=backend --include-data-dir=D:\vse\dependencies=dependencies --nofollow-imports --windows-icon-from-ico=D:\vse\design\vse.ico --plugin-enable=tk-inter,multiprocessing --output-dir=out .\gui.py +``` + +编译成单个文件(pip安装zstandard可以减小体积) +```shell +python -m nuitka --standalone --windows-disable-console --mingw64 --lto no --include-data-dir=C:\Users\Yao\Downloads\vse\backend=backend --include-data-dir=C:\Users\Yao\Downloads\vse\design=design --include-data-dir=C:\Users\Yao\Downloads\vse\dependencies=dependencies --nofollow-imports --windows-icon-from-ico=C:\Users\Yao\Downloads\vse\design\vse.ico --plugin-enable=tk-inter,multiprocessing --output-dir=C:\Users\Yao\Downloads\out --onefile .\gui.py +``` + +## 社区支持 + +#### Jetbrains 全家桶支持 +本项目开发所使用的IDE由Jetbrains支持。 +
+ JetBrains Logo (Main) logo. +
-[![Demo Video](https://s1.ax1x.com/2020/10/05/0JWVeJ.png)](https://www.bilibili.com/video/BV1t5411h78J "Demo Video") +## 赞助 + +| 捐赠者 | 累计捐赠金额 | 赞助席位 | +|----------------------------------------| --- | --- | +| **伟 | 300.00 RMB | 金牌赞助席位 | +| 周学彬 | 200.00 RMB | 金牌赞助席位 | +| 爱东 | 100.00 RMB | 金牌赞助席位 | +| **迪 | 100.00 RMB | 金牌赞助席位 | +| ysjm | 100.00 RMB | 金牌赞助席位 | +| [ischeung](https://github.com/ischeung) | 100.00 RMB | 金牌赞助席位 | +| 明 | 88.00 RMB | 金牌赞助席位 | +| [neoyxm](https://github.com/neoyxm) | 50.00 RMB | 银牌赞助席位 | +| 亦 | 50.00 RMB | 银牌赞助席位 | +| 周昊 | 50.00 RMB | 银牌赞助席位 | +| 玛卡巴卡 | 35.00 RMB | 银牌赞助席位 | +| 净心 | 30.00 RMB | 银牌赞助席位 | +| ysjm | 30.00 RMB | 银牌赞助席位 | +| 生活不止眼前的苟且 | 30.00 RMB | 银牌赞助席位 | +| 迷走神经病 | 30.00 RMB | 银牌赞助席位 | +| [AcelXiao](https://github.com/acelxiao) | 20.00 RMB | 银牌赞助席位 | +| 又是李啊 | 10.00 RMB | 铜牌赞助席位 | +| 匿名 | 8.80 RMB | 铜牌赞助席位 | +| 落墨 | 6.00 RMB | 铜牌赞助席位 | +| 未闻花名 | 5.00 RMB | 铜牌赞助席位 | +| sky | 5.00 RMB | 铜牌赞助席位 | diff --git a/README_en.md b/README_en.md new file mode 100644 index 00000000..8e081282 --- /dev/null +++ b/README_en.md @@ -0,0 +1,262 @@ +[简体中文](README.md) | English + +## Introduction + +![License](https://img.shields.io/badge/License-Apache%202-red.svg) +![python version](https://img.shields.io/badge/Python-3.12+-blue.svg) +![support os](https://img.shields.io/badge/OS-Windows/macOS/Linux-green.svg) + +**Video-subtitle-extractor** (VSE) is a free, open-source tool which can help you rip the hard-coded subtitles from videos and automatically generate corresponding **srt** files for each video. It includes the following implementations: + +- Detect and extract subtitle frames (using traditional graphic method) +- Detect subtitle areas (i.e., coordinates) (as well as scene text if you want) (using deep learning algorithms) +- Converting graphic text into plain-text (using deep learning algorithms) +- Filter non-subtitle text (e.g., Logo and watermark etc.) +- Remove watermark, logo text and original video hard subtitles, see: [video-subtitle-remover (VSR)](https://github.com/YaoFANGUK/video-subtitle-remover/tree/main). +- Remove duplicated subtitle line and **generate srt file** (by calculating text similarity) +- Batch extraction. You can select multiple video files at one time and this tool can generate subtitles for each video. +- Multiple language support. You can extract subtitles in 87 languages such as: **Simplified Chinese**, **English**, + **Japanese**, **Korean**, **Arabic**, **Traditional Chinese**, **French**, **German**, **Russian**, **Spanish**, + **Portuguese**, **Italian** +- Multi-mode: + - **fast**: (Recommended) Uses a lightweight model for quick subtitle extraction, though it might miss a small amount of subtitles and contains a few typos. + - **auto**: (Recommended) Automatically selects the model. It uses the lightweight model under the CPU, and the precise model under the GPU. While subtitle extraction speed is slower and might miss a minor amount of subtitles, there are almost no typos. + - **accurate**: (Not Recommended) Uses the precise model with frame-by-frame detection under the GPU, ensuring no missed subtitles and almost non-existent typos, but the speed is **very slow**. + +

demo.png

+ +**Features**: + +- You don't need to do any preprocessing (e.g., binarization) and don't need to consider all aspects like subtitle fonts and size etc.. +- This is an offline project. There is no online API call and you dont need to connect to the Internet service provider in order to get results. + +**Usage**: + +- After clicking "Open", select video file(s), adjust the subtitle area, and then click "Run". + - Single file extraction: When opening a file, choose a single video. + - Batch extraction: When opening files, choose multiple videos, ensure that every video's resolution and subtitle area remain consistent. + +- Remove watermark text/replace specific text: +> If specific text needs to be deleted from generated .srt file, or specific text needs to be replaced, you can edit the ``backend/configs/typoMap.json`` file and add the content you want to replace or remove. + +```json +{ + "l'm": "I'm", + "l just": "I just", + "Let'sqo": "Let's go", + "Iife": "life", + "威筋": "threat", + "性感荷官在线发牌": "" +} +``` + +> In this way, you can replace all occurrences of "威筋" in the text with "threat" and delete all instances of the text "性感荷官在线发牌". + + +- Directly download the compressed package, unzip it and run it. If it cannot run, follow the tutorial below and try to install the Conda environment and run it using the source code. + +**Download**: + +- Windows executable (might be slow when initial start): vse.exe + +- Windows GPU version:vse_windows_gpu_v2.0.0.7z + +- Windows CPU version:vse_windows_cpu_v2.0.0.zip + +- MacOS:vse_macOS_CPU.dmg + + +> **Provide your suggestions to improve this project in ISSUES & DISCUSSION** + + +## Demo + +- Graphic User Interface (GUI): + +

demo.gif

+ + +- Command Line Interface (CLI): + +[![Demo Video](https://s1.ax1x.com/2020/10/05/0JWVeJ.png)](https://www.bilibili.com/video/BV1t5411h78J "Demo Video") + + +## Running Online + +- **Google Colab Notebook with free GPU**: Open In Colab + +> PS: can only run CLI version on Google Colab + + +## Getting Started with Source Code + +#### 1. Download and Install Miniconda + + +- Windows: Miniconda3-py312_24.7.1-0-Windows-x86_64.exe + + +- MacOS:Miniconda3-py312_24.7.1-0-MacOSX-x86_64.pkg + + +- Linux: Miniconda3-py312_24.7.1-0-Linux-x86_64.sh + + +#### 2. Activate Vitrual Environment + +(1) Switch to working directory +```shell +cd +``` + +(2) create and activate conda environment +```shell +conda create -n videoEnv python=3.12 pip +``` + +```shell +conda activate videoEnv +``` + + +#### 3. Install Dependencies + +Before you install dependencies, make sure your python 3.8+ has installed as well as conda virtual environment has created and activated. + +- Install dependencies: + + ```shell + pip install -r requirements.txt + ``` + +- Install **CUDA** and **cuDNN** + > make sure that you have **NVIDIA** graphic card before doing this step + +
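+Once CUDA, cuDNN and paddlepaddle are installed (platform-specific steps below), a quick sanity check like the following (illustrative snippet, using only public paddle APIs) confirms that the GPU build is detected:
+
+```python
+# Illustrative check: is this a GPU build of paddlepaddle, and which CUDA/cuDNN was it built with?
+import paddle
+
+print(paddle.is_compiled_with_cuda())  # True for the GPU build
+print(paddle.version.cuda())           # CUDA version the wheel was built against
+print(paddle.version.cudnn())          # cuDNN version the wheel was built against
+paddle.utils.run_check()               # runs a minimal program to verify the GPU is usable
+```
+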
+ Linux +
(1) Download CUDA 11.7
+
wget https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_515.43.04_linux.run
+
(2) Install CUDA 11.7
+
sudo sh cuda_11.7.0_515.43.04_linux.run
+

1. Type accept

+ +

2. Make sure CUDA Toolkit 11.7 is selected (if you have already installed the NVIDIA driver, do not select Driver), then select Install and press Enter

+ +

3. Add environment variables

+

add the following lines to ~/.bashrc

+
# CUDA
+    export PATH=/usr/local/cuda-11.7/bin${PATH:+:${PATH}}
+    export LD_LIBRARY_PATH=/usr/local/cuda-11.7/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
+

Apply the changes so they take effect

+
source ~/.bashrc
+
(3) Download cuDNN 8.4.1
+

cudnn-linux-x86_64-8.4.1.50_cuda11.6-archive.tar.xz

+
(4) Install cuDNN 8.4.1
+
 tar -xf cudnn-linux-x86_64-8.4.1.50_cuda11.6-archive.tar.xz
+   mv cudnn-linux-x86_64-8.4.1.50_cuda11.6-archive cuda
+   sudo cp ./cuda/include/* /usr/local/cuda-11.7/include/
+   sudo cp ./cuda/lib/* /usr/local/cuda-11.7/lib64/
+   sudo chmod a+r /usr/local/cuda-11.7/lib64/*
+   sudo chmod a+r /usr/local/cuda-11.7/include/*
+
+ +
+ Windows +
(1) Download CUDA 11.7
+ cuda_11.7.0_516.01_windows.exe +
(2) Install CUDA 11.7
+
(3) Download cuDNN 8.4.0
+

cudnn-windows-x86_64-8.4.0.27_cuda11.6-archive.zip

+
(4) Install cuDNN 8.4.0
+

+ unzip "cudnn-windows-x86_64-8.4.0.27_cuda11.6-archive.zip", then move all files in "bin, include, lib" in cuda + directory to C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\ +

+
+ + + - Install paddlepaddle: + - windows: + + ```shell + python -m pip install paddlepaddle-gpu==2.6.1.post117 -f https://www.paddlepaddle.org.cn/whl/windows/mkl/avx/stable.html + ``` + + - Linux: + + ```shell + python -m pip install paddlepaddle-gpu==2.6.1.post117 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html + ``` + + > If you installed cuda 10.2,please install cuDNN 7.6.5 instead of cuDNN v8.x + + > If you installed cuda 11.2, please install cuDNN 8.1.1. However, RTX 30xx might be incompatible with cuda 11.2 + + +#### 3. Running the program + +- Run GUI version + +```shell +python gui.py +``` + +- Run CLI version + +```shell +python ./backend/main.py +``` + +## Q & A + +#### 1. Running Failure or Environment Problem + +Solution: If you are using a nvidia ampere architecture graphic card such as RTX 3050/3060/3070/3080, please use the latest PaddlePaddle version and CUDA 11.6 with cuDNN 8.2.1. Otherwise, check your which cuda and cudnn works with your GPU and then install them. + + +#### 2. For Windows users, if you encounter errors related to "geos_c.dll" + +```text + _lgeos = CDLL(os.path.join(sys.prefix, 'Library', 'bin', 'geos_c.dll')) + File "C:\Users\Flavi\anaconda3\envs\subEnv\lib\ctypes\__init__.py", line 364, in __init__ + self._handle = _dlopen(self._name, mode) +OSError: [WinError 126] The specified module could not be found。 +``` + +Solution: + +1) Uninstall Shapely + +```shell +pip uninstall Shapely -y +``` + +2) Reinstall Shapely via conda (make sure you have anaconda or miniconda installed) + +```shell +conda install Shapely +``` + + +#### 3. How to generate executables +Using Nuitka version 0.6.19, copy all the files of ```site-packages``` under the Lib folder of the conda virtual environment to the ```dependencies``` folder, and comment all codes relevant to ```subprocess``` of ```image.py``` under the ```paddle``` library dataset, and use the following packaging command: + +```shell + python -m nuitka --standalone --mingw64 --include-data-dir=D:\vse\backend=backend --include-data-dir=D:\vse\dependencies=dependencies --nofollow-imports --windows-icon-from-ico=D:\vse\design\vse.ico --plugin-enable=tk-inter,multiprocessing --output-dir=out .\gui.py +``` + +Make a single ```.exe``` file, (pip install zstandard can compress the file): + +```shell + python -m nuitka --standalone --windows-disable-console --mingw64 --lto no --include-data-dir=D:\vse\backend=backend --include-data-dir=D:\vse\dependencies=dependencies --nofollow-imports --windows-icon-from-ico=D:\vse\design\vse.ico --plugin-enable=tk-inter,multiprocessing --output-dir=out --onefile .\gui.py +``` + + +## Community Support + +#### Jetbrains All Products Pack +The IDE this project used is supported by Jetbrains +
+ JetBrains Logo (Main) logo. +
+ diff --git a/backend/__init__.py b/backend/__init__.py new file mode 100644 index 00000000..622b46f0 --- /dev/null +++ b/backend/__init__.py @@ -0,0 +1,3 @@ +import os +import sys +sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'dependencies')) diff --git a/backend/config.py b/backend/config.py new file mode 100644 index 00000000..5576f58e --- /dev/null +++ b/backend/config.py @@ -0,0 +1,256 @@ +# -*- coding: utf-8 -*- +""" +@Author : Fang Yao +@Time : 2021/3/24 9:36 上午 +@FileName: config.py +@desc: 项目配置文件,可以在这里调参,牺牲时间换取精确度,或者牺牲准确度换取时间 +""" +import warnings +warnings.filterwarnings("ignore", category=DeprecationWarning) +import configparser +import os +import re +import time +from pathlib import Path +from fsplit.filesplit import Filesplit +import paddle +from tools.constant import * + + +# 项目的base目录 +BASE_DIR = str(Path(os.path.abspath(__file__)).parent) + +# ×××××××××××××××××××× [不要改]读取配置文件 start ×××××××××××××××××××× +# 读取settings.ini配置 +settings_config = configparser.ConfigParser() +MODE_CONFIG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'settings.ini') +if not os.path.exists(os.path.join(os.path.dirname(os.path.dirname(__file__)), 'settings.ini')): + # 如果没有配置文件,默认使用中文 + with open(os.path.join(os.path.dirname(os.path.dirname(__file__)), 'settings.ini'), mode='w', encoding='utf-8') as f: + f.write('[DEFAULT]\n') + f.write('Interface = 简体中文\n') + f.write('Language = ch\n') + f.write('Mode = fast') +settings_config.read(MODE_CONFIG_PATH, encoding='utf-8') + +# 读取interface下的语言配置,e.g. ch.ini +interface_config = configparser.ConfigParser() +INTERFACE_KEY_NAME_MAP = { + '简体中文': 'ch', + '繁體中文': 'chinese_cht', + 'English': 'en', + '한국어': 'ko', + '日本語': 'japan', + 'Tiếng Việt': 'vi', + 'Español': 'es' +} +interface_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'interface', + f"{INTERFACE_KEY_NAME_MAP[settings_config['DEFAULT']['Interface']]}.ini") +interface_config.read(interface_file, encoding='utf-8') +# ×××××××××××××××××××× [不要改]读取配置文件 end ×××××××××××××××××××× + + +# ×××××××××××××××××××× [不要改]判断程序运行路径是否合法 start ×××××××××××××××××××× +# 程序运行路径如果包含中文或者空格,运行过程在程序可能会存在bug,因此需要检查路径合法性 +# 默认为合法路径 +IS_LEGAL_PATH = True +# 如果路径包含中文,设置路径为非法 +if re.search(r"[\u4e00-\u9fa5]+", BASE_DIR): + IS_LEGAL_PATH = False +# 如果路径包含空格,设置路径为非法 +if re.search(r"\s", BASE_DIR): + IS_LEGAL_PATH = False +# 如果为程序存放在非法路径则一直提示用户路径不合法 +while not IS_LEGAL_PATH: + print(interface_config['Main']['IllegalPathWarning']) + time.sleep(3) +# ×××××××××××××××××××× [不要改]判断程序运行路径是否合法 end ×××××××××××××××××××× + + +# ×××××××××××××××××××× [不要改]判断是否使用GPU start ×××××××××××××××××××× +# 是否使用GPU +USE_GPU = False +# 如果paddlepaddle编译了gpu的版本 +if paddle.is_compiled_with_cuda(): + # 查看是否有可用的gpu + if len(paddle.static.cuda_places()) > 0: + # 如果有GPU则使用GPU + USE_GPU = True +# ×××××××××××××××××××× [不要改]判断是否使用GPU start ×××××××××××××××××××× + + +# ×××××××××××××××××××× [不要改]读取语言、模型路径、字典路径 start ×××××××××××××××××××× +# 设置识别语言 +REC_CHAR_TYPE = settings_config['DEFAULT']['Language'] + +# 设置识别模式 +MODE_TYPE = settings_config['DEFAULT']['Mode'] +ACCURATE_MODE_ON = False +if MODE_TYPE == 'accurate': + ACCURATE_MODE_ON = True +if MODE_TYPE == 'fast': + ACCURATE_MODE_ON = False +if MODE_TYPE == 'auto': + if USE_GPU: + ACCURATE_MODE_ON = True + else: + ACCURATE_MODE_ON = False +# 模型文件目录 +# 默认模型版本 V4 +MODEL_VERSION = 'V4' +# 文本检测模型 +DET_MODEL_BASE = os.path.join(BASE_DIR, 'models') +# 设置文本识别模型 + 字典 +REC_MODEL_BASE = os.path.join(BASE_DIR, 'models') +# 默认字典路径为中文 +DICT_BASE = 
os.path.join(BASE_DIR, 'ppocr', 'utils', 'dict') +# V3, V4模型默认图形识别的shape为3, 48, 320 +REC_IMAGE_SHAPE = '3,48,320' +REC_MODEL_PATH = os.path.join(REC_MODEL_BASE, MODEL_VERSION, f'{REC_CHAR_TYPE}_rec') +DET_MODEL_PATH = os.path.join(DET_MODEL_BASE, MODEL_VERSION, f'{REC_CHAR_TYPE}_det') + +LATIN_LANG = [ + 'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga', 'hr', + 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms', 'mt', 'nl', + 'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin', 'sk', 'sl', 'sq', 'sv', + 'sw', 'tl', 'tr', 'uz', 'vi', 'latin', 'german', 'french' +] +ARABIC_LANG = ['ar', 'fa', 'ug', 'ur'] +CYRILLIC_LANG = [ + 'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd', 'ava', + 'dar', 'inh', 'che', 'lbe', 'lez', 'tab', 'cyrillic' +] +DEVANAGARI_LANG = [ + 'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new', 'gom', + 'sa', 'bgc', 'devanagari' +] +OTHER_LANG = [ + 'ch', 'japan', 'korean', 'en', 'ta', 'kn', 'te', 'ka', + 'chinese_cht', +] +MULTI_LANG = LATIN_LANG + ARABIC_LANG + CYRILLIC_LANG + DEVANAGARI_LANG + \ + OTHER_LANG + +# 定义字典路径 +DICT_PATH = os.path.join(DICT_BASE, f'{REC_CHAR_TYPE}_dict.txt') +DET_MODEL_FAST_PATH = os.path.join(DET_MODEL_BASE, MODEL_VERSION, 'ch_det_fast') + + +# 如果设置了识别文本语言类型,则设置为对应的语言 +if REC_CHAR_TYPE in MULTI_LANG: + # 定义文本检测与识别模型 + # 使用快速模式时,调用轻量级模型 + if MODE_TYPE == 'fast': + DET_MODEL_PATH = os.path.join(DET_MODEL_BASE, MODEL_VERSION, 'ch_det_fast') + REC_MODEL_PATH = os.path.join(REC_MODEL_BASE, MODEL_VERSION, f'{REC_CHAR_TYPE}_rec_fast') + # 使用自动模式时,检测有没有使用GPU,根据GPU判断模型 + elif MODE_TYPE == 'auto': + # 如果使用GPU,则使用大模型 + if USE_GPU: + DET_MODEL_PATH = os.path.join(DET_MODEL_BASE, MODEL_VERSION, 'ch_det') + # 英文模式的ch模型识别效果好于fast + if REC_CHAR_TYPE == 'en': + REC_MODEL_PATH = os.path.join(REC_MODEL_BASE, MODEL_VERSION, f'ch_rec') + DICT_PATH = os.path.join(DICT_BASE, f'ch_dict.txt') + else: + REC_MODEL_PATH = os.path.join(REC_MODEL_BASE, MODEL_VERSION, f'{REC_CHAR_TYPE}_rec') + else: + DET_MODEL_PATH = os.path.join(DET_MODEL_BASE, MODEL_VERSION, 'ch_det_fast') + REC_MODEL_PATH = os.path.join(REC_MODEL_BASE, MODEL_VERSION, f'{REC_CHAR_TYPE}_rec_fast') + else: + DET_MODEL_PATH = os.path.join(DET_MODEL_BASE, MODEL_VERSION, 'ch_det') + REC_MODEL_PATH = os.path.join(REC_MODEL_BASE, MODEL_VERSION, f'{REC_CHAR_TYPE}_rec') + # 如果默认版本(V4)没有大模型,则切换为默认版本(V4)的fast模型 + if not os.path.exists(REC_MODEL_PATH): + REC_MODEL_PATH = os.path.join(REC_MODEL_BASE, MODEL_VERSION, f'{REC_CHAR_TYPE}_rec_fast') + # 如果默认版本(V4)既没有大模型,又没有fast模型,则使用V3版本的大模型 + if not os.path.exists(REC_MODEL_PATH): + MODEL_VERSION = 'V3' + REC_MODEL_PATH = os.path.join(REC_MODEL_BASE, MODEL_VERSION, f'{REC_CHAR_TYPE}_rec') + # 如果V3版本没有大模型,则使用V3版本的fast模型 + if not os.path.exists(REC_MODEL_PATH): + MODEL_VERSION = 'V3' + REC_MODEL_PATH = os.path.join(REC_MODEL_BASE, MODEL_VERSION, f'{REC_CHAR_TYPE}_rec_fast') + + if REC_CHAR_TYPE in LATIN_LANG: + REC_MODEL_PATH = os.path.join(REC_MODEL_BASE, MODEL_VERSION, f'latin_rec_fast') + DICT_PATH = os.path.join(DICT_BASE, f'latin_dict.txt') + elif REC_CHAR_TYPE in ARABIC_LANG: + REC_MODEL_PATH = os.path.join(REC_MODEL_BASE, MODEL_VERSION, f'arabic_rec_fast') + DICT_PATH = os.path.join(DICT_BASE, f'arabic_dict.txt') + elif REC_CHAR_TYPE in CYRILLIC_LANG: + REC_MODEL_PATH = os.path.join(REC_MODEL_BASE, MODEL_VERSION, f'cyrillic_rec_fast') + DICT_PATH = os.path.join(DICT_BASE, f'cyrillic_dict.txt') + elif REC_CHAR_TYPE in DEVANAGARI_LANG: + REC_MODEL_PATH = os.path.join(REC_MODEL_BASE, MODEL_VERSION, f'devanagari_rec_fast') 
+ DICT_PATH = os.path.join(DICT_BASE, f'devanagari_dict.txt') + + # 定义图像识别shape + if MODEL_VERSION == 'V2': + REC_IMAGE_SHAPE = '3,32,320' + else: + REC_IMAGE_SHAPE = '3,48,320' + + # 查看该路径下是否有文本模型识别完整文件,没有的话合并小文件生成完整文件 + if 'inference.pdiparams' not in (os.listdir(REC_MODEL_PATH)): + fs = Filesplit() + fs.merge(input_dir=REC_MODEL_PATH) + # 查看该路径下是否有文本模型识别完整文件,没有的话合并小文件生成完整文件 + if 'inference.pdiparams' not in (os.listdir(DET_MODEL_PATH)): + fs = Filesplit() + fs.merge(input_dir=DET_MODEL_PATH) +# ×××××××××××××××××××× [不要改]读取语言、模型路径、字典路径 end ×××××××××××××××××××× + + +# --------------------- 请根据自己的实际情况改 start----------------- +# 是否生成TXT文本字幕 +GENERATE_TXT = True + +# 每张图中同时识别6个文本框中的文本,GPU显存越大,该数值可以设置越大 +REC_BATCH_NUM = 6 +# DB算法每个batch识别多少张,默认为10 +MAX_BATCH_SIZE = 10 + +# 默认字幕出现区域为下方 +DEFAULT_SUBTITLE_AREA = SubtitleArea.UNKNOWN + +# 每一秒抓取多少帧进行OCR识别 +EXTRACT_FREQUENCY = 3 + +# 容忍的像素点偏差 +PIXEL_TOLERANCE_Y = 50 # 允许检测框纵向偏差50个像素点 +PIXEL_TOLERANCE_X = 100 # 允许检测框横向偏差100个像素点 + +# 字幕区域偏移量 +SUBTITLE_AREA_DEVIATION_PIXEL = 50 + +# 最有可能出现的水印区域 +WATERMARK_AREA_NUM = 5 + +# 文本相似度阈值 +# 用于去重时判断两行字幕是不是同一行,这个值越高越严格。 e.g. 0.99表示100个字里面有99各个字一模一样才算相似 +# 采用动态算法实现相似度阈值判断: 对于短文本要求较低的阈值,对于长文本要求较高的阈值 +# 如:文本较短,人民、入民,0.5就算相似 +THRESHOLD_TEXT_SIMILARITY = 0.8 + +# 字幕提取中置信度低于0.75的不要 +DROP_SCORE = 0.75 + +# 字幕区域允许偏差, 0为不允许越界, 0.03表示可以越界3% +SUB_AREA_DEVIATION_RATE = 0 + +# 输出丢失的字幕帧, 仅简体中文,繁体中文,日文,韩语有效, 默认将调试信息输出到: 视频路径/loss +DEBUG_OCR_LOSS = False + +# 是否不删除缓存数据,以方便调试 +DEBUG_NO_DELETE_CACHE = False + +# 是否删除空时间轴 +DELETE_EMPTY_TIMESTAMP = True + +# 是否重新分词, 用于解决没有语句没有空格 +WORD_SEGMENTATION = True + +# --------------------- 请根据自己的实际情况改 end----------------------------- + +os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' diff --git a/backend/configs/typoMap.json b/backend/configs/typoMap.json new file mode 100644 index 00000000..9cb3b171 --- /dev/null +++ b/backend/configs/typoMap.json @@ -0,0 +1,7 @@ +{ + "l'm": "I'm", + "l just": "I just", + "Let'sqo": "Let's go", + "Iife": "life", + "威筋": "威胁" +} \ No newline at end of file diff --git a/backend/interface/ch.ini b/backend/interface/ch.ini new file mode 100644 index 00000000..7ad16484 --- /dev/null +++ b/backend/interface/ch.ini @@ -0,0 +1,139 @@ +[LanguageModeGUI] +Title = 字幕提取器 +InterfaceLanguage = 选择语言: +SubtitleLanguage = 选择视频字幕的语言: +Mode = 选择识别模式: +ModeAuto = 自动 +ModeFast = 快速 +ModeAccurate = 精准 +InterfaceDefault = 简体中文 +LanguageCH = 简体中文 +LanguageCHINESE_CHT = 繁体中文 +LanguageEN = 英文 +LanguageJAPAN = 日文 +LanguageKOREAN = 韩文 +LanguageAR = 阿拉伯文 +LanguageFRENCH = 法文 +LanguageGERMAN = 德文 +LanguageRU = 俄罗斯文 +LanguageES = 西班牙文 +LanguagePT = 葡萄牙文 +LanguageIT = 意大利文 +LanguageAF = 南非荷兰文 +LanguageAZ = 阿塞拜疆文 +LanguageBS = 波斯尼亚文 +LanguageCS = 捷克文 +LanguageCY = 威尔士文 +LanguageDA = 丹麦文 +LanguageDE = 德文 +LanguageET = 爱沙尼亚文 +LanguageFR = 法文 +LanguageGA = 爱尔兰文 +LanguageHR = 克罗地亚文 +LanguageHU = 匈牙利文 +LanguageID = 印尼文 +LanguageIS = 冰岛文 +LanguageKU = 库尔德文 +LanguageLA = 拉丁文 +LanguageLT = 立陶宛文 +LanguageLV = 拉脱维亚文 +LanguageMI = 毛利文 +LanguageMS = 马来文 +LanguageMT = 马耳他文 +LanguageNL = 荷兰文 +LanguageNO = 挪威文 +LanguageOC = 欧西坦文 +LanguagePI = 巴利文 +LanguagePL = 波兰文 +LanguageRO = 罗马尼亚文 +LanguageRS_LATIN = 塞尔维亚文(latin) +LanguageSK = 斯洛伐克文 +LanguageSL = 斯洛文尼亚文 +LanguageSQ = 阿尔巴尼亚文 +LanguageSV = 瑞典文 +LanguageSW = 西瓦希里文 +LanguageTL = 塔加洛文 +LanguageTR = 土耳其文 +LanguageUZ = 乌兹别克文 +LanguageVI = 越南文 +LanguageLATIN = 拉丁文 +LanguageFA = 波斯文 +LanguageUG = 维吾尔文 +LanguageUR = 乌尔都文 +LanguageRS_CYRILLIC = 塞尔维亚文(cyrillic) +LanguageBE = 白俄罗斯文 +LanguageBG = 保加利亚文 +LanguageUK = 乌克兰文 +LanguageMN = 蒙古文 +LanguageABQ = 阿巴扎文 
+LanguageADY = 阿迪赫文 +LanguageKBD = 卡巴尔达文 +LanguageAVA = 阿瓦尔文 +LanguageDAR = 达尔瓦文 +LanguageINH = 因古什文 +LanguageCHE = 车臣文 +LanguageLBE = 拉克文 +LanguageLEZ = 莱兹甘文 +LanguageTAB = 塔巴萨兰文 +LanguageCYRILLIC = 西里尔文 +LanguageHI = 印地文 +LanguageMR = 马拉地文 +LanguageNE = 尼泊尔文 +LanguageBH = 比尔哈文 +LanguageMAI = 迈蒂利文 +LanguageANG = 昂加文 +LanguageBHO = 孟加拉文 +LanguageMAH = 摩揭陀文 +LanguageSCK = 那格浦尔文 +LanguageNEW = 尼瓦尔文 +LanguageGOM = 保加利亚文 +LanguageSA = 沙特阿拉伯文 +LanguageBGC = 哈里亚纳文 +LanguageDEVANAGARI = 德瓦那加里文 +LanguageTA = 泰米尔文 +LanguageKN = 卡纳达文 +LanguageTE = 泰卢固文 +LanguageKA = 卡纳达文 + +[SubtitleExtractorGUI] +Title = 字幕提取器 +Open = 打开 +AllFile = 所有文件 +Vertical = 垂直方向 +Horizontal = 水平方向 +Run = 运行 +Setting = 设置 +OpenVideoSuccess = 成功打开视频 +OpenVideoFirst = 请先打开视频 +SubtitleArea = 字幕区域 + +[Main] +RecSubLang = 识别字幕语言 +RecMode = 识别模式 +IllegalPathWarning = 【警告】程序运行中断!路径不合法!请不要将程序放入带有空格和中文的路径下!!!请修改程序路径名后重新运行程序 +GPUSpeedUp = 使用GPU进行加速 +FrameCount = 帧数 +FrameRate = 帧率 +StartProcessFrame = 【处理中】开启提取视频关键帧... +FinishProcessFrame = 【结束】提取视频关键帧完毕... +StartFindSub = 【处理中】开始提取字幕信息,此步骤可能花费较长时间,请耐心等待... +FinishFindSub = 【结束】完成字幕提取,生成原始字幕文件... +StartDetectWaterMark = 【处理中】开始检测并过滤水印区域内容 +checkWaterMark = 视频是否存在水印区域,存在的话输入y,不存在的话输入n: +FinishDetectWaterMark = 【结束】已经成功过滤水印区域内容 +StartDeleteNonSub = 【处理中】开始检测非字幕区域,并将非字幕区域的内容删除 +FinishDeleteNonSub = 【结束】已将非字幕区域的内容删除 +StartGenerateSub = 【处理中】开始生成字幕文件 +FinishGenerateSub = 【结束】字幕文件生成成功 +SubFrameNo = 字幕帧 +Elapse = 耗时 +ChooseSubArea = 请指定字幕区域 +WatchPicture = 请查看图片, 确定水印区域 +QuestionDelete = 区域中的字幕是否去除? 输入 "y" 或 "回车" 表示去除,输入"n"或其他表示不去除: +FinishDelete = 已经删除该区域字幕... +FinishWaterMarkFilter = 水印区域字幕过滤完毕... +CheckSubArea = 请查看图片, 确定字幕区域是否正确: +DeleteNoSubArea = 红色框区域外的字幕是否去除? 输入 "y" 或 "回车" 表示去除,输入"n"或其他表示不去除: +FinishDeleteNoSubArea = 去除完毕 +SubLocation = 字幕文件生成位置: +InputVideo = 请输入视频完整路径: diff --git a/backend/interface/chinese_cht.ini b/backend/interface/chinese_cht.ini new file mode 100644 index 00000000..cc64f9a0 --- /dev/null +++ b/backend/interface/chinese_cht.ini @@ -0,0 +1,139 @@ +[LanguageModeGUI] +Title = 字幕提取器 +InterfaceLanguage = 選擇語言: +SubtitleLanguage = 選擇視頻字幕的語言: +Mode = 選擇識別模式: +ModeAuto = 自動 +ModeFast = 快速 +ModeAccurate = 精準 +InterfaceDefault = 繁體中文 +LanguageCH = 簡體中文 +LanguageCHINESE_CHT = 繁體中文 +LanguageEN = 英文 +LanguageJAPAN = 日文 +LanguageKOREAN = 韓文 +LanguageAR = 阿拉伯文 +LanguageFRENCH = 法文 +LanguageGERMAN = 德文 +LanguageRU = 俄羅斯文 +LanguageES = 西班牙文 +LanguagePT = 葡萄牙文 +LanguageIT = 意大利文 +LanguageAF = 南非荷蘭文 +LanguageAZ = 阿塞拜疆文 +LanguageBS = 波斯尼亞文 +LanguageCS = 捷克文 +LanguageCY = 威爾士文 +LanguageDA = 丹麥文 +LanguageDE = 德文 +LanguageET = 愛沙尼亞文 +LanguageFR = 法文 +LanguageGA = 愛爾蘭文 +LanguageHR = 克羅地亞文 +LanguageHU = 匈牙利文 +LanguageID = 印尼文 +LanguageIS = 冰島文 +LanguageKU = 庫爾德文 +LanguageLA = 拉丁文 +LanguageLT = 立陶宛文 +LanguageLV = 拉脫維亞文 +LanguageMI = 毛利文 +LanguageMS = 馬來文 +LanguageMT = 馬耳他文 +LanguageNL = 荷蘭文 +LanguageNO = 挪威文 +LanguageOC = 歐西坦文 +LanguagePI = 巴利文 +LanguagePL = 波蘭文 +LanguageRO = 羅馬尼亞文 +LanguageRS_LATIN = 塞爾維亞文(latin) +LanguageSK = 斯洛伐克文 +LanguageSL = 斯洛文尼亞文 +LanguageSQ = 阿爾巴尼亞文 +LanguageSV = 瑞典文 +LanguageSW = 西瓦希裏文 +LanguageTL = 塔加洛文 +LanguageTR = 土耳其文 +LanguageUZ = 烏茲別克文 +LanguageVI = 越南文 +LanguageLATIN = 拉丁文 +LanguageFA = 波斯文 +LanguageUG = 維吾爾文 +LanguageUR = 烏爾都文 +LanguageRS_CYRILLIC = 塞爾維亞文(cyrillic) +LanguageBE = 白俄羅斯文 +LanguageBG = 保加利亞文 +LanguageUK = 烏克蘭文 +LanguageMN = 蒙古文 +LanguageABQ = 阿巴紮文 +LanguageADY = 阿迪赫文 +LanguageKBD = 卡巴爾達文 +LanguageAVA = 阿瓦爾文 +LanguageDAR = 達爾瓦文 +LanguageINH = 因古什文 +LanguageCHE = 車臣文 +LanguageLBE = 拉克文 +LanguageLEZ = 萊茲甘文 +LanguageTAB = 塔巴薩蘭文 
+LanguageCYRILLIC = 西裏爾文 +LanguageHI = 印地文 +LanguageMR = 馬拉地文 +LanguageNE = 尼泊爾文 +LanguageBH = 比爾哈文 +LanguageMAI = 邁蒂利文 +LanguageANG = 昂加文 +LanguageBHO = 孟加拉文 +LanguageMAH = 摩揭陀文 +LanguageSCK = 那格浦爾文 +LanguageNEW = 尼瓦爾文 +LanguageGOM = 保加利亞文 +LanguageSA = 沙特阿拉伯文 +LanguageBGC = 哈裏亞納文 +LanguageDEVANAGARI = 德瓦那加裏文 +LanguageTA = 泰米爾文 +LanguageKN = 卡納達文 +LanguageTE = 泰盧固文 +LanguageKA = 卡納達文 + +[SubtitleExtractorGUI] +Title = 字幕提取器 +Open = 打開 +AllFile = 所有文件 +Vertical = 垂直方向 +Horizontal = 水平方向 +Run = 運行 +Setting = 設置 +OpenVideoSuccess = 成功打開視頻 +OpenVideoFirst = 請先打開視頻 +SubtitleArea = 字幕區域 + +[Main] +RecSubLang = 識別字幕語言 +RecMode = 識別模式 +IllegalPathWarning = 【警告】程序運行中斷!路徑不合法!請不要將程序放入帶有空格和中文的路徑下!!!請修改程序路徑名後重新運行程序 +GPUSpeedUp = 使用GPU進行加速 +FrameCount = 幀數 +FrameRate = 幀率 +StartProcessFrame = 【處理中】開啟提取視頻關鍵幀... +FinishProcessFrame = 【結束】提取視頻關鍵幀完畢... +StartFindSub = 【處理中】開始提取字幕信息,此步驟可能花費較長時間,請耐心等待... +FinishFindSub = 【結束】完成字幕提取,生成原始字幕文件... +StartDetectWaterMark = 【處理中】開始檢測並過濾水印區域內容 +checkWaterMark = 視頻是否存在水印區域,存在的話輸入y,不存在的話輸入n: +FinishDetectWaterMark = 【結束】已經成功過濾水印區域內容 +StartDeleteNonSub = 【處理中】開始檢測非字幕區域,並將非字幕區域的內容刪除 +FinishDeleteNonSub = 【結束】已將非字幕區域的內容刪除 +StartGenerateSub = 【處理中】開始生成字幕文件 +FinishGenerateSub = 【結束】字幕文件生成成功 +SubFrameNo = 字幕幀 +Elapse = 耗時 +ChooseSubArea = 請指定字幕區域 +WatchPicture = 請查看圖片, 確定水印區域 +QuestionDelete = 區域中的字幕是否去除? 輸入 "y" 或 "回車" 表示去除,輸入"n"或其他表示不去除: +FinishDelete = 已經刪除該區域字幕... +FinishWaterMarkFilter = 水印區域字幕過濾完畢... +CheckSubArea = 請查看圖片, 確定字幕區域是否正確: +DeleteNoSubArea = 紅色框區域外的字幕是否去除? 輸入 "y" 或 "回車" 表示去除,輸入"n"或其他表示不去除: +FinishDeleteNoSubArea = 去除完畢 +SubLocation = 字幕文件生成位置: +InputVideo = 請輸入視頻完整路徑: diff --git a/backend/interface/en.ini b/backend/interface/en.ini new file mode 100644 index 00000000..52fa70cb --- /dev/null +++ b/backend/interface/en.ini @@ -0,0 +1,139 @@ +[LanguageModeGUI] +Title = Subtitle Extractor +InterfaceLanguage = Choose Language: +SubtitleLanguage = Subtitle Language: +Mode = Choose Mode: +ModeAuto = auto +ModeFast = fast +ModeAccurate = accurate +InterfaceDefault = English +LanguageCH = Simplified Chinese +LanguageCHINESE_CHT = Traditional Chinese +LanguageEN = English +LanguageJAPAN = Japanese +LanguageKOREAN = Korean +LanguageAR = Arabic +LanguageFRENCH = French +LanguageGERMAN = German +LanguageRU = Russian +LanguageES = Spanish +LanguagePT = Portuguese +LanguageIT = Italian +LanguageAF = Afrikaans +LanguageAZ = Azerbaijani +LanguageBS = Bosnian +LanguageCS = Czech +LanguageCY = Welsh +LanguageDA = Danish +LanguageDE = German +LanguageET = Estonian +LanguageFR = French +LanguageGA = Irish +LanguageHR = Croatian +LanguageHU = Hungarian +LanguageID = Indonesian +LanguageIS = Icelandic +LanguageKU = Kurdish +LanguageLA = Latin +LanguageLT = Lithuanian +LanguageLV = Latvian +LanguageMI = Maori +LanguageMS = Malay +LanguageMT = Maltese +LanguageNL = Dutch +LanguageNO = Norwegian +LanguageOC = Occitan +LanguagePI = Pali +LanguagePL = Polish +LanguageRO = Romanian +LanguageRS_LATIN = Serbian(latin) +LanguageSK = Slovak +LanguageSL = Slovenian +LanguageSQ = Albanian +LanguageSV = Swedish +LanguageSW = Swahili +LanguageTL = Tagalog +LanguageTR = Turkish +LanguageUZ = Uzbek +LanguageVI = Vietnamese +LanguageLATIN = Latin +LanguageFA = Persian +LanguageUR = Urdu +LanguageRS_CYRILLIC = Serbian(cyrillic) +LanguageBE = Belarusian +LanguageBG = Bulgarian +LanguageUK = Ukranian +LanguageMN = Mongolian +LanguageABQ = Abaza +LanguageADY = Adyghe +LanguageKBD = Kabardian +LanguageAVA = Avar +LanguageDAR = Dargwa +LanguageINH = Ingush +LanguageCHE = Chechen +LanguageLBE = Lak 
+LanguageLEZ = Lezghian +LanguageTAB = Tabassaran +LanguageCYRILLIC = Cyrillic +LanguageHI = Hindi +LanguageMR = Marathi +LanguageNE = Nepali +LanguageBH = Bihari +LanguageMAI = Maithili +LanguageANG = Angika +LanguageBHO = Bhojpuri +LanguageMAH = Magahi +LanguageSCK = Nagpur +LanguageNEW = Newari +LanguageGOM = Goan Konkani +LanguageSA = Saudi Arabia +LanguageBGC = Haryanvi +LanguageDEVANAGARI = Devanagari +LanguageTA = Tamil +LanguageKN = Kannada +LanguageUG = Uyghur +LanguageTE = Telugu +LanguageKA = Kannada + +[SubtitleExtractorGUI] +Title = Subtitle Extractor +Open = Open +AllFile = All Files +Vertical = Vertical +Horizontal = Horizontal +Run = Run +Setting = Settings +OpenVideoSuccess = Successfully Open Video +OpenVideoFirst = Please Open Video First +SubtitleArea = Subtitle Area + +[Main] +RecSubLang = Subtitle Language +RecMode = Mode +IllegalPathWarning = [Warning] The program is interrupted! The path is illegal! Please do not put the program in a path with spaces and Chinese! ! ! Please modify the program path name and re-run the program +GPUSpeedUp = Use GPU for acceleration +FrameCount = Frame Count +FrameRate = Frame Rate +StartProcessFrame = [Processing] Start to extracting video keyframes... +FinishProcessFrame = [Finished] Finished extracting video key frames... +StartFindSub = [Processing] Start to extract subtitle information, this step may take a long time, please be patient... +FinishFindSub = [Finished] Finish subtitle extraction, generate original subtitle file... +StartDetectWaterMark = [Processing] Start to detect and filter watermark area +checkWaterMark = Whether there is a watermark area in the video, if it exists, enter "y", if it does not exist, enter "n": +FinishDetectWaterMark = [Finished] Watermark area has been successfully filtered +StartDeleteNonSub = [Processing] Start to detect the non-subtitle area and delete the content in the non-subtitle area +FinishDeleteNonSub = [Finished] Non-subtitle area has been deleted +StartGenerateSub = [Processing] Start generating subtitle files +FinishGenerateSub = [Finished] Subtitle file generated successfully +SubFrameNo = Subtitle frame +Elapse = elapse +ChooseSubArea = Please specify subtitle area +WatchPicture = Please check the picture to determine the watermark area +QuestionDelete = Whether to remove the subtitles in the area? Input "y" or "Enter" to remove, input "n" or other means not to remove: +FinishDelete = Subtitles in this area have been deleted... +FinishWaterMarkFilter = The subtitles in the watermark area are filtered... +CheckSubArea = Please check the picture to make sure the subtitle area is correct: +DeleteNoSubArea = Are the subtitles outside the red box area removed? 
Input "y" or "Enter" to remove, input "n" or other means not to remove: +FinishDeleteNoSubArea = Removed +SubLocation = Subtitle file generated at: +InputVideo = Please enter the full path of the video: diff --git a/backend/interface/es.ini b/backend/interface/es.ini new file mode 100644 index 00000000..ef846ea5 --- /dev/null +++ b/backend/interface/es.ini @@ -0,0 +1,139 @@ +[LanguageModeGUI] +Title = Extractor de subtítulos +InterfaceLanguage = Elija idioma: +SubtitleLanguage = Idioma de subtítulos: +Mode = Elija modo: +ModeAuto = automático +ModeFast = rápido +ModeAccurate = preciso +InterfaceDefault = Inglés +LanguageCH = Chino simplificado +LanguageCHINESE_CHT = Chino tradicional +LanguageEN = Inglés +LanguageJAPAN = Japonés +LanguageKOREAN = Coreano +LanguageAR = Árabe +LanguageFRENCH = Francés +LanguageGERMAN = Alemán +LanguageRU = Ruso +LanguageES = Español +LanguagePT = Portugués +LanguageIT = Italiano +LanguageAF = Afrikáans +LanguageAZ = Azerí +LanguageBS = Bosnio +LanguageCS = Checo +LanguageCY = Galés +LanguageDA = Danés +LanguageDE = Alemán +LanguageET = Estonio +LanguageFR = Francés +LanguageGA = Irlandés +LanguageHR = Croata +LanguageHU = Húngaro +LanguageID = Indonesio +LanguageIS = Islandés +LanguageKU = Kurdo +LanguageLA = Latín +LanguageLT = Lituano +LanguageLV = Letón +LanguageMI = Maorí +LanguageMS = Malay +LanguageMT = Maltés +LanguageNL = Neerlandés +LanguageNO = Noruego +LanguageOC = Occitano +LanguagePI = Pali +LanguagePL = Polaco +LanguageRO = Rumano +LanguageRS_LATIN = Serbio (latín) +LanguageSK = Eslovaco +LanguageSL = Esloveno +LanguageSQ = Albanés +LanguageSV = Sueco +LanguageSW = Swahili +LanguageTL = Tagalo +LanguageTR = Turco +LanguageUZ = Uzbeko +LanguageVI = Vietnamita +LanguageLATIN = Latín +LanguageFA = Persa +LanguageUR = Urdu +LanguageRS_CYRILLIC = Serbio (cirílico) +LanguageBE = Bielorruso +LanguageBG = Búlgaro +LanguageUK = Ucraniano +LanguageMN = Mongol +LanguageABQ = Abaza +LanguageADY = Adigueo +LanguageKBD = Kabardino +LanguageAVA = Avar +LanguageDAR = Dargwa +LanguageINH = Ingush +LanguageCHE = Checheno +LanguageLBE = Lak +LanguageLEZ = Lezgiano +LanguageTAB = Tabasarán +LanguageCYRILLIC = Cirílico +LanguageHI = Hindi +LanguageMR = Marathi +LanguageNE = Nepalí +LanguageBH = Bihari +LanguageMAI = Maithili +LanguageANG = Angika +LanguageBHO = Bhojpuri +LanguageMAH = Magahi +LanguageSCK = Nagpur +LanguageNEW = Newari +LanguageGOM = Concaní +LanguageSA = Arabia Saudita +LanguageBGC = Haryanvi +LanguageDEVANAGARI = Devanagari +LanguageTA = Tamil +LanguageKN = Kannada +LanguageUG = Uigur +LanguageTE = Telugu +LanguageKA = Kannada + +[SubtitleExtractorGUI] +Title = Extractor de subtítulos +Open = Abrir +AllFile = Todos los archivos +Vertical = Vertical +Horizontal = Horizontal +Run = Ejecutar +Setting = Configuraciones +OpenVideoSuccess = Se abrió el video exitosamente +OpenVideoFirst = Por favor, abra el video primero +SubtitleArea = Área de subtítulos + +[Main] +RecSubLang = Idioma de subtítulos +RecMode = Modo +IllegalPathWarning = [Advertencia] ¡El programa se interrumpió! ¡La ruta es ilegal! ¡No coloque el programa en una ruta con espacios y caracteres chinos! Por favor, modifique el nombre de ruta del programa y vuelva a ejecutarlo. +GPUSpeedUp = Use GPU para aceleración +FrameCount = Conteo de fotogramas +FrameRate = Velocidad de fotogramas +StartProcessFrame = [Procesamiento] Empieza a extraer fotogramas clave del video... +FinishProcessFrame = [Terminado] Se termino de extraer fotogramas clave del video... 
+StartFindSub = [Procesamiento] Empieza la extracción de información de subtítulos, este paso puede llevar bastante tiempo, por favor tenga paciencia... +FinishFindSub = [Terminado] Se termino la extracción de subtítulos, se generó el archivo de subtítulos original... +StartDetectWaterMark = [Procesamiento] Empieza la detección y filtración del área de marca de agua +checkWaterMark = Si hay un área de marca de agua en el video, si existe, escriba "s", si no, escriba "n": +FinishDetectWaterMark = [Terminado] El área de marca de agua ha sido filtrada exitosamente +StartDeleteNonSub = [Procesamiento] Empieza la detección del área sin subtítulos y se elimina el contenido de la misma +FinishDeleteNonSub = [Terminado] El área sin subtítulos ha sido eliminada +StartGenerateSub = [Procesamiento] Empieza la generación de archivos de subtítulos +FinishGenerateSub = [Terminado] Se generó el archivo de subtítulos exitosamente +SubFrameNo = Fotograma de subtítulo +Elapse = transcurrir +ChooseSubArea = Por favor, especifique el área de subtítulo +WatchPicture = Por favor, revise la imagen para determinar la area de marca de agua +QuestionDelete = ¿Desea quitar los subtítulos en esta area? Pulsar "s" o "Enter" para eliminar, pulsar "n" o cualquier otro to te significa no eliminar: +FinishDelete = Se han eliminado los subtítulos de esta área... +FinishWaterMarkFilter = Se han filtrado los subtítulos en la área de la marca de agua... +CheckSubArea = Por favor, revise la imagen para asegurarse que el área de subtítulo es correcta: +DeleteNoSubArea = ¿Se remueven los subtítulos fuera del área del cuadro rojo? Pulsar "s" o "Enter" para eliminar, pulsar "n" o cualquier otro to te significa no eliminar: +FinishDeleteNoSubArea = Eliminado +SubLocation = Archivo de subtítulos generado en: +InputVideo = Por favor, ingrese la ruta completa del video: diff --git a/backend/interface/japan.ini b/backend/interface/japan.ini new file mode 100644 index 00000000..8dde5c1e --- /dev/null +++ b/backend/interface/japan.ini @@ -0,0 +1,139 @@ +[LanguageModeGUI] +Title = サブタイトル抽出器 +InterfaceLanguage = 言語の選択: +SubtitleLanguage = サブタイトルの言語: +Mode = モードを選択: +ModeAuto = 自動 +ModeFast = 高速 +ModeAccurate = 正確 +InterfaceDefault = 英語 +LanguageCH = 簡体字中国語 +LanguageCHINESE_CHT = 繁体字中国語 +LanguageEN = 英語 +LanguageJAPAN = 日本語 +LanguageKOREAN = 韓国語 +LanguageAR = アラビア語 +LanguageFRENCH = フランス語 +LanguageGERMAN = ドイツ語 +LanguageRU = ロシア語 +LanguageES = スペイン語 +LanguagePT = ポルトガル語 +LanguageIT = イタリア語 +LanguageAF = アフリカーンス語 +LanguageAZ = アゼルバイジャン語 +LanguageBS = ボスニア語 +LanguageCS = チェコ語 +LanguageCY = ウェールズ語 +LanguageDA = デンマーク語 +LanguageDE = ドイツ語 +LanguageET = エストニア語 +LanguageFR = フランス語 +LanguageGA = アイルランド語 +LanguageHR = クロアチア語 +LanguageHU = ハンガリー語 +LanguageID = インドネシア語 +LanguageIS = アイスランド語 +LanguageKU = クルド語 +LanguageLA = ラテン語 +LanguageLT = リトアニア語 +LanguageLV = ラトビア語 +LanguageMI = マオリ語 +LanguageMS = マレー語 +LanguageMT = マルタ語 +LanguageNL = オランダ語 +LanguageNO = ノルウェー語 +LanguageOC = オック語 +LanguagePI = パーリ語 +LanguagePL = ポーランド語 +LanguageRO = ルーマニア語 +LanguageRS_LATIN = セルビア語(ラテン文字) +LanguageSK = スロバキア語 +LanguageSL = スロベニア語 +LanguageSQ = アルバニア語 +LanguageSV = スウェーデン語 +LanguageSW = スワヒリ語 +LanguageTL = タガログ語 +LanguageTR = トルコ語 +LanguageUZ = ウズベク語 +LanguageVI = ベトナム語 +LanguageLATIN = ラテン語 +LanguageFA = ペルシア語 +LanguageUR = ウルドゥ語 +LanguageRS_CYRILLIC = セルビア語(キリル文字) +LanguageBE = ベラルーシ語 +LanguageBG = ブルガリア語 +LanguageUK = ウクライナ語 +LanguageMN = モンゴル語 +LanguageABQ = アバザ語 +LanguageADY = アディゲ語 +LanguageKBD = カバルディン語 +LanguageAVA = アヴァール語 +LanguageDAR = ダルガン語 +LanguageINH 
= イングーシ語 +LanguageCHE = チェチェン語 +LanguageLBE = ラク語 +LanguageLEZ = レズギ語 +LanguageTAB = タバサラン語 +LanguageCYRILLIC = キリル文字 +LanguageHI = ヒンディー語 +LanguageMR = マラーティー語 +LanguageNE = ネパール語 +LanguageBH = ビハリ語 +LanguageMAI = マイティリー語 +LanguageANG = アンギカ語 +LanguageBHO = ボージプリー語 +LanguageMAH = マガヒー語 +LanguageSCK = ナグプール語 +LanguageNEW = ネワール語 +LanguageGOM = ゴアのコンカニ語 +LanguageSA = サウジアラビア +LanguageBGC = ハリヤンビ語 +LanguageDEVANAGARI = デーヴァナーガリー文字 +LanguageTA = タミル語 +LanguageKN = カンナダ語 +LanguageUG = ウイグル語 +LanguageTE = テルグ語 +LanguageKA = カンナダ語 + +[SubtitleExtractorGUI] +Title = サブタイトル抽出器 +Open = 開く +AllFile = 全てのファイル +Vertical = 垂直 +Horizontal = 水平 +Run = 実行 +Setting = 設定 +OpenVideoSuccess = ビデオが正常に開きました +OpenVideoFirst = 最初にビデオを開いてください +SubtitleArea = サブタイトル領域 + +[Main] +RecSubLang = サブタイトル言語 +RecMode = モード +IllegalPathWarning = 【注意】プログラムは中断されました! パスが不正です! プログラムをスペースや中国語が含まれるパスに置かないでください!!! プログラムのパス名を修正してプログラムを再実行してください +GPUSpeedUp = GPUを使用して加速します +FrameCount = フレーム数 +FrameRate = フレームレート +StartProcessFrame = 【処理中】ビデオのキーフレームの抽出を開始します… +FinishProcessFrame = 【完了】ビデオのキーフレームの抽出が終了しました… +StartFindSub = 【処理中】サブタイトル情報の抽出を開始します。このステップでは時間がかかる場合がありますので、お待ちください… +FinishFindSub = 【終了】サブタイトルの抽出が完了し、元のサブタイトルファイルを生成します… +StartDetectWaterMark = 【処理中】透かし領域を検出し、フィルタリングします +checkWaterMark = ビデオに透かし領域が存在するかどうか、存在する場合は"y"、存在しない場合は"n"を入力してください: +FinishDetectWaterMark = 【完了】透かし領域が正常にフィルタリングされました +StartDeleteNonSub = 【処理中】非字幕領域を検出し、非字幕領域の内容を削除します +FinishDeleteNonSub = 【完了】非字幕領域が削除されました +StartGenerateSub = 【処理中】字幕ファイルの生成を開始します +FinishGenerateSub = 【完了】字幕ファイルが正常に生成されました +SubFrameNo = 字幕フレーム +Elapse = 経過 +ChooseSubArea = サブタイトル領域を指定してください +WatchPicture = 写真をチェックして透かし領域を確認してください +QuestionDelete = この領域の字幕を削除しますか? "y"または"Enter"を入力して削除、"n"または他の手段を入力して削除しない: +FinishDelete = この領域の字幕は削除されました… +FinishWaterMarkFilter = 透かし領域の字幕がフィルタリングされました… +CheckSubArea = 写真を確認して字幕エリアが正しいことを確認してください: +DeleteNoSubArea = 赤い枠以外の領域の字幕を削除しますか? 
"y"または"Enter"を入力して削除、"n"または他の手段を入力して削除しない: +FinishDeleteNoSubArea = 削除済み +SubLocation = 生成された字幕ファイルの場所: +InputVideo = ビデオのフルパスを入力してください: diff --git a/backend/interface/ko.ini b/backend/interface/ko.ini new file mode 100644 index 00000000..2d0f9171 --- /dev/null +++ b/backend/interface/ko.ini @@ -0,0 +1,139 @@ +[LanguageModeGUI] +Title = Subtitle Extractor +InterfaceLanguage = 언어 선택: +SubtitleLanguage = 자막 언어: +Mode = 모드 선택: +ModeAuto = 자동적 인 +ModeFast = 빠름 +ModeAccurate = 정확함 +InterfaceDefault = 한국어 +LanguageCH = 중국어(간체) +LanguageCHINESE_CHT = 중국어(번체) +LanguageEN = 영어 +LanguageJAPAN = 일본어 +LanguageKOREAN = 한국어 +LanguageAR = 아랍어 +LanguageFRENCH = 프랑스어 +LanguageGERMAN = 독일어 +LanguageRU = 러시아어 +LanguageES = 스페인어 +LanguagePT = 포르투갈어 +LanguageIT = 이탈리아어 +LanguageAF = 아프리칸스어 +LanguageAZ = 아제르바이잔어 +LanguageBS = 보스니아어 +LanguageCS = 체코어 +LanguageCY = 웨일스어 +LanguageDA = 덴마크어 +LanguageDE = 독일어 +LanguageET = 에스토니아어 +LanguageFR = 프랑스어 +LanguageGA = 아일랜드어 +LanguageHR = 크로아티아어 +LanguageHU = 헝가리어 +LanguageID = 인도네시아어 +LanguageIS = 아이슬란드어 +LanguageKU = 쿠르드어 +LanguageLA = 라틴어 +LanguageLT = 리투아니아어 +LanguageLV = 라트비아어 +LanguageMI = 마오리어 +LanguageMS = 말레이어 +LanguageMT = 몰타어 +LanguageNL = 네덜란드어 +LanguageNO = 노르웨이어 +LanguageOC = 옥시타니아어 +LanguagePI = 팔리어 +LanguagePL = 폴란드어 +LanguageRO = 루마니아어 +LanguageRS_LATIN = 세르비아어(라틴어) +LanguageSK = 슬로바키아어 +LanguageSL = 슬로베니아어 +LanguageSQ = 알바니아어 +LanguageSV = 스웨덴어 +LanguageSW = 스와힐리어 +LanguageTL = 타갈로그어 +LanguageTR = 터키어 +LanguageUZ = 우즈베크어 +LanguageVI = 베트남어 +LanguageLATIN = 라틴어 +LanguageFA = 페르시아어 +LanguageUR = 우르두어 +LanguageRS_CYRILLIC = 세르비아어(키릴 문자) +LanguageBE = 벨라루스어 +LanguageBG = 불가리아어 +LanguageUK = 우크라이나어 +LanguageMN = 몽골어 +LanguageABQ = 아바자어 +LanguageADY = 아디게어 +LanguageKBD = 카바르다어 +LanguageAVA = 아바르어 +LanguageDAR = 다르기어 +LanguageINH = 인구시어 +LanguageCHE = 체첸어 +LanguageLBE = 라크어 +LanguageLEZ = 레즈기안어 +LanguageTAB = 타바사란어 +LanguageCYRILLIC = 키릴 문자 +LanguageHI = 힌디어 +LanguageMR = 마라티어 +LanguageNE = 네팔어 +LanguageBH = 비하르어 +LanguageMAI = 마이틸리어 +LanguageANG = 앙기카어 +LanguageBHO = 보즈푸리어 +LanguageMAH = 마가히어 +LanguageSCK = 나그푸르어 +LanguageNEW = 네와르어 +LanguageGOM = 불가리아어 +LanguageSA = 사우디아라비아어 +LanguageBGC = 하리아나어 +LanguageDEVANAGARI = 데바나가리 문자 +LanguageTA = 타밀어 +LanguageKN = 칸나다어 +LanguageUG = 위구르어 +LanguageTE = 텔루구어 +LanguageKA = 칸나다어 + +[SubtitleExtractorGUI] +Title = Subtitle Extractor +Open = 열기 +AllFile = 모든 파일 +Vertical = 세로 +Horizontal = 가로 +Run = 실행 +Setting = 설정 +OpenVideoSuccess = 비디오 열기 성공 +OpenVideoFirst = 비디오을 먼저 열어주세요 +SubtitleArea = 자막 영역 + +[Main] +RecSubLang = 자막 언어 인식 +RecMode = 인식 모드 +IllegalPathWarning = [경고] 프로그램이 중단되었습니다! 경로가 올바르지 않습니다! 공백과 한국어가 포함된 경로에 프로그램을 넣지 마세요! 경로 이름을 변경하고 프로그램을 다시 실행해주세요 +GPUSpeedUp = 가속을 위해 GPU 사용 +FrameCount = 프레임 수 +FrameRate = 프레임 속도 +StartProcessFrame = [처리 중] 비디오 키프레임 추출 시작... +FinishProcessFrame = [완료] 비디오 키 프레임 추출 완료... +StartFindSub = [처리 중] 자막 정보 추출 시작, 이 단계는 시간이 오래 걸릴 수 있으니 조금만 기다려주세요... +FinishFindSub = [완료] 자막 추출 완료, 원본 자막 파일 생성... +StartDetectWaterMark = [처리 중] 워터마크 영역 내용 검출 및 필터링 시작 +checkWaterMark = 워터마크 영역이 영상에 존재하는지 여부, 존재하면 y, 존재하지 않으면 n을 입력합니다: +FinishDetectWaterMark = [완료] 워터마크 영역의 내용이 성공적으로 필터링되었습니다 +StartDeleteNonSub = [처리 중] 비자막 영역 검출을 시작하고, 비자막 영역의 내용을 삭제합니다 +FinishDeleteNonSub = [완료] 비자막 영역이 삭제되었습니다 +StartGenerateSub = [처리 중] 자막 파일 생성 시작 +FinishGenerateSub = [완료] 자막 파일 생성 성공 +SubFrameNo = 자막 프레임 +Elapse = 경과 시간 +ChooseSubArea = 자막 영역을 지정해주세요. +WatchPicture = 이미지를 확인하고 워터마크 영역을 지정해주세요 +QuestionDelete = 해당 영역의 자막을 제거할까요? 
제거하려면 "y" 또는 "Enter"를 입력하고, 제거하지 않으려면 "n" 또는 아무거나 입력합니다: +FinishDelete = 이 영역의 자막이 삭제되었습니다... +FinishWaterMarkFilter = 워터마크 영역이 필터링되었습니다... +CheckSubArea = 자막 영역이 올바른지 이미지를 확인해주세요. +DeleteNoSubArea = 빨간색 상자 영역 밖의 자막이 제거되었나요? 제거하려면 "y" 또는 "Enter"를 입력하고, 제거하지 않으려면 "n" 또는 아무거나 입력합니다: +FinishDeleteNoSubArea = 제거 완료 +SubLocation = 자막 파일이 생성된 위치: +InputVideo = 비디오의 전체 경로를 입력하세요: diff --git a/backend/interface/vi.ini b/backend/interface/vi.ini new file mode 100644 index 00000000..a34a059f --- /dev/null +++ b/backend/interface/vi.ini @@ -0,0 +1,139 @@ +[LanguageModeGUI] +Title = Trích xuất phụ đề +InterfaceLanguage = Chọn Ngôn ngữ: +SubtitleLanguage = Ngôn ngữ phụ đề: +Mode = Chọn chế độ: +ModeAuto = tự động +ModeFast = nhanh +ModeAccurate = chính xác +InterfaceDefault = Tiếng Anh +LanguageCH = Tiếng Trung giản thể +LanguageCHINESE_CHT = Tiếng Trung phồn thể +LanguageEN = Tiếng Anh +LanguageJAPAN = Tiếng Nhật +LanguageKOREAN = Tiếng Hàn +LanguageAR = Tiếng Ả Rập +LanguageFRENCH = Tiếng Pháp +LanguageGERMAN = Tiếng Đức +LanguageRU = Tiếng Nga +LanguageES = Tiếng Tây Ban Nha +LanguagePT = Tiếng Bồ Đào Nha +LanguageIT = Tiếng Ý +LanguageAF = Afrikaans +LanguageAZ = Azerbaijani +LanguageBS = Bosnian +LanguageCS = Tiếng Séc +LanguageCY = Welsh +LanguageDA = Danish +LanguageDE = Tiếng Đức +LanguageET = Estonian +LanguageFR = Tiếng Pháp +LanguageGA = Irish +LanguageHR = Croatian +LanguageHU = Hungarian +LanguageID = Tiếng Indonesia +LanguageIS = Icelandic +LanguageKU = Kurdish +LanguageLA = Tiếng La Mã +LanguageLT = Lithuanian +LanguageLV = Latvian +LanguageMI = Mã Lai +LanguageMS = Mã Lai +LanguageMT = Maltese +LanguageNL = Dutch +LanguageNO = Norwegian +LanguageOC = Occitan +LanguagePI = Pali +LanguagePL = Tiếng Ba Lan +LanguageRO = Tiếng Romania +LanguageRS_LATIN = Serbian (Latinh) +LanguageSK = Slovak +LanguageSL = Slovenian +LanguageSQ = Albanian +LanguageSV = Swedish +LanguageSW = Swahili +LanguageTL = Tagalog +LanguageTR = Tiếng Thổ Nhĩ Kỳ +LanguageUZ = Uzbek +LanguageVI = Tiếng Việt +LanguageLATIN = Tiếng La Mã +LanguageFA = Tiếng Ba Tư +LanguageUR = Tiếng Urdu +LanguageRS_CYRILLIC = Serbian (Cyrillic) +LanguageBE = Belarussian +LanguageBG = Bulgarian +LanguageUK = Ukraina +LanguageMN = Mongol +LanguageABQ = Abaza +LanguageADY = Adyghe +LanguageKBD = Kabardian +LanguageAVA = Avar +LanguageDAR = Dargwa +LanguageINH = Ingush +LanguageCHE = Chechen +LanguageLBE = Lak +LanguageLEZ = Lezghian +LanguageTAB = Tabassaran +LanguageCYRILLIC = Cyrillic +LanguageHI = Hindi +LanguageMR = Marathi +LanguageNE = Nepali +LanguageBH = Bihari +LanguageMAI = Maithili +LanguageANG = Angika +LanguageBHO = Bhojpuri +LanguageMAH = Magahi +LanguageSCK = Nagpur +LanguageNEW = Newari +LanguageGOM = Goan Konkani +LanguageSA = Ả Rập Xê Út +LanguageBGC = Haryanvi +LanguageDEVANAGARI = Devanagari +LanguageTA = Tamil +LanguageKN = Kannada +LanguageUG = Uyghur +LanguageTE = Telugu +LanguageKA = Kannada + +[SubtitleExtractorGUI] +Title = Trích xuất phụ đề +Open = Mở +AllFile = Tất cả tệp +Vertical = Thẳng đứng +Horizontal = Ngang +Run = Chạy +Setting = Cài đặt +OpenVideoSuccess = Mở video thành công +OpenVideoFirst = Xin mở video trước đã +SubtitleArea = Khu vực phụ đề + +[Main] +RecSubLang = Ngôn ngữ phụ đề +RecMode = Chế độ +IllegalPathWarning = [Cảnh báo] Chương trình bị gián đoạn! Đường dẫn không hợp lệ! Xin đừng để chương trình trong đường dẫn có dấu cách và tiếng Trung! ! ! 
Xin sửa tên đường dẫn chương trình và chạy lại chương trình +GPUSpeedUp = Sử dụng GPU để tăng tốc +FrameCount = Số khung hình +FrameRate = Tốc độ khung hình +StartProcessFrame = [Đang xử lý] Bắt đầu trích xuất keyframe video... +FinishProcessFrame = [Đã hoàn thành] Đã hoàn thành việc trích xuất key frame video... +StartFindSub = [Đang xử lý] Bắt đầu tìm thông tin phụ đề, bước này có thể mất nhiều thời gian, xin kiên nhẫn... +FinishFindSub = [Đã hoàn thành] Hoàn thành việc tìm kiếm phụ đề, sinh ra tệp nguyên bản của phụ đề... +StartDetectWaterMark = [Đang xử lý] Bắt đầu phát hiện và lọc khu vực watermark +checkWaterMark = Có watermark trong video không? Nếu có, nhập "y", nếu không có, nhập "n": +FinishDetectWaterMark = [Đã hoàn thành] Khu vực watermark đã được lọc thành công +StartDeleteNonSub = [Đang xử lý] Bắt đầu tìm khu vực không có phụ đề và xóa nội dung trong khu vực không có phụ đề +FinishDeleteNonSub = [Đã hoàn thành] Khu vực không có phụ đề đã được xóa +StartGenerateSub = [Đang xử lý] Bắt đầu tạo tệp phụ đề +FinishGenerateSub = [Đã hoàn thành] Tạo tệp phụ đề thành công +SubFrameNo = Khung hình phụ đề +Elapse = thời gian trôi qua +ChooseSubArea = Xin chỉ định khu vực phụ đề +WatchPicture = Xin kiểm tra hình ảnh để xác định khu vực watermark +QuestionDelete = Có xóa phụ đề trong khu vực này? Nhập "y" hoặc "Enter" để xóa, nhập "n" hoặc khác có nghĩa là không xóa: +FinishDelete = Phụ đề trong khu vực này đã được xóa... +FinishWaterMarkFilter = Các phụ đề trong khu vực watermark đã được lọc... +CheckSubArea = Xin kiểm tra hình ảnh để chắc chắn rằng khu vực phụ đề là đúng: +DeleteNoSubArea = Có xóa phụ đề ngoài khu vực hộp đỏ? Nhập "y" hoặc "Enter" để xóa, nhập "n" hoặc khác có nghĩa là không xóa: +FinishDeleteNoSubArea = Đã xóa +SubLocation = Tệp phụ đề được tạo tại: +InputVideo = Xin nhập đường dẫn đầy đủ của video: diff --git a/backend/main.py b/backend/main.py new file mode 100644 index 00000000..71ac4bff --- /dev/null +++ b/backend/main.py @@ -0,0 +1,1021 @@ +# -*- coding: utf-8 -*- +""" +@Author : Fang Yao +@Time : 2021/3/24 9:28 上午 +@FileName: main.py +@desc: 主程序入口文件 +""" +import os +import random +import shutil +from collections import Counter, namedtuple +import unicodedata +from threading import Thread +from pathlib import Path +import cv2 +from Levenshtein import ratio +from PIL import Image +from numpy import average, dot, linalg +from tqdm import tqdm +import sys + +sys.path.insert(0, os.path.dirname(__file__)) +import importlib +import config +from tools import reformat +from tools.infer import utility +from tools.infer.predict_det import TextDetector +from tools.ocr import OcrRecogniser, get_coordinates +from tools import subtitle_ocr +import threading +import platform +import multiprocessing +import time +import pysrt + + +class SubtitleDetect: + """ + 文本框检测类,用于检测视频帧中是否存在文本框 + """ + + def __init__(self): + # 获取参数对象 + importlib.reload(config) + args = utility.parse_args() + args.det_algorithm = 'DB' + args.det_model_dir = config.DET_MODEL_PATH + self.text_detector = TextDetector(args) + + def detect_subtitle(self, img): + dt_boxes, elapse = self.text_detector(img) + return dt_boxes, elapse + + +class SubtitleExtractor: + """ + 视频字幕提取类 + """ + + def __init__(self, vd_path, sub_area=None): + importlib.reload(config) + # 线程锁 + self.lock = threading.RLock() + # 用户指定的字幕区域位置 + self.sub_area = sub_area + # 创建字幕检测对象 + self.sub_detector = SubtitleDetect() + # 视频路径 + self.video_path = vd_path + self.video_cap = cv2.VideoCapture(vd_path) + # 通过视频路径获取视频名称 + self.vd_name = 
Path(self.video_path).stem + # 临时存储文件夹 + self.temp_output_dir = os.path.join(os.path.dirname(config.BASE_DIR), 'output', str(self.vd_name)) + # 视频帧总数 + self.frame_count = self.video_cap.get(cv2.CAP_PROP_FRAME_COUNT) + # 视频帧率 + self.fps = self.video_cap.get(cv2.CAP_PROP_FPS) + # 视频尺寸 + self.frame_height = int(self.video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.frame_width = int(self.video_cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + # 用户未指定字幕区域时,默认字幕出现的区域 + self.default_subtitle_area = config.DEFAULT_SUBTITLE_AREA + # 提取的视频帧储存目录 + self.frame_output_dir = os.path.join(self.temp_output_dir, 'frames') + # 提取的字幕文件存储目录 + self.subtitle_output_dir = os.path.join(self.temp_output_dir, 'subtitle') + # 若目录不存在,则创建文件夹 + if not os.path.exists(self.frame_output_dir): + os.makedirs(self.frame_output_dir) + if not os.path.exists(self.subtitle_output_dir): + os.makedirs(self.subtitle_output_dir) + # 定义是否使用vsf提取字幕帧 + self.use_vsf = False + # 定义vsf的字幕输出路径 + self.vsf_subtitle = os.path.join(self.subtitle_output_dir, 'raw_vsf.srt') + # 提取的原始字幕文本存储路径 + self.raw_subtitle_path = os.path.join(self.subtitle_output_dir, 'raw.txt') + # 自定义ocr对象 + self.ocr = None + # 打印识别语言与识别模式 + print(f"{config.interface_config['Main']['RecSubLang']}:{config.REC_CHAR_TYPE}") + print(f"{config.interface_config['Main']['RecMode']}:{config.MODE_TYPE}") + # 如果使用GPU加速,则打印GPU加速提示 + if config.USE_GPU: + print(config.interface_config['Main']['GPUSpeedUp']) + # 总处理进度 + self.progress_total = 0 + # 视频帧提取进度 + self.progress_frame_extract = 0 + # OCR识别进度 + self.progress_ocr = 0 + # 是否完成 + self.isFinished = False + # 字幕OCR任务队列 + self.subtitle_ocr_task_queue = None + # 字幕OCR进度队列 + self.subtitle_ocr_progress_queue = None + # vsf运行状态 + self.vsf_running = False + + def run(self): + """ + 运行整个提取视频的步骤 + """ + # 记录开始运行的时间 + start_time = time.time() + self.lock.acquire() + # 重置进度条 + self.update_progress(ocr=0, frame_extract=0) + # 打印视频帧数与帧率 + print(f"{config.interface_config['Main']['FrameCount']}:{self.frame_count}" + f",{config.interface_config['Main']['FrameRate']}:{self.fps}") + # 打印加载模型信息 + print(f'{os.path.basename(os.path.dirname(config.DET_MODEL_PATH))}-{os.path.basename(config.DET_MODEL_PATH)}') + print(f'{os.path.basename(os.path.dirname(config.REC_MODEL_PATH))}-{os.path.basename(config.REC_MODEL_PATH)}') + # 打印视频帧提取开始提示 + print(config.interface_config['Main']['StartProcessFrame']) + # 创建一个字幕OCR识别进程 + subtitle_ocr_process = self.start_subtitle_ocr_async() + if self.sub_area is not None: + if platform.system() in ['Windows', 'Linux']: + # 使用GPU且使用accurate模式时才开放此方法: + if config.USE_GPU and config.MODE_TYPE == 'accurate': + self.extract_frame_by_det() + else: + self.extract_frame_by_vsf() + else: + self.extract_frame_by_fps() + else: + self.extract_frame_by_fps() + + # 往字幕OCR任务队列中,添加OCR识别任务结束标志 + # 任务格式为:(total_frame_count总帧数, current_frame_no当前帧, dt_box检测框, rec_res识别结果, 当前帧时间, subtitle_area字幕区域) + self.subtitle_ocr_task_queue.put((self.frame_count, -1, None, None, None, None)) + # 等待子线程完成 + subtitle_ocr_process.join() + # 打印完成提示 + print(config.interface_config['Main']['FinishProcessFrame']) + print(config.interface_config['Main']['FinishFindSub']) + + if self.sub_area is None: + print(config.interface_config['Main']['StartDetectWaterMark']) + # 询问用户视频是否有水印区域 + user_input = input(config.interface_config['Main']['checkWaterMark']).strip() + if user_input == 'y': + self.filter_watermark() + print(config.interface_config['Main']['FinishDetectWaterMark']) + else: + print('-----------------------------') + + if self.sub_area is None: + 
print(config.interface_config['Main']['StartDeleteNonSub']) + self.filter_scene_text() + print(config.interface_config['Main']['FinishDeleteNonSub']) + + # 打印开始字幕生成提示 + print(config.interface_config['Main']['StartGenerateSub']) + # 判断是否使用了vsf提取字幕 + if self.use_vsf: + # 如果使用了vsf提取字幕,则使用vsf的字幕生成方法 + self.generate_subtitle_file_vsf() + else: + # 如果未使用vsf提取字幕,则使用常规字幕生成方法 + self.generate_subtitle_file() + if config.WORD_SEGMENTATION: + reformat.execute(os.path.join(os.path.splitext(self.video_path)[0] + '.srt'), config.REC_CHAR_TYPE) + print(config.interface_config['Main']['FinishGenerateSub'], f"{round(time.time() - start_time, 2)}s") + self.update_progress(ocr=100, frame_extract=100) + self.isFinished = True + # 删除缓存文件 + self.empty_cache() + self.lock.release() + if config.GENERATE_TXT: + self.srt2txt(os.path.join(os.path.splitext(self.video_path)[0] + '.srt')) + + def extract_frame_by_fps(self): + """ + 根据帧率,定时提取视频帧,容易丢字幕,但速度快,将提取到的视频帧加入ocr识别任务队列 + """ + # 删除缓存 + self.__delete_frame_cache() + # 当前视频帧的帧号 + current_frame_no = 0 + while self.video_cap.isOpened(): + ret, frame = self.video_cap.read() + # 如果读取视频帧失败(视频读到最后一帧) + if not ret: + break + # 读取视频帧成功 + else: + current_frame_no += 1 + # subtitle_ocr_task_queue: (total_frame_count总帧数, current_frame_no当前帧, dt_box检测框, rec_res识别结果, 当前帧时间,subtitle_area字幕区域) + task = (self.frame_count, current_frame_no, None, None, None, self.default_subtitle_area) + self.subtitle_ocr_task_queue.put(task) + # 跳过剩下的帧 + for i in range(int(self.fps // config.EXTRACT_FREQUENCY) - 1): + ret, _ = self.video_cap.read() + if ret: + current_frame_no += 1 + # 更新进度条 + self.update_progress(frame_extract=(current_frame_no / self.frame_count) * 100) + + self.video_cap.release() + + def extract_frame_by_det(self): + """ + 通过检测字幕区域位置提取字幕帧 + """ + # 删除缓存 + self.__delete_frame_cache() + + # 当前视频帧的帧号 + current_frame_no = 0 + frame_lru_list = [] + frame_lru_list_max_size = 2 + ocr_args_list = [] + compare_ocr_result_cache = {} + tbar = tqdm(total=int(self.frame_count), unit='f', position=0, file=sys.__stdout__) + first_flag = True + is_finding_start_frame_no = False + is_finding_end_frame_no = False + start_frame_no = 0 + start_end_frame_no = [] + start_frame = None + if self.ocr is None: + self.ocr = OcrRecogniser() + while self.video_cap.isOpened(): + ret, frame = self.video_cap.read() + # 如果读取视频帧失败(视频读到最后一帧) + if not ret: + break + # 读取视频帧成功 + current_frame_no += 1 + tbar.update(1) + dt_boxes, elapse = self.sub_detector.detect_subtitle(frame) + has_subtitle = False + if self.sub_area is not None: + s_ymin, s_ymax, s_xmin, s_xmax = self.sub_area + coordinate_list = get_coordinates(dt_boxes.tolist()) + if coordinate_list: + for coordinate in coordinate_list: + xmin, xmax, ymin, ymax = coordinate + if (s_xmin <= xmin and xmax <= s_xmax + and s_ymin <= ymin + and ymax <= s_ymax): + has_subtitle = True + # 检测到字幕时,如果列表为空,则为字幕头 + if first_flag: + is_finding_start_frame_no = True + first_flag = False + break + else: + has_subtitle = len(dt_boxes) > 0 + # 检测到包含字幕帧的起始帧号与结束帧号 + if has_subtitle: + # 判断是字幕头还是尾 + if is_finding_start_frame_no: + start_frame_no = current_frame_no + dt_box, rec_res = self.ocr.predict(frame) + area_text1 = "".join(self.__get_area_text((dt_box, rec_res))) + if start_frame_no not in compare_ocr_result_cache.keys(): + compare_ocr_result_cache[current_frame_no] = {'text': area_text1, 'dt_box': dt_box, 'rec_res': rec_res} + frame_lru_list.append((frame, current_frame_no)) + ocr_args_list.append((self.frame_count, current_frame_no)) + # 缓存头帧 + start_frame = frame + # 
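The fixed-rate path (extract_frame_by_fps) above simply walks the stream and keeps roughly EXTRACT_FREQUENCY frames per second; it is fast but can miss subtitles that appear only between sampled frames. A minimal standalone sketch of that sampling loop, with a hypothetical extract_frequency argument standing in for the config value:

import cv2

def sample_frames(video_path, extract_frequency=2):
    # Yield (frame_no, frame) roughly `extract_frequency` times per second.
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS) or 25.0   # fall back if the container reports no FPS
    step = max(int(fps // extract_frequency), 1)
    frame_no = 0
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        frame_no += 1
        if (frame_no - 1) % step == 0:        # keep every step-th frame, starting with the first
            yield frame_no, frame
    cap.release()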
开始找尾 + is_finding_start_frame_no = False + is_finding_end_frame_no = True + # 判断是否为最后一帧 + if is_finding_end_frame_no and current_frame_no == self.frame_count: + is_finding_end_frame_no = False + is_finding_start_frame_no = False + end_frame_no = current_frame_no + frame_lru_list.append((frame, current_frame_no)) + ocr_args_list.append((self.frame_count, current_frame_no)) + start_end_frame_no.append((start_frame_no, end_frame_no)) + # 如果在找结束帧的时候 + if is_finding_end_frame_no: + # 判断该帧与头帧ocr内容是否一致,若不一致则找到尾,尾巴为前一帧 + if not self._compare_ocr_result(compare_ocr_result_cache, None, start_frame_no, frame, current_frame_no): + is_finding_end_frame_no = False + is_finding_start_frame_no = True + end_frame_no = current_frame_no - 1 + frame_lru_list.append((start_frame, end_frame_no)) + ocr_args_list.append((self.frame_count, end_frame_no)) + start_end_frame_no.append((start_frame_no, end_frame_no)) + + else: + # 如果检测到字幕头后有没有字幕,则找到结尾,尾巴为前一帧 + if is_finding_end_frame_no: + end_frame_no = current_frame_no - 1 + is_finding_end_frame_no = False + is_finding_start_frame_no = True + frame_lru_list.append((start_frame, end_frame_no)) + ocr_args_list.append((self.frame_count, end_frame_no)) + start_end_frame_no.append((start_frame_no, end_frame_no)) + + while len(frame_lru_list) > frame_lru_list_max_size: + frame_lru_list.pop(0) + + # if len(start_end_frame_no) > 0: + # print(start_end_frame_no) + + while len(ocr_args_list) > 1: + total_frame_count, ocr_info_frame_no = ocr_args_list.pop(0) + if current_frame_no in compare_ocr_result_cache: + predict_result = compare_ocr_result_cache[current_frame_no] + dt_box, rec_res = predict_result['dt_box'], predict_result['rec_res'] + else: + dt_box, rec_res = None, None + # subtitle_ocr_task_queue: (total_frame_count总帧数, current_frame_no当前帧, dt_box检测框, rec_res识别结果, 当前帧时间, subtitle_area字幕区域) + task = (total_frame_count, ocr_info_frame_no, dt_box, rec_res, None, self.default_subtitle_area) + # 添加任务 + self.subtitle_ocr_task_queue.put(task) + self.update_progress(frame_extract=(current_frame_no / self.frame_count) * 100) + + while len(ocr_args_list) > 0: + total_frame_count, ocr_info_frame_no = ocr_args_list.pop(0) + if current_frame_no in compare_ocr_result_cache: + predict_result = compare_ocr_result_cache[current_frame_no] + dt_box, rec_res = predict_result['dt_box'], predict_result['rec_res'] + else: + dt_box, rec_res = None, None + task = (total_frame_count, ocr_info_frame_no, dt_box, rec_res, None, self.default_subtitle_area) + # 添加任务 + self.subtitle_ocr_task_queue.put(task) + self.video_cap.release() + + def extract_frame_by_vsf(self): + """ + 通过调用videoSubFinder获取字幕帧 + """ + self.use_vsf = True + + def count_process(): + duration_ms = (self.frame_count / self.fps) * 1000 + last_total_ms = 0 + processed_image = set() + rgb_images_path = os.path.join(self.temp_output_dir, 'RGBImages') + while self.vsf_running and not self.isFinished: + # 如果还没有rgb_images_path说明vsf还没处理完 + if not os.path.exists(rgb_images_path): + # 继续等待 + continue + try: + # 将列表按文件名排序 + rgb_images = sorted(os.listdir(rgb_images_path)) + for rgb_image in rgb_images: + # 如果当前图片已被处理,则跳过 + if rgb_image in processed_image: + continue + processed_image.add(rgb_image) + # 根据vsf生成的文件名读取时间 + h, m, s, ms = rgb_image.split('__')[0].split('_') + total_ms = int(ms) + int(s) * 1000 + int(m) * 60 * 1000 + int(h) * 60 * 60 * 1000 + if total_ms > last_total_ms: + frame_no = int(total_ms / self.fps) + task = (self.frame_count, frame_no, None, None, total_ms, self.default_subtitle_area) + 
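In the detection-based path (extract_frame_by_det) a frame counts as carrying a subtitle only when at least one detected text box lies entirely inside the user-selected area. A short sketch of that containment test, assuming the same (ymin, ymax, xmin, xmax) ordering for sub_area and (xmin, xmax, ymin, ymax) ordering for detected boxes as in the code above:

def box_in_sub_area(box, sub_area):
    # box: (xmin, xmax, ymin, ymax) of one detected text region
    # sub_area: (ymin, ymax, xmin, xmax) as entered by the user
    s_ymin, s_ymax, s_xmin, s_xmax = sub_area
    xmin, xmax, ymin, ymax = box
    return s_xmin <= xmin and xmax <= s_xmax and s_ymin <= ymin and ymax <= s_ymax

def frame_has_subtitle(boxes, sub_area):
    # True if any detected box falls completely inside the subtitle area.
    return any(box_in_sub_area(box, sub_area) for box in boxes)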
self.subtitle_ocr_task_queue.put(task) + last_total_ms = total_ms + if total_ms / duration_ms >= 1: + self.update_progress(frame_extract=100) + return + else: + self.update_progress(frame_extract=(total_ms / duration_ms) * 100) + # 文件被清理了 + except FileNotFoundError: + return + + def vsf_output(out, ): + duration_ms = (self.frame_count / self.fps) * 1000 + last_total_ms = 0 + for line in iter(out.readline, b''): + line = line.decode("utf-8") + # print('line', line, type(line), line.startswith('Frame: ')) + if line.startswith('Frame: '): + line = line.replace("\n", "") + line = line.replace("Frame: ", "") + h, m, s, ms = line.split('__')[0].split('_') + total_ms = int(ms) + int(s) * 1000 + int(m) * 60 * 1000 + int(h) * 60 * 60 * 1000 + if total_ms > last_total_ms: + frame_no = int(total_ms / self.fps) + task = (self.frame_count, frame_no, None, None, total_ms, self.default_subtitle_area) + self.subtitle_ocr_task_queue.put(task) + last_total_ms = total_ms + if total_ms / duration_ms >= 1: + self.update_progress(frame_extract=100) + return + else: + self.update_progress(frame_extract=(total_ms / duration_ms) * 100) + else: + print(line.strip()) + out.close() + + # 删除缓存 + self.__delete_frame_cache() + # 定义videoSubFinder所在路径 + if platform.system() == 'Windows': + path_vsf = os.path.join(config.BASE_DIR, 'subfinder', 'windows', 'VideoSubFinderWXW.exe') + else: + path_vsf = os.path.join(config.BASE_DIR, 'subfinder', 'linux', 'VideoSubFinderCli.run') + os.chmod(path_vsf, 0o775) + # :图像上半部分所占百分比,取值【0-1】 + top_end = 1 - self.sub_area[0] / self.frame_height + # bottom_end:图像下半部分所占百分比,取值【0-1】 + bottom_end = 1 - self.sub_area[1] / self.frame_height + # left_end:图像左半部分所占百分比,取值【0-1】 + left_end = self.sub_area[2] / self.frame_width + # re:图像右半部分所占百分比,取值【0-1】 + right_end = self.sub_area[3] / self.frame_width + cpu_count = max(int(multiprocessing.cpu_count() * 2 / 3), 1) + if cpu_count < 4: + cpu_count = max(multiprocessing.cpu_count() - 1, 1) + if platform.system() == 'Windows': + # 定义执行命令 + cmd = f"{path_vsf} --use_cuda -c -r -i \"{self.video_path}\" -o \"{self.temp_output_dir}\" -ces \"{self.vsf_subtitle}\" " + cmd += f"-te {top_end} -be {bottom_end} -le {left_end} -re {right_end} -nthr {cpu_count} -nocrthr {cpu_count}" + self.vsf_running = True + # 计算进度 + Thread(target=count_process, daemon=True).start() + import subprocess + subprocess.run(cmd, shell=True, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + self.vsf_running = False + else: + # 定义执行命令 + cmd = f"{path_vsf} -c -r -i \"{self.video_path}\" -o \"{self.temp_output_dir}\" -ces \"{self.vsf_subtitle}\" " + if config.USE_GPU: + cmd += "--use_cuda " + cmd += f"-te {top_end} -be {bottom_end} -le {left_end} -re {right_end} -nthr {cpu_count} -dsi" + self.vsf_running = True + import subprocess + p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=1, + close_fds='posix' in sys.builtin_module_names, shell=True) + Thread(target=vsf_output, daemon=True, args=(p.stderr,)).start() + p.wait() + self.vsf_running = False + + def filter_watermark(self): + """ + 去除原始字幕文本中的水印区域的文本 + """ + # 获取潜在水印区域 + watermark_areas = self._detect_watermark_area() + + # 随机选择一帧, 将所水印区域标记出来,用户看图判断是否是水印区域 + cap = cv2.VideoCapture(self.video_path) + ret, sample_frame = False, None + for i in range(10): + frame_no = random.randint(int(self.frame_count * 0.1), int(self.frame_count * 0.9)) + cap.set(cv2.CAP_PROP_POS_FRAMES, frame_no) + ret, sample_frame = cap.read() + if ret: + break + cap.release() + + if not ret: + print("Error in 
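VideoSubFinder reports progress with timestamps of the form H_MM_SS_MS (both in the RGBImages file names and in its "Frame: ..." output lines), which count_process and vsf_output above convert to milliseconds. The conversion, written as a standalone helper:

def vsf_name_to_ms(name):
    # "0_01_23_456__..." -> 83456 (start time of the frame, in milliseconds)
    h, m, s, ms = name.split('__')[0].split('_')
    return ((int(h) * 60 + int(m)) * 60 + int(s)) * 1000 + int(ms)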
filter_watermark: reading frame from video") + return + + # 给潜在的水印区域编号 + area_num = ['E', 'D', 'C', 'B', 'A'] + + for watermark_area in watermark_areas: + ymin = min(watermark_area[0][2], watermark_area[0][3]) + ymax = max(watermark_area[0][3], watermark_area[0][2]) + xmin = min(watermark_area[0][0], watermark_area[0][1]) + xmax = max(watermark_area[0][1], watermark_area[0][0]) + cover = sample_frame[ymin:ymax, xmin:xmax] + cover = cv2.blur(cover, (10, 10)) + cv2.rectangle(cover, pt1=(0, cover.shape[0]), pt2=(cover.shape[1], 0), color=(0, 0, 255), thickness=3) + sample_frame[ymin:ymax, xmin:xmax] = cover + position = ((xmin + xmax) // 2, ymax) + + cv2.putText(sample_frame, text=area_num.pop(), org=position, fontFace=cv2.FONT_HERSHEY_SIMPLEX, + fontScale=1, color=(255, 0, 0), thickness=2, lineType=cv2.LINE_AA) + + sample_frame_file_path = os.path.join(os.path.dirname(self.frame_output_dir), 'watermark_area.jpg') + cv2.imwrite(sample_frame_file_path, sample_frame) + print(f"{config.interface_config['Main']['WatchPicture']}: {sample_frame_file_path}") + + area_num = ['E', 'D', 'C', 'B', 'A'] + for watermark_area in watermark_areas: + user_input = input(f"{area_num.pop()}{str(watermark_area)} " + f"{config.interface_config['Main']['QuestionDelete']}").strip() + if user_input == 'y' or user_input == '\n': + with open(self.raw_subtitle_path, mode='r+', encoding='utf-8') as f: + content = f.readlines() + f.seek(0) + for i in content: + if i.find(str(watermark_area[0])) == -1: + f.write(i) + f.truncate() + print(config.interface_config['Main']['FinishDelete']) + print(config.interface_config['Main']['FinishWaterMarkFilter']) + # 删除缓存 + if os.path.exists(sample_frame_file_path): + os.remove(sample_frame_file_path) + + def filter_scene_text(self): + """ + 将场景里提取的文字过滤,仅保留字幕区域 + """ + # 获取潜在字幕区域 + subtitle_area = self._detect_subtitle_area()[0][0] + + # 随机选择一帧,将所水印区域标记出来,用户看图判断是否是水印区域 + cap = cv2.VideoCapture(self.video_path) + ret, sample_frame = False, None + for i in range(10): + frame_no = random.randint(int(self.frame_count * 0.1), int(self.frame_count * 0.9)) + cap.set(cv2.CAP_PROP_POS_FRAMES, frame_no) + ret, sample_frame = cap.read() + if ret: + break + cap.release() + + if not ret: + print("Error in filter_scene_text: reading frame from video") + return + + # 为了防止有双行字幕,根据容忍度,将字幕区域y范围加高 + ymin = abs(subtitle_area[0] - config.SUBTITLE_AREA_DEVIATION_PIXEL) + ymax = subtitle_area[1] + config.SUBTITLE_AREA_DEVIATION_PIXEL + # 画出字幕框的区域 + cv2.rectangle(sample_frame, pt1=(0, ymin), pt2=(sample_frame.shape[1], ymax), color=(0, 0, 255), thickness=3) + sample_frame_file_path = os.path.join(os.path.dirname(self.frame_output_dir), 'subtitle_area.jpg') + cv2.imwrite(sample_frame_file_path, sample_frame) + print(f"{config.interface_config['Main']['CheckSubArea']} {sample_frame_file_path}") + + user_input = input(f"{(ymin, ymax)} {config.interface_config['Main']['DeleteNoSubArea']}").strip() + if user_input == 'y' or user_input == '\n': + with open(self.raw_subtitle_path, mode='r+', encoding='utf-8') as f: + content = f.readlines() + f.seek(0) + for i in content: + i_ymin = int(i.split('\t')[1].split('(')[1].split(')')[0].split(', ')[2]) + i_ymax = int(i.split('\t')[1].split('(')[1].split(')')[0].split(', ')[3]) + if ymin <= i_ymin and i_ymax <= ymax: + f.write(i) + f.truncate() + print(config.interface_config['Main']['FinishDeleteNoSubArea']) + # 删除缓存 + if os.path.exists(sample_frame_file_path): + os.remove(sample_frame_file_path) + + def generate_subtitle_file(self): + """ + 生成srt格式的字幕文件 + """ + if not 
self.use_vsf: + subtitle_content = self._remove_duplicate_subtitle() + srt_filename = os.path.join(os.path.splitext(self.video_path)[0] + '.srt') + # 保存持续时间不足1秒的字幕行,用于后续处理 + post_process_subtitle = [] + with open(srt_filename, mode='w', encoding='utf-8') as f: + for index, content in enumerate(subtitle_content): + line_code = index + 1 + frame_start = self._frame_to_timecode(int(content[0])) + # 比较起始帧号与结束帧号, 如果字幕持续时间不足1秒,则将显示时间设为1s + if abs(int(content[1]) - int(content[0])) < self.fps: + frame_end = self._frame_to_timecode(int(int(content[0]) + self.fps)) + post_process_subtitle.append(line_code) + else: + frame_end = self._frame_to_timecode(int(content[1])) + frame_content = content[2] + subtitle_line = f'{line_code}\n{frame_start} --> {frame_end}\n{frame_content}\n' + f.write(subtitle_line) + print(f"[NO-VSF]{config.interface_config['Main']['SubLocation']} {srt_filename}") + # 返回持续时间低于1s的字幕行 + return post_process_subtitle + + def generate_subtitle_file_vsf(self): + if not self.use_vsf: + return + subs = pysrt.open(self.vsf_subtitle) + sub_no_map = {} + for sub in subs: + sub.start.no = self._timestamp_to_frameno(sub.start.ordinal) + sub_no_map[sub.start.no] = sub + + subtitle_content = self._remove_duplicate_subtitle() + subtitle_content_start_map = {int(a[0]): a for a in subtitle_content} + final_subtitles = [] + for sub in subs: + found = sub.start.no in subtitle_content_start_map + if found: + subtitle_content_line = subtitle_content_start_map[sub.start.no] + sub.text = subtitle_content_line[2] + end_no = int(subtitle_content_line[1]) + sub.end = sub_no_map[end_no].end if end_no in sub_no_map else sub.end + sub.index = len(final_subtitles) + 1 + final_subtitles.append(sub) + + if not found and not config.DELETE_EMPTY_TIMESTAMP: + # 保留时间轴 + sub.text = "" + sub.index = len(final_subtitles) + 1 + final_subtitles.append(sub) + continue + + srt_filename = os.path.join(os.path.splitext(self.video_path)[0] + '.srt') + pysrt.SubRipFile(final_subtitles).save(srt_filename, encoding='utf-8') + print(f"[VSF]{config.interface_config['Main']['SubLocation']} {srt_filename}") + + def _detect_watermark_area(self): + """ + 根据识别出来的raw txt文件中的坐标点信息,查找水印区域 + 假定:水印区域(台标)的坐标在水平和垂直方向都是固定的,也就是具有(xmin, xmax, ymin, ymax)相对固定 + 根据坐标点信息,进行统计,将一直具有固定坐标的文本区域选出 + :return 返回最有可能的水印区域 + """ + f = open(self.raw_subtitle_path, mode='r', encoding='utf-8') # 打开txt文件,以‘utf-8’编码读取 + line = f.readline() # 以行的形式进行读取文件 + # 坐标点列表 + coordinates_list = [] + # 帧列表 + frame_no_list = [] + # 内容列表 + content_list = [] + while line: + frame_no = line.split('\t')[0] + text_position = line.split('\t')[1].split('(')[1].split(')')[0].split(', ') + content = line.split('\t')[2] + frame_no_list.append(frame_no) + coordinates_list.append((int(text_position[0]), + int(text_position[1]), + int(text_position[2]), + int(text_position[3]))) + content_list.append(content) + line = f.readline() + f.close() + # 将坐标列表的相似值统一 + coordinates_list = self._unite_coordinates(coordinates_list) + + # 将原txt文件的坐标更新为归一后的坐标 + with open(self.raw_subtitle_path, mode='w', encoding='utf-8') as f: + for frame_no, coordinate, content in zip(frame_no_list, coordinates_list, content_list): + f.write(f'{frame_no}\t{coordinate}\t{content}') + + if len(Counter(coordinates_list).most_common()) > config.WATERMARK_AREA_NUM: + # 读取配置文件,返回可能为水印区域的坐标列表 + return Counter(coordinates_list).most_common(config.WATERMARK_AREA_NUM) + else: + # 不够则有几个返回几个 + return Counter(coordinates_list).most_common() + + def _detect_subtitle_area(self): + """ + 读取过滤水印区域后的raw txt文件,根据坐标信息,查找字幕区域 + 
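_detect_watermark_area relies on the assumption stated in its docstring: a station logo keeps an essentially fixed bounding box for the whole video, so after the coordinates have been normalised, the most frequent (xmin, xmax, ymin, ymax) tuples are the strongest watermark candidates. The core of that step is a plain frequency count; a minimal sketch with a hypothetical max_candidates in place of config.WATERMARK_AREA_NUM:

from collections import Counter

def watermark_candidates(coordinates, max_candidates=5):
    # coordinates: one (xmin, xmax, ymin, ymax) tuple per recognised text line
    # Returns [(coordinate, occurrence_count), ...] ordered by frequency.
    return Counter(coordinates).most_common()[:max_candidates]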
假定:字幕区域在y轴上有一个相对固定的坐标范围,相对于场景文本,这个范围出现频率更高 + :return 返回字幕的区域位置 + """ + # 打开去水印区域处理过的raw txt + f = open(self.raw_subtitle_path, mode='r', encoding='utf-8') # 打开txt文件,以‘utf-8’编码读取 + line = f.readline() # 以行的形式进行读取文件 + # y坐标点列表 + y_coordinates_list = [] + while line: + text_position = line.split('\t')[1].split('(')[1].split(')')[0].split(', ') + y_coordinates_list.append((int(text_position[2]), int(text_position[3]))) + line = f.readline() + f.close() + return Counter(y_coordinates_list).most_common(1) + + def _frame_to_timecode(self, frame_no): + """ + 将视频帧转换成时间 + :param frame_no: 视频的帧号,i.e. 第几帧视频帧 + :returns: SMPTE格式时间戳 as string, 如'01:02:12:032' 或者 '01:02:12;032' + """ + # 设置当前帧号 + cap = cv2.VideoCapture(self.video_path) + cap.set(cv2.CAP_PROP_POS_FRAMES, frame_no) + ret, _ = cap.read() + # 获取当前帧号对应的时间戳 + if ret: + milliseconds = cap.get(cv2.CAP_PROP_POS_MSEC) + if milliseconds <= 0: + return '{0:02d}:{1:02d}:{2:02d},{3:03d}'.format(int(frame_no / (3600 * self.fps)), + int(frame_no / (60 * self.fps) % 60), + int(frame_no / self.fps % 60), + int(frame_no % self.fps)) + seconds = milliseconds // 1000 + milliseconds = int(milliseconds % 1000) + minutes = 0 + hours = 0 + if seconds >= 60: + minutes = int(seconds // 60) + seconds = int(seconds % 60) + if minutes >= 60: + hours = int(minutes // 60) + minutes = int(minutes % 60) + smpte_token = ',' + cap.release() + return "%02d:%02d:%02d%s%03d" % (hours, minutes, seconds, smpte_token, milliseconds) + else: + return '{0:02d}:{1:02d}:{2:02d},{3:03d}'.format(int(frame_no / (3600 * self.fps)), + int(frame_no / (60 * self.fps) % 60), + int(frame_no / self.fps % 60), + int(frame_no % self.fps)) + + def _timestamp_to_frameno(self, time_ms): + return int(time_ms / self.fps) + + def _frameno_to_milliseconds(self, frame_no): + return float(int(frame_no / self.fps * 1000)) + + def _remove_duplicate_subtitle(self): + """ + 读取原始的raw txt,去除重复行,返回去除了重复后的字幕列表 + """ + self._concat_content_with_same_frameno() + with open(self.raw_subtitle_path, mode='r', encoding='utf-8') as r: + lines = r.readlines() + RawInfo = namedtuple('RawInfo', 'no content') + content_list = [] + for line in lines: + frame_no = line.split('\t')[0] + content = line.split('\t')[2] + content_list.append(RawInfo(frame_no, content)) + # 去重后的字幕列表 + unique_subtitle_list = [] + idx_i = 0 + content_list_len = len(content_list) + # 循环遍历每行字幕,记录开始时间与结束时间 + while idx_i < content_list_len: + i = content_list[idx_i] + start_frame = i.no + idx_j = idx_i + while idx_j < content_list_len: + # 计算当前行与下一行的Levenshtein距离 + # 判决idx_j的下一帧是否与idx_i不同,若不同(或者是最后一帧)则找到结束帧 + if idx_j + 1 == content_list_len or ratio(i.content.replace(' ', ''), content_list[idx_j + 1].content.replace(' ', '')) < config.THRESHOLD_TEXT_SIMILARITY: + # 若找到终点帧,定义字幕结束帧帧号 + end_frame = content_list[idx_j].no + if not self.use_vsf: + if end_frame == start_frame and idx_j + 1 < content_list_len: + # 针对只有一帧的情况,以下一帧的开始时间为准(除非是最后一帧) + end_frame = content_list[idx_j + 1][0] + # 寻找最长字幕 + similar_list = content_list[idx_i:idx_j + 1] + similar_content_strip_list = [item.content.replace(' ', '') for item in similar_list] + index, _ = max(enumerate(similar_content_strip_list), key=lambda x: len(x[1])) + + # 添加进列表 + unique_subtitle_list.append((start_frame, end_frame, similar_list[index].content)) + idx_i = idx_j + 1 + break + else: + idx_j += 1 + continue + return unique_subtitle_list + + def _concat_content_with_same_frameno(self): + """ + 将raw txt文本中具有相同帧号的字幕行合并 + """ + with open(self.raw_subtitle_path, mode='r', encoding='utf-8') as r: + lines = 
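_frame_to_timecode above prefers the container's own CAP_PROP_POS_MSEC value and only falls back to frame_no / fps arithmetic when that is unavailable. A simplified, exact-millisecond variant of that fallback, written as a pure function:

def frame_to_srt_timestamp(frame_no, fps):
    # Convert a frame number to an SRT-style "HH:MM:SS,mmm" timestamp.
    total_ms = int(round(frame_no / fps * 1000))
    hours, rem = divmod(total_ms, 3600 * 1000)
    minutes, rem = divmod(rem, 60 * 1000)
    seconds, ms = divmod(rem, 1000)
    return '%02d:%02d:%02d,%03d' % (hours, minutes, seconds, ms)

# e.g. frame_to_srt_timestamp(1234, 25.0) == '00:00:49,360'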
r.readlines() + content_list = [] + frame_no_list = [] + for line in lines: + frame_no = line.split('\t')[0] + frame_no_list.append(frame_no) + coordinate = line.split('\t')[1] + content = line.split('\t')[2] + content_list.append([frame_no, coordinate, content]) + + # 找出那些不止一行的帧号 + frame_no_list = [i[0] for i in Counter(frame_no_list).most_common() if i[1] > 1] + + # 找出这些帧号出现的位置 + concatenation_list = [] + for frame_no in frame_no_list: + position = [i for i, x in enumerate(content_list) if x[0] == frame_no] + concatenation_list.append((frame_no, position)) + + for i in concatenation_list: + content = [] + for j in i[1]: + content.append(content_list[j][2]) + content = ' '.join(content).replace('\n', ' ') + '\n' + for k in i[1]: + content_list[k][2] = content + + # 将多余的字幕行删除 + to_delete = [] + for i in concatenation_list: + for j in i[1][1:]: + to_delete.append(content_list[j]) + for i in to_delete: + if i in content_list: + content_list.remove(i) + + with open(self.raw_subtitle_path, mode='w', encoding='utf-8') as f: + for frame_no, coordinate, content in content_list: + content = unicodedata.normalize('NFKC', content) + f.write(f'{frame_no}\t{coordinate}\t{content}') + + def _unite_coordinates(self, coordinates_list): + """ + 给定一个坐标列表,将这个列表中相似的坐标统一为一个值 + e.g. 由于检测框检测的结果不是一致的,相同位置文字的坐标可能一次检测为(255,123,456,789),另一次检测为(253,122,456,799) + 因此要对相似的坐标进行值的统一 + :param coordinates_list 包含坐标点的列表 + :return: 返回一个统一值后的坐标列表 + """ + # 将相似的坐标统一为一个 + index = 0 + for coordinate in coordinates_list: # TODO:时间复杂度n^2,待优化 + for i in coordinates_list: + if self.__is_coordinate_similar(coordinate, i): + coordinates_list[index] = i + index += 1 + return coordinates_list + + def _compute_image_similarity(self, image1, image2): + """ + 计算两张图片的余弦相似度 + """ + image1 = self.__get_thum(image1) + image2 = self.__get_thum(image2) + images = [image1, image2] + vectors = [] + norms = [] + for image in images: + vector = [] + for pixel_tuple in image.getdata(): + vector.append(average(pixel_tuple)) + vectors.append(vector) + # linalg=linear(线性)+algebra(代数),norm则表示范数 + # 求图片的范数 + norms.append(linalg.norm(vector, 2)) + a, b = vectors + a_norm, b_norm = norms + # dot返回的是点积,对二维数组(矩阵)进行计算 + res = dot(a / a_norm, b / b_norm) + return res + + def __get_area_text(self, ocr_result): + """ + 获取字幕区域内的文本内容 + """ + box, text = ocr_result + coordinates = get_coordinates(box) + area_text = [] + for content, coordinate in zip(text, coordinates): + if self.sub_area is not None: + s_ymin = self.sub_area[0] + s_ymax = self.sub_area[1] + s_xmin = self.sub_area[2] + s_xmax = self.sub_area[3] + xmin = coordinate[0] + xmax = coordinate[1] + ymin = coordinate[2] + ymax = coordinate[3] + if s_xmin <= xmin and xmax <= s_xmax and s_ymin <= ymin and ymax <= s_ymax: + area_text.append(content[0]) + return area_text + + def _compare_ocr_result(self, result_cache, img1, img1_no, img2, img2_no): + """ + 比较两张图片预测出的字幕区域文本是否相同 + """ + if self.ocr is None: + self.ocr = OcrRecogniser() + if img1_no in result_cache: + area_text1 = result_cache[img1_no]['text'] + else: + dt_box, rec_res = self.ocr.predict(img1) + area_text1 = "".join(self.__get_area_text((dt_box, rec_res))) + result_cache[img1_no] = {'text': area_text1, 'dt_box': dt_box, 'rec_res': rec_res} + + if img2_no in result_cache: + area_text2 = result_cache[img2_no]['text'] + else: + dt_box, rec_res = self.ocr.predict(img2) + area_text2 = "".join(self.__get_area_text((dt_box, rec_res))) + result_cache[img2_no] = {'text': area_text2, 'dt_box': dt_box, 'rec_res': rec_res} + delete_no_list = [] + for no in 
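_compute_image_similarity reduces each frame to a small thumbnail, flattens it to a vector of per-pixel values, and takes the cosine of the angle between the two normalised vectors. A compact standalone sketch using NumPy; note the original averages the RGB channels per pixel, while this variant simply converts to greyscale first:

import numpy as np

def image_cosine_similarity(img1, img2, size=(64, 64)):
    # img1 / img2: PIL.Image objects; result lies in [0, 1] because pixel values are non-negative.
    vectors = []
    for img in (img1, img2):
        thumb = img.convert('L').resize(size)                 # greyscale thumbnail
        vectors.append(np.asarray(thumb, dtype=np.float64).ravel())
    a, b = vectors
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))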
result_cache: + if no < min(img1_no, img2_no) - 10: + delete_no_list.append(no) + for no in delete_no_list: + del result_cache[no] + if ratio(area_text1, area_text2) > config.THRESHOLD_TEXT_SIMILARITY: + return True + else: + return False + + @staticmethod + def __is_coordinate_similar(coordinate1, coordinate2): + """ + 计算两个坐标是否相似,如果两个坐标点的xmin,xmax,ymin,ymax的差值都在像素点容忍度内 + 则认为这两个坐标点相似 + """ + return abs(coordinate1[0] - coordinate2[0]) < config.PIXEL_TOLERANCE_X and \ + abs(coordinate1[1] - coordinate2[1]) < config.PIXEL_TOLERANCE_X and \ + abs(coordinate1[2] - coordinate2[2]) < config.PIXEL_TOLERANCE_Y and \ + abs(coordinate1[3] - coordinate2[3]) < config.PIXEL_TOLERANCE_Y + + @staticmethod + def __get_thum(image, size=(64, 64), greyscale=False): + """ + 对图片进行统一化处理 + """ + # 利用image对图像大小重新设置, Image.ANTIALIAS为高质量的 + image = image.resize(size, Image.ANTIALIAS) + if greyscale: + # 将图片转换为L模式,其为灰度图,其每个像素用8个bit表示 + image = image.convert('L') + return image + + def __delete_frame_cache(self): + if not config.DEBUG_NO_DELETE_CACHE: + if len(os.listdir(self.frame_output_dir)) > 0: + for i in os.listdir(self.frame_output_dir): + os.remove(os.path.join(self.frame_output_dir, i)) + + def empty_cache(self): + """ + 删除字幕提取过程中所有生产的缓存文件 + """ + if not config.DEBUG_NO_DELETE_CACHE: + if os.path.exists(self.temp_output_dir): + shutil.rmtree(self.temp_output_dir, True) + + def update_progress(self, ocr=None, frame_extract=None): + """ + 更新进度条 + :param ocr ocr进度 + :param frame_extract 视频帧提取进度 + """ + if ocr is not None: + self.progress_ocr = ocr + if frame_extract is not None: + self.progress_frame_extract = frame_extract + self.progress_total = (self.progress_frame_extract + self.progress_ocr) / 2 + + def start_subtitle_ocr_async(self): + def get_ocr_progress(): + """ + 获取ocr识别进度 + """ + # 获取视频总帧数 + total_frame_count = self.frame_count + # 是否打印提示开始查找字幕的信息 + notify = True + while True: + current_frame_no = self.subtitle_ocr_progress_queue.get(block=True) + if notify: + print(config.interface_config['Main']['StartFindSub']) + notify = False + self.update_progress( + ocr=100 if current_frame_no == -1 else (current_frame_no / total_frame_count * 100)) + # print(f'recv total_ms:{total_ms}') + if current_frame_no == -1: + return + + process, task_queue, progress_queue = subtitle_ocr.async_start(self.video_path, + self.raw_subtitle_path, + self.sub_area, + options={'REC_CHAR_TYPE': config.REC_CHAR_TYPE, + 'DROP_SCORE': config.DROP_SCORE, + 'SUB_AREA_DEVIATION_RATE': config.SUB_AREA_DEVIATION_RATE, + 'DEBUG_OCR_LOSS': config.DEBUG_OCR_LOSS, + } + ) + self.subtitle_ocr_task_queue = task_queue + self.subtitle_ocr_progress_queue = progress_queue + # 开启线程负责更新OCR进度 + Thread(target=get_ocr_progress, daemon=True).start() + return process + + @staticmethod + def srt2txt(srt_file): + subs = pysrt.open(srt_file, encoding='utf-8') + output_path = os.path.join(os.path.dirname(srt_file), Path(srt_file).stem + '.txt') + print(output_path) + with open(output_path, 'w') as f: + for sub in subs: + f.write(f'{sub.text}\n') + + +if __name__ == '__main__': + multiprocessing.set_start_method("spawn") + # 提示用户输入视频路径 + video_path = input(f"{config.interface_config['Main']['InputVideo']}").strip() + # 提示用户输入字幕区域 + try: + y_min, y_max, x_min, x_max = map(int, input( + f"{config.interface_config['Main']['ChooseSubArea']} (ymin ymax xmin xmax):").split()) + subtitle_area = (y_min, y_max, x_min, x_max) + except ValueError as e: + subtitle_area = None + # 新建字幕提取对象 + se = SubtitleExtractor(video_path, subtitle_area) + # 开始提取字幕 + se.run() diff 
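_remove_duplicate_subtitle walks the per-frame OCR lines in order, extends the current subtitle while the Levenshtein ratio to the run's first line stays above the similarity threshold, and keeps the longest text seen inside each run. A compressed sketch of that grouping, with a hypothetical threshold standing in for config.THRESHOLD_TEXT_SIMILARITY:

from Levenshtein import ratio

def group_subtitles(lines, threshold=0.8):
    # lines: ordered list of (frame_no, text); returns (start_frame, end_frame, text) per subtitle.
    groups = []
    i = 0
    while i < len(lines):
        j = i
        while j + 1 < len(lines) and ratio(lines[i][1].replace(' ', ''),
                                           lines[j + 1][1].replace(' ', '')) >= threshold:
            j += 1
        run = lines[i:j + 1]
        longest = max(run, key=lambda item: len(item[1].replace(' ', '')))
        groups.append((lines[i][0], lines[j][0], longest[1]))
        i = j + 1
    return groups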
--git a/backend/models/V2/ch_det/inference.pdiparams b/backend/models/V2/ch_det/inference.pdiparams new file mode 100644 index 00000000..cb82f0a3 Binary files /dev/null and b/backend/models/V2/ch_det/inference.pdiparams differ diff --git a/backend/models/V2/ch_det/inference.pdiparams.info b/backend/models/V2/ch_det/inference.pdiparams.info new file mode 100644 index 00000000..04521aac Binary files /dev/null and b/backend/models/V2/ch_det/inference.pdiparams.info differ diff --git a/backend/models/V2/ch_det/inference.pdmodel b/backend/models/V2/ch_det/inference.pdmodel new file mode 100644 index 00000000..73794b44 Binary files /dev/null and b/backend/models/V2/ch_det/inference.pdmodel differ diff --git a/backend/models/V2/ch_rec/fs_manifest.csv b/backend/models/V2/ch_rec/fs_manifest.csv new file mode 100644 index 00000000..20f5f22d --- /dev/null +++ b/backend/models/V2/ch_rec/fs_manifest.csv @@ -0,0 +1,4 @@ +filename,filesize,encoding,header +inference_1.pdiparams,50000000,, +inference_2.pdiparams,50000000,, +inference_3.pdiparams,11517375,, diff --git a/backend/models/V2/ch_rec/inference.pdiparams.info b/backend/models/V2/ch_rec/inference.pdiparams.info new file mode 100644 index 00000000..bd7148fb Binary files /dev/null and b/backend/models/V2/ch_rec/inference.pdiparams.info differ diff --git a/backend/models/V2/ch_rec/inference.pdmodel b/backend/models/V2/ch_rec/inference.pdmodel new file mode 100644 index 00000000..2eb52972 Binary files /dev/null and b/backend/models/V2/ch_rec/inference.pdmodel differ diff --git a/backend/models/V2/ch_rec/inference_1.pdiparams b/backend/models/V2/ch_rec/inference_1.pdiparams new file mode 100644 index 00000000..02ef6c03 Binary files /dev/null and b/backend/models/V2/ch_rec/inference_1.pdiparams differ diff --git a/backend/models/V2/ch_rec/inference_2.pdiparams b/backend/models/V2/ch_rec/inference_2.pdiparams new file mode 100644 index 00000000..03ba1a0d Binary files /dev/null and b/backend/models/V2/ch_rec/inference_2.pdiparams differ diff --git a/backend/models/V2/ch_rec/inference_3.pdiparams b/backend/models/V2/ch_rec/inference_3.pdiparams new file mode 100644 index 00000000..66ddde2f Binary files /dev/null and b/backend/models/V2/ch_rec/inference_3.pdiparams differ diff --git a/backend/models/V3/ar_rec_fast/inference.pdiparams b/backend/models/V3/ar_rec_fast/inference.pdiparams new file mode 100644 index 00000000..50b9f1fa Binary files /dev/null and b/backend/models/V3/ar_rec_fast/inference.pdiparams differ diff --git a/backend/models/V3/ar_rec_fast/inference.pdiparams.info b/backend/models/V3/ar_rec_fast/inference.pdiparams.info new file mode 100644 index 00000000..1cdccfce Binary files /dev/null and b/backend/models/V3/ar_rec_fast/inference.pdiparams.info differ diff --git a/backend/models/V3/ar_rec_fast/inference.pdmodel b/backend/models/V3/ar_rec_fast/inference.pdmodel new file mode 100644 index 00000000..8426633e Binary files /dev/null and b/backend/models/V3/ar_rec_fast/inference.pdmodel differ diff --git a/backend/models/V3/arabic_rec_fast/inference.pdiparams b/backend/models/V3/arabic_rec_fast/inference.pdiparams new file mode 100644 index 00000000..5c251e31 Binary files /dev/null and b/backend/models/V3/arabic_rec_fast/inference.pdiparams differ diff --git a/backend/models/V3/arabic_rec_fast/inference.pdiparams.info b/backend/models/V3/arabic_rec_fast/inference.pdiparams.info new file mode 100644 index 00000000..1cdccfce Binary files /dev/null and b/backend/models/V3/arabic_rec_fast/inference.pdiparams.info differ diff --git 
a/backend/models/V3/arabic_rec_fast/inference.pdmodel b/backend/models/V3/arabic_rec_fast/inference.pdmodel new file mode 100644 index 00000000..44109545 Binary files /dev/null and b/backend/models/V3/arabic_rec_fast/inference.pdmodel differ diff --git a/backend/models/V3/ch_det_fast/inference.pdiparams b/backend/models/V3/ch_det_fast/inference.pdiparams new file mode 100644 index 00000000..b33c3b0c Binary files /dev/null and b/backend/models/V3/ch_det_fast/inference.pdiparams differ diff --git a/backend/models/V3/ch_det_fast/inference.pdiparams.info b/backend/models/V3/ch_det_fast/inference.pdiparams.info new file mode 100644 index 00000000..622d87ba Binary files /dev/null and b/backend/models/V3/ch_det_fast/inference.pdiparams.info differ diff --git a/backend/models/V3/ch_det_fast/inference.pdmodel b/backend/models/V3/ch_det_fast/inference.pdmodel new file mode 100644 index 00000000..d09eb790 Binary files /dev/null and b/backend/models/V3/ch_det_fast/inference.pdmodel differ diff --git a/backend/models/V3/ch_rec_fast/inference.pdiparams b/backend/models/V3/ch_rec_fast/inference.pdiparams new file mode 100644 index 00000000..5eff5394 Binary files /dev/null and b/backend/models/V3/ch_rec_fast/inference.pdiparams differ diff --git a/backend/models/V3/ch_rec_fast/inference.pdiparams.info b/backend/models/V3/ch_rec_fast/inference.pdiparams.info new file mode 100644 index 00000000..9e133bfd Binary files /dev/null and b/backend/models/V3/ch_rec_fast/inference.pdiparams.info differ diff --git a/backend/models/V3/ch_rec_fast/inference.pdmodel b/backend/models/V3/ch_rec_fast/inference.pdmodel new file mode 100644 index 00000000..585fcdc6 Binary files /dev/null and b/backend/models/V3/ch_rec_fast/inference.pdmodel differ diff --git a/backend/models/V3/chinese_cht_rec_fast/inference.pdiparams b/backend/models/V3/chinese_cht_rec_fast/inference.pdiparams new file mode 100644 index 00000000..1b1ea1a9 Binary files /dev/null and b/backend/models/V3/chinese_cht_rec_fast/inference.pdiparams differ diff --git a/backend/models/V3/chinese_cht_rec_fast/inference.pdiparams.info b/backend/models/V3/chinese_cht_rec_fast/inference.pdiparams.info new file mode 100644 index 00000000..1cdccfce Binary files /dev/null and b/backend/models/V3/chinese_cht_rec_fast/inference.pdiparams.info differ diff --git a/backend/models/V3/chinese_cht_rec_fast/inference.pdmodel b/backend/models/V3/chinese_cht_rec_fast/inference.pdmodel new file mode 100644 index 00000000..95acc925 Binary files /dev/null and b/backend/models/V3/chinese_cht_rec_fast/inference.pdmodel differ diff --git a/backend/models/V3/cyrillic_rec_fast/inference.pdiparams b/backend/models/V3/cyrillic_rec_fast/inference.pdiparams new file mode 100644 index 00000000..5cde6bed Binary files /dev/null and b/backend/models/V3/cyrillic_rec_fast/inference.pdiparams differ diff --git a/backend/models/V3/cyrillic_rec_fast/inference.pdiparams.info b/backend/models/V3/cyrillic_rec_fast/inference.pdiparams.info new file mode 100644 index 00000000..1cdccfce Binary files /dev/null and b/backend/models/V3/cyrillic_rec_fast/inference.pdiparams.info differ diff --git a/backend/models/V3/cyrillic_rec_fast/inference.pdmodel b/backend/models/V3/cyrillic_rec_fast/inference.pdmodel new file mode 100644 index 00000000..dfa47a12 Binary files /dev/null and b/backend/models/V3/cyrillic_rec_fast/inference.pdmodel differ diff --git a/backend/models/V3/devanagari_rec_fast/inference.pdiparams b/backend/models/V3/devanagari_rec_fast/inference.pdiparams new file mode 100644 index 00000000..b5dddba4 
Binary files /dev/null and b/backend/models/V3/devanagari_rec_fast/inference.pdiparams differ diff --git a/backend/models/V3/devanagari_rec_fast/inference.pdiparams.info b/backend/models/V3/devanagari_rec_fast/inference.pdiparams.info new file mode 100644 index 00000000..1cdccfce Binary files /dev/null and b/backend/models/V3/devanagari_rec_fast/inference.pdiparams.info differ diff --git a/backend/models/V3/devanagari_rec_fast/inference.pdmodel b/backend/models/V3/devanagari_rec_fast/inference.pdmodel new file mode 100644 index 00000000..f0ddf6a5 Binary files /dev/null and b/backend/models/V3/devanagari_rec_fast/inference.pdmodel differ diff --git a/backend/models/V3/en_rec_fast/inference.pdiparams b/backend/models/V3/en_rec_fast/inference.pdiparams new file mode 100644 index 00000000..26ba0c91 Binary files /dev/null and b/backend/models/V3/en_rec_fast/inference.pdiparams differ diff --git a/backend/models/V3/en_rec_fast/inference.pdiparams.info b/backend/models/V3/en_rec_fast/inference.pdiparams.info new file mode 100644 index 00000000..1cdccfce Binary files /dev/null and b/backend/models/V3/en_rec_fast/inference.pdiparams.info differ diff --git a/backend/models/V3/en_rec_fast/inference.pdmodel b/backend/models/V3/en_rec_fast/inference.pdmodel new file mode 100644 index 00000000..5dfe4cfa Binary files /dev/null and b/backend/models/V3/en_rec_fast/inference.pdmodel differ diff --git a/backend/models/V3/japan_rec_fast/inference.pdiparams b/backend/models/V3/japan_rec_fast/inference.pdiparams new file mode 100644 index 00000000..aa36dbd8 Binary files /dev/null and b/backend/models/V3/japan_rec_fast/inference.pdiparams differ diff --git a/backend/models/V3/japan_rec_fast/inference.pdiparams.info b/backend/models/V3/japan_rec_fast/inference.pdiparams.info new file mode 100644 index 00000000..1cdccfce Binary files /dev/null and b/backend/models/V3/japan_rec_fast/inference.pdiparams.info differ diff --git a/backend/models/V3/japan_rec_fast/inference.pdmodel b/backend/models/V3/japan_rec_fast/inference.pdmodel new file mode 100644 index 00000000..41b96742 Binary files /dev/null and b/backend/models/V3/japan_rec_fast/inference.pdmodel differ diff --git a/backend/models/V3/ka_rec_fast/inference.pdiparams b/backend/models/V3/ka_rec_fast/inference.pdiparams new file mode 100644 index 00000000..90c1a30c Binary files /dev/null and b/backend/models/V3/ka_rec_fast/inference.pdiparams differ diff --git a/backend/models/V3/ka_rec_fast/inference.pdiparams.info b/backend/models/V3/ka_rec_fast/inference.pdiparams.info new file mode 100644 index 00000000..1cdccfce Binary files /dev/null and b/backend/models/V3/ka_rec_fast/inference.pdiparams.info differ diff --git a/backend/models/V3/ka_rec_fast/inference.pdmodel b/backend/models/V3/ka_rec_fast/inference.pdmodel new file mode 100644 index 00000000..38d70902 Binary files /dev/null and b/backend/models/V3/ka_rec_fast/inference.pdmodel differ diff --git a/backend/models/V3/korean_rec_fast/inference.pdiparams b/backend/models/V3/korean_rec_fast/inference.pdiparams new file mode 100644 index 00000000..1419d768 Binary files /dev/null and b/backend/models/V3/korean_rec_fast/inference.pdiparams differ diff --git a/backend/models/V3/korean_rec_fast/inference.pdiparams.info b/backend/models/V3/korean_rec_fast/inference.pdiparams.info new file mode 100644 index 00000000..1cdccfce Binary files /dev/null and b/backend/models/V3/korean_rec_fast/inference.pdiparams.info differ diff --git a/backend/models/V3/korean_rec_fast/inference.pdmodel 
b/backend/models/V3/korean_rec_fast/inference.pdmodel new file mode 100644 index 00000000..fa2d8882 Binary files /dev/null and b/backend/models/V3/korean_rec_fast/inference.pdmodel differ diff --git a/backend/models/V3/latin_rec_fast/inference.pdiparams b/backend/models/V3/latin_rec_fast/inference.pdiparams new file mode 100644 index 00000000..6a7305b7 Binary files /dev/null and b/backend/models/V3/latin_rec_fast/inference.pdiparams differ diff --git a/backend/models/V3/latin_rec_fast/inference.pdiparams.info b/backend/models/V3/latin_rec_fast/inference.pdiparams.info new file mode 100644 index 00000000..1cdccfce Binary files /dev/null and b/backend/models/V3/latin_rec_fast/inference.pdiparams.info differ diff --git a/backend/models/V3/latin_rec_fast/inference.pdmodel b/backend/models/V3/latin_rec_fast/inference.pdmodel new file mode 100644 index 00000000..144978c1 Binary files /dev/null and b/backend/models/V3/latin_rec_fast/inference.pdmodel differ diff --git a/backend/models/V3/ta_rec_fast/inference.pdiparams b/backend/models/V3/ta_rec_fast/inference.pdiparams new file mode 100644 index 00000000..632b81a7 Binary files /dev/null and b/backend/models/V3/ta_rec_fast/inference.pdiparams differ diff --git a/backend/models/V3/ta_rec_fast/inference.pdiparams.info b/backend/models/V3/ta_rec_fast/inference.pdiparams.info new file mode 100644 index 00000000..1cdccfce Binary files /dev/null and b/backend/models/V3/ta_rec_fast/inference.pdiparams.info differ diff --git a/backend/models/V3/ta_rec_fast/inference.pdmodel b/backend/models/V3/ta_rec_fast/inference.pdmodel new file mode 100644 index 00000000..57a7b8c2 Binary files /dev/null and b/backend/models/V3/ta_rec_fast/inference.pdmodel differ diff --git a/backend/models/V3/te_rec_fast/inference.pdiparams b/backend/models/V3/te_rec_fast/inference.pdiparams new file mode 100644 index 00000000..7d4a1922 Binary files /dev/null and b/backend/models/V3/te_rec_fast/inference.pdiparams differ diff --git a/backend/models/V3/te_rec_fast/inference.pdiparams.info b/backend/models/V3/te_rec_fast/inference.pdiparams.info new file mode 100644 index 00000000..1cdccfce Binary files /dev/null and b/backend/models/V3/te_rec_fast/inference.pdiparams.info differ diff --git a/backend/models/V3/te_rec_fast/inference.pdmodel b/backend/models/V3/te_rec_fast/inference.pdmodel new file mode 100644 index 00000000..685b27a1 Binary files /dev/null and b/backend/models/V3/te_rec_fast/inference.pdmodel differ diff --git a/backend/models/V4/ch_det/fs_manifest.csv b/backend/models/V4/ch_det/fs_manifest.csv new file mode 100644 index 00000000..9cc67865 --- /dev/null +++ b/backend/models/V4/ch_det/fs_manifest.csv @@ -0,0 +1,4 @@ +filename,filesize,encoding,header +inference_1.pdiparams,50000000,, +inference_2.pdiparams,50000000,, +inference_3.pdiparams,13295054,, diff --git a/backend/models/V4/ch_det/inference.pdiparams.info b/backend/models/V4/ch_det/inference.pdiparams.info new file mode 100644 index 00000000..272488fb Binary files /dev/null and b/backend/models/V4/ch_det/inference.pdiparams.info differ diff --git a/backend/models/V4/ch_det/inference.pdmodel b/backend/models/V4/ch_det/inference.pdmodel new file mode 100644 index 00000000..15797214 Binary files /dev/null and b/backend/models/V4/ch_det/inference.pdmodel differ diff --git a/backend/models/V4/ch_det/inference_1.pdiparams b/backend/models/V4/ch_det/inference_1.pdiparams new file mode 100644 index 00000000..322c93dd Binary files /dev/null and b/backend/models/V4/ch_det/inference_1.pdiparams differ diff --git 
a/backend/models/V4/ch_det/inference_2.pdiparams b/backend/models/V4/ch_det/inference_2.pdiparams new file mode 100644 index 00000000..a3aa0605 Binary files /dev/null and b/backend/models/V4/ch_det/inference_2.pdiparams differ diff --git a/backend/models/V4/ch_det/inference_3.pdiparams b/backend/models/V4/ch_det/inference_3.pdiparams new file mode 100644 index 00000000..030d4888 Binary files /dev/null and b/backend/models/V4/ch_det/inference_3.pdiparams differ diff --git a/backend/models/V4/ch_det_fast/inference.pdiparams b/backend/models/V4/ch_det_fast/inference.pdiparams new file mode 100644 index 00000000..089594ae Binary files /dev/null and b/backend/models/V4/ch_det_fast/inference.pdiparams differ diff --git a/backend/models/V4/ch_det_fast/inference.pdiparams.info b/backend/models/V4/ch_det_fast/inference.pdiparams.info new file mode 100644 index 00000000..082c148e Binary files /dev/null and b/backend/models/V4/ch_det_fast/inference.pdiparams.info differ diff --git a/backend/models/V4/ch_det_fast/inference.pdmodel b/backend/models/V4/ch_det_fast/inference.pdmodel new file mode 100644 index 00000000..223b8614 Binary files /dev/null and b/backend/models/V4/ch_det_fast/inference.pdmodel differ diff --git a/backend/models/V4/ch_rec/inference.pdiparams b/backend/models/V4/ch_rec/inference.pdiparams new file mode 100644 index 00000000..45fc0939 Binary files /dev/null and b/backend/models/V4/ch_rec/inference.pdiparams differ diff --git a/backend/models/V4/ch_rec/inference.pdiparams.info b/backend/models/V4/ch_rec/inference.pdiparams.info new file mode 100644 index 00000000..abe688ab Binary files /dev/null and b/backend/models/V4/ch_rec/inference.pdiparams.info differ diff --git a/backend/models/V4/ch_rec/inference.pdmodel b/backend/models/V4/ch_rec/inference.pdmodel new file mode 100644 index 00000000..e4befed5 Binary files /dev/null and b/backend/models/V4/ch_rec/inference.pdmodel differ diff --git a/backend/models/V4/ch_rec_fast/inference.pdiparams b/backend/models/V4/ch_rec_fast/inference.pdiparams new file mode 100644 index 00000000..4c3d9e9c Binary files /dev/null and b/backend/models/V4/ch_rec_fast/inference.pdiparams differ diff --git a/backend/models/V4/ch_rec_fast/inference.pdiparams.info b/backend/models/V4/ch_rec_fast/inference.pdiparams.info new file mode 100644 index 00000000..923329f5 Binary files /dev/null and b/backend/models/V4/ch_rec_fast/inference.pdiparams.info differ diff --git a/backend/models/V4/ch_rec_fast/inference.pdmodel b/backend/models/V4/ch_rec_fast/inference.pdmodel new file mode 100644 index 00000000..dccddcc7 Binary files /dev/null and b/backend/models/V4/ch_rec_fast/inference.pdmodel differ diff --git a/backend/models/V4/en_rec_fast/inference.pdiparams b/backend/models/V4/en_rec_fast/inference.pdiparams new file mode 100644 index 00000000..49dac847 Binary files /dev/null and b/backend/models/V4/en_rec_fast/inference.pdiparams differ diff --git a/backend/models/V4/en_rec_fast/inference.pdiparams.info b/backend/models/V4/en_rec_fast/inference.pdiparams.info new file mode 100644 index 00000000..b1c33cb4 Binary files /dev/null and b/backend/models/V4/en_rec_fast/inference.pdiparams.info differ diff --git a/backend/models/V4/en_rec_fast/inference.pdmodel b/backend/models/V4/en_rec_fast/inference.pdmodel new file mode 100644 index 00000000..d0a593c3 Binary files /dev/null and b/backend/models/V4/en_rec_fast/inference.pdmodel differ diff --git a/backend/ppocr/__init__.py b/backend/ppocr/__init__.py index d0c32e26..e438e531 100755 --- a/backend/ppocr/__init__.py +++ 
b/backend/ppocr/__init__.py @@ -11,3 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import warnings +warnings.filterwarnings("ignore", category=Warning) +warnings.filterwarnings("ignore", category=DeprecationWarning) diff --git a/backend/ppocr/data/__init__.py b/backend/ppocr/data/__init__.py index 7cb50d7a..78c32796 100644 --- a/backend/ppocr/data/__init__.py +++ b/backend/ppocr/data/__init__.py @@ -20,6 +20,7 @@ import os import sys import numpy as np +import skimage import paddle import signal import random @@ -34,6 +35,8 @@ from ppocr.data.imaug import transform, create_operators from ppocr.data.simple_dataset import SimpleDataSet from ppocr.data.lmdb_dataset import LMDBDataSet +from ppocr.data.pgnet_dataset import PGDataSet +from ppocr.data.pubtab_dataset import PubTabDataSet __all__ = ['build_dataloader', 'transform', 'create_operators'] @@ -47,14 +50,12 @@ def term_mp(sig_num, frame): os.killpg(pgid, signal.SIGKILL) -signal.signal(signal.SIGINT, term_mp) -signal.signal(signal.SIGTERM, term_mp) - - def build_dataloader(config, mode, device, logger, seed=None): config = copy.deepcopy(config) - support_dict = ['SimpleDataSet', 'LMDBDataSet'] + support_dict = [ + 'SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet' + ] module_name = config[mode]['dataset']['name'] assert module_name in support_dict, Exception( 'DataSet only support {}'.format(support_dict)) @@ -71,27 +72,38 @@ def build_dataloader(config, mode, device, logger, seed=None): use_shared_memory = loader_config['use_shared_memory'] else: use_shared_memory = True + if mode == "Train": - #Distribute data to multiple cards + # Distribute data to multiple cards batch_sampler = DistributedBatchSampler( dataset=dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last) else: - #Distribute data to single card + # Distribute data to single card batch_sampler = BatchSampler( dataset=dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last) + if 'collate_fn' in loader_config: + from . import collate_fn + collate_fn = getattr(collate_fn, loader_config['collate_fn'])() + else: + collate_fn = None data_loader = DataLoader( dataset=dataset, batch_sampler=batch_sampler, places=device, num_workers=num_workers, return_list=True, - use_shared_memory=use_shared_memory) + use_shared_memory=use_shared_memory, + collate_fn=collate_fn) + + # support exit using ctrl+c + signal.signal(signal.SIGINT, term_mp) + signal.signal(signal.SIGTERM, term_mp) return data_loader diff --git a/backend/ppocr/data/collate_fn.py b/backend/ppocr/data/collate_fn.py new file mode 100644 index 00000000..0da6060f --- /dev/null +++ b/backend/ppocr/data/collate_fn.py @@ -0,0 +1,72 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
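Illustration (not part of the patch): how the new collate_fn hook in build_dataloader resolves a collator. When the loader section of the config names a class, it is looked up in the new ppocr/data/collate_fn.py module and instantiated; otherwise the DataLoader keeps its default batching. The config fragment below is an assumption for the sketch, not taken from a real config file.

    loader_config = {'collate_fn': 'SSLRotateCollate'}   # assumed loader section of the config

    if 'collate_fn' in loader_config:
        from ppocr.data import collate_fn as collate_mod
        # e.g. SSLRotateCollate(), DictCollator() or ListCollator()
        collate_fn = getattr(collate_mod, loader_config['collate_fn'])()
    else:
        collate_fn = None   # fall back to Paddle's default batching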
+ +import paddle +import numbers +import numpy as np +from collections import defaultdict + + +class DictCollator(object): + """ + data batch + """ + + def __call__(self, batch): + # todo:support batch operators + data_dict = defaultdict(list) + to_tensor_keys = [] + for sample in batch: + for k, v in sample.items(): + if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)): + if k not in to_tensor_keys: + to_tensor_keys.append(k) + data_dict[k].append(v) + for k in to_tensor_keys: + data_dict[k] = paddle.to_tensor(data_dict[k]) + return data_dict + + +class ListCollator(object): + """ + data batch + """ + + def __call__(self, batch): + # todo:support batch operators + data_dict = defaultdict(list) + to_tensor_idxs = [] + for sample in batch: + for idx, v in enumerate(sample): + if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)): + if idx not in to_tensor_idxs: + to_tensor_idxs.append(idx) + data_dict[idx].append(v) + for idx in to_tensor_idxs: + data_dict[idx] = paddle.to_tensor(data_dict[idx]) + return list(data_dict.values()) + + +class SSLRotateCollate(object): + """ + bach: [ + [(4*3xH*W), (4,)] + [(4*3xH*W), (4,)] + ... + ] + """ + + def __call__(self, batch): + output = [np.concatenate(d, axis=0) for d in zip(*batch)] + return output diff --git a/backend/ppocr/data/imaug/ColorJitter.py b/backend/ppocr/data/imaug/ColorJitter.py new file mode 100644 index 00000000..4b542abc --- /dev/null +++ b/backend/ppocr/data/imaug/ColorJitter.py @@ -0,0 +1,26 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
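Illustration (not part of the patch) of what the DictCollator added above produces: values are grouped per key across the batch, and every key holding arrays, tensors or numbers is converted to a Paddle tensor with a leading batch axis.

    import numpy as np
    from ppocr.data.collate_fn import DictCollator   # module path as added by this patch

    batch = [
        {'image': np.zeros((3, 32, 32), dtype='float32'), 'label': 1},
        {'image': np.ones((3, 32, 32), dtype='float32'), 'label': 0},
    ]
    out = DictCollator()(batch)
    # out['image']: paddle.Tensor of shape [2, 3, 32, 32]
    # out['label']: paddle.Tensor of shape [2]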
+from paddle.vision.transforms import ColorJitter as pp_ColorJitter + +__all__ = ['ColorJitter'] + +class ColorJitter(object): + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0,**kwargs): + self.aug = pp_ColorJitter(brightness, contrast, saturation, hue) + + def __call__(self, data): + image = data['image'] + image = self.aug(image) + data['image'] = image + return data diff --git a/backend/ppocr/data/imaug/__init__.py b/backend/ppocr/data/imaug/__init__.py index 250ac75e..548832fb 100644 --- a/backend/ppocr/data/imaug/__init__.py +++ b/backend/ppocr/data/imaug/__init__.py @@ -19,15 +19,27 @@ from .iaa_augment import IaaAugment from .make_border_map import MakeBorderMap from .make_shrink_map import MakeShrinkMap -from .random_crop_data import EastRandomCropData, PSERandomCrop +from .random_crop_data import EastRandomCropData, RandomCropImgMask +from .make_pse_gt import MakePseGt -from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg +from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \ + SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg +from .ssl_img_aug import SSLRotateResize from .randaugment import RandAugment +from .copy_paste import CopyPaste +from .ColorJitter import ColorJitter from .operators import * from .label_ops import * from .east_process import * from .sast_process import * +from .pg_process import * +from .gen_table_mask import * + +from .vqa import * + +from .fce_aug import * +from .fce_targets import FCENetTargets def transform(data, ops=None): diff --git a/backend/ppocr/data/imaug/copy_paste.py b/backend/ppocr/data/imaug/copy_paste.py new file mode 100644 index 00000000..0b3386c8 --- /dev/null +++ b/backend/ppocr/data/imaug/copy_paste.py @@ -0,0 +1,170 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
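Illustration (not part of the patch) of the ColorJitter op defined above: it is a thin dict-in/dict-out wrapper around paddle.vision's ColorJitter so it can sit in an imaug transform list. The parameter values and the HWC uint8 input below are assumptions for the sketch.

    import numpy as np
    from ppocr.data.imaug.ColorJitter import ColorJitter

    op = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
    data = {'image': np.random.randint(0, 256, (48, 320, 3), dtype=np.uint8)}
    data = op(data)   # data['image'] now holds the jittered image, same key and shape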
+import copy +import cv2 +import random +import numpy as np +from PIL import Image +from shapely.geometry import Polygon + +from ppocr.data.imaug.iaa_augment import IaaAugment +from ppocr.data.imaug.random_crop_data import is_poly_outside_rect +from tools.infer.utility import get_rotate_crop_image + + +class CopyPaste(object): + def __init__(self, objects_paste_ratio=0.2, limit_paste=True, **kwargs): + self.ext_data_num = 1 + self.objects_paste_ratio = objects_paste_ratio + self.limit_paste = limit_paste + augmenter_args = [{'type': 'Resize', 'args': {'size': [0.5, 3]}}] + self.aug = IaaAugment(augmenter_args) + + def __call__(self, data): + point_num = data['polys'].shape[1] + src_img = data['image'] + src_polys = data['polys'].tolist() + src_ignores = data['ignore_tags'].tolist() + ext_data = data['ext_data'][0] + ext_image = ext_data['image'] + ext_polys = ext_data['polys'] + ext_ignores = ext_data['ignore_tags'] + + indexs = [i for i in range(len(ext_ignores)) if not ext_ignores[i]] + select_num = max( + 1, min(int(self.objects_paste_ratio * len(ext_polys)), 30)) + + random.shuffle(indexs) + select_idxs = indexs[:select_num] + select_polys = ext_polys[select_idxs] + select_ignores = ext_ignores[select_idxs] + + src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB) + ext_image = cv2.cvtColor(ext_image, cv2.COLOR_BGR2RGB) + src_img = Image.fromarray(src_img).convert('RGBA') + for poly, tag in zip(select_polys, select_ignores): + box_img = get_rotate_crop_image(ext_image, poly) + + src_img, box = self.paste_img(src_img, box_img, src_polys) + if box is not None: + box = box.tolist() + for _ in range(len(box), point_num): + box.append(box[-1]) + src_polys.append(box) + src_ignores.append(tag) + src_img = cv2.cvtColor(np.array(src_img), cv2.COLOR_RGB2BGR) + h, w = src_img.shape[:2] + src_polys = np.array(src_polys) + src_polys[:, :, 0] = np.clip(src_polys[:, :, 0], 0, w) + src_polys[:, :, 1] = np.clip(src_polys[:, :, 1], 0, h) + data['image'] = src_img + data['polys'] = src_polys + data['ignore_tags'] = np.array(src_ignores) + return data + + def paste_img(self, src_img, box_img, src_polys): + box_img_pil = Image.fromarray(box_img).convert('RGBA') + src_w, src_h = src_img.size + box_w, box_h = box_img_pil.size + + angle = np.random.randint(0, 360) + box = np.array([[[0, 0], [box_w, 0], [box_w, box_h], [0, box_h]]]) + box = rotate_bbox(box_img, box, angle)[0] + box_img_pil = box_img_pil.rotate(angle, expand=1) + box_w, box_h = box_img_pil.width, box_img_pil.height + if src_w - box_w < 0 or src_h - box_h < 0: + return src_img, None + + paste_x, paste_y = self.select_coord(src_polys, box, src_w - box_w, + src_h - box_h) + if paste_x is None: + return src_img, None + box[:, 0] += paste_x + box[:, 1] += paste_y + r, g, b, A = box_img_pil.split() + src_img.paste(box_img_pil, (paste_x, paste_y), mask=A) + + return src_img, box + + def select_coord(self, src_polys, box, endx, endy): + if self.limit_paste: + xmin, ymin, xmax, ymax = box[:, 0].min(), box[:, 1].min( + ), box[:, 0].max(), box[:, 1].max() + for _ in range(50): + paste_x = random.randint(0, endx) + paste_y = random.randint(0, endy) + xmin1 = xmin + paste_x + xmax1 = xmax + paste_x + ymin1 = ymin + paste_y + ymax1 = ymax + paste_y + + num_poly_in_rect = 0 + for poly in src_polys: + if not is_poly_outside_rect(poly, xmin1, ymin1, + xmax1 - xmin1, ymax1 - ymin1): + num_poly_in_rect += 1 + break + if num_poly_in_rect == 0: + return paste_x, paste_y + return None, None + else: + paste_x = random.randint(0, endx) + paste_y = random.randint(0, 
endy) + return paste_x, paste_y + + +def get_union(pD, pG): + return Polygon(pD).union(Polygon(pG)).area + + +def get_intersection_over_union(pD, pG): + return get_intersection(pD, pG) / get_union(pD, pG) + + +def get_intersection(pD, pG): + return Polygon(pD).intersection(Polygon(pG)).area + + +def rotate_bbox(img, text_polys, angle, scale=1): + """ + from https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/augment.py + Args: + img: np.ndarray + text_polys: np.ndarray N*4*2 + angle: int + scale: int + + Returns: + + """ + w = img.shape[1] + h = img.shape[0] + + rangle = np.deg2rad(angle) + nw = (abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w)) + nh = (abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w)) + rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, scale) + rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0])) + rot_mat[0, 2] += rot_move[0] + rot_mat[1, 2] += rot_move[1] + + # ---------------------- rotate box ---------------------- + rot_text_polys = list() + for bbox in text_polys: + point1 = np.dot(rot_mat, np.array([bbox[0, 0], bbox[0, 1], 1])) + point2 = np.dot(rot_mat, np.array([bbox[1, 0], bbox[1, 1], 1])) + point3 = np.dot(rot_mat, np.array([bbox[2, 0], bbox[2, 1], 1])) + point4 = np.dot(rot_mat, np.array([bbox[3, 0], bbox[3, 1], 1])) + rot_text_polys.append([point1, point2, point3, point4]) + return np.array(rot_text_polys, dtype=np.float32) diff --git a/backend/ppocr/data/imaug/east_process.py b/backend/ppocr/data/imaug/east_process.py index b1d7a5e5..df08adfa 100644 --- a/backend/ppocr/data/imaug/east_process.py +++ b/backend/ppocr/data/imaug/east_process.py @@ -11,7 +11,10 @@ #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #See the License for the specific language governing permissions and #limitations under the License. - +""" +This code is refered from: +https://github.com/songdejia/EAST/blob/master/data_utils.py +""" import math import cv2 import numpy as np @@ -24,10 +27,10 @@ class EASTProcessTrain(object): def __init__(self, - image_shape = [512, 512], - background_ratio = 0.125, - min_crop_side_ratio = 0.1, - min_text_size = 10, + image_shape=[512, 512], + background_ratio=0.125, + min_crop_side_ratio=0.1, + min_text_size=10, **kwargs): self.input_size = image_shape[1] self.random_scale = np.array([0.5, 1, 2.0, 3.0]) @@ -282,12 +285,7 @@ def generate_quad(self, im_size, polys, tags): 1.0 / max(min(poly_h, poly_w), 1.0) return score_map, geo_map, training_mask - def crop_area(self, - im, - polys, - tags, - crop_background=False, - max_tries=50): + def crop_area(self, im, polys, tags, crop_background=False, max_tries=50): """ make random crop from the input image :param im: @@ -435,5 +433,4 @@ def __call__(self, data): data['score_map'] = score_map data['geo_map'] = geo_map data['training_mask'] = training_mask - # print(im.shape, score_map.shape, geo_map.shape, training_mask.shape) - return data \ No newline at end of file + return data diff --git a/backend/ppocr/data/imaug/fce_aug.py b/backend/ppocr/data/imaug/fce_aug.py new file mode 100644 index 00000000..66bafef1 --- /dev/null +++ b/backend/ppocr/data/imaug/fce_aug.py @@ -0,0 +1,564 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
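A quick worked example (not part of the patch) for the Shapely-backed IoU helpers defined at the end of copy_paste.py above, assuming those helpers are in scope: two 1x1 squares overlapping by half share an area of 0.5 over a union of 1.5, giving an IoU of 1/3.

    pD = [(0, 0), (1, 0), (1, 1), (0, 1)]            # unit square
    pG = [(0.5, 0), (1.5, 0), (1.5, 1), (0.5, 1)]    # same square shifted right by 0.5
    # intersection = 0.5, union = 1.0 + 1.0 - 0.5 = 1.5
    print(get_intersection_over_union(pD, pG))       # -> 0.333...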
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is refer from: +https://github.com/open-mmlab/mmocr/blob/main/mmocr/datasets/pipelines/transforms.py +""" +import numpy as np +from PIL import Image, ImageDraw +import cv2 +from shapely.geometry import Polygon +import math +from ppocr.utils.poly_nms import poly_intersection + + +class RandomScaling: + def __init__(self, size=800, scale=(3. / 4, 5. / 2), **kwargs): + """Random scale the image while keeping aspect. + + Args: + size (int) : Base size before scaling. + scale (tuple(float)) : The range of scaling. + """ + assert isinstance(size, int) + assert isinstance(scale, float) or isinstance(scale, tuple) + self.size = size + self.scale = scale if isinstance(scale, tuple) \ + else (1 - scale, 1 + scale) + + def __call__(self, data): + image = data['image'] + text_polys = data['polys'] + h, w, _ = image.shape + + aspect_ratio = np.random.uniform(min(self.scale), max(self.scale)) + scales = self.size * 1.0 / max(h, w) * aspect_ratio + scales = np.array([scales, scales]) + out_size = (int(h * scales[1]), int(w * scales[0])) + image = cv2.resize(image, out_size[::-1]) + + data['image'] = image + text_polys[:, :, 0::2] = text_polys[:, :, 0::2] * scales[1] + text_polys[:, :, 1::2] = text_polys[:, :, 1::2] * scales[0] + data['polys'] = text_polys + + return data + + +class RandomCropFlip: + def __init__(self, + pad_ratio=0.1, + crop_ratio=0.5, + iter_num=1, + min_area_ratio=0.2, + **kwargs): + """Random crop and flip a patch of the image. + + Args: + crop_ratio (float): The ratio of cropping. + iter_num (int): Number of operations. + min_area_ratio (float): Minimal area ratio between cropped patch + and original image. 
+ """ + assert isinstance(crop_ratio, float) + assert isinstance(iter_num, int) + assert isinstance(min_area_ratio, float) + + self.pad_ratio = pad_ratio + self.epsilon = 1e-2 + self.crop_ratio = crop_ratio + self.iter_num = iter_num + self.min_area_ratio = min_area_ratio + + def __call__(self, results): + for i in range(self.iter_num): + results = self.random_crop_flip(results) + + return results + + def random_crop_flip(self, results): + image = results['image'] + polygons = results['polys'] + ignore_tags = results['ignore_tags'] + if len(polygons) == 0: + return results + + if np.random.random() >= self.crop_ratio: + return results + + h, w, _ = image.shape + area = h * w + pad_h = int(h * self.pad_ratio) + pad_w = int(w * self.pad_ratio) + h_axis, w_axis = self.generate_crop_target(image, polygons, pad_h, + pad_w) + if len(h_axis) == 0 or len(w_axis) == 0: + return results + + attempt = 0 + while attempt < 50: + attempt += 1 + polys_keep = [] + polys_new = [] + ignore_tags_keep = [] + ignore_tags_new = [] + xx = np.random.choice(w_axis, size=2) + xmin = np.min(xx) - pad_w + xmax = np.max(xx) - pad_w + xmin = np.clip(xmin, 0, w - 1) + xmax = np.clip(xmax, 0, w - 1) + yy = np.random.choice(h_axis, size=2) + ymin = np.min(yy) - pad_h + ymax = np.max(yy) - pad_h + ymin = np.clip(ymin, 0, h - 1) + ymax = np.clip(ymax, 0, h - 1) + if (xmax - xmin) * (ymax - ymin) < area * self.min_area_ratio: + # area too small + continue + + pts = np.stack([[xmin, xmax, xmax, xmin], + [ymin, ymin, ymax, ymax]]).T.astype(np.int32) + pp = Polygon(pts) + fail_flag = False + for polygon, ignore_tag in zip(polygons, ignore_tags): + ppi = Polygon(polygon.reshape(-1, 2)) + ppiou, _ = poly_intersection(ppi, pp, buffer=0) + if np.abs(ppiou - float(ppi.area)) > self.epsilon and \ + np.abs(ppiou) > self.epsilon: + fail_flag = True + break + elif np.abs(ppiou - float(ppi.area)) < self.epsilon: + polys_new.append(polygon) + ignore_tags_new.append(ignore_tag) + else: + polys_keep.append(polygon) + ignore_tags_keep.append(ignore_tag) + + if fail_flag: + continue + else: + break + + cropped = image[ymin:ymax, xmin:xmax, :] + select_type = np.random.randint(3) + if select_type == 0: + img = np.ascontiguousarray(cropped[:, ::-1]) + elif select_type == 1: + img = np.ascontiguousarray(cropped[::-1, :]) + else: + img = np.ascontiguousarray(cropped[::-1, ::-1]) + image[ymin:ymax, xmin:xmax, :] = img + results['img'] = image + + if len(polys_new) != 0: + height, width, _ = cropped.shape + if select_type == 0: + for idx, polygon in enumerate(polys_new): + poly = polygon.reshape(-1, 2) + poly[:, 0] = width - poly[:, 0] + 2 * xmin + polys_new[idx] = poly + elif select_type == 1: + for idx, polygon in enumerate(polys_new): + poly = polygon.reshape(-1, 2) + poly[:, 1] = height - poly[:, 1] + 2 * ymin + polys_new[idx] = poly + else: + for idx, polygon in enumerate(polys_new): + poly = polygon.reshape(-1, 2) + poly[:, 0] = width - poly[:, 0] + 2 * xmin + poly[:, 1] = height - poly[:, 1] + 2 * ymin + polys_new[idx] = poly + polygons = polys_keep + polys_new + ignore_tags = ignore_tags_keep + ignore_tags_new + results['polys'] = np.array(polygons) + results['ignore_tags'] = ignore_tags + + return results + + def generate_crop_target(self, image, all_polys, pad_h, pad_w): + """Generate crop target and make sure not to crop the polygon + instances. + + Args: + image (ndarray): The image waited to be crop. + all_polys (list[list[ndarray]]): All polygons including ground + truth polygons and ground truth ignored polygons. 
+ pad_h (int): Padding length of height. + pad_w (int): Padding length of width. + Returns: + h_axis (ndarray): Vertical cropping range. + w_axis (ndarray): Horizontal cropping range. + """ + h, w, _ = image.shape + h_array = np.zeros((h + pad_h * 2), dtype=np.int32) + w_array = np.zeros((w + pad_w * 2), dtype=np.int32) + + text_polys = [] + for polygon in all_polys: + rect = cv2.minAreaRect(polygon.astype(np.int32).reshape(-1, 2)) + box = cv2.boxPoints(rect) + box = np.int0(box) + text_polys.append([box[0], box[1], box[2], box[3]]) + + polys = np.array(text_polys, dtype=np.int32) + for poly in polys: + poly = np.round(poly, decimals=0).astype(np.int32) + minx = np.min(poly[:, 0]) + maxx = np.max(poly[:, 0]) + w_array[minx + pad_w:maxx + pad_w] = 1 + miny = np.min(poly[:, 1]) + maxy = np.max(poly[:, 1]) + h_array[miny + pad_h:maxy + pad_h] = 1 + + h_axis = np.where(h_array == 0)[0] + w_axis = np.where(w_array == 0)[0] + return h_axis, w_axis + + +class RandomCropPolyInstances: + """Randomly crop images and make sure to contain at least one intact + instance.""" + + def __init__(self, crop_ratio=5.0 / 8.0, min_side_ratio=0.4, **kwargs): + super().__init__() + self.crop_ratio = crop_ratio + self.min_side_ratio = min_side_ratio + + def sample_valid_start_end(self, valid_array, min_len, max_start, min_end): + + assert isinstance(min_len, int) + assert len(valid_array) > min_len + + start_array = valid_array.copy() + max_start = min(len(start_array) - min_len, max_start) + start_array[max_start:] = 0 + start_array[0] = 1 + diff_array = np.hstack([0, start_array]) - np.hstack([start_array, 0]) + region_starts = np.where(diff_array < 0)[0] + region_ends = np.where(diff_array > 0)[0] + region_ind = np.random.randint(0, len(region_starts)) + start = np.random.randint(region_starts[region_ind], + region_ends[region_ind]) + + end_array = valid_array.copy() + min_end = max(start + min_len, min_end) + end_array[:min_end] = 0 + end_array[-1] = 1 + diff_array = np.hstack([0, end_array]) - np.hstack([end_array, 0]) + region_starts = np.where(diff_array < 0)[0] + region_ends = np.where(diff_array > 0)[0] + region_ind = np.random.randint(0, len(region_starts)) + end = np.random.randint(region_starts[region_ind], + region_ends[region_ind]) + return start, end + + def sample_crop_box(self, img_size, results): + """Generate crop box and make sure not to crop the polygon instances. + + Args: + img_size (tuple(int)): The image size (h, w). + results (dict): The results dict. 
+ """ + + assert isinstance(img_size, tuple) + h, w = img_size[:2] + + key_masks = results['polys'] + + x_valid_array = np.ones(w, dtype=np.int32) + y_valid_array = np.ones(h, dtype=np.int32) + + selected_mask = key_masks[np.random.randint(0, len(key_masks))] + selected_mask = selected_mask.reshape((-1, 2)).astype(np.int32) + max_x_start = max(np.min(selected_mask[:, 0]) - 2, 0) + min_x_end = min(np.max(selected_mask[:, 0]) + 3, w - 1) + max_y_start = max(np.min(selected_mask[:, 1]) - 2, 0) + min_y_end = min(np.max(selected_mask[:, 1]) + 3, h - 1) + + for mask in key_masks: + mask = mask.reshape((-1, 2)).astype(np.int32) + clip_x = np.clip(mask[:, 0], 0, w - 1) + clip_y = np.clip(mask[:, 1], 0, h - 1) + min_x, max_x = np.min(clip_x), np.max(clip_x) + min_y, max_y = np.min(clip_y), np.max(clip_y) + + x_valid_array[min_x - 2:max_x + 3] = 0 + y_valid_array[min_y - 2:max_y + 3] = 0 + + min_w = int(w * self.min_side_ratio) + min_h = int(h * self.min_side_ratio) + + x1, x2 = self.sample_valid_start_end(x_valid_array, min_w, max_x_start, + min_x_end) + y1, y2 = self.sample_valid_start_end(y_valid_array, min_h, max_y_start, + min_y_end) + + return np.array([x1, y1, x2, y2]) + + def crop_img(self, img, bbox): + assert img.ndim == 3 + h, w, _ = img.shape + assert 0 <= bbox[1] < bbox[3] <= h + assert 0 <= bbox[0] < bbox[2] <= w + return img[bbox[1]:bbox[3], bbox[0]:bbox[2]] + + def __call__(self, results): + image = results['image'] + polygons = results['polys'] + ignore_tags = results['ignore_tags'] + if len(polygons) < 1: + return results + + if np.random.random_sample() < self.crop_ratio: + + crop_box = self.sample_crop_box(image.shape, results) + img = self.crop_img(image, crop_box) + results['image'] = img + # crop and filter masks + x1, y1, x2, y2 = crop_box + w = max(x2 - x1, 1) + h = max(y2 - y1, 1) + polygons[:, :, 0::2] = polygons[:, :, 0::2] - x1 + polygons[:, :, 1::2] = polygons[:, :, 1::2] - y1 + + valid_masks_list = [] + valid_tags_list = [] + for ind, polygon in enumerate(polygons): + if (polygon[:, ::2] > -4).all() and ( + polygon[:, ::2] < w + 4).all() and ( + polygon[:, 1::2] > -4).all() and ( + polygon[:, 1::2] < h + 4).all(): + polygon[:, ::2] = np.clip(polygon[:, ::2], 0, w) + polygon[:, 1::2] = np.clip(polygon[:, 1::2], 0, h) + valid_masks_list.append(polygon) + valid_tags_list.append(ignore_tags[ind]) + + results['polys'] = np.array(valid_masks_list) + results['ignore_tags'] = valid_tags_list + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + return repr_str + + +class RandomRotatePolyInstances: + def __init__(self, + rotate_ratio=0.5, + max_angle=10, + pad_with_fixed_color=False, + pad_value=(0, 0, 0), + **kwargs): + """Randomly rotate images and polygon masks. + + Args: + rotate_ratio (float): The ratio of samples to operate rotation. + max_angle (int): The maximum rotation angle. + pad_with_fixed_color (bool): The flag for whether to pad rotated + image with fixed value. If set to False, the rotated image will + be padded onto cropped image. + pad_value (tuple(int)): The color value for padding rotated image. + """ + self.rotate_ratio = rotate_ratio + self.max_angle = max_angle + self.pad_with_fixed_color = pad_with_fixed_color + self.pad_value = pad_value + + def rotate(self, center, points, theta, center_shift=(0, 0)): + # rotate points. 
+ (center_x, center_y) = center + center_y = -center_y + x, y = points[:, ::2], points[:, 1::2] + y = -y + + theta = theta / 180 * math.pi + cos = math.cos(theta) + sin = math.sin(theta) + + x = (x - center_x) + y = (y - center_y) + + _x = center_x + x * cos - y * sin + center_shift[0] + _y = -(center_y + x * sin + y * cos) + center_shift[1] + + points[:, ::2], points[:, 1::2] = _x, _y + return points + + def cal_canvas_size(self, ori_size, degree): + assert isinstance(ori_size, tuple) + angle = degree * math.pi / 180.0 + h, w = ori_size[:2] + + cos = math.cos(angle) + sin = math.sin(angle) + canvas_h = int(w * math.fabs(sin) + h * math.fabs(cos)) + canvas_w = int(w * math.fabs(cos) + h * math.fabs(sin)) + + canvas_size = (canvas_h, canvas_w) + return canvas_size + + def sample_angle(self, max_angle): + angle = np.random.random_sample() * 2 * max_angle - max_angle + return angle + + def rotate_img(self, img, angle, canvas_size): + h, w = img.shape[:2] + rotation_matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1) + rotation_matrix[0, 2] += int((canvas_size[1] - w) / 2) + rotation_matrix[1, 2] += int((canvas_size[0] - h) / 2) + + if self.pad_with_fixed_color: + target_img = cv2.warpAffine( + img, + rotation_matrix, (canvas_size[1], canvas_size[0]), + flags=cv2.INTER_NEAREST, + borderValue=self.pad_value) + else: + mask = np.zeros_like(img) + (h_ind, w_ind) = (np.random.randint(0, h * 7 // 8), + np.random.randint(0, w * 7 // 8)) + img_cut = img[h_ind:(h_ind + h // 9), w_ind:(w_ind + w // 9)] + img_cut = cv2.resize(img_cut, (canvas_size[1], canvas_size[0])) + + mask = cv2.warpAffine( + mask, + rotation_matrix, (canvas_size[1], canvas_size[0]), + borderValue=[1, 1, 1]) + target_img = cv2.warpAffine( + img, + rotation_matrix, (canvas_size[1], canvas_size[0]), + borderValue=[0, 0, 0]) + target_img = target_img + img_cut * mask + + return target_img + + def __call__(self, results): + if np.random.random_sample() < self.rotate_ratio: + image = results['image'] + polygons = results['polys'] + h, w = image.shape[:2] + + angle = self.sample_angle(self.max_angle) + canvas_size = self.cal_canvas_size((h, w), angle) + center_shift = (int((canvas_size[1] - w) / 2), int( + (canvas_size[0] - h) / 2)) + image = self.rotate_img(image, angle, canvas_size) + results['image'] = image + # rotate polygons + rotated_masks = [] + for mask in polygons: + rotated_mask = self.rotate((w / 2, h / 2), mask, angle, + center_shift) + rotated_masks.append(rotated_mask) + results['polys'] = np.array(rotated_masks) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + return repr_str + + +class SquareResizePad: + def __init__(self, + target_size, + pad_ratio=0.6, + pad_with_fixed_color=False, + pad_value=(0, 0, 0), + **kwargs): + """Resize or pad images to be square shape. + + Args: + target_size (int): The target size of square shaped image. + pad_with_fixed_color (bool): The flag for whether to pad rotated + image with fixed value. If set to False, the rescales image will + be padded onto cropped image. + pad_value (tuple(int)): The color value for padding rotated image. 
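A numeric sanity check (not part of the patch) for the rotate() helper above, which works in image coordinates with the y axis pointing down: rotating the single point (2, 1) by 90 degrees around the center (1, 1) lands on (1, 0), up to floating point error.

    import numpy as np
    from ppocr.data.imaug.fce_aug import RandomRotatePolyInstances   # file added by this patch

    op = RandomRotatePolyInstances()
    pts = np.array([[2.0, 1.0]])               # one polygon point stored as [x, y]
    print(op.rotate((1, 1), pts.copy(), 90))   # ~[[1., 0.]]  (rotate() mutates its input)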
+ """ + assert isinstance(target_size, int) + assert isinstance(pad_ratio, float) + assert isinstance(pad_with_fixed_color, bool) + assert isinstance(pad_value, tuple) + + self.target_size = target_size + self.pad_ratio = pad_ratio + self.pad_with_fixed_color = pad_with_fixed_color + self.pad_value = pad_value + + def resize_img(self, img, keep_ratio=True): + h, w, _ = img.shape + if keep_ratio: + t_h = self.target_size if h >= w else int(h * self.target_size / w) + t_w = self.target_size if h <= w else int(w * self.target_size / h) + else: + t_h = t_w = self.target_size + img = cv2.resize(img, (t_w, t_h)) + return img, (t_h, t_w) + + def square_pad(self, img): + h, w = img.shape[:2] + if h == w: + return img, (0, 0) + pad_size = max(h, w) + if self.pad_with_fixed_color: + expand_img = np.ones((pad_size, pad_size, 3), dtype=np.uint8) + expand_img[:] = self.pad_value + else: + (h_ind, w_ind) = (np.random.randint(0, h * 7 // 8), + np.random.randint(0, w * 7 // 8)) + img_cut = img[h_ind:(h_ind + h // 9), w_ind:(w_ind + w // 9)] + expand_img = cv2.resize(img_cut, (pad_size, pad_size)) + if h > w: + y0, x0 = 0, (h - w) // 2 + else: + y0, x0 = (w - h) // 2, 0 + expand_img[y0:y0 + h, x0:x0 + w] = img + offset = (x0, y0) + + return expand_img, offset + + def square_pad_mask(self, points, offset): + x0, y0 = offset + pad_points = points.copy() + pad_points[::2] = pad_points[::2] + x0 + pad_points[1::2] = pad_points[1::2] + y0 + return pad_points + + def __call__(self, results): + image = results['image'] + polygons = results['polys'] + h, w = image.shape[:2] + + if np.random.random_sample() < self.pad_ratio: + image, out_size = self.resize_img(image, keep_ratio=True) + image, offset = self.square_pad(image) + else: + image, out_size = self.resize_img(image, keep_ratio=False) + offset = (0, 0) + results['image'] = image + try: + polygons[:, :, 0::2] = polygons[:, :, 0::2] * out_size[ + 1] / w + offset[0] + polygons[:, :, 1::2] = polygons[:, :, 1::2] * out_size[ + 0] / h + offset[1] + except: + pass + results['polys'] = polygons + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + return repr_str diff --git a/backend/ppocr/data/imaug/fce_targets.py b/backend/ppocr/data/imaug/fce_targets.py new file mode 100644 index 00000000..18184808 --- /dev/null +++ b/backend/ppocr/data/imaug/fce_targets.py @@ -0,0 +1,658 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is refer from: +https://github.com/open-mmlab/mmocr/blob/main/mmocr/datasets/pipelines/textdet_targets/fcenet_targets.py +""" + +import cv2 +import numpy as np +from numpy.fft import fft +from numpy.linalg import norm +import sys + + +class FCENetTargets: + """Generate the ground truth targets of FCENet: Fourier Contour Embedding + for Arbitrary-Shaped Text Detection. + + [https://arxiv.org/abs/2104.10442] + + Args: + fourier_degree (int): The maximum Fourier transform degree k. 
+ resample_step (float): The step size for resampling the text center + line (TCL). It's better not to exceed half of the minimum width. + center_region_shrink_ratio (float): The shrink ratio of text center + region. + level_size_divisors (tuple(int)): The downsample ratio on each level. + level_proportion_range (tuple(tuple(int))): The range of text sizes + assigned to each level. + """ + + def __init__(self, + fourier_degree=5, + resample_step=4.0, + center_region_shrink_ratio=0.3, + level_size_divisors=(8, 16, 32), + level_proportion_range=((0, 0.25), (0.2, 0.65), (0.55, 1.0)), + orientation_thr=2.0, + **kwargs): + + super().__init__() + assert isinstance(level_size_divisors, tuple) + assert isinstance(level_proportion_range, tuple) + assert len(level_size_divisors) == len(level_proportion_range) + self.fourier_degree = fourier_degree + self.resample_step = resample_step + self.center_region_shrink_ratio = center_region_shrink_ratio + self.level_size_divisors = level_size_divisors + self.level_proportion_range = level_proportion_range + + self.orientation_thr = orientation_thr + + def vector_angle(self, vec1, vec2): + if vec1.ndim > 1: + unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8).reshape((-1, 1)) + else: + unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8) + if vec2.ndim > 1: + unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8).reshape((-1, 1)) + else: + unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8) + return np.arccos( + np.clip( + np.sum(unit_vec1 * unit_vec2, axis=-1), -1.0, 1.0)) + + def resample_line(self, line, n): + """Resample n points on a line. + + Args: + line (ndarray): The points composing a line. + n (int): The resampled points number. + + Returns: + resampled_line (ndarray): The points composing the resampled line. + """ + + assert line.ndim == 2 + assert line.shape[0] >= 2 + assert line.shape[1] == 2 + assert isinstance(n, int) + assert n > 0 + + length_list = [ + norm(line[i + 1] - line[i]) for i in range(len(line) - 1) + ] + total_length = sum(length_list) + length_cumsum = np.cumsum([0.0] + length_list) + delta_length = total_length / (float(n) + 1e-8) + + current_edge_ind = 0 + resampled_line = [line[0]] + + for i in range(1, n): + current_line_len = i * delta_length + + while current_line_len >= length_cumsum[current_edge_ind + 1]: + current_edge_ind += 1 + current_edge_end_shift = current_line_len - length_cumsum[ + current_edge_ind] + end_shift_ratio = current_edge_end_shift / length_list[ + current_edge_ind] + current_point = line[current_edge_ind] + (line[current_edge_ind + 1] + - line[current_edge_ind] + ) * end_shift_ratio + resampled_line.append(current_point) + + resampled_line.append(line[-1]) + resampled_line = np.array(resampled_line) + + return resampled_line + + def reorder_poly_edge(self, points): + """Get the respective points composing head edge, tail edge, top + sideline and bottom sideline. + + Args: + points (ndarray): The points composing a text polygon. + + Returns: + head_edge (ndarray): The two points composing the head edge of text + polygon. + tail_edge (ndarray): The two points composing the tail edge of text + polygon. + top_sideline (ndarray): The points composing top curved sideline of + text polygon. + bot_sideline (ndarray): The points composing bottom curved sideline + of text polygon. 
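A small worked example (not part of the patch) for resample_line() above: both endpoints are always kept, so requesting n = 5 samples on a straight 10-unit segment returns 6 evenly spaced points.

    import numpy as np
    from ppocr.data.imaug.fce_targets import FCENetTargets   # file added by this patch

    line = np.array([[0.0, 0.0], [10.0, 0.0]])
    print(FCENetTargets().resample_line(line, 5))
    # -> [[ 0.  0.] [ 2.  0.] [ 4.  0.] [ 6.  0.] [ 8.  0.] [10.  0.]]  (approximately)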
+ """ + + assert points.ndim == 2 + assert points.shape[0] >= 4 + assert points.shape[1] == 2 + + head_inds, tail_inds = self.find_head_tail(points, self.orientation_thr) + head_edge, tail_edge = points[head_inds], points[tail_inds] + + pad_points = np.vstack([points, points]) + if tail_inds[1] < 1: + tail_inds[1] = len(points) + sideline1 = pad_points[head_inds[1]:tail_inds[1]] + sideline2 = pad_points[tail_inds[1]:(head_inds[1] + len(points))] + sideline_mean_shift = np.mean( + sideline1, axis=0) - np.mean( + sideline2, axis=0) + + if sideline_mean_shift[1] > 0: + top_sideline, bot_sideline = sideline2, sideline1 + else: + top_sideline, bot_sideline = sideline1, sideline2 + + return head_edge, tail_edge, top_sideline, bot_sideline + + def find_head_tail(self, points, orientation_thr): + """Find the head edge and tail edge of a text polygon. + + Args: + points (ndarray): The points composing a text polygon. + orientation_thr (float): The threshold for distinguishing between + head edge and tail edge among the horizontal and vertical edges + of a quadrangle. + + Returns: + head_inds (list): The indexes of two points composing head edge. + tail_inds (list): The indexes of two points composing tail edge. + """ + + assert points.ndim == 2 + assert points.shape[0] >= 4 + assert points.shape[1] == 2 + assert isinstance(orientation_thr, float) + + if len(points) > 4: + pad_points = np.vstack([points, points[0]]) + edge_vec = pad_points[1:] - pad_points[:-1] + + theta_sum = [] + adjacent_vec_theta = [] + for i, edge_vec1 in enumerate(edge_vec): + adjacent_ind = [x % len(edge_vec) for x in [i - 1, i + 1]] + adjacent_edge_vec = edge_vec[adjacent_ind] + temp_theta_sum = np.sum( + self.vector_angle(edge_vec1, adjacent_edge_vec)) + temp_adjacent_theta = self.vector_angle(adjacent_edge_vec[0], + adjacent_edge_vec[1]) + theta_sum.append(temp_theta_sum) + adjacent_vec_theta.append(temp_adjacent_theta) + theta_sum_score = np.array(theta_sum) / np.pi + adjacent_theta_score = np.array(adjacent_vec_theta) / np.pi + poly_center = np.mean(points, axis=0) + edge_dist = np.maximum( + norm( + pad_points[1:] - poly_center, axis=-1), + norm( + pad_points[:-1] - poly_center, axis=-1)) + dist_score = edge_dist / np.max(edge_dist) + position_score = np.zeros(len(edge_vec)) + score = 0.5 * theta_sum_score + 0.15 * adjacent_theta_score + score += 0.35 * dist_score + if len(points) % 2 == 0: + position_score[(len(score) // 2 - 1)] += 1 + position_score[-1] += 1 + score += 0.1 * position_score + pad_score = np.concatenate([score, score]) + score_matrix = np.zeros((len(score), len(score) - 3)) + x = np.arange(len(score) - 3) / float(len(score) - 4) + gaussian = 1. / (np.sqrt(2. * np.pi) * 0.5) * np.exp(-np.power( + (x - 0.5) / 0.5, 2.) 
/ 2) + gaussian = gaussian / np.max(gaussian) + for i in range(len(score)): + score_matrix[i, :] = score[i] + pad_score[(i + 2):(i + len( + score) - 1)] * gaussian * 0.3 + + head_start, tail_increment = np.unravel_index(score_matrix.argmax(), + score_matrix.shape) + tail_start = (head_start + tail_increment + 2) % len(points) + head_end = (head_start + 1) % len(points) + tail_end = (tail_start + 1) % len(points) + + if head_end > tail_end: + head_start, tail_start = tail_start, head_start + head_end, tail_end = tail_end, head_end + head_inds = [head_start, head_end] + tail_inds = [tail_start, tail_end] + else: + if self.vector_slope(points[1] - points[0]) + self.vector_slope( + points[3] - points[2]) < self.vector_slope(points[ + 2] - points[1]) + self.vector_slope(points[0] - points[ + 3]): + horizontal_edge_inds = [[0, 1], [2, 3]] + vertical_edge_inds = [[3, 0], [1, 2]] + else: + horizontal_edge_inds = [[3, 0], [1, 2]] + vertical_edge_inds = [[0, 1], [2, 3]] + + vertical_len_sum = norm(points[vertical_edge_inds[0][0]] - points[ + vertical_edge_inds[0][1]]) + norm(points[vertical_edge_inds[1][ + 0]] - points[vertical_edge_inds[1][1]]) + horizontal_len_sum = norm(points[horizontal_edge_inds[0][ + 0]] - points[horizontal_edge_inds[0][1]]) + norm(points[ + horizontal_edge_inds[1][0]] - points[horizontal_edge_inds[1] + [1]]) + + if vertical_len_sum > horizontal_len_sum * orientation_thr: + head_inds = horizontal_edge_inds[0] + tail_inds = horizontal_edge_inds[1] + else: + head_inds = vertical_edge_inds[0] + tail_inds = vertical_edge_inds[1] + + return head_inds, tail_inds + + def resample_sidelines(self, sideline1, sideline2, resample_step): + """Resample two sidelines to be of the same points number according to + step size. + + Args: + sideline1 (ndarray): The points composing a sideline of a text + polygon. + sideline2 (ndarray): The points composing another sideline of a + text polygon. + resample_step (float): The resampled step size. + + Returns: + resampled_line1 (ndarray): The resampled line 1. + resampled_line2 (ndarray): The resampled line 2. + """ + + assert sideline1.ndim == sideline2.ndim == 2 + assert sideline1.shape[1] == sideline2.shape[1] == 2 + assert sideline1.shape[0] >= 2 + assert sideline2.shape[0] >= 2 + assert isinstance(resample_step, float) + + length1 = sum([ + norm(sideline1[i + 1] - sideline1[i]) + for i in range(len(sideline1) - 1) + ]) + length2 = sum([ + norm(sideline2[i + 1] - sideline2[i]) + for i in range(len(sideline2) - 1) + ]) + + total_length = (length1 + length2) / 2 + resample_point_num = max(int(float(total_length) / resample_step), 1) + + resampled_line1 = self.resample_line(sideline1, resample_point_num) + resampled_line2 = self.resample_line(sideline2, resample_point_num) + + return resampled_line1, resampled_line2 + + def generate_center_region_mask(self, img_size, text_polys): + """Generate text center region mask. + + Args: + img_size (tuple): The image size of (height, width). + text_polys (list[list[ndarray]]): The list of text polygons. + + Returns: + center_region_mask (ndarray): The text center region mask. 
+ """ + + assert isinstance(img_size, tuple) + # assert check_argument.is_2dlist(text_polys) + + h, w = img_size + + center_region_mask = np.zeros((h, w), np.uint8) + + center_region_boxes = [] + for poly in text_polys: + # assert len(poly) == 1 + polygon_points = poly.reshape(-1, 2) + _, _, top_line, bot_line = self.reorder_poly_edge(polygon_points) + resampled_top_line, resampled_bot_line = self.resample_sidelines( + top_line, bot_line, self.resample_step) + resampled_bot_line = resampled_bot_line[::-1] + center_line = (resampled_top_line + resampled_bot_line) / 2 + + line_head_shrink_len = norm(resampled_top_line[0] - + resampled_bot_line[0]) / 4.0 + line_tail_shrink_len = norm(resampled_top_line[-1] - + resampled_bot_line[-1]) / 4.0 + head_shrink_num = int(line_head_shrink_len // self.resample_step) + tail_shrink_num = int(line_tail_shrink_len // self.resample_step) + if len(center_line) > head_shrink_num + tail_shrink_num + 2: + center_line = center_line[head_shrink_num:len(center_line) - + tail_shrink_num] + resampled_top_line = resampled_top_line[head_shrink_num:len( + resampled_top_line) - tail_shrink_num] + resampled_bot_line = resampled_bot_line[head_shrink_num:len( + resampled_bot_line) - tail_shrink_num] + + for i in range(0, len(center_line) - 1): + tl = center_line[i] + (resampled_top_line[i] - center_line[i] + ) * self.center_region_shrink_ratio + tr = center_line[i + 1] + (resampled_top_line[i + 1] - + center_line[i + 1] + ) * self.center_region_shrink_ratio + br = center_line[i + 1] + (resampled_bot_line[i + 1] - + center_line[i + 1] + ) * self.center_region_shrink_ratio + bl = center_line[i] + (resampled_bot_line[i] - center_line[i] + ) * self.center_region_shrink_ratio + current_center_box = np.vstack([tl, tr, br, + bl]).astype(np.int32) + center_region_boxes.append(current_center_box) + + cv2.fillPoly(center_region_mask, center_region_boxes, 1) + return center_region_mask + + def resample_polygon(self, polygon, n=400): + """Resample one polygon with n points on its boundary. + + Args: + polygon (list[float]): The input polygon. + n (int): The number of resampled points. + Returns: + resampled_polygon (list[float]): The resampled polygon. + """ + length = [] + + for i in range(len(polygon)): + p1 = polygon[i] + if i == len(polygon) - 1: + p2 = polygon[0] + else: + p2 = polygon[i + 1] + length.append(((p1[0] - p2[0])**2 + (p1[1] - p2[1])**2)**0.5) + + total_length = sum(length) + n_on_each_line = (np.array(length) / (total_length + 1e-8)) * n + n_on_each_line = n_on_each_line.astype(np.int32) + new_polygon = [] + + for i in range(len(polygon)): + num = n_on_each_line[i] + p1 = polygon[i] + if i == len(polygon) - 1: + p2 = polygon[0] + else: + p2 = polygon[i + 1] + + if num == 0: + continue + + dxdy = (p2 - p1) / num + for j in range(num): + point = p1 + dxdy * j + new_polygon.append(point) + + return np.array(new_polygon) + + def normalize_polygon(self, polygon): + """Normalize one polygon so that its start point is at right most. + + Args: + polygon (list[float]): The origin polygon. + Returns: + new_polygon (lost[float]): The polygon with start point at right. 
+ """ + temp_polygon = polygon - polygon.mean(axis=0) + x = np.abs(temp_polygon[:, 0]) + y = temp_polygon[:, 1] + index_x = np.argsort(x) + index_y = np.argmin(y[index_x[:8]]) + index = index_x[index_y] + new_polygon = np.concatenate([polygon[index:], polygon[:index]]) + return new_polygon + + def poly2fourier(self, polygon, fourier_degree): + """Perform Fourier transformation to generate Fourier coefficients ck + from polygon. + + Args: + polygon (ndarray): An input polygon. + fourier_degree (int): The maximum Fourier degree K. + Returns: + c (ndarray(complex)): Fourier coefficients. + """ + points = polygon[:, 0] + polygon[:, 1] * 1j + c_fft = fft(points) / len(points) + c = np.hstack((c_fft[-fourier_degree:], c_fft[:fourier_degree + 1])) + return c + + def clockwise(self, c, fourier_degree): + """Make sure the polygon reconstructed from Fourier coefficients c in + the clockwise direction. + + Args: + polygon (list[float]): The origin polygon. + Returns: + new_polygon (lost[float]): The polygon in clockwise point order. + """ + if np.abs(c[fourier_degree + 1]) > np.abs(c[fourier_degree - 1]): + return c + elif np.abs(c[fourier_degree + 1]) < np.abs(c[fourier_degree - 1]): + return c[::-1] + else: + if np.abs(c[fourier_degree + 2]) > np.abs(c[fourier_degree - 2]): + return c + else: + return c[::-1] + + def cal_fourier_signature(self, polygon, fourier_degree): + """Calculate Fourier signature from input polygon. + + Args: + polygon (ndarray): The input polygon. + fourier_degree (int): The maximum Fourier degree K. + Returns: + fourier_signature (ndarray): An array shaped (2k+1, 2) containing + real part and image part of 2k+1 Fourier coefficients. + """ + resampled_polygon = self.resample_polygon(polygon) + resampled_polygon = self.normalize_polygon(resampled_polygon) + + fourier_coeff = self.poly2fourier(resampled_polygon, fourier_degree) + fourier_coeff = self.clockwise(fourier_coeff, fourier_degree) + + real_part = np.real(fourier_coeff).reshape((-1, 1)) + image_part = np.imag(fourier_coeff).reshape((-1, 1)) + fourier_signature = np.hstack([real_part, image_part]) + + return fourier_signature + + def generate_fourier_maps(self, img_size, text_polys): + """Generate Fourier coefficient maps. + + Args: + img_size (tuple): The image size of (height, width). + text_polys (list[list[ndarray]]): The list of text polygons. + + Returns: + fourier_real_map (ndarray): The Fourier coefficient real part maps. + fourier_image_map (ndarray): The Fourier coefficient image part + maps. 
+ """ + + assert isinstance(img_size, tuple) + + h, w = img_size + k = self.fourier_degree + real_map = np.zeros((k * 2 + 1, h, w), dtype=np.float32) + imag_map = np.zeros((k * 2 + 1, h, w), dtype=np.float32) + + for poly in text_polys: + mask = np.zeros((h, w), dtype=np.uint8) + polygon = np.array(poly).reshape((1, -1, 2)) + cv2.fillPoly(mask, polygon.astype(np.int32), 1) + fourier_coeff = self.cal_fourier_signature(polygon[0], k) + for i in range(-k, k + 1): + if i != 0: + real_map[i + k, :, :] = mask * fourier_coeff[i + k, 0] + ( + 1 - mask) * real_map[i + k, :, :] + imag_map[i + k, :, :] = mask * fourier_coeff[i + k, 1] + ( + 1 - mask) * imag_map[i + k, :, :] + else: + yx = np.argwhere(mask > 0.5) + k_ind = np.ones((len(yx)), dtype=np.int64) * k + y, x = yx[:, 0], yx[:, 1] + real_map[k_ind, y, x] = fourier_coeff[k, 0] - x + imag_map[k_ind, y, x] = fourier_coeff[k, 1] - y + + return real_map, imag_map + + def generate_text_region_mask(self, img_size, text_polys): + """Generate text center region mask and geometry attribute maps. + + Args: + img_size (tuple): The image size (height, width). + text_polys (list[list[ndarray]]): The list of text polygons. + + Returns: + text_region_mask (ndarray): The text region mask. + """ + + assert isinstance(img_size, tuple) + + h, w = img_size + text_region_mask = np.zeros((h, w), dtype=np.uint8) + + for poly in text_polys: + polygon = np.array(poly, dtype=np.int32).reshape((1, -1, 2)) + cv2.fillPoly(text_region_mask, polygon, 1) + + return text_region_mask + + def generate_effective_mask(self, mask_size: tuple, polygons_ignore): + """Generate effective mask by setting the ineffective regions to 0 and + effective regions to 1. + + Args: + mask_size (tuple): The mask size. + polygons_ignore (list[[ndarray]]: The list of ignored text + polygons. + + Returns: + mask (ndarray): The effective mask of (height, width). + """ + + mask = np.ones(mask_size, dtype=np.uint8) + + for poly in polygons_ignore: + instance = poly.reshape(-1, 2).astype(np.int32).reshape(1, -1, 2) + cv2.fillPoly(mask, instance, 0) + + return mask + + def generate_level_targets(self, img_size, text_polys, ignore_polys): + """Generate ground truth target on each level. + + Args: + img_size (list[int]): Shape of input image. + text_polys (list[list[ndarray]]): A list of ground truth polygons. + ignore_polys (list[list[ndarray]]): A list of ignored polygons. + Returns: + level_maps (list(ndarray)): A list of ground target on each level. 
+ """ + h, w = img_size + lv_size_divs = self.level_size_divisors + lv_proportion_range = self.level_proportion_range + lv_text_polys = [[] for i in range(len(lv_size_divs))] + lv_ignore_polys = [[] for i in range(len(lv_size_divs))] + level_maps = [] + for poly in text_polys: + polygon = np.array(poly, dtype=np.int).reshape((1, -1, 2)) + _, _, box_w, box_h = cv2.boundingRect(polygon) + proportion = max(box_h, box_w) / (h + 1e-8) + + for ind, proportion_range in enumerate(lv_proportion_range): + if proportion_range[0] < proportion < proportion_range[1]: + lv_text_polys[ind].append(poly / lv_size_divs[ind]) + + for ignore_poly in ignore_polys: + polygon = np.array(ignore_poly, dtype=np.int).reshape((1, -1, 2)) + _, _, box_w, box_h = cv2.boundingRect(polygon) + proportion = max(box_h, box_w) / (h + 1e-8) + + for ind, proportion_range in enumerate(lv_proportion_range): + if proportion_range[0] < proportion < proportion_range[1]: + lv_ignore_polys[ind].append(ignore_poly / lv_size_divs[ind]) + + for ind, size_divisor in enumerate(lv_size_divs): + current_level_maps = [] + level_img_size = (h // size_divisor, w // size_divisor) + + text_region = self.generate_text_region_mask( + level_img_size, lv_text_polys[ind])[None] + current_level_maps.append(text_region) + + center_region = self.generate_center_region_mask( + level_img_size, lv_text_polys[ind])[None] + current_level_maps.append(center_region) + + effective_mask = self.generate_effective_mask( + level_img_size, lv_ignore_polys[ind])[None] + current_level_maps.append(effective_mask) + + fourier_real_map, fourier_image_maps = self.generate_fourier_maps( + level_img_size, lv_text_polys[ind]) + current_level_maps.append(fourier_real_map) + current_level_maps.append(fourier_image_maps) + + level_maps.append(np.concatenate(current_level_maps)) + + return level_maps + + def generate_targets(self, results): + """Generate the ground truth targets for FCENet. + + Args: + results (dict): The input result dictionary. + + Returns: + results (dict): The output result dictionary. + """ + + assert isinstance(results, dict) + image = results['image'] + polygons = results['polys'] + ignore_tags = results['ignore_tags'] + h, w, _ = image.shape + + polygon_masks = [] + polygon_masks_ignore = [] + for tag, polygon in zip(ignore_tags, polygons): + if tag is True: + polygon_masks_ignore.append(polygon) + else: + polygon_masks.append(polygon) + + level_maps = self.generate_level_targets((h, w), polygon_masks, + polygon_masks_ignore) + + mapping = { + 'p3_maps': level_maps[0], + 'p4_maps': level_maps[1], + 'p5_maps': level_maps[2] + } + for key, value in mapping.items(): + results[key] = value + + return results + + def __call__(self, results): + results = self.generate_targets(results) + return results diff --git a/backend/ppocr/data/imaug/gen_table_mask.py b/backend/ppocr/data/imaug/gen_table_mask.py new file mode 100644 index 00000000..08e35d5d --- /dev/null +++ b/backend/ppocr/data/imaug/gen_table_mask.py @@ -0,0 +1,244 @@ +""" +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import sys
+import six
+import cv2
+import numpy as np
+
+
+class GenTableMask(object):
+    """ gen table mask """
+
+    def __init__(self, shrink_h_max, shrink_w_max, mask_type=0, **kwargs):
+        self.shrink_h_max = 5
+        self.shrink_w_max = 5
+        self.mask_type = mask_type
+
+    def projection(self, erosion, h, w, spilt_threshold=0):
+        # Horizontal projection
+        projection_map = np.ones_like(erosion)
+        project_val_array = [0 for _ in range(0, h)]
+
+        for j in range(0, h):
+            for i in range(0, w):
+                if erosion[j, i] == 255:
+                    project_val_array[j] += 1
+        # Get the split points from the projection array
+        start_idx = 0  # index where a text region starts
+        end_idx = 0  # index where a blank region starts
+        in_text = False  # whether the scan is currently inside a text region
+        box_list = []
+        for i in range(len(project_val_array)):
+            if in_text == False and project_val_array[i] > spilt_threshold:  # entering a text region
+                in_text = True
+                start_idx = i
+            elif project_val_array[i] <= spilt_threshold and in_text == True:  # entering a blank region
+                end_idx = i
+                in_text = False
+                if end_idx - start_idx <= 2:
+                    continue
+                box_list.append((start_idx, end_idx + 1))
+
+        if in_text:
+            box_list.append((start_idx, h - 1))
+        # Draw the projection histogram
+        for j in range(0, h):
+            for i in range(0, project_val_array[j]):
+                projection_map[j, i] = 0
+        return box_list, projection_map
+
+    def projection_cx(self, box_img):
+        box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY)
+        h, w = box_gray_img.shape
+        # Binarize the grayscale image
+        ret, thresh1 = cv2.threshold(box_gray_img, 200, 255, cv2.THRESH_BINARY_INV)
+        # Vertical erosion
+        if h < w:
+            kernel = np.ones((2, 1), np.uint8)
+            erode = cv2.erode(thresh1, kernel, iterations=1)
+        else:
+            erode = thresh1
+        # Horizontal dilation
+        kernel = np.ones((1, 5), np.uint8)
+        erosion = cv2.dilate(erode, kernel, iterations=1)
+        # Horizontal projection
+        projection_map = np.ones_like(erosion)
+        project_val_array = [0 for _ in range(0, h)]
+
+        for j in range(0, h):
+            for i in range(0, w):
+                if erosion[j, i] == 255:
+                    project_val_array[j] += 1
+        # Get the split points from the projection array
+        start_idx = 0  # index where a text region starts
+        end_idx = 0  # index where a blank region starts
+        in_text = False  # whether the scan is currently inside a text region
+        box_list = []
+        spilt_threshold = 0
+        for i in range(len(project_val_array)):
+            if in_text == False and project_val_array[i] > spilt_threshold:  # entering a text region
+                in_text = True
+                start_idx = i
+            elif project_val_array[i] <= spilt_threshold and in_text == True:  # entering a blank region
+                end_idx = i
+                in_text = False
+                if end_idx - start_idx <= 2:
+                    continue
+                box_list.append((start_idx, end_idx + 1))
+
+        if in_text:
+            box_list.append((start_idx, h - 1))
+        # Draw the projection histogram
+        for j in range(0, h):
+            for i in range(0, project_val_array[j]):
+                projection_map[j, i] = 0
+        split_bbox_list = []
+        if len(box_list) > 1:
+            for i, (h_start, h_end) in enumerate(box_list):
+                if i == 0:
+                    h_start = 0
+                if i == len(box_list):
+                    h_end = h
+                word_img = erosion[h_start:h_end + 1, :]
+                word_h, word_w = word_img.shape
+                w_split_list, w_projection_map = self.projection(word_img.T, word_w, word_h)
+                w_start, w_end = w_split_list[0][0], w_split_list[-1][1]
+                if h_start > 0:
+                    h_start -= 1
+                h_end += 1
+                word_img = box_img[h_start:h_end + 1:, w_start:w_end + 1, :]
+                split_bbox_list.append([w_start, h_start, w_end, h_end])
+        else:
+            split_bbox_list.append([0, 0, w, h])
+        return split_bbox_list
+
+    def shrink_bbox(self, bbox):
+        left, top, right, bottom = bbox
+        sh_h = min(max(int((bottom - top) * 0.1), 1), self.shrink_h_max)
+        sh_w = min(max(int((right - left) * 0.1), 1),
self.shrink_w_max) + left_new = left + sh_w + right_new = right - sh_w + top_new = top + sh_h + bottom_new = bottom - sh_h + if left_new >= right_new: + left_new = left + right_new = right + if top_new >= bottom_new: + top_new = top + bottom_new = bottom + return [left_new, top_new, right_new, bottom_new] + + def __call__(self, data): + img = data['image'] + cells = data['cells'] + height, width = img.shape[0:2] + if self.mask_type == 1: + mask_img = np.zeros((height, width), dtype=np.float32) + else: + mask_img = np.zeros((height, width, 3), dtype=np.float32) + cell_num = len(cells) + for cno in range(cell_num): + if "bbox" in cells[cno]: + bbox = cells[cno]['bbox'] + left, top, right, bottom = bbox + box_img = img[top:bottom, left:right, :].copy() + split_bbox_list = self.projection_cx(box_img) + for sno in range(len(split_bbox_list)): + split_bbox_list[sno][0] += left + split_bbox_list[sno][1] += top + split_bbox_list[sno][2] += left + split_bbox_list[sno][3] += top + + for sno in range(len(split_bbox_list)): + left, top, right, bottom = split_bbox_list[sno] + left, top, right, bottom = self.shrink_bbox([left, top, right, bottom]) + if self.mask_type == 1: + mask_img[top:bottom, left:right] = 1.0 + data['mask_img'] = mask_img + else: + mask_img[top:bottom, left:right, :] = (255, 255, 255) + data['image'] = mask_img + return data + +class ResizeTableImage(object): + def __init__(self, max_len, **kwargs): + super(ResizeTableImage, self).__init__() + self.max_len = max_len + + def get_img_bbox(self, cells): + bbox_list = [] + if len(cells) == 0: + return bbox_list + cell_num = len(cells) + for cno in range(cell_num): + if "bbox" in cells[cno]: + bbox = cells[cno]['bbox'] + bbox_list.append(bbox) + return bbox_list + + def resize_img_table(self, img, bbox_list, max_len): + height, width = img.shape[0:2] + ratio = max_len / (max(height, width) * 1.0) + resize_h = int(height * ratio) + resize_w = int(width * ratio) + img_new = cv2.resize(img, (resize_w, resize_h)) + bbox_list_new = [] + for bno in range(len(bbox_list)): + left, top, right, bottom = bbox_list[bno].copy() + left = int(left * ratio) + top = int(top * ratio) + right = int(right * ratio) + bottom = int(bottom * ratio) + bbox_list_new.append([left, top, right, bottom]) + return img_new, bbox_list_new + + def __call__(self, data): + img = data['image'] + if 'cells' not in data: + cells = [] + else: + cells = data['cells'] + bbox_list = self.get_img_bbox(cells) + img_new, bbox_list_new = self.resize_img_table(img, bbox_list, self.max_len) + data['image'] = img_new + cell_num = len(cells) + bno = 0 + for cno in range(cell_num): + if "bbox" in data['cells'][cno]: + data['cells'][cno]['bbox'] = bbox_list_new[bno] + bno += 1 + data['max_len'] = self.max_len + return data + +class PaddingTableImage(object): + def __init__(self, **kwargs): + super(PaddingTableImage, self).__init__() + + def __call__(self, data): + img = data['image'] + max_len = data['max_len'] + padding_img = np.zeros((max_len, max_len, 3), dtype=np.float32) + height, width = img.shape[0:2] + padding_img[0:height, 0:width, :] = img.copy() + data['image'] = padding_img + return data + \ No newline at end of file diff --git a/backend/ppocr/data/imaug/iaa_augment.py b/backend/ppocr/data/imaug/iaa_augment.py index 9ce6bd42..0aac7877 100644 --- a/backend/ppocr/data/imaug/iaa_augment.py +++ b/backend/ppocr/data/imaug/iaa_augment.py @@ -11,6 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
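The two table preprocessing ops above are meant to be chained: ResizeTableImage scales the longer side to max_len (rescaling the cell bboxes by the same ratio), and PaddingTableImage drops the result into a square zero canvas. A rough, self-contained sketch of that combination; 488 is only an example value for max_len, not one mandated by this patch.

import cv2
import numpy as np

def resize_and_pad(img, max_len=488):  # max_len value is illustrative
    h, w = img.shape[:2]
    ratio = max_len / (max(h, w) * 1.0)
    resized = cv2.resize(img, (int(w * ratio), int(h * ratio)))
    canvas = np.zeros((max_len, max_len, 3), dtype=np.float32)
    canvas[:resized.shape[0], :resized.shape[1], :] = resized
    return canvas, ratio

dummy = np.ones((100, 400, 3), dtype=np.float32)
padded, ratio = resize_and_pad(dummy)
print(padded.shape, round(ratio, 2))  # (488, 488, 3) 1.22

Cell bboxes would be multiplied by the same ratio before padding, exactly as resize_img_table does above.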
# See the License for the specific language governing permissions and # limitations under the License. +""" +This code is refer from: +https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/iaa_augment.py +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function diff --git a/backend/ppocr/data/imaug/label_ops.py b/backend/ppocr/data/imaug/label_ops.py new file mode 100644 index 00000000..c9bc2e77 --- /dev/null +++ b/backend/ppocr/data/imaug/label_ops.py @@ -0,0 +1,1041 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import copy +import numpy as np +import string +from shapely.geometry import LineString, Point, Polygon +import json +import copy + +from ppocr.utils.logging import get_logger + + +class ClsLabelEncode(object): + def __init__(self, label_list, **kwargs): + self.label_list = label_list + + def __call__(self, data): + label = data['label'] + if label not in self.label_list: + return None + label = self.label_list.index(label) + data['label'] = label + return data + + +class DetLabelEncode(object): + def __init__(self, **kwargs): + pass + + def __call__(self, data): + label = data['label'] + label = json.loads(label) + nBox = len(label) + boxes, txts, txt_tags = [], [], [] + for bno in range(0, nBox): + box = label[bno]['points'] + txt = label[bno]['transcription'] + boxes.append(box) + txts.append(txt) + if txt in ['*', '###']: + txt_tags.append(True) + else: + txt_tags.append(False) + if len(boxes) == 0: + return None + boxes = self.expand_points_num(boxes) + boxes = np.array(boxes, dtype=np.float32) + txt_tags = np.array(txt_tags, dtype=np.bool) + + data['polys'] = boxes + data['texts'] = txts + data['ignore_tags'] = txt_tags + return data + + def order_points_clockwise(self, pts): + rect = np.zeros((4, 2), dtype="float32") + s = pts.sum(axis=1) + rect[0] = pts[np.argmin(s)] + rect[2] = pts[np.argmax(s)] + diff = np.diff(pts, axis=1) + rect[1] = pts[np.argmin(diff)] + rect[3] = pts[np.argmax(diff)] + return rect + + def expand_points_num(self, boxes): + max_points_num = 0 + for box in boxes: + if len(box) > max_points_num: + max_points_num = len(box) + ex_boxes = [] + for box in boxes: + ex_box = box + [box[-1]] * (max_points_num - len(box)) + ex_boxes.append(ex_box) + return ex_boxes + + +class BaseRecLabelEncode(object): + """ Convert between text-label and text-index """ + + def __init__(self, + max_text_length, + character_dict_path=None, + use_space_char=False): + + self.max_text_len = max_text_length + self.beg_str = "sos" + self.end_str = "eos" + self.lower = False + + if character_dict_path is None: + logger = get_logger() + logger.warning( + "The character_dict_path is None, model can only recognize number and lower letters" + ) + self.character_str = 
"0123456789abcdefghijklmnopqrstuvwxyz" + dict_character = list(self.character_str) + self.lower = True + else: + self.character_str = [] + with open(character_dict_path, "rb") as fin: + lines = fin.readlines() + for line in lines: + line = line.decode('utf-8').strip("\n").strip("\r\n") + self.character_str.append(line) + if use_space_char: + self.character_str.append(" ") + dict_character = list(self.character_str) + dict_character = self.add_special_char(dict_character) + self.dict = {} + for i, char in enumerate(dict_character): + self.dict[char] = i + self.character = dict_character + + def add_special_char(self, dict_character): + return dict_character + + def encode(self, text): + """convert text-label into text-index. + input: + text: text labels of each image. [batch_size] + + output: + text: concatenated text index for CTCLoss. + [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)] + length: length of each text. [batch_size] + """ + if len(text) == 0 or len(text) > self.max_text_len: + return None + if self.lower: + text = text.lower() + text_list = [] + for char in text: + if char not in self.dict: + # logger = get_logger() + # logger.warning('{} is not in dict'.format(char)) + continue + text_list.append(self.dict[char]) + if len(text_list) == 0: + return None + return text_list + + +class NRTRLabelEncode(BaseRecLabelEncode): + """ Convert between text-label and text-index """ + + def __init__(self, + max_text_length, + character_dict_path=None, + use_space_char=False, + **kwargs): + + super(NRTRLabelEncode, self).__init__( + max_text_length, character_dict_path, use_space_char) + + def __call__(self, data): + text = data['label'] + text = self.encode(text) + if text is None: + return None + if len(text) >= self.max_text_len - 1: + return None + data['length'] = np.array(len(text)) + text.insert(0, 2) + text.append(3) + text = text + [0] * (self.max_text_len - len(text)) + data['label'] = np.array(text) + return data + + def add_special_char(self, dict_character): + dict_character = ['blank', '', '', ''] + dict_character + return dict_character + + +class CTCLabelEncode(BaseRecLabelEncode): + """ Convert between text-label and text-index """ + + def __init__(self, + max_text_length, + character_dict_path=None, + use_space_char=False, + **kwargs): + super(CTCLabelEncode, self).__init__( + max_text_length, character_dict_path, use_space_char) + + def __call__(self, data): + text = data['label'] + text = self.encode(text) + if text is None: + return None + data['length'] = np.array(len(text)) + text = text + [0] * (self.max_text_len - len(text)) + data['label'] = np.array(text) + + label = [0] * len(self.character) + for x in text: + label[x] += 1 + data['label_ace'] = np.array(label) + return data + + def add_special_char(self, dict_character): + dict_character = ['blank'] + dict_character + return dict_character + + +class E2ELabelEncodeTest(BaseRecLabelEncode): + def __init__(self, + max_text_length, + character_dict_path=None, + use_space_char=False, + **kwargs): + super(E2ELabelEncodeTest, self).__init__( + max_text_length, character_dict_path, use_space_char) + + def __call__(self, data): + import json + padnum = len(self.dict) + label = data['label'] + label = json.loads(label) + nBox = len(label) + boxes, txts, txt_tags = [], [], [] + for bno in range(0, nBox): + box = label[bno]['points'] + txt = label[bno]['transcription'] + boxes.append(box) + txts.append(txt) + if txt in ['*', '###']: + txt_tags.append(True) + else: + txt_tags.append(False) + 
boxes = np.array(boxes, dtype=np.float32) + txt_tags = np.array(txt_tags, dtype=np.bool) + data['polys'] = boxes + data['ignore_tags'] = txt_tags + temp_texts = [] + for text in txts: + text = text.lower() + text = self.encode(text) + if text is None: + return None + text = text + [padnum] * (self.max_text_len - len(text) + ) # use 36 to pad + temp_texts.append(text) + data['texts'] = np.array(temp_texts) + return data + + +class E2ELabelEncodeTrain(object): + def __init__(self, **kwargs): + pass + + def __call__(self, data): + import json + label = data['label'] + label = json.loads(label) + nBox = len(label) + boxes, txts, txt_tags = [], [], [] + for bno in range(0, nBox): + box = label[bno]['points'] + txt = label[bno]['transcription'] + boxes.append(box) + txts.append(txt) + if txt in ['*', '###']: + txt_tags.append(True) + else: + txt_tags.append(False) + boxes = np.array(boxes, dtype=np.float32) + txt_tags = np.array(txt_tags, dtype=np.bool) + + data['polys'] = boxes + data['texts'] = txts + data['ignore_tags'] = txt_tags + return data + + +class KieLabelEncode(object): + def __init__(self, character_dict_path, norm=10, directed=False, **kwargs): + super(KieLabelEncode, self).__init__() + self.dict = dict({'': 0}) + with open(character_dict_path, 'r', encoding='utf-8') as fr: + idx = 1 + for line in fr: + char = line.strip() + self.dict[char] = idx + idx += 1 + self.norm = norm + self.directed = directed + + def compute_relation(self, boxes): + """Compute relation between every two boxes.""" + x1s, y1s = boxes[:, 0:1], boxes[:, 1:2] + x2s, y2s = boxes[:, 4:5], boxes[:, 5:6] + ws, hs = x2s - x1s + 1, np.maximum(y2s - y1s + 1, 1) + dxs = (x1s[:, 0][None] - x1s) / self.norm + dys = (y1s[:, 0][None] - y1s) / self.norm + xhhs, xwhs = hs[:, 0][None] / hs, ws[:, 0][None] / hs + whs = ws / hs + np.zeros_like(xhhs) + relations = np.stack([dxs, dys, whs, xhhs, xwhs], -1) + bboxes = np.concatenate([x1s, y1s, x2s, y2s], -1).astype(np.float32) + return relations, bboxes + + def pad_text_indices(self, text_inds): + """Pad text index to same length.""" + max_len = 300 + recoder_len = max([len(text_ind) for text_ind in text_inds]) + padded_text_inds = -np.ones((len(text_inds), max_len), np.int32) + for idx, text_ind in enumerate(text_inds): + padded_text_inds[idx, :len(text_ind)] = np.array(text_ind) + return padded_text_inds, recoder_len + + def list_to_numpy(self, ann_infos): + """Convert bboxes, relations, texts and labels to ndarray.""" + boxes, text_inds = ann_infos['points'], ann_infos['text_inds'] + boxes = np.array(boxes, np.int32) + relations, bboxes = self.compute_relation(boxes) + + labels = ann_infos.get('labels', None) + if labels is not None: + labels = np.array(labels, np.int32) + edges = ann_infos.get('edges', None) + if edges is not None: + labels = labels[:, None] + edges = np.array(edges) + edges = (edges[:, None] == edges[None, :]).astype(np.int32) + if self.directed: + edges = (edges & labels == 1).astype(np.int32) + np.fill_diagonal(edges, -1) + labels = np.concatenate([labels, edges], -1) + padded_text_inds, recoder_len = self.pad_text_indices(text_inds) + max_num = 300 + temp_bboxes = np.zeros([max_num, 4]) + h, _ = bboxes.shape + temp_bboxes[:h, :] = bboxes + + temp_relations = np.zeros([max_num, max_num, 5]) + temp_relations[:h, :h, :] = relations + + temp_padded_text_inds = np.zeros([max_num, max_num]) + temp_padded_text_inds[:h, :] = padded_text_inds + + temp_labels = np.zeros([max_num, max_num]) + temp_labels[:h, :h + 1] = labels + + tag = np.array([h, recoder_len]) + 
return dict( + image=ann_infos['image'], + points=temp_bboxes, + relations=temp_relations, + texts=temp_padded_text_inds, + labels=temp_labels, + tag=tag) + + def convert_canonical(self, points_x, points_y): + + assert len(points_x) == 4 + assert len(points_y) == 4 + + points = [Point(points_x[i], points_y[i]) for i in range(4)] + + polygon = Polygon([(p.x, p.y) for p in points]) + min_x, min_y, _, _ = polygon.bounds + points_to_lefttop = [ + LineString([points[i], Point(min_x, min_y)]) for i in range(4) + ] + distances = np.array([line.length for line in points_to_lefttop]) + sort_dist_idx = np.argsort(distances) + lefttop_idx = sort_dist_idx[0] + + if lefttop_idx == 0: + point_orders = [0, 1, 2, 3] + elif lefttop_idx == 1: + point_orders = [1, 2, 3, 0] + elif lefttop_idx == 2: + point_orders = [2, 3, 0, 1] + else: + point_orders = [3, 0, 1, 2] + + sorted_points_x = [points_x[i] for i in point_orders] + sorted_points_y = [points_y[j] for j in point_orders] + + return sorted_points_x, sorted_points_y + + def sort_vertex(self, points_x, points_y): + + assert len(points_x) == 4 + assert len(points_y) == 4 + + x = np.array(points_x) + y = np.array(points_y) + center_x = np.sum(x) * 0.25 + center_y = np.sum(y) * 0.25 + + x_arr = np.array(x - center_x) + y_arr = np.array(y - center_y) + + angle = np.arctan2(y_arr, x_arr) * 180.0 / np.pi + sort_idx = np.argsort(angle) + + sorted_points_x, sorted_points_y = [], [] + for i in range(4): + sorted_points_x.append(points_x[sort_idx[i]]) + sorted_points_y.append(points_y[sort_idx[i]]) + + return self.convert_canonical(sorted_points_x, sorted_points_y) + + def __call__(self, data): + import json + label = data['label'] + annotations = json.loads(label) + boxes, texts, text_inds, labels, edges = [], [], [], [], [] + for ann in annotations: + box = ann['points'] + x_list = [box[i][0] for i in range(4)] + y_list = [box[i][1] for i in range(4)] + sorted_x_list, sorted_y_list = self.sort_vertex(x_list, y_list) + sorted_box = [] + for x, y in zip(sorted_x_list, sorted_y_list): + sorted_box.append(x) + sorted_box.append(y) + boxes.append(sorted_box) + text = ann['transcription'] + texts.append(ann['transcription']) + text_ind = [self.dict[c] for c in text if c in self.dict] + text_inds.append(text_ind) + labels.append(ann['label']) + edges.append(ann.get('edge', 0)) + ann_infos = dict( + image=data['image'], + points=boxes, + texts=texts, + text_inds=text_inds, + edges=edges, + labels=labels) + + return self.list_to_numpy(ann_infos) + + +class AttnLabelEncode(BaseRecLabelEncode): + """ Convert between text-label and text-index """ + + def __init__(self, + max_text_length, + character_dict_path=None, + use_space_char=False, + **kwargs): + super(AttnLabelEncode, self).__init__( + max_text_length, character_dict_path, use_space_char) + + def add_special_char(self, dict_character): + self.beg_str = "sos" + self.end_str = "eos" + dict_character = [self.beg_str] + dict_character + [self.end_str] + return dict_character + + def __call__(self, data): + text = data['label'] + text = self.encode(text) + if text is None: + return None + if len(text) >= self.max_text_len: + return None + data['length'] = np.array(len(text)) + text = [0] + text + [len(self.character) - 1] + [0] * (self.max_text_len + - len(text) - 2) + data['label'] = np.array(text) + return data + + def get_ignored_tokens(self): + beg_idx = self.get_beg_end_flag_idx("beg") + end_idx = self.get_beg_end_flag_idx("end") + return [beg_idx, end_idx] + + def get_beg_end_flag_idx(self, beg_or_end): + if 
beg_or_end == "beg": + idx = np.array(self.dict[self.beg_str]) + elif beg_or_end == "end": + idx = np.array(self.dict[self.end_str]) + else: + assert False, "Unsupport type %s in get_beg_end_flag_idx" \ + % beg_or_end + return idx + + +class SEEDLabelEncode(BaseRecLabelEncode): + """ Convert between text-label and text-index """ + + def __init__(self, + max_text_length, + character_dict_path=None, + use_space_char=False, + **kwargs): + super(SEEDLabelEncode, self).__init__( + max_text_length, character_dict_path, use_space_char) + + def add_special_char(self, dict_character): + self.padding = "padding" + self.end_str = "eos" + self.unknown = "unknown" + dict_character = dict_character + [ + self.end_str, self.padding, self.unknown + ] + return dict_character + + def __call__(self, data): + text = data['label'] + text = self.encode(text) + if text is None: + return None + if len(text) >= self.max_text_len: + return None + data['length'] = np.array(len(text)) + 1 # conclude eos + text = text + [len(self.character) - 3] + [len(self.character) - 2] * ( + self.max_text_len - len(text) - 1) + data['label'] = np.array(text) + return data + + +class SRNLabelEncode(BaseRecLabelEncode): + """ Convert between text-label and text-index """ + + def __init__(self, + max_text_length=25, + character_dict_path=None, + use_space_char=False, + **kwargs): + super(SRNLabelEncode, self).__init__( + max_text_length, character_dict_path, use_space_char) + + def add_special_char(self, dict_character): + dict_character = dict_character + [self.beg_str, self.end_str] + return dict_character + + def __call__(self, data): + text = data['label'] + text = self.encode(text) + char_num = len(self.character) + if text is None: + return None + if len(text) > self.max_text_len: + return None + data['length'] = np.array(len(text)) + text = text + [char_num - 1] * (self.max_text_len - len(text)) + data['label'] = np.array(text) + return data + + def get_ignored_tokens(self): + beg_idx = self.get_beg_end_flag_idx("beg") + end_idx = self.get_beg_end_flag_idx("end") + return [beg_idx, end_idx] + + def get_beg_end_flag_idx(self, beg_or_end): + if beg_or_end == "beg": + idx = np.array(self.dict[self.beg_str]) + elif beg_or_end == "end": + idx = np.array(self.dict[self.end_str]) + else: + assert False, "Unsupport type %s in get_beg_end_flag_idx" \ + % beg_or_end + return idx + + +class TableLabelEncode(object): + """ Convert between text-label and text-index """ + + def __init__(self, + max_text_length, + max_elem_length, + max_cell_num, + character_dict_path, + span_weight=1.0, + **kwargs): + self.max_text_length = max_text_length + self.max_elem_length = max_elem_length + self.max_cell_num = max_cell_num + list_character, list_elem = self.load_char_elem_dict( + character_dict_path) + list_character = self.add_special_char(list_character) + list_elem = self.add_special_char(list_elem) + self.dict_character = {} + for i, char in enumerate(list_character): + self.dict_character[char] = i + self.dict_elem = {} + for i, elem in enumerate(list_elem): + self.dict_elem[elem] = i + self.span_weight = span_weight + + def load_char_elem_dict(self, character_dict_path): + list_character = [] + list_elem = [] + with open(character_dict_path, "rb") as fin: + lines = fin.readlines() + substr = lines[0].decode('utf-8').strip("\r\n").split("\t") + character_num = int(substr[0]) + elem_num = int(substr[1]) + for cno in range(1, 1 + character_num): + character = lines[cno].decode('utf-8').strip("\r\n") + list_character.append(character) + for eno in 
range(1 + character_num, 1 + character_num + elem_num):
+                elem = lines[eno].decode('utf-8').strip("\r\n")
+                list_elem.append(elem)
+        return list_character, list_elem
+
+    def add_special_char(self, list_character):
+        self.beg_str = "sos"
+        self.end_str = "eos"
+        list_character = [self.beg_str] + list_character + [self.end_str]
+        return list_character
+
+    def get_span_idx_list(self):
+        span_idx_list = []
+        for elem in self.dict_elem:
+            if 'span' in elem:
+                span_idx_list.append(self.dict_elem[elem])
+        return span_idx_list
+
+    def __call__(self, data):
+        cells = data['cells']
+        structure = data['structure']['tokens']
+        structure = self.encode(structure, 'elem')
+        if structure is None:
+            return None
+        elem_num = len(structure)
+        structure = [0] + structure + [len(self.dict_elem) - 1]
+        structure = structure + [0] * (self.max_elem_length + 2 - len(structure))
+        structure = np.array(structure)
+        data['structure'] = structure
+        elem_char_idx1 = self.dict_elem['<td>']
+        elem_char_idx2 = self.dict_elem['<td']
+        span_idx_list = self.get_span_idx_list()
+        td_idx_list = np.logical_or(structure == elem_char_idx1,
+                                    structure == elem_char_idx2)
+        td_idx_list = np.where(td_idx_list)[0]
+        structure_mask = np.ones(
+            (self.max_elem_length + 2, 1), dtype=np.float32)
+        bbox_list = np.zeros((self.max_elem_length + 2, 4), dtype=np.float32)
+        bbox_list_mask = np.zeros(
+            (self.max_elem_length + 2, 1), dtype=np.float32)
+        img_height, img_width, img_ch = data['image'].shape
+        if len(span_idx_list) > 0:
+            span_weight = len(td_idx_list) * 1.0 / len(span_idx_list)
+            span_weight = min(max(span_weight, 1.0), self.span_weight)
+        for cno in range(len(cells)):
+            if 'bbox' in cells[cno]:
+                bbox = cells[cno]['bbox'].copy()
+                bbox[0] = bbox[0] * 1.0 / img_width
+                bbox[1] = bbox[1] * 1.0 / img_height
+                bbox[2] = bbox[2] * 1.0 / img_width
+                bbox[3] = bbox[3] * 1.0 / img_height
+                td_idx = td_idx_list[cno]
+                bbox_list[td_idx] = bbox
+                bbox_list_mask[td_idx] = 1.0
+                cand_span_idx = td_idx + 1
+                if cand_span_idx < (self.max_elem_length + 2):
+                    if structure[cand_span_idx] in span_idx_list:
+                        structure_mask[cand_span_idx] = span_weight
+
+        data['bbox_list'] = bbox_list
+        data['bbox_list_mask'] = bbox_list_mask
+        data['structure_mask'] = structure_mask
+        char_beg_idx = self.get_beg_end_flag_idx('beg', 'char')
+        char_end_idx = self.get_beg_end_flag_idx('end', 'char')
+        elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
+        elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
+        data['sp_tokens'] = np.array([
+            char_beg_idx, char_end_idx, elem_beg_idx, elem_end_idx,
+            elem_char_idx1, elem_char_idx2, self.max_text_length,
+            self.max_elem_length, self.max_cell_num, elem_num
+        ])
+        return data
+
+    def encode(self, text, char_or_elem):
+        """convert text-label into text-index.
+ """ + if char_or_elem == "char": + max_len = self.max_text_length + current_dict = self.dict_character + else: + max_len = self.max_elem_length + current_dict = self.dict_elem + if len(text) > max_len: + return None + if len(text) == 0: + if char_or_elem == "char": + return [self.dict_character['space']] + else: + return None + text_list = [] + for char in text: + if char not in current_dict: + return None + text_list.append(current_dict[char]) + if len(text_list) == 0: + if char_or_elem == "char": + return [self.dict_character['space']] + else: + return None + return text_list + + def get_ignored_tokens(self, char_or_elem): + beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem) + end_idx = self.get_beg_end_flag_idx("end", char_or_elem) + return [beg_idx, end_idx] + + def get_beg_end_flag_idx(self, beg_or_end, char_or_elem): + if char_or_elem == "char": + if beg_or_end == "beg": + idx = np.array(self.dict_character[self.beg_str]) + elif beg_or_end == "end": + idx = np.array(self.dict_character[self.end_str]) + else: + assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \ + % beg_or_end + elif char_or_elem == "elem": + if beg_or_end == "beg": + idx = np.array(self.dict_elem[self.beg_str]) + elif beg_or_end == "end": + idx = np.array(self.dict_elem[self.end_str]) + else: + assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \ + % beg_or_end + else: + assert False, "Unsupport type %s in char_or_elem" \ + % char_or_elem + return idx + + +class SARLabelEncode(BaseRecLabelEncode): + """ Convert between text-label and text-index """ + + def __init__(self, + max_text_length, + character_dict_path=None, + use_space_char=False, + **kwargs): + super(SARLabelEncode, self).__init__( + max_text_length, character_dict_path, use_space_char) + + def add_special_char(self, dict_character): + beg_end_str = "" + unknown_str = "" + padding_str = "" + dict_character = dict_character + [unknown_str] + self.unknown_idx = len(dict_character) - 1 + dict_character = dict_character + [beg_end_str] + self.start_idx = len(dict_character) - 1 + self.end_idx = len(dict_character) - 1 + dict_character = dict_character + [padding_str] + self.padding_idx = len(dict_character) - 1 + + return dict_character + + def __call__(self, data): + text = data['label'] + text = self.encode(text) + if text is None: + return None + if len(text) >= self.max_text_len - 1: + return None + data['length'] = np.array(len(text)) + target = [self.start_idx] + text + [self.end_idx] + padded_text = [self.padding_idx for _ in range(self.max_text_len)] + + padded_text[:len(target)] = target + data['label'] = np.array(padded_text) + return data + + def get_ignored_tokens(self): + return [self.padding_idx] + + +class PRENLabelEncode(BaseRecLabelEncode): + def __init__(self, + max_text_length, + character_dict_path, + use_space_char=False, + **kwargs): + super(PRENLabelEncode, self).__init__( + max_text_length, character_dict_path, use_space_char) + + def add_special_char(self, dict_character): + padding_str = '' # 0 + end_str = '' # 1 + unknown_str = '' # 2 + + dict_character = [padding_str, end_str, unknown_str] + dict_character + self.padding_idx = 0 + self.end_idx = 1 + self.unknown_idx = 2 + + return dict_character + + def encode(self, text): + if len(text) == 0 or len(text) >= self.max_text_len: + return None + if self.lower: + text = text.lower() + text_list = [] + for char in text: + if char not in self.dict: + text_list.append(self.unknown_idx) + else: + text_list.append(self.dict[char]) + 
text_list.append(self.end_idx) + if len(text_list) < self.max_text_len: + text_list += [self.padding_idx] * ( + self.max_text_len - len(text_list)) + return text_list + + def __call__(self, data): + text = data['label'] + encoded_text = self.encode(text) + if encoded_text is None: + return None + data['label'] = np.array(encoded_text) + return data + + +class VQATokenLabelEncode(object): + """ + Label encode for NLP VQA methods + """ + + def __init__(self, + class_path, + contains_re=False, + add_special_ids=False, + algorithm='LayoutXLM', + infer_mode=False, + ocr_engine=None, + **kwargs): + super(VQATokenLabelEncode, self).__init__() + from paddlenlp.transformers import LayoutXLMTokenizer, LayoutLMTokenizer, LayoutLMv2Tokenizer + from ppocr.utils.utility import load_vqa_bio_label_maps + tokenizer_dict = { + 'LayoutXLM': { + 'class': LayoutXLMTokenizer, + 'pretrained_model': 'layoutxlm-base-uncased' + }, + 'LayoutLM': { + 'class': LayoutLMTokenizer, + 'pretrained_model': 'layoutlm-base-uncased' + }, + 'LayoutLMv2': { + 'class': LayoutLMv2Tokenizer, + 'pretrained_model': 'layoutlmv2-base-uncased' + } + } + self.contains_re = contains_re + tokenizer_config = tokenizer_dict[algorithm] + self.tokenizer = tokenizer_config['class'].from_pretrained( + tokenizer_config['pretrained_model']) + self.label2id_map, id2label_map = load_vqa_bio_label_maps(class_path) + self.add_special_ids = add_special_ids + self.infer_mode = infer_mode + self.ocr_engine = ocr_engine + + def __call__(self, data): + # load bbox and label info + ocr_info = self._load_ocr_info(data) + + height, width, _ = data['image'].shape + + words_list = [] + bbox_list = [] + input_ids_list = [] + token_type_ids_list = [] + segment_offset_id = [] + gt_label_list = [] + + entities = [] + + # for re + train_re = self.contains_re and not self.infer_mode + if train_re: + relations = [] + id2label = {} + entity_id_to_index_map = {} + empty_entity = set() + + data['ocr_info'] = copy.deepcopy(ocr_info) + + for info in ocr_info: + if train_re: + # for re + if len(info["text"]) == 0: + empty_entity.add(info["id"]) + continue + id2label[info["id"]] = info["label"] + relations.extend([tuple(sorted(l)) for l in info["linking"]]) + # smooth_box + bbox = self._smooth_box(info["bbox"], height, width) + + text = info["text"] + encode_res = self.tokenizer.encode( + text, pad_to_max_seq_len=False, return_attention_mask=True) + + if not self.add_special_ids: + # TODO: use tok.all_special_ids to remove + encode_res["input_ids"] = encode_res["input_ids"][1:-1] + encode_res["token_type_ids"] = encode_res["token_type_ids"][1: + -1] + encode_res["attention_mask"] = encode_res["attention_mask"][1: + -1] + # parse label + if not self.infer_mode: + label = info['label'] + gt_label = self._parse_label(label, encode_res) + + # construct entities for re + if train_re: + if gt_label[0] != self.label2id_map["O"]: + entity_id_to_index_map[info["id"]] = len(entities) + label = label.upper() + entities.append({ + "start": len(input_ids_list), + "end": + len(input_ids_list) + len(encode_res["input_ids"]), + "label": label.upper(), + }) + else: + entities.append({ + "start": len(input_ids_list), + "end": len(input_ids_list) + len(encode_res["input_ids"]), + "label": 'O', + }) + input_ids_list.extend(encode_res["input_ids"]) + token_type_ids_list.extend(encode_res["token_type_ids"]) + bbox_list.extend([bbox] * len(encode_res["input_ids"])) + words_list.append(text) + segment_offset_id.append(len(input_ids_list)) + if not self.infer_mode: + gt_label_list.extend(gt_label) + 
+ data['input_ids'] = input_ids_list + data['token_type_ids'] = token_type_ids_list + data['bbox'] = bbox_list + data['attention_mask'] = [1] * len(input_ids_list) + data['labels'] = gt_label_list + data['segment_offset_id'] = segment_offset_id + data['tokenizer_params'] = dict( + padding_side=self.tokenizer.padding_side, + pad_token_type_id=self.tokenizer.pad_token_type_id, + pad_token_id=self.tokenizer.pad_token_id) + data['entities'] = entities + + if train_re: + data['relations'] = relations + data['id2label'] = id2label + data['empty_entity'] = empty_entity + data['entity_id_to_index_map'] = entity_id_to_index_map + return data + + def _load_ocr_info(self, data): + def trans_poly_to_bbox(poly): + x1 = np.min([p[0] for p in poly]) + x2 = np.max([p[0] for p in poly]) + y1 = np.min([p[1] for p in poly]) + y2 = np.max([p[1] for p in poly]) + return [x1, y1, x2, y2] + + if self.infer_mode: + ocr_result = self.ocr_engine.ocr(data['image'], cls=False) + ocr_info = [] + for res in ocr_result: + ocr_info.append({ + "text": res[1][0], + "bbox": trans_poly_to_bbox(res[0]), + "poly": res[0], + }) + return ocr_info + else: + info = data['label'] + # read text info + info_dict = json.loads(info) + return info_dict["ocr_info"] + + def _smooth_box(self, bbox, height, width): + bbox[0] = int(bbox[0] * 1000.0 / width) + bbox[2] = int(bbox[2] * 1000.0 / width) + bbox[1] = int(bbox[1] * 1000.0 / height) + bbox[3] = int(bbox[3] * 1000.0 / height) + return bbox + + def _parse_label(self, label, encode_res): + gt_label = [] + if label.lower() == "other": + gt_label.extend([0] * len(encode_res["input_ids"])) + else: + gt_label.append(self.label2id_map[("b-" + label).upper()]) + gt_label.extend([self.label2id_map[("i-" + label).upper()]] * + (len(encode_res["input_ids"]) - 1)) + return gt_label + + +class MultiLabelEncode(BaseRecLabelEncode): + def __init__(self, + max_text_length, + character_dict_path=None, + use_space_char=False, + **kwargs): + super(MultiLabelEncode, self).__init__( + max_text_length, character_dict_path, use_space_char) + + self.ctc_encode = CTCLabelEncode(max_text_length, character_dict_path, + use_space_char, **kwargs) + self.sar_encode = SARLabelEncode(max_text_length, character_dict_path, + use_space_char, **kwargs) + + def __call__(self, data): + + data_ctc = copy.deepcopy(data) + data_sar = copy.deepcopy(data) + data_out = dict() + data_out['img_path'] = data.get('img_path', None) + data_out['image'] = data['image'] + ctc = self.ctc_encode.__call__(data_ctc) + sar = self.sar_encode.__call__(data_sar) + if ctc is None or sar is None: + return None + data_out['label_ctc'] = ctc['label'] + data_out['label_sar'] = sar['label'] + data_out['length'] = ctc['length'] + return data_out diff --git a/backend/ppocr/data/imaug/make_border_map.py b/backend/ppocr/data/imaug/make_border_map.py index cc2c9034..abab3836 100644 --- a/backend/ppocr/data/imaug/make_border_map.py +++ b/backend/ppocr/data/imaug/make_border_map.py @@ -1,4 +1,20 @@ -# -*- coding:utf-8 -*- +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
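For the VQA token encoder above, the easiest part to get wrong is how a segment-level label is expanded over sub-word tokens: only the first token receives the B- id, the rest receive the I- id, and "other" maps every token to O. A small stand-alone sketch follows; the label map is invented for illustration, the real ids come from the class file passed to the transform.

label2id_map = {"O": 0, "B-QUESTION": 1, "I-QUESTION": 2, "B-ANSWER": 3, "I-ANSWER": 4}

def parse_label(label, num_sub_tokens):
    # Mirrors the "b-"/"i-" expansion performed by _parse_label above.
    if label.lower() == "other":
        return [label2id_map["O"]] * num_sub_tokens
    return ([label2id_map[("b-" + label).upper()]]
            + [label2id_map[("i-" + label).upper()]] * (num_sub_tokens - 1))

print(parse_label("question", 4))  # [1, 2, 2, 2]
print(parse_label("other", 3))     # [0, 0, 0]

The bounding box attached to a text segment is normalised into the 0-1000 layout grid by _smooth_box and then repeated once per sub-token, which is why bbox_list above is extended by len(input_ids) copies of the same box.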
+# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is refer from: +https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/make_border_map.py +""" from __future__ import absolute_import from __future__ import division diff --git a/backend/ppocr/data/imaug/make_pse_gt.py b/backend/ppocr/data/imaug/make_pse_gt.py new file mode 100644 index 00000000..255d076b --- /dev/null +++ b/backend/ppocr/data/imaug/make_pse_gt.py @@ -0,0 +1,106 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import cv2 +import numpy as np +import pyclipper +from shapely.geometry import Polygon + +__all__ = ['MakePseGt'] + + +class MakePseGt(object): + def __init__(self, kernel_num=7, size=640, min_shrink_ratio=0.4, **kwargs): + self.kernel_num = kernel_num + self.min_shrink_ratio = min_shrink_ratio + self.size = size + + def __call__(self, data): + + image = data['image'] + text_polys = data['polys'] + ignore_tags = data['ignore_tags'] + + h, w, _ = image.shape + short_edge = min(h, w) + if short_edge < self.size: + # keep short_size >= self.size + scale = self.size / short_edge + image = cv2.resize(image, dsize=None, fx=scale, fy=scale) + text_polys *= scale + + gt_kernels = [] + for i in range(1, self.kernel_num + 1): + # s1->sn, from big to small + rate = 1.0 - (1.0 - self.min_shrink_ratio) / (self.kernel_num - 1 + ) * i + text_kernel, ignore_tags = self.generate_kernel( + image.shape[0:2], rate, text_polys, ignore_tags) + gt_kernels.append(text_kernel) + + training_mask = np.ones(image.shape[0:2], dtype='uint8') + for i in range(text_polys.shape[0]): + if ignore_tags[i]: + cv2.fillPoly(training_mask, + text_polys[i].astype(np.int32)[np.newaxis, :, :], + 0) + + gt_kernels = np.array(gt_kernels) + gt_kernels[gt_kernels > 0] = 1 + + data['image'] = image + data['polys'] = text_polys + data['gt_kernels'] = gt_kernels[0:] + data['gt_text'] = gt_kernels[0] + data['mask'] = training_mask.astype('float32') + return data + + def generate_kernel(self, + img_size, + shrink_ratio, + text_polys, + ignore_tags=None): + """ + Refer to part of the code: + https://github.com/open-mmlab/mmocr/blob/main/mmocr/datasets/pipelines/textdet_targets/base_textdet_targets.py + """ + + h, w = img_size + text_kernel = np.zeros((h, w), dtype=np.float32) + for i, poly in enumerate(text_polys): + polygon = Polygon(poly) + distance = polygon.area * (1 - shrink_ratio * shrink_ratio) / ( + polygon.length + 1e-6) + subject = [tuple(l) for l in poly] + pco = pyclipper.PyclipperOffset() + pco.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + shrinked = np.array(pco.Execute(-distance)) + + if len(shrinked) == 0 or shrinked.size == 0: + if ignore_tags is not None: + ignore_tags[i] = True + continue + try: + shrinked = 
np.array(shrinked[0]).reshape(-1, 2) + except: + if ignore_tags is not None: + ignore_tags[i] = True + continue + cv2.fillPoly(text_kernel, [shrinked.astype(np.int32)], i + 1) + return text_kernel, ignore_tags diff --git a/backend/ppocr/data/imaug/make_shrink_map.py b/backend/ppocr/data/imaug/make_shrink_map.py index ccdcd015..6c65c20e 100644 --- a/backend/ppocr/data/imaug/make_shrink_map.py +++ b/backend/ppocr/data/imaug/make_shrink_map.py @@ -1,4 +1,20 @@ -# -*- coding:utf-8 -*- +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is refer from: +https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/make_shrink_map.py +""" from __future__ import absolute_import from __future__ import division @@ -49,7 +65,7 @@ def __call__(self, data): pyclipper.ET_CLOSEDPOLYGON) shrinked = [] - # Increase the shrink ratio every time we get multiple polygon returned back + # Increase the shrink ratio every time we get multiple polygon returned back possible_ratios = np.arange(self.shrink_ratio, 1, self.shrink_ratio) np.append(possible_ratios, 1) @@ -71,7 +87,6 @@ def __call__(self, data): for each_shirnk in shrinked: shirnk = np.array(each_shirnk).reshape(-1, 2) cv2.fillPoly(gt, [shirnk.astype(np.int32)], 1) - # cv2.fillPoly(gt[0], [shrinked.astype(np.int32)], 1) data['shrink_map'] = gt data['shrink_mask'] = mask @@ -97,11 +112,12 @@ def validate_polygons(self, polygons, ignore_tags, h, w): return polygons, ignore_tags def polygon_area(self, polygon): - # return cv2.contourArea(polygon.astype(np.float32)) - edge = 0 - for i in range(polygon.shape[0]): - next_index = (i + 1) % polygon.shape[0] - edge += (polygon[next_index, 0] - polygon[i, 0]) * ( - polygon[next_index, 1] - polygon[i, 1]) - - return edge / 2. + """ + compute polygon area + """ + area = 0 + q = polygon[-1] + for p in polygon: + area += p[0] * q[1] - p[1] * q[0] + q = p + return area / 2.0 diff --git a/backend/ppocr/data/imaug/operators.py b/backend/ppocr/data/imaug/operators.py new file mode 100644 index 00000000..09736515 --- /dev/null +++ b/backend/ppocr/data/imaug/operators.py @@ -0,0 +1,468 @@ +""" +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
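The kernel generation above shrinks each polygon with pyclipper, using the offset distance d = area * (1 - r^2) / perimeter. Below is a minimal sketch of a single shrink step under that formula, using the same libraries the module imports.

import numpy as np
import pyclipper
from shapely.geometry import Polygon

def shrink_polygon(poly, shrink_ratio):
    polygon = Polygon(poly)
    distance = polygon.area * (1 - shrink_ratio * shrink_ratio) / (polygon.length + 1e-6)
    pco = pyclipper.PyclipperOffset()
    pco.AddPath([tuple(p) for p in poly], pyclipper.JT_ROUND,
                pyclipper.ET_CLOSEDPOLYGON)
    shrunk = pco.Execute(-distance)
    return np.array(shrunk[0]) if len(shrunk) > 0 else None

square = [(0, 0), (100, 0), (100, 100), (0, 100)]
# area 10000, perimeter 400, r = 0.4  ->  d = 21, so the square shrinks to roughly 58x58
print(shrink_polygon(square, 0.4))

MakePseGt runs this once per kernel with a rate that steps linearly from 1.0 down to min_shrink_ratio, and marks a polygon as ignored whenever the offset collapses it to nothing.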
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import sys +import six +import cv2 +import numpy as np +import math + + +class DecodeImage(object): + """ decode image """ + + def __init__(self, + img_mode='RGB', + channel_first=False, + ignore_orientation=False, + **kwargs): + self.img_mode = img_mode + self.channel_first = channel_first + self.ignore_orientation = ignore_orientation + + def __call__(self, data): + img = data['image'] + if six.PY2: + assert type(img) is str and len( + img) > 0, "invalid input 'img' in DecodeImage" + else: + assert type(img) is bytes and len( + img) > 0, "invalid input 'img' in DecodeImage" + img = np.frombuffer(img, dtype='uint8') + if self.ignore_orientation: + img = cv2.imdecode(img, cv2.IMREAD_IGNORE_ORIENTATION | + cv2.IMREAD_COLOR) + else: + img = cv2.imdecode(img, 1) + if img is None: + return None + if self.img_mode == 'GRAY': + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + elif self.img_mode == 'RGB': + assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape) + img = img[:, :, ::-1] + + if self.channel_first: + img = img.transpose((2, 0, 1)) + + data['image'] = img + return data + + +class NRTRDecodeImage(object): + """ decode image """ + + def __init__(self, img_mode='RGB', channel_first=False, **kwargs): + self.img_mode = img_mode + self.channel_first = channel_first + + def __call__(self, data): + img = data['image'] + if six.PY2: + assert type(img) is str and len( + img) > 0, "invalid input 'img' in DecodeImage" + else: + assert type(img) is bytes and len( + img) > 0, "invalid input 'img' in DecodeImage" + img = np.frombuffer(img, dtype='uint8') + + img = cv2.imdecode(img, 1) + + if img is None: + return None + if self.img_mode == 'GRAY': + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + elif self.img_mode == 'RGB': + assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape) + img = img[:, :, ::-1] + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + if self.channel_first: + img = img.transpose((2, 0, 1)) + data['image'] = img + return data + + +class NormalizeImage(object): + """ normalize image such as substract mean, divide std + """ + + def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs): + if isinstance(scale, str): + scale = eval(scale) + self.scale = np.float32(scale if scale is not None else 1.0 / 255.0) + mean = mean if mean is not None else [0.485, 0.456, 0.406] + std = std if std is not None else [0.229, 0.224, 0.225] + + shape = (3, 1, 1) if order == 'chw' else (1, 1, 3) + self.mean = np.array(mean).reshape(shape).astype('float32') + self.std = np.array(std).reshape(shape).astype('float32') + + def __call__(self, data): + img = data['image'] + from PIL import Image + if isinstance(img, Image.Image): + img = np.array(img) + assert isinstance(img, + np.ndarray), "invalid input 'img' in NormalizeImage" + data['image'] = ( + img.astype('float32') * self.scale - self.mean) / self.std + return data + + +class ToCHWImage(object): + """ convert hwc image to chw image + """ + + def __init__(self, **kwargs): + pass + + def __call__(self, data): + img = data['image'] + from PIL import Image + if isinstance(img, Image.Image): + img = np.array(img) + data['image'] = img.transpose((2, 0, 1)) + return data + + +class Fasttext(object): + def __init__(self, path="None", **kwargs): + import fasttext + self.fast_model = fasttext.load_model(path) + + def __call__(self, data): + label = data['label'] + 
fast_label = self.fast_model[label] + data['fast_label'] = fast_label + return data + + +class KeepKeys(object): + def __init__(self, keep_keys, **kwargs): + self.keep_keys = keep_keys + + def __call__(self, data): + data_list = [] + for key in self.keep_keys: + data_list.append(data[key]) + return data_list + + +class Pad(object): + def __init__(self, size=None, size_div=32, **kwargs): + if size is not None and not isinstance(size, (int, list, tuple)): + raise TypeError("Type of target_size is invalid. Now is {}".format( + type(size))) + if isinstance(size, int): + size = [size, size] + self.size = size + self.size_div = size_div + + def __call__(self, data): + + img = data['image'] + img_h, img_w = img.shape[0], img.shape[1] + if self.size: + resize_h2, resize_w2 = self.size + assert ( + img_h < resize_h2 and img_w < resize_w2 + ), '(h, w) of target size should be greater than (img_h, img_w)' + else: + resize_h2 = max( + int(math.ceil(img.shape[0] / self.size_div) * self.size_div), + self.size_div) + resize_w2 = max( + int(math.ceil(img.shape[1] / self.size_div) * self.size_div), + self.size_div) + img = cv2.copyMakeBorder( + img, + 0, + resize_h2 - img_h, + 0, + resize_w2 - img_w, + cv2.BORDER_CONSTANT, + value=0) + data['image'] = img + return data + + +class Resize(object): + def __init__(self, size=(640, 640), **kwargs): + self.size = size + + def resize_image(self, img): + resize_h, resize_w = self.size + ori_h, ori_w = img.shape[:2] # (h, w, c) + ratio_h = float(resize_h) / ori_h + ratio_w = float(resize_w) / ori_w + img = cv2.resize(img, (int(resize_w), int(resize_h))) + return img, [ratio_h, ratio_w] + + def __call__(self, data): + img = data['image'] + if 'polys' in data: + text_polys = data['polys'] + + img_resize, [ratio_h, ratio_w] = self.resize_image(img) + if 'polys' in data: + new_boxes = [] + for box in text_polys: + new_box = [] + for cord in box: + new_box.append([cord[0] * ratio_w, cord[1] * ratio_h]) + new_boxes.append(new_box) + data['polys'] = np.array(new_boxes, dtype=np.float32) + data['image'] = img_resize + return data + + +class DetResizeForTest(object): + def __init__(self, **kwargs): + super(DetResizeForTest, self).__init__() + self.resize_type = 0 + if 'image_shape' in kwargs: + self.image_shape = kwargs['image_shape'] + self.resize_type = 1 + elif 'limit_side_len' in kwargs: + self.limit_side_len = kwargs['limit_side_len'] + self.limit_type = kwargs.get('limit_type', 'min') + elif 'resize_long' in kwargs: + self.resize_type = 2 + self.resize_long = kwargs.get('resize_long', 960) + else: + self.limit_side_len = 736 + self.limit_type = 'min' + + def __call__(self, data): + img = data['image'] + src_h, src_w, _ = img.shape + + if self.resize_type == 0: + # img, shape = self.resize_image_type0(img) + img, [ratio_h, ratio_w] = self.resize_image_type0(img) + elif self.resize_type == 2: + img, [ratio_h, ratio_w] = self.resize_image_type2(img) + else: + # img, shape = self.resize_image_type1(img) + img, [ratio_h, ratio_w] = self.resize_image_type1(img) + data['image'] = img + data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w]) + return data + + def resize_image_type1(self, img): + resize_h, resize_w = self.image_shape + ori_h, ori_w = img.shape[:2] # (h, w, c) + ratio_h = float(resize_h) / ori_h + ratio_w = float(resize_w) / ori_w + img = cv2.resize(img, (int(resize_w), int(resize_h))) + # return img, np.array([ori_h, ori_w]) + return img, [ratio_h, ratio_w] + + def resize_image_type0(self, img): + """ + resize image to a size multiple of 32 which is 
required by the network + args: + img(array): array with shape [h, w, c] + return(tuple): + img, (ratio_h, ratio_w) + """ + limit_side_len = self.limit_side_len + h, w, c = img.shape + + # limit the max side + if self.limit_type == 'max': + if max(h, w) > limit_side_len: + if h > w: + ratio = float(limit_side_len) / h + else: + ratio = float(limit_side_len) / w + else: + ratio = 1. + elif self.limit_type == 'min': + if min(h, w) < limit_side_len: + if h < w: + ratio = float(limit_side_len) / h + else: + ratio = float(limit_side_len) / w + else: + ratio = 1. + elif self.limit_type == 'resize_long': + ratio = float(limit_side_len) / max(h, w) + else: + raise Exception('not support limit type, image ') + resize_h = int(h * ratio) + resize_w = int(w * ratio) + + resize_h = max(int(round(resize_h / 32) * 32), 32) + resize_w = max(int(round(resize_w / 32) * 32), 32) + + try: + if int(resize_w) <= 0 or int(resize_h) <= 0: + return None, (None, None) + img = cv2.resize(img, (int(resize_w), int(resize_h))) + except: + print(img.shape, resize_w, resize_h) + sys.exit(0) + ratio_h = resize_h / float(h) + ratio_w = resize_w / float(w) + return img, [ratio_h, ratio_w] + + def resize_image_type2(self, img): + h, w, _ = img.shape + + resize_w = w + resize_h = h + + if resize_h > resize_w: + ratio = float(self.resize_long) / resize_h + else: + ratio = float(self.resize_long) / resize_w + + resize_h = int(resize_h * ratio) + resize_w = int(resize_w * ratio) + + max_stride = 128 + resize_h = (resize_h + max_stride - 1) // max_stride * max_stride + resize_w = (resize_w + max_stride - 1) // max_stride * max_stride + img = cv2.resize(img, (int(resize_w), int(resize_h))) + ratio_h = resize_h / float(h) + ratio_w = resize_w / float(w) + + return img, [ratio_h, ratio_w] + + +class E2EResizeForTest(object): + def __init__(self, **kwargs): + super(E2EResizeForTest, self).__init__() + self.max_side_len = kwargs['max_side_len'] + self.valid_set = kwargs['valid_set'] + + def __call__(self, data): + img = data['image'] + src_h, src_w, _ = img.shape + if self.valid_set == 'totaltext': + im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext( + img, max_side_len=self.max_side_len) + else: + im_resized, (ratio_h, ratio_w) = self.resize_image( + img, max_side_len=self.max_side_len) + data['image'] = im_resized + data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w]) + return data + + def resize_image_for_totaltext(self, im, max_side_len=512): + + h, w, _ = im.shape + resize_w = w + resize_h = h + ratio = 1.25 + if h * ratio > max_side_len: + ratio = float(max_side_len) / resize_h + resize_h = int(resize_h * ratio) + resize_w = int(resize_w * ratio) + + max_stride = 128 + resize_h = (resize_h + max_stride - 1) // max_stride * max_stride + resize_w = (resize_w + max_stride - 1) // max_stride * max_stride + im = cv2.resize(im, (int(resize_w), int(resize_h))) + ratio_h = resize_h / float(h) + ratio_w = resize_w / float(w) + return im, (ratio_h, ratio_w) + + def resize_image(self, im, max_side_len=512): + """ + resize image to a size multiple of max_stride which is required by the network + :param im: the resized image + :param max_side_len: limit of max image size to avoid out of memory in gpu + :return: the resized image and the resize ratio + """ + h, w, _ = im.shape + + resize_w = w + resize_h = h + + # Fix the longer side + if resize_h > resize_w: + ratio = float(max_side_len) / resize_h + else: + ratio = float(max_side_len) / resize_w + + resize_h = int(resize_h * ratio) + resize_w = int(resize_w * ratio) + + 
max_stride = 128 + resize_h = (resize_h + max_stride - 1) // max_stride * max_stride + resize_w = (resize_w + max_stride - 1) // max_stride * max_stride + im = cv2.resize(im, (int(resize_w), int(resize_h))) + ratio_h = resize_h / float(h) + ratio_w = resize_w / float(w) + + return im, (ratio_h, ratio_w) + + +class KieResize(object): + def __init__(self, **kwargs): + super(KieResize, self).__init__() + self.max_side, self.min_side = kwargs['img_scale'][0], kwargs[ + 'img_scale'][1] + + def __call__(self, data): + img = data['image'] + points = data['points'] + src_h, src_w, _ = img.shape + im_resized, scale_factor, [ratio_h, ratio_w + ], [new_h, new_w] = self.resize_image(img) + resize_points = self.resize_boxes(img, points, scale_factor) + data['ori_image'] = img + data['ori_boxes'] = points + data['points'] = resize_points + data['image'] = im_resized + data['shape'] = np.array([new_h, new_w]) + return data + + def resize_image(self, img): + norm_img = np.zeros([1024, 1024, 3], dtype='float32') + scale = [512, 1024] + h, w = img.shape[:2] + max_long_edge = max(scale) + max_short_edge = min(scale) + scale_factor = min(max_long_edge / max(h, w), + max_short_edge / min(h, w)) + resize_w, resize_h = int(w * float(scale_factor) + 0.5), int(h * float( + scale_factor) + 0.5) + max_stride = 32 + resize_h = (resize_h + max_stride - 1) // max_stride * max_stride + resize_w = (resize_w + max_stride - 1) // max_stride * max_stride + im = cv2.resize(img, (resize_w, resize_h)) + new_h, new_w = im.shape[:2] + w_scale = new_w / w + h_scale = new_h / h + scale_factor = np.array( + [w_scale, h_scale, w_scale, h_scale], dtype=np.float32) + norm_img[:new_h, :new_w, :] = im + return norm_img, scale_factor, [h_scale, w_scale], [new_h, new_w] + + def resize_boxes(self, im, points, scale_factor): + points = points * scale_factor + img_shape = im.shape[:2] + points[:, 0::2] = np.clip(points[:, 0::2], 0, img_shape[1]) + points[:, 1::2] = np.clip(points[:, 1::2], 0, img_shape[0]) + return points diff --git a/backend/ppocr/data/imaug/pg_process.py b/backend/ppocr/data/imaug/pg_process.py new file mode 100644 index 00000000..53031064 --- /dev/null +++ b/backend/ppocr/data/imaug/pg_process.py @@ -0,0 +1,906 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
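DetResizeForTest above has three resize modes; the default type-0 path scales the image so the constrained side respects limit_side_len and then snaps both sides to multiples of 32, the detector's stride. A condensed sketch of just that arithmetic; 960 and 736 are typical example limits, not values fixed by this patch.

def det_target_size(h, w, limit_side_len=960, limit_type='max'):
    if limit_type == 'max':
        ratio = float(limit_side_len) / max(h, w) if max(h, w) > limit_side_len else 1.0
    else:  # 'min'
        ratio = float(limit_side_len) / min(h, w) if min(h, w) < limit_side_len else 1.0
    resize_h = max(int(round(h * ratio / 32) * 32), 32)
    resize_w = max(int(round(w * ratio / 32) * 32), 32)
    return resize_h, resize_w

print(det_target_size(1080, 1920, limit_side_len=960, limit_type='max'))  # (544, 960)
print(det_target_size(400, 700, limit_side_len=736, limit_type='min'))    # (736, 1280)

Returning the ratio_h/ratio_w pair alongside the resized image, as the operator does, is what lets the postprocessing map detected boxes back to the original resolution.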
+ +import math +import cv2 +import numpy as np + +__all__ = ['PGProcessTrain'] + + +class PGProcessTrain(object): + def __init__(self, + character_dict_path, + max_text_length, + max_text_nums, + tcl_len, + batch_size=14, + min_crop_size=24, + min_text_size=4, + max_text_size=512, + **kwargs): + self.tcl_len = tcl_len + self.max_text_length = max_text_length + self.max_text_nums = max_text_nums + self.batch_size = batch_size + self.min_crop_size = min_crop_size + self.min_text_size = min_text_size + self.max_text_size = max_text_size + self.Lexicon_Table = self.get_dict(character_dict_path) + self.pad_num = len(self.Lexicon_Table) + self.img_id = 0 + + def get_dict(self, character_dict_path): + character_str = "" + with open(character_dict_path, "rb") as fin: + lines = fin.readlines() + for line in lines: + line = line.decode('utf-8').strip("\n").strip("\r\n") + character_str += line + dict_character = list(character_str) + return dict_character + + def quad_area(self, poly): + """ + compute area of a polygon + :param poly: + :return: + """ + edge = [(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]), + (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]), + (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]), + (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])] + return np.sum(edge) / 2. + + def gen_quad_from_poly(self, poly): + """ + Generate min area quad from poly. + """ + point_num = poly.shape[0] + min_area_quad = np.zeros((4, 2), dtype=np.float32) + rect = cv2.minAreaRect(poly.astype( + np.int32)) # (center (x,y), (width, height), angle of rotation) + box = np.array(cv2.boxPoints(rect)) + + first_point_idx = 0 + min_dist = 1e4 + for i in range(4): + dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \ + np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \ + np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \ + np.linalg.norm(box[(i + 3) % 4] - poly[-1]) + if dist < min_dist: + min_dist = dist + first_point_idx = i + for i in range(4): + min_area_quad[i] = box[(first_point_idx + i) % 4] + + return min_area_quad + + def check_and_validate_polys(self, polys, tags, im_size): + """ + check so that the text poly is in the same direction, + and also filter some invalid polygons + :param polys: + :param tags: + :return: + """ + (h, w) = im_size + if polys.shape[0] == 0: + return polys, np.array([]), np.array([]) + polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1) + polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1) + + validated_polys = [] + validated_tags = [] + hv_tags = [] + for poly, tag in zip(polys, tags): + quad = self.gen_quad_from_poly(poly) + p_area = self.quad_area(quad) + if abs(p_area) < 1: + print('invalid poly') + continue + if p_area > 0: + if tag == False: + print('poly in wrong direction') + tag = True # reversed cases should be ignore + poly = poly[(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, + 1), :] + quad = quad[(0, 3, 2, 1), :] + + len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[3] - + quad[2]) + len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - + quad[2]) + hv_tag = 1 + + if len_w * 2.0 < len_h: + hv_tag = 0 + + validated_polys.append(poly) + validated_tags.append(tag) + hv_tags.append(hv_tag) + return np.array(validated_polys), np.array(validated_tags), np.array( + hv_tags) + + def crop_area(self, + im, + polys, + tags, + hv_tags, + txts, + crop_background=False, + max_tries=25): + """ + make random crop from the input image + :param im: + :param polys: [b,4,2] + :param tags: + :param 
crop_background: + :param max_tries: 50 -> 25 + :return: + """ + h, w, _ = im.shape + pad_h = h // 10 + pad_w = w // 10 + h_array = np.zeros((h + pad_h * 2), dtype=np.int32) + w_array = np.zeros((w + pad_w * 2), dtype=np.int32) + for poly in polys: + poly = np.round(poly, decimals=0).astype(np.int32) + minx = np.min(poly[:, 0]) + maxx = np.max(poly[:, 0]) + w_array[minx + pad_w:maxx + pad_w] = 1 + miny = np.min(poly[:, 1]) + maxy = np.max(poly[:, 1]) + h_array[miny + pad_h:maxy + pad_h] = 1 + # ensure the cropped area not across a text + h_axis = np.where(h_array == 0)[0] + w_axis = np.where(w_array == 0)[0] + if len(h_axis) == 0 or len(w_axis) == 0: + return im, polys, tags, hv_tags, txts + for i in range(max_tries): + xx = np.random.choice(w_axis, size=2) + xmin = np.min(xx) - pad_w + xmax = np.max(xx) - pad_w + xmin = np.clip(xmin, 0, w - 1) + xmax = np.clip(xmax, 0, w - 1) + yy = np.random.choice(h_axis, size=2) + ymin = np.min(yy) - pad_h + ymax = np.max(yy) - pad_h + ymin = np.clip(ymin, 0, h - 1) + ymax = np.clip(ymax, 0, h - 1) + if xmax - xmin < self.min_crop_size or \ + ymax - ymin < self.min_crop_size: + continue + if polys.shape[0] != 0: + poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \ + & (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax) + selected_polys = np.where( + np.sum(poly_axis_in_area, axis=1) == 4)[0] + else: + selected_polys = [] + if len(selected_polys) == 0: + # no text in this area + if crop_background: + txts_tmp = [] + for selected_poly in selected_polys: + txts_tmp.append(txts[selected_poly]) + txts = txts_tmp + return im[ymin: ymax + 1, xmin: xmax + 1, :], \ + polys[selected_polys], tags[selected_polys], hv_tags[selected_polys], txts + else: + continue + im = im[ymin:ymax + 1, xmin:xmax + 1, :] + polys = polys[selected_polys] + tags = tags[selected_polys] + hv_tags = hv_tags[selected_polys] + txts_tmp = [] + for selected_poly in selected_polys: + txts_tmp.append(txts[selected_poly]) + txts = txts_tmp + polys[:, :, 0] -= xmin + polys[:, :, 1] -= ymin + return im, polys, tags, hv_tags, txts + + return im, polys, tags, hv_tags, txts + + def fit_and_gather_tcl_points_v2(self, + min_area_quad, + poly, + max_h, + max_w, + fixed_point_num=64, + img_id=0, + reference_height=3): + """ + Find the center point of poly as key_points, then fit and gather. 
+ """ + key_point_xys = [] + point_num = poly.shape[0] + for idx in range(point_num // 2): + center_point = (poly[idx] + poly[point_num - 1 - idx]) / 2.0 + key_point_xys.append(center_point) + + tmp_image = np.zeros( + shape=( + max_h, + max_w, ), dtype='float32') + cv2.polylines(tmp_image, [np.array(key_point_xys).astype('int32')], + False, 1.0) + ys, xs = np.where(tmp_image > 0) + xy_text = np.array(list(zip(xs, ys)), dtype='float32') + + left_center_pt = ( + (min_area_quad[0] - min_area_quad[1]) / 2.0).reshape(1, 2) + right_center_pt = ( + (min_area_quad[1] - min_area_quad[2]) / 2.0).reshape(1, 2) + proj_unit_vec = (right_center_pt - left_center_pt) / ( + np.linalg.norm(right_center_pt - left_center_pt) + 1e-6) + proj_unit_vec_tile = np.tile(proj_unit_vec, + (xy_text.shape[0], 1)) # (n, 2) + left_center_pt_tile = np.tile(left_center_pt, + (xy_text.shape[0], 1)) # (n, 2) + xy_text_to_left_center = xy_text - left_center_pt_tile + proj_value = np.sum(xy_text_to_left_center * proj_unit_vec_tile, axis=1) + xy_text = xy_text[np.argsort(proj_value)] + + # convert to np and keep the num of point not greater then fixed_point_num + pos_info = np.array(xy_text).reshape(-1, 2)[:, ::-1] # xy-> yx + point_num = len(pos_info) + if point_num > fixed_point_num: + keep_ids = [ + int((point_num * 1.0 / fixed_point_num) * x) + for x in range(fixed_point_num) + ] + pos_info = pos_info[keep_ids, :] + + keep = int(min(len(pos_info), fixed_point_num)) + if np.random.rand() < 0.2 and reference_height >= 3: + dl = (np.random.rand(keep) - 0.5) * reference_height * 0.3 + random_float = np.array([1, 0]).reshape([1, 2]) * dl.reshape( + [keep, 1]) + pos_info += random_float + pos_info[:, 0] = np.clip(pos_info[:, 0], 0, max_h - 1) + pos_info[:, 1] = np.clip(pos_info[:, 1], 0, max_w - 1) + + # padding to fixed length + pos_l = np.zeros((self.tcl_len, 3), dtype=np.int32) + pos_l[:, 0] = np.ones((self.tcl_len, )) * img_id + pos_m = np.zeros((self.tcl_len, 1), dtype=np.float32) + pos_l[:keep, 1:] = np.round(pos_info).astype(np.int32) + pos_m[:keep] = 1.0 + return pos_l, pos_m + + def generate_direction_map(self, poly_quads, n_char, direction_map): + """ + """ + width_list = [] + height_list = [] + for quad in poly_quads: + quad_w = (np.linalg.norm(quad[0] - quad[1]) + + np.linalg.norm(quad[2] - quad[3])) / 2.0 + quad_h = (np.linalg.norm(quad[0] - quad[3]) + + np.linalg.norm(quad[2] - quad[1])) / 2.0 + width_list.append(quad_w) + height_list.append(quad_h) + norm_width = max(sum(width_list) / n_char, 1.0) + average_height = max(sum(height_list) / len(height_list), 1.0) + k = 1 + for quad in poly_quads: + direct_vector_full = ( + (quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0 + direct_vector = direct_vector_full / ( + np.linalg.norm(direct_vector_full) + 1e-6) * norm_width + direction_label = tuple( + map(float, + [direct_vector[0], direct_vector[1], 1.0 / average_height])) + cv2.fillPoly(direction_map, + quad.round().astype(np.int32)[np.newaxis, :, :], + direction_label) + k += 1 + return direction_map + + def calculate_average_height(self, poly_quads): + """ + """ + height_list = [] + for quad in poly_quads: + quad_h = (np.linalg.norm(quad[0] - quad[3]) + + np.linalg.norm(quad[2] - quad[1])) / 2.0 + height_list.append(quad_h) + average_height = max(sum(height_list) / len(height_list), 1.0) + return average_height + + def generate_tcl_ctc_label(self, + h, + w, + polys, + tags, + text_strs, + ds_ratio, + tcl_ratio=0.3, + shrink_ratio_of_width=0.15): + """ + Generate polygon. 
+ """ + score_map_big = np.zeros( + ( + h, + w, ), dtype=np.float32) + h, w = int(h * ds_ratio), int(w * ds_ratio) + polys = polys * ds_ratio + + score_map = np.zeros( + ( + h, + w, ), dtype=np.float32) + score_label_map = np.zeros( + ( + h, + w, ), dtype=np.float32) + tbo_map = np.zeros((h, w, 5), dtype=np.float32) + training_mask = np.ones( + ( + h, + w, ), dtype=np.float32) + direction_map = np.ones((h, w, 3)) * np.array([0, 0, 1]).reshape( + [1, 1, 3]).astype(np.float32) + + label_idx = 0 + score_label_map_text_label_list = [] + pos_list, pos_mask, label_list = [], [], [] + for poly_idx, poly_tag in enumerate(zip(polys, tags)): + poly = poly_tag[0] + tag = poly_tag[1] + + # generate min_area_quad + min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly) + min_area_quad_h = 0.5 * ( + np.linalg.norm(min_area_quad[0] - min_area_quad[3]) + + np.linalg.norm(min_area_quad[1] - min_area_quad[2])) + min_area_quad_w = 0.5 * ( + np.linalg.norm(min_area_quad[0] - min_area_quad[1]) + + np.linalg.norm(min_area_quad[2] - min_area_quad[3])) + + if min(min_area_quad_h, min_area_quad_w) < self.min_text_size * ds_ratio \ + or min(min_area_quad_h, min_area_quad_w) > self.max_text_size * ds_ratio: + continue + + if tag: + cv2.fillPoly(training_mask, + poly.astype(np.int32)[np.newaxis, :, :], 0.15) + else: + text_label = text_strs[poly_idx] + text_label = self.prepare_text_label(text_label, + self.Lexicon_Table) + + text_label_index_list = [[self.Lexicon_Table.index(c_)] + for c_ in text_label + if c_ in self.Lexicon_Table] + if len(text_label_index_list) < 1: + continue + + tcl_poly = self.poly2tcl(poly, tcl_ratio) + tcl_quads = self.poly2quads(tcl_poly) + poly_quads = self.poly2quads(poly) + + stcl_quads, quad_index = self.shrink_poly_along_width( + tcl_quads, + shrink_ratio_of_width=shrink_ratio_of_width, + expand_height_ratio=1.0 / tcl_ratio) + + cv2.fillPoly(score_map, + np.round(stcl_quads).astype(np.int32), 1.0) + cv2.fillPoly(score_map_big, + np.round(stcl_quads / ds_ratio).astype(np.int32), + 1.0) + + for idx, quad in enumerate(stcl_quads): + quad_mask = np.zeros((h, w), dtype=np.float32) + quad_mask = cv2.fillPoly( + quad_mask, + np.round(quad[np.newaxis, :, :]).astype(np.int32), 1.0) + tbo_map = self.gen_quad_tbo(poly_quads[quad_index[idx]], + quad_mask, tbo_map) + + # score label map and score_label_map_text_label_list for refine + if label_idx == 0: + text_pos_list_ = [[len(self.Lexicon_Table)], ] + score_label_map_text_label_list.append(text_pos_list_) + + label_idx += 1 + cv2.fillPoly(score_label_map, + np.round(poly_quads).astype(np.int32), label_idx) + score_label_map_text_label_list.append(text_label_index_list) + + # direction info, fix-me + n_char = len(text_label_index_list) + direction_map = self.generate_direction_map(poly_quads, n_char, + direction_map) + + # pos info + average_shrink_height = self.calculate_average_height( + stcl_quads) + pos_l, pos_m = self.fit_and_gather_tcl_points_v2( + min_area_quad, + poly, + max_h=h, + max_w=w, + fixed_point_num=64, + img_id=self.img_id, + reference_height=average_shrink_height) + + label_l = text_label_index_list + if len(text_label_index_list) < 2: + continue + + pos_list.append(pos_l) + pos_mask.append(pos_m) + label_list.append(label_l) + + # use big score_map for smooth tcl lines + score_map_big_resized = cv2.resize( + score_map_big, dsize=None, fx=ds_ratio, fy=ds_ratio) + score_map = np.array(score_map_big_resized > 1e-3, dtype='float32') + + return score_map, score_label_map, tbo_map, direction_map, training_mask, \ + 
pos_list, pos_mask, label_list, score_label_map_text_label_list + + def adjust_point(self, poly): + """ + adjust point order. + """ + point_num = poly.shape[0] + if point_num == 4: + len_1 = np.linalg.norm(poly[0] - poly[1]) + len_2 = np.linalg.norm(poly[1] - poly[2]) + len_3 = np.linalg.norm(poly[2] - poly[3]) + len_4 = np.linalg.norm(poly[3] - poly[0]) + + if (len_1 + len_3) * 1.5 < (len_2 + len_4): + poly = poly[[1, 2, 3, 0], :] + + elif point_num > 4: + vector_1 = poly[0] - poly[1] + vector_2 = poly[1] - poly[2] + cos_theta = np.dot(vector_1, vector_2) / ( + np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6) + theta = np.arccos(np.round(cos_theta, decimals=4)) + + if abs(theta) > (70 / 180 * math.pi): + index = list(range(1, point_num)) + [0] + poly = poly[np.array(index), :] + return poly + + def gen_min_area_quad_from_poly(self, poly): + """ + Generate min area quad from poly. + """ + point_num = poly.shape[0] + min_area_quad = np.zeros((4, 2), dtype=np.float32) + if point_num == 4: + min_area_quad = poly + center_point = np.sum(poly, axis=0) / 4 + else: + rect = cv2.minAreaRect(poly.astype( + np.int32)) # (center (x,y), (width, height), angle of rotation) + center_point = rect[0] + box = np.array(cv2.boxPoints(rect)) + + first_point_idx = 0 + min_dist = 1e4 + for i in range(4): + dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \ + np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \ + np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \ + np.linalg.norm(box[(i + 3) % 4] - poly[-1]) + if dist < min_dist: + min_dist = dist + first_point_idx = i + + for i in range(4): + min_area_quad[i] = box[(first_point_idx + i) % 4] + + return min_area_quad, center_point + + def shrink_quad_along_width(self, + quad, + begin_width_ratio=0., + end_width_ratio=1.): + """ + Generate shrink_quad_along_width. + """ + ratio_pair = np.array( + [[begin_width_ratio], [end_width_ratio]], dtype=np.float32) + p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair + p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair + return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]]) + + def shrink_poly_along_width(self, + quads, + shrink_ratio_of_width, + expand_height_ratio=1.0): + """ + shrink poly with given length. + """ + upper_edge_list = [] + + def get_cut_info(edge_len_list, cut_len): + for idx, edge_len in enumerate(edge_len_list): + cut_len -= edge_len + if cut_len <= 0.000001: + ratio = (cut_len + edge_len_list[idx]) / edge_len_list[idx] + return idx, ratio + + for quad in quads: + upper_edge_len = np.linalg.norm(quad[0] - quad[1]) + upper_edge_list.append(upper_edge_len) + + # length of left edge and right edge. 
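+        # the shrink length below is capped by these (expanded) side lengths and the total upper-edge length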
+ left_length = np.linalg.norm(quads[0][0] - quads[0][ + 3]) * expand_height_ratio + right_length = np.linalg.norm(quads[-1][1] - quads[-1][ + 2]) * expand_height_ratio + + shrink_length = min(left_length, right_length, + sum(upper_edge_list)) * shrink_ratio_of_width + # shrinking length + upper_len_left = shrink_length + upper_len_right = sum(upper_edge_list) - shrink_length + + left_idx, left_ratio = get_cut_info(upper_edge_list, upper_len_left) + left_quad = self.shrink_quad_along_width( + quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1) + right_idx, right_ratio = get_cut_info(upper_edge_list, upper_len_right) + right_quad = self.shrink_quad_along_width( + quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio) + + out_quad_list = [] + if left_idx == right_idx: + out_quad_list.append( + [left_quad[0], right_quad[1], right_quad[2], left_quad[3]]) + else: + out_quad_list.append(left_quad) + for idx in range(left_idx + 1, right_idx): + out_quad_list.append(quads[idx]) + out_quad_list.append(right_quad) + + return np.array(out_quad_list), list(range(left_idx, right_idx + 1)) + + def prepare_text_label(self, label_str, Lexicon_Table): + """ + Prepare text lablel by given Lexicon_Table. + """ + if len(Lexicon_Table) == 36: + return label_str.lower() + else: + return label_str + + def vector_angle(self, A, B): + """ + Calculate the angle between vector AB and x-axis positive direction. + """ + AB = np.array([B[1] - A[1], B[0] - A[0]]) + return np.arctan2(*AB) + + def theta_line_cross_point(self, theta, point): + """ + Calculate the line through given point and angle in ax + by + c =0 form. + """ + x, y = point + cos = np.cos(theta) + sin = np.sin(theta) + return [sin, -cos, cos * y - sin * x] + + def line_cross_two_point(self, A, B): + """ + Calculate the line through given point A and B in ax + by + c =0 form. + """ + angle = self.vector_angle(A, B) + return self.theta_line_cross_point(angle, A) + + def average_angle(self, poly): + """ + Calculate the average angle between left and right edge in given poly. + """ + p0, p1, p2, p3 = poly + angle30 = self.vector_angle(p3, p0) + angle21 = self.vector_angle(p2, p1) + return (angle30 + angle21) / 2 + + def line_cross_point(self, line1, line2): + """ + line1 and line2 in 0=ax+by+c form, compute the cross point of line1 and line2 + """ + a1, b1, c1 = line1 + a2, b2, c2 = line2 + d = a1 * b2 - a2 * b1 + + if d == 0: + print('Cross point does not exist') + return np.array([0, 0], dtype=np.float32) + else: + x = (b1 * c2 - b2 * c1) / d + y = (a2 * c1 - a1 * c2) / d + + return np.array([x, y], dtype=np.float32) + + def quad2tcl(self, poly, ratio): + """ + Generate center line by poly clock-wise point. (4, 2) + """ + ratio_pair = np.array( + [[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32) + p0_3 = poly[0] + (poly[3] - poly[0]) * ratio_pair + p1_2 = poly[1] + (poly[2] - poly[1]) * ratio_pair + return np.array([p0_3[0], p1_2[0], p1_2[1], p0_3[1]]) + + def poly2tcl(self, poly, ratio): + """ + Generate center line by poly clock-wise point. + """ + ratio_pair = np.array( + [[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32) + tcl_poly = np.zeros_like(poly) + point_num = poly.shape[0] + + for idx in range(point_num // 2): + point_pair = poly[idx] + (poly[point_num - 1 - idx] - poly[idx] + ) * ratio_pair + tcl_poly[idx] = point_pair[0] + tcl_poly[point_num - 1 - idx] = point_pair[1] + return tcl_poly + + def gen_quad_tbo(self, quad, tcl_mask, tbo_map): + """ + Generate tbo_map for give quad. 
+ """ + # upper and lower line function: ax + by + c = 0; + up_line = self.line_cross_two_point(quad[0], quad[1]) + lower_line = self.line_cross_two_point(quad[3], quad[2]) + + quad_h = 0.5 * (np.linalg.norm(quad[0] - quad[3]) + + np.linalg.norm(quad[1] - quad[2])) + quad_w = 0.5 * (np.linalg.norm(quad[0] - quad[1]) + + np.linalg.norm(quad[2] - quad[3])) + + # average angle of left and right line. + angle = self.average_angle(quad) + + xy_in_poly = np.argwhere(tcl_mask == 1) + for y, x in xy_in_poly: + point = (x, y) + line = self.theta_line_cross_point(angle, point) + cross_point_upper = self.line_cross_point(up_line, line) + cross_point_lower = self.line_cross_point(lower_line, line) + ##FIX, offset reverse + upper_offset_x, upper_offset_y = cross_point_upper - point + lower_offset_x, lower_offset_y = cross_point_lower - point + tbo_map[y, x, 0] = upper_offset_y + tbo_map[y, x, 1] = upper_offset_x + tbo_map[y, x, 2] = lower_offset_y + tbo_map[y, x, 3] = lower_offset_x + tbo_map[y, x, 4] = 1.0 / max(min(quad_h, quad_w), 1.0) * 2 + return tbo_map + + def poly2quads(self, poly): + """ + Split poly into quads. + """ + quad_list = [] + point_num = poly.shape[0] + + # point pair + point_pair_list = [] + for idx in range(point_num // 2): + point_pair = [poly[idx], poly[point_num - 1 - idx]] + point_pair_list.append(point_pair) + + quad_num = point_num // 2 - 1 + for idx in range(quad_num): + # reshape and adjust to clock-wise + quad_list.append((np.array(point_pair_list)[[idx, idx + 1]] + ).reshape(4, 2)[[0, 2, 3, 1]]) + + return np.array(quad_list) + + def rotate_im_poly(self, im, text_polys): + """ + rotate image with 90 / 180 / 270 degre + """ + im_w, im_h = im.shape[1], im.shape[0] + dst_im = im.copy() + dst_polys = [] + rand_degree_ratio = np.random.rand() + rand_degree_cnt = 1 + if rand_degree_ratio > 0.5: + rand_degree_cnt = 3 + for i in range(rand_degree_cnt): + dst_im = np.rot90(dst_im) + rot_degree = -90 * rand_degree_cnt + rot_angle = rot_degree * math.pi / 180.0 + n_poly = text_polys.shape[0] + cx, cy = 0.5 * im_w, 0.5 * im_h + ncx, ncy = 0.5 * dst_im.shape[1], 0.5 * dst_im.shape[0] + for i in range(n_poly): + wordBB = text_polys[i] + poly = [] + for j in range(4): # 16->4 + sx, sy = wordBB[j][0], wordBB[j][1] + dx = math.cos(rot_angle) * (sx - cx) - math.sin(rot_angle) * ( + sy - cy) + ncx + dy = math.sin(rot_angle) * (sx - cx) + math.cos(rot_angle) * ( + sy - cy) + ncy + poly.append([dx, dy]) + dst_polys.append(poly) + return dst_im, np.array(dst_polys, dtype=np.float32) + + def __call__(self, data): + input_size = 512 + im = data['image'] + text_polys = data['polys'] + text_tags = data['ignore_tags'] + text_strs = data['texts'] + h, w, _ = im.shape + text_polys, text_tags, hv_tags = self.check_and_validate_polys( + text_polys, text_tags, (h, w)) + if text_polys.shape[0] <= 0: + return None + # set aspect ratio and keep area fix + asp_scales = np.arange(1.0, 1.55, 0.1) + asp_scale = np.random.choice(asp_scales) + if np.random.rand() < 0.5: + asp_scale = 1.0 / asp_scale + asp_scale = math.sqrt(asp_scale) + + asp_wx = asp_scale + asp_hy = 1.0 / asp_scale + im = cv2.resize(im, dsize=None, fx=asp_wx, fy=asp_hy) + text_polys[:, :, 0] *= asp_wx + text_polys[:, :, 1] *= asp_hy + + h, w, _ = im.shape + if max(h, w) > 2048: + rd_scale = 2048.0 / max(h, w) + im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale) + text_polys *= rd_scale + h, w, _ = im.shape + if min(h, w) < 16: + return None + + # no background + im, text_polys, text_tags, hv_tags, text_strs = self.crop_area( + im, + 
text_polys, + text_tags, + hv_tags, + text_strs, + crop_background=False) + + if text_polys.shape[0] == 0: + return None + # # continue for all ignore case + if np.sum((text_tags * 1.0)) >= text_tags.size: + return None + new_h, new_w, _ = im.shape + if (new_h is None) or (new_w is None): + return None + # resize image + std_ratio = float(input_size) / max(new_w, new_h) + rand_scales = np.array( + [0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0]) + rz_scale = std_ratio * np.random.choice(rand_scales) + im = cv2.resize(im, dsize=None, fx=rz_scale, fy=rz_scale) + text_polys[:, :, 0] *= rz_scale + text_polys[:, :, 1] *= rz_scale + + # add gaussian blur + if np.random.rand() < 0.1 * 0.5: + ks = np.random.permutation(5)[0] + 1 + ks = int(ks / 2) * 2 + 1 + im = cv2.GaussianBlur(im, ksize=(ks, ks), sigmaX=0, sigmaY=0) + # add brighter + if np.random.rand() < 0.1 * 0.5: + im = im * (1.0 + np.random.rand() * 0.5) + im = np.clip(im, 0.0, 255.0) + # add darker + if np.random.rand() < 0.1 * 0.5: + im = im * (1.0 - np.random.rand() * 0.5) + im = np.clip(im, 0.0, 255.0) + + # Padding the im to [input_size, input_size] + new_h, new_w, _ = im.shape + if min(new_w, new_h) < input_size * 0.5: + return None + im_padded = np.ones((input_size, input_size, 3), dtype=np.float32) + im_padded[:, :, 2] = 0.485 * 255 + im_padded[:, :, 1] = 0.456 * 255 + im_padded[:, :, 0] = 0.406 * 255 + + # Random the start position + del_h = input_size - new_h + del_w = input_size - new_w + sh, sw = 0, 0 + if del_h > 1: + sh = int(np.random.rand() * del_h) + if del_w > 1: + sw = int(np.random.rand() * del_w) + + # Padding + im_padded[sh:sh + new_h, sw:sw + new_w, :] = im.copy() + text_polys[:, :, 0] += sw + text_polys[:, :, 1] += sh + + score_map, score_label_map, border_map, direction_map, training_mask, \ + pos_list, pos_mask, label_list, score_label_map_text_label = self.generate_tcl_ctc_label(input_size, + input_size, + text_polys, + text_tags, + text_strs, 0.25) + if len(label_list) <= 0: # eliminate negative samples + return None + pos_list_temp = np.zeros([64, 3]) + pos_mask_temp = np.zeros([64, 1]) + label_list_temp = np.zeros([self.max_text_length, 1]) + self.pad_num + + for i, label in enumerate(label_list): + n = len(label) + if n > self.max_text_length: + label_list[i] = label[:self.max_text_length] + continue + while n < self.max_text_length: + label.append([self.pad_num]) + n += 1 + + for i in range(len(label_list)): + label_list[i] = np.array(label_list[i]) + + if len(pos_list) <= 0 or len(pos_list) > self.max_text_nums: + return None + for __ in range(self.max_text_nums - len(pos_list), 0, -1): + pos_list.append(pos_list_temp) + pos_mask.append(pos_mask_temp) + label_list.append(label_list_temp) + + if self.img_id == self.batch_size - 1: + self.img_id = 0 + else: + self.img_id += 1 + + im_padded[:, :, 2] -= 0.485 * 255 + im_padded[:, :, 1] -= 0.456 * 255 + im_padded[:, :, 0] -= 0.406 * 255 + im_padded[:, :, 2] /= (255.0 * 0.229) + im_padded[:, :, 1] /= (255.0 * 0.224) + im_padded[:, :, 0] /= (255.0 * 0.225) + im_padded = im_padded.transpose((2, 0, 1)) + images = im_padded[::-1, :, :] + tcl_maps = score_map[np.newaxis, :, :] + tcl_label_maps = score_label_map[np.newaxis, :, :] + border_maps = border_map.transpose((2, 0, 1)) + direction_maps = direction_map.transpose((2, 0, 1)) + training_masks = training_mask[np.newaxis, :, :] + pos_list = np.array(pos_list) + pos_mask = np.array(pos_mask) + label_list = np.array(label_list) + data['images'] = images + data['tcl_maps'] = tcl_maps + 
data['tcl_label_maps'] = tcl_label_maps + data['border_maps'] = border_maps + data['direction_maps'] = direction_maps + data['training_masks'] = training_masks + data['label_list'] = label_list + data['pos_list'] = pos_list + data['pos_mask'] = pos_mask + return data diff --git a/backend/ppocr/data/imaug/randaugment.py b/backend/ppocr/data/imaug/randaugment.py new file mode 100644 index 00000000..56f114d2 --- /dev/null +++ b/backend/ppocr/data/imaug/randaugment.py @@ -0,0 +1,143 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from PIL import Image, ImageEnhance, ImageOps +import numpy as np +import random +import six + + +class RawRandAugment(object): + def __init__(self, + num_layers=2, + magnitude=5, + fillcolor=(128, 128, 128), + **kwargs): + self.num_layers = num_layers + self.magnitude = magnitude + self.max_level = 10 + + abso_level = self.magnitude / self.max_level + self.level_map = { + "shearX": 0.3 * abso_level, + "shearY": 0.3 * abso_level, + "translateX": 150.0 / 331 * abso_level, + "translateY": 150.0 / 331 * abso_level, + "rotate": 30 * abso_level, + "color": 0.9 * abso_level, + "posterize": int(4.0 * abso_level), + "solarize": 256.0 * abso_level, + "contrast": 0.9 * abso_level, + "sharpness": 0.9 * abso_level, + "brightness": 0.9 * abso_level, + "autocontrast": 0, + "equalize": 0, + "invert": 0 + } + + # from https://stackoverflow.com/questions/5252170/ + # specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand + def rotate_with_fill(img, magnitude): + rot = img.convert("RGBA").rotate(magnitude) + return Image.composite(rot, + Image.new("RGBA", rot.size, (128, ) * 4), + rot).convert(img.mode) + + rnd_ch_op = random.choice + + self.func = { + "shearX": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, magnitude * rnd_ch_op([-1, 1]), 0, 0, 1, 0), + Image.BICUBIC, + fillcolor=fillcolor), + "shearY": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, 0, 0, magnitude * rnd_ch_op([-1, 1]), 1, 0), + Image.BICUBIC, + fillcolor=fillcolor), + "translateX": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, 0, magnitude * img.size[0] * rnd_ch_op([-1, 1]), 0, 1, 0), + fillcolor=fillcolor), + "translateY": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, 0, 0, 0, 1, magnitude * img.size[1] * rnd_ch_op([-1, 1])), + fillcolor=fillcolor), + "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude), + "color": lambda img, magnitude: ImageEnhance.Color(img).enhance( + 1 + magnitude * rnd_ch_op([-1, 1])), + "posterize": lambda img, magnitude: + ImageOps.posterize(img, magnitude), + "solarize": lambda img, magnitude: + ImageOps.solarize(img, magnitude), + "contrast": lambda img, magnitude: + ImageEnhance.Contrast(img).enhance( + 1 + magnitude * rnd_ch_op([-1, 
1])), + "sharpness": lambda img, magnitude: + ImageEnhance.Sharpness(img).enhance( + 1 + magnitude * rnd_ch_op([-1, 1])), + "brightness": lambda img, magnitude: + ImageEnhance.Brightness(img).enhance( + 1 + magnitude * rnd_ch_op([-1, 1])), + "autocontrast": lambda img, magnitude: + ImageOps.autocontrast(img), + "equalize": lambda img, magnitude: ImageOps.equalize(img), + "invert": lambda img, magnitude: ImageOps.invert(img) + } + + def __call__(self, img): + avaiable_op_names = list(self.level_map.keys()) + for layer_num in range(self.num_layers): + op_name = np.random.choice(avaiable_op_names) + img = self.func[op_name](img, self.level_map[op_name]) + return img + + +class RandAugment(RawRandAugment): + """ RandAugment wrapper to auto fit different img types """ + + def __init__(self, prob=0.5, *args, **kwargs): + self.prob = prob + if six.PY2: + super(RandAugment, self).__init__(*args, **kwargs) + else: + super().__init__(*args, **kwargs) + + def __call__(self, data): + if np.random.rand() > self.prob: + return data + img = data['image'] + if not isinstance(img, Image.Image): + img = np.ascontiguousarray(img) + img = Image.fromarray(img) + + if six.PY2: + img = super(RandAugment, self).__call__(img) + else: + img = super().__call__(img) + + if isinstance(img, Image.Image): + img = np.asarray(img) + data['image'] = img + return data diff --git a/backend/ppocr/data/imaug/random_crop_data.py b/backend/ppocr/data/imaug/random_crop_data.py new file mode 100644 index 00000000..64aa110d --- /dev/null +++ b/backend/ppocr/data/imaug/random_crop_data.py @@ -0,0 +1,234 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +This code is refer from: +https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/random_crop_data.py +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import numpy as np +import cv2 +import random + + +def is_poly_in_rect(poly, x, y, w, h): + poly = np.array(poly) + if poly[:, 0].min() < x or poly[:, 0].max() > x + w: + return False + if poly[:, 1].min() < y or poly[:, 1].max() > y + h: + return False + return True + + +def is_poly_outside_rect(poly, x, y, w, h): + poly = np.array(poly) + if poly[:, 0].max() < x or poly[:, 0].min() > x + w: + return True + if poly[:, 1].max() < y or poly[:, 1].min() > y + h: + return True + return False + + +def split_regions(axis): + regions = [] + min_axis = 0 + for i in range(1, axis.shape[0]): + if axis[i] != axis[i - 1] + 1: + region = axis[min_axis:i] + min_axis = i + regions.append(region) + return regions + + +def random_select(axis, max_size): + xx = np.random.choice(axis, size=2) + xmin = np.min(xx) + xmax = np.max(xx) + xmin = np.clip(xmin, 0, max_size - 1) + xmax = np.clip(xmax, 0, max_size - 1) + return xmin, xmax + + +def region_wise_random_select(regions, max_size): + selected_index = list(np.random.choice(len(regions), 2)) + selected_values = [] + for index in selected_index: + axis = regions[index] + xx = int(np.random.choice(axis, size=1)) + selected_values.append(xx) + xmin = min(selected_values) + xmax = max(selected_values) + return xmin, xmax + + +def crop_area(im, text_polys, min_crop_side_ratio, max_tries): + h, w, _ = im.shape + h_array = np.zeros(h, dtype=np.int32) + w_array = np.zeros(w, dtype=np.int32) + for points in text_polys: + points = np.round(points, decimals=0).astype(np.int32) + minx = np.min(points[:, 0]) + maxx = np.max(points[:, 0]) + w_array[minx:maxx] = 1 + miny = np.min(points[:, 1]) + maxy = np.max(points[:, 1]) + h_array[miny:maxy] = 1 + # ensure the cropped area not across a text + h_axis = np.where(h_array == 0)[0] + w_axis = np.where(w_array == 0)[0] + + if len(h_axis) == 0 or len(w_axis) == 0: + return 0, 0, w, h + + h_regions = split_regions(h_axis) + w_regions = split_regions(w_axis) + + for i in range(max_tries): + if len(w_regions) > 1: + xmin, xmax = region_wise_random_select(w_regions, w) + else: + xmin, xmax = random_select(w_axis, w) + if len(h_regions) > 1: + ymin, ymax = region_wise_random_select(h_regions, h) + else: + ymin, ymax = random_select(h_axis, h) + + if xmax - xmin < min_crop_side_ratio * w or ymax - ymin < min_crop_side_ratio * h: + # area too small + continue + num_poly_in_rect = 0 + for poly in text_polys: + if not is_poly_outside_rect(poly, xmin, ymin, xmax - xmin, + ymax - ymin): + num_poly_in_rect += 1 + break + + if num_poly_in_rect > 0: + return xmin, ymin, xmax - xmin, ymax - ymin + + return 0, 0, w, h + + +class EastRandomCropData(object): + def __init__(self, + size=(640, 640), + max_tries=10, + min_crop_side_ratio=0.1, + keep_ratio=True, + **kwargs): + self.size = size + self.max_tries = max_tries + self.min_crop_side_ratio = min_crop_side_ratio + self.keep_ratio = keep_ratio + + def __call__(self, data): + img = data['image'] + text_polys = data['polys'] + ignore_tags = data['ignore_tags'] + texts = data['texts'] + all_care_polys = [ + text_polys[i] for i, tag in enumerate(ignore_tags) if not tag + ] + # 计算crop区域 + crop_x, crop_y, crop_w, crop_h = crop_area( + img, all_care_polys, self.min_crop_side_ratio, self.max_tries) + # crop 图片 
保持比例填充 + scale_w = self.size[0] / crop_w + scale_h = self.size[1] / crop_h + scale = min(scale_w, scale_h) + h = int(crop_h * scale) + w = int(crop_w * scale) + if self.keep_ratio: + padimg = np.zeros((self.size[1], self.size[0], img.shape[2]), + img.dtype) + padimg[:h, :w] = cv2.resize( + img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h)) + img = padimg + else: + img = cv2.resize( + img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], + tuple(self.size)) + # crop 文本框 + text_polys_crop = [] + ignore_tags_crop = [] + texts_crop = [] + for poly, text, tag in zip(text_polys, texts, ignore_tags): + poly = ((poly - (crop_x, crop_y)) * scale).tolist() + if not is_poly_outside_rect(poly, 0, 0, w, h): + text_polys_crop.append(poly) + ignore_tags_crop.append(tag) + texts_crop.append(text) + data['image'] = img + data['polys'] = np.array(text_polys_crop) + data['ignore_tags'] = ignore_tags_crop + data['texts'] = texts_crop + return data + + +class RandomCropImgMask(object): + def __init__(self, size, main_key, crop_keys, p=3 / 8, **kwargs): + self.size = size + self.main_key = main_key + self.crop_keys = crop_keys + self.p = p + + def __call__(self, data): + image = data['image'] + + h, w = image.shape[0:2] + th, tw = self.size + if w == tw and h == th: + return data + + mask = data[self.main_key] + if np.max(mask) > 0 and random.random() > self.p: + # make sure to crop the text region + tl = np.min(np.where(mask > 0), axis=1) - (th, tw) + tl[tl < 0] = 0 + br = np.max(np.where(mask > 0), axis=1) - (th, tw) + br[br < 0] = 0 + + br[0] = min(br[0], h - th) + br[1] = min(br[1], w - tw) + + i = random.randint(tl[0], br[0]) if tl[0] < br[0] else 0 + j = random.randint(tl[1], br[1]) if tl[1] < br[1] else 0 + else: + i = random.randint(0, h - th) if h - th > 0 else 0 + j = random.randint(0, w - tw) if w - tw > 0 else 0 + + # return i, j, th, tw + for k in data: + if k in self.crop_keys: + if len(data[k].shape) == 3: + if np.argmin(data[k].shape) == 0: + img = data[k][:, i:i + th, j:j + tw] + if img.shape[1] != img.shape[2]: + a = 1 + elif np.argmin(data[k].shape) == 2: + img = data[k][i:i + th, j:j + tw, :] + if img.shape[1] != img.shape[0]: + a = 1 + else: + img = data[k] + else: + img = data[k][i:i + th, j:j + tw] + if img.shape[0] != img.shape[1]: + a = 1 + data[k] = img + return data diff --git a/backend/ppocr/data/imaug/rec_img_aug.py b/backend/ppocr/data/imaug/rec_img_aug.py new file mode 100644 index 00000000..7483dffe --- /dev/null +++ b/backend/ppocr/data/imaug/rec_img_aug.py @@ -0,0 +1,601 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
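+
+# Augmentations (TIA distort / stretch / perspective, blur, color jitter, Gaussian noise)
+# and resize / normalization helpers for text recognition training images.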
+ +import math +import cv2 +import numpy as np +import random +import copy +from PIL import Image +from .text_image_aug import tia_perspective, tia_stretch, tia_distort + + +class RecAug(object): + def __init__(self, use_tia=True, aug_prob=0.4, **kwargs): + self.use_tia = use_tia + self.aug_prob = aug_prob + + def __call__(self, data): + img = data['image'] + img = warp(img, 10, self.use_tia, self.aug_prob) + data['image'] = img + return data + + +class RecConAug(object): + def __init__(self, + prob=0.5, + image_shape=(32, 320, 3), + max_text_length=25, + ext_data_num=1, + **kwargs): + self.ext_data_num = ext_data_num + self.prob = prob + self.max_text_length = max_text_length + self.image_shape = image_shape + self.max_wh_ratio = self.image_shape[1] / self.image_shape[0] + + def merge_ext_data(self, data, ext_data): + ori_w = round(data['image'].shape[1] / data['image'].shape[0] * + self.image_shape[0]) + ext_w = round(ext_data['image'].shape[1] / ext_data['image'].shape[0] * + self.image_shape[0]) + data['image'] = cv2.resize(data['image'], (ori_w, self.image_shape[0])) + ext_data['image'] = cv2.resize(ext_data['image'], + (ext_w, self.image_shape[0])) + data['image'] = np.concatenate( + [data['image'], ext_data['image']], axis=1) + data["label"] += ext_data["label"] + return data + + def __call__(self, data): + rnd_num = random.random() + if rnd_num > self.prob: + return data + for idx, ext_data in enumerate(data["ext_data"]): + if len(data["label"]) + len(ext_data[ + "label"]) > self.max_text_length: + break + concat_ratio = data['image'].shape[1] / data['image'].shape[ + 0] + ext_data['image'].shape[1] / ext_data['image'].shape[0] + if concat_ratio > self.max_wh_ratio: + break + data = self.merge_ext_data(data, ext_data) + data.pop("ext_data") + return data + + +class ClsResizeImg(object): + def __init__(self, image_shape, **kwargs): + self.image_shape = image_shape + + def __call__(self, data): + img = data['image'] + norm_img, _ = resize_norm_img(img, self.image_shape) + data['image'] = norm_img + return data + + +class NRTRRecResizeImg(object): + def __init__(self, image_shape, resize_type, padding=False, **kwargs): + self.image_shape = image_shape + self.resize_type = resize_type + self.padding = padding + + def __call__(self, data): + img = data['image'] + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + image_shape = self.image_shape + if self.padding: + imgC, imgH, imgW = image_shape + # todo: change to 0 and modified image shape + h = img.shape[0] + w = img.shape[1] + ratio = w / float(h) + if math.ceil(imgH * ratio) > imgW: + resized_w = imgW + else: + resized_w = int(math.ceil(imgH * ratio)) + resized_image = cv2.resize(img, (resized_w, imgH)) + norm_img = np.expand_dims(resized_image, -1) + norm_img = norm_img.transpose((2, 0, 1)) + resized_image = norm_img.astype(np.float32) / 128. - 1. + padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) + padding_im[:, :, 0:resized_w] = resized_image + data['image'] = padding_im + return data + if self.resize_type == 'PIL': + image_pil = Image.fromarray(np.uint8(img)) + img = image_pil.resize(self.image_shape, Image.ANTIALIAS) + img = np.array(img) + if self.resize_type == 'OpenCV': + img = cv2.resize(img, self.image_shape) + norm_img = np.expand_dims(img, -1) + norm_img = norm_img.transpose((2, 0, 1)) + data['image'] = norm_img.astype(np.float32) / 128. - 1. 
+ return data + + +class RecResizeImg(object): + def __init__(self, + image_shape, + infer_mode=False, + character_dict_path='./ppocr/utils/ppocr_keys_v1.txt', + padding=True, + **kwargs): + self.image_shape = image_shape + self.infer_mode = infer_mode + self.character_dict_path = character_dict_path + self.padding = padding + + def __call__(self, data): + img = data['image'] + if self.infer_mode and self.character_dict_path is not None: + norm_img, valid_ratio = resize_norm_img_chinese(img, + self.image_shape) + else: + norm_img, valid_ratio = resize_norm_img(img, self.image_shape, + self.padding) + data['image'] = norm_img + data['valid_ratio'] = valid_ratio + return data + + +class SRNRecResizeImg(object): + def __init__(self, image_shape, num_heads, max_text_length, **kwargs): + self.image_shape = image_shape + self.num_heads = num_heads + self.max_text_length = max_text_length + + def __call__(self, data): + img = data['image'] + norm_img = resize_norm_img_srn(img, self.image_shape) + data['image'] = norm_img + [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \ + srn_other_inputs(self.image_shape, self.num_heads, self.max_text_length) + + data['encoder_word_pos'] = encoder_word_pos + data['gsrm_word_pos'] = gsrm_word_pos + data['gsrm_slf_attn_bias1'] = gsrm_slf_attn_bias1 + data['gsrm_slf_attn_bias2'] = gsrm_slf_attn_bias2 + return data + + +class SARRecResizeImg(object): + def __init__(self, image_shape, width_downsample_ratio=0.25, **kwargs): + self.image_shape = image_shape + self.width_downsample_ratio = width_downsample_ratio + + def __call__(self, data): + img = data['image'] + norm_img, resize_shape, pad_shape, valid_ratio = resize_norm_img_sar( + img, self.image_shape, self.width_downsample_ratio) + data['image'] = norm_img + data['resized_shape'] = resize_shape + data['pad_shape'] = pad_shape + data['valid_ratio'] = valid_ratio + return data + + +class PRENResizeImg(object): + def __init__(self, image_shape, **kwargs): + """ + Accroding to original paper's realization, it's a hard resize method here. + So maybe you should optimize it to fit for your task better. + """ + self.dst_h, self.dst_w = image_shape + + def __call__(self, data): + img = data['image'] + resized_img = cv2.resize( + img, (self.dst_w, self.dst_h), interpolation=cv2.INTER_LINEAR) + resized_img = resized_img.transpose((2, 0, 1)) / 255 + resized_img -= 0.5 + resized_img /= 0.5 + data['image'] = resized_img.astype(np.float32) + return data + + +def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25): + imgC, imgH, imgW_min, imgW_max = image_shape + h = img.shape[0] + w = img.shape[1] + valid_ratio = 1.0 + # make sure new_width is an integral multiple of width_divisor. 
+ width_divisor = int(1 / width_downsample_ratio) + # resize + ratio = w / float(h) + resize_w = math.ceil(imgH * ratio) + if resize_w % width_divisor != 0: + resize_w = round(resize_w / width_divisor) * width_divisor + if imgW_min is not None: + resize_w = max(imgW_min, resize_w) + if imgW_max is not None: + valid_ratio = min(1.0, 1.0 * resize_w / imgW_max) + resize_w = min(imgW_max, resize_w) + resized_image = cv2.resize(img, (resize_w, imgH)) + resized_image = resized_image.astype('float32') + # norm + if image_shape[0] == 1: + resized_image = resized_image / 255 + resized_image = resized_image[np.newaxis, :] + else: + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + resize_shape = resized_image.shape + padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32) + padding_im[:, :, 0:resize_w] = resized_image + pad_shape = padding_im.shape + + return padding_im, resize_shape, pad_shape, valid_ratio + + +def resize_norm_img(img, image_shape, padding=True): + imgC, imgH, imgW = image_shape + h = img.shape[0] + w = img.shape[1] + if not padding: + resized_image = cv2.resize( + img, (imgW, imgH), interpolation=cv2.INTER_LINEAR) + resized_w = imgW + else: + ratio = w / float(h) + if math.ceil(imgH * ratio) > imgW: + resized_w = imgW + else: + resized_w = int(math.ceil(imgH * ratio)) + resized_image = cv2.resize(img, (resized_w, imgH)) + resized_image = resized_image.astype('float32') + if image_shape[0] == 1: + resized_image = resized_image / 255 + resized_image = resized_image[np.newaxis, :] + else: + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) + padding_im[:, :, 0:resized_w] = resized_image + valid_ratio = min(1.0, float(resized_w / imgW)) + return padding_im, valid_ratio + + +def resize_norm_img_chinese(img, image_shape): + imgC, imgH, imgW = image_shape + # todo: change to 0 and modified image shape + max_wh_ratio = imgW * 1.0 / imgH + h, w = img.shape[0], img.shape[1] + ratio = w * 1.0 / h + max_wh_ratio = max(max_wh_ratio, ratio) + imgW = int(imgH * max_wh_ratio) + if math.ceil(imgH * ratio) > imgW: + resized_w = imgW + else: + resized_w = int(math.ceil(imgH * ratio)) + resized_image = cv2.resize(img, (resized_w, imgH)) + resized_image = resized_image.astype('float32') + if image_shape[0] == 1: + resized_image = resized_image / 255 + resized_image = resized_image[np.newaxis, :] + else: + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) + padding_im[:, :, 0:resized_w] = resized_image + valid_ratio = min(1.0, float(resized_w / imgW)) + return padding_im, valid_ratio + + +def resize_norm_img_srn(img, image_shape): + imgC, imgH, imgW = image_shape + + img_black = np.zeros((imgH, imgW)) + im_hei = img.shape[0] + im_wid = img.shape[1] + + if im_wid <= im_hei * 1: + img_new = cv2.resize(img, (imgH * 1, imgH)) + elif im_wid <= im_hei * 2: + img_new = cv2.resize(img, (imgH * 2, imgH)) + elif im_wid <= im_hei * 3: + img_new = cv2.resize(img, (imgH * 3, imgH)) + else: + img_new = cv2.resize(img, (imgW, imgH)) + + img_np = np.asarray(img_new) + img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY) + img_black[:, 0:img_np.shape[1]] = img_np + img_black = img_black[:, :, np.newaxis] + + row, col, c = img_black.shape + c = 1 + + return np.reshape(img_black, (c, row, col)).astype(np.float32) + + +def 
srn_other_inputs(image_shape, num_heads, max_text_length): + + imgC, imgH, imgW = image_shape + feature_dim = int((imgH / 8) * (imgW / 8)) + + encoder_word_pos = np.array(range(0, feature_dim)).reshape( + (feature_dim, 1)).astype('int64') + gsrm_word_pos = np.array(range(0, max_text_length)).reshape( + (max_text_length, 1)).astype('int64') + + gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length)) + gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape( + [1, max_text_length, max_text_length]) + gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1, + [num_heads, 1, 1]) * [-1e9] + + gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape( + [1, max_text_length, max_text_length]) + gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2, + [num_heads, 1, 1]) * [-1e9] + + return [ + encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, + gsrm_slf_attn_bias2 + ] + + +def flag(): + """ + flag + """ + return 1 if random.random() > 0.5000001 else -1 + + +def cvtColor(img): + """ + cvtColor + """ + hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) + delta = 0.001 * random.random() * flag() + hsv[:, :, 2] = hsv[:, :, 2] * (1 + delta) + new_img = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) + return new_img + + +def blur(img): + """ + blur + """ + h, w, _ = img.shape + if h > 10 and w > 10: + return cv2.GaussianBlur(img, (5, 5), 1) + else: + return img + + +def jitter(img): + """ + jitter + """ + w, h, _ = img.shape + if h > 10 and w > 10: + thres = min(w, h) + s = int(random.random() * thres * 0.01) + src_img = img.copy() + for i in range(s): + img[i:, i:, :] = src_img[:w - i, :h - i, :] + return img + else: + return img + + +def add_gasuss_noise(image, mean=0, var=0.1): + """ + Gasuss noise + """ + + noise = np.random.normal(mean, var**0.5, image.shape) + out = image + 0.5 * noise + out = np.clip(out, 0, 255) + out = np.uint8(out) + return out + + +def get_crop(image): + """ + random crop + """ + h, w, _ = image.shape + top_min = 1 + top_max = 8 + top_crop = int(random.randint(top_min, top_max)) + top_crop = min(top_crop, h - 1) + crop_img = image.copy() + ratio = random.randint(0, 1) + if ratio: + crop_img = crop_img[top_crop:h, :, :] + else: + crop_img = crop_img[0:h - top_crop, :, :] + return crop_img + + +class Config: + """ + Config + """ + + def __init__(self, use_tia): + self.anglex = random.random() * 30 + self.angley = random.random() * 15 + self.anglez = random.random() * 10 + self.fov = 42 + self.r = 0 + self.shearx = random.random() * 0.3 + self.sheary = random.random() * 0.05 + self.borderMode = cv2.BORDER_REPLICATE + self.use_tia = use_tia + + def make(self, w, h, ang): + """ + make + """ + self.anglex = random.random() * 5 * flag() + self.angley = random.random() * 5 * flag() + self.anglez = -1 * random.random() * int(ang) * flag() + self.fov = 42 + self.r = 0 + self.shearx = 0 + self.sheary = 0 + self.borderMode = cv2.BORDER_REPLICATE + self.w = w + self.h = h + + self.perspective = self.use_tia + self.stretch = self.use_tia + self.distort = self.use_tia + + self.crop = True + self.affine = False + self.reverse = True + self.noise = True + self.jitter = True + self.blur = True + self.color = True + + +def rad(x): + """ + rad + """ + return x * np.pi / 180 + + +def get_warpR(config): + """ + get_warpR + """ + anglex, angley, anglez, fov, w, h, r = \ + config.anglex, config.angley, config.anglez, config.fov, config.w, config.h, config.r + if w > 69 and w < 112: + anglex = anglex * 1.5 + + z = np.sqrt(w**2 + h**2) / 2 / np.tan(rad(fov / 2)) + # Homogeneous coordinate 
transformation matrix + rx = np.array([[1, 0, 0, 0], + [0, np.cos(rad(anglex)), -np.sin(rad(anglex)), 0], [ + 0, + -np.sin(rad(anglex)), + np.cos(rad(anglex)), + 0, + ], [0, 0, 0, 1]], np.float32) + ry = np.array([[np.cos(rad(angley)), 0, np.sin(rad(angley)), 0], + [0, 1, 0, 0], [ + -np.sin(rad(angley)), + 0, + np.cos(rad(angley)), + 0, + ], [0, 0, 0, 1]], np.float32) + rz = np.array([[np.cos(rad(anglez)), np.sin(rad(anglez)), 0, 0], + [-np.sin(rad(anglez)), np.cos(rad(anglez)), 0, 0], + [0, 0, 1, 0], [0, 0, 0, 1]], np.float32) + r = rx.dot(ry).dot(rz) + # generate 4 points + pcenter = np.array([h / 2, w / 2, 0, 0], np.float32) + p1 = np.array([0, 0, 0, 0], np.float32) - pcenter + p2 = np.array([w, 0, 0, 0], np.float32) - pcenter + p3 = np.array([0, h, 0, 0], np.float32) - pcenter + p4 = np.array([w, h, 0, 0], np.float32) - pcenter + dst1 = r.dot(p1) + dst2 = r.dot(p2) + dst3 = r.dot(p3) + dst4 = r.dot(p4) + list_dst = np.array([dst1, dst2, dst3, dst4]) + org = np.array([[0, 0], [w, 0], [0, h], [w, h]], np.float32) + dst = np.zeros((4, 2), np.float32) + # Project onto the image plane + dst[:, 0] = list_dst[:, 0] * z / (z - list_dst[:, 2]) + pcenter[0] + dst[:, 1] = list_dst[:, 1] * z / (z - list_dst[:, 2]) + pcenter[1] + + warpR = cv2.getPerspectiveTransform(org, dst) + + dst1, dst2, dst3, dst4 = dst + r1 = int(min(dst1[1], dst2[1])) + r2 = int(max(dst3[1], dst4[1])) + c1 = int(min(dst1[0], dst3[0])) + c2 = int(max(dst2[0], dst4[0])) + + try: + ratio = min(1.0 * h / (r2 - r1), 1.0 * w / (c2 - c1)) + + dx = -c1 + dy = -r1 + T1 = np.float32([[1., 0, dx], [0, 1., dy], [0, 0, 1.0 / ratio]]) + ret = T1.dot(warpR) + except: + ratio = 1.0 + T1 = np.float32([[1., 0, 0], [0, 1., 0], [0, 0, 1.]]) + ret = T1 + return ret, (-r1, -c1), ratio, dst + + +def get_warpAffine(config): + """ + get_warpAffine + """ + anglez = config.anglez + rz = np.array([[np.cos(rad(anglez)), np.sin(rad(anglez)), 0], + [-np.sin(rad(anglez)), np.cos(rad(anglez)), 0]], np.float32) + return rz + + +def warp(img, ang, use_tia=True, prob=0.4): + """ + warp + """ + h, w, _ = img.shape + config = Config(use_tia=use_tia) + config.make(w, h, ang) + new_img = img + + if config.distort: + img_height, img_width = img.shape[0:2] + if random.random() <= prob and img_height >= 20 and img_width >= 20: + new_img = tia_distort(new_img, random.randint(3, 6)) + + if config.stretch: + img_height, img_width = img.shape[0:2] + if random.random() <= prob and img_height >= 20 and img_width >= 20: + new_img = tia_stretch(new_img, random.randint(3, 6)) + + if config.perspective: + if random.random() <= prob: + new_img = tia_perspective(new_img) + + if config.crop: + img_height, img_width = img.shape[0:2] + if random.random() <= prob and img_height >= 20 and img_width >= 20: + new_img = get_crop(new_img) + + if config.blur: + if random.random() <= prob: + new_img = blur(new_img) + if config.color: + if random.random() <= prob: + new_img = cvtColor(new_img) + if config.jitter: + new_img = jitter(new_img) + if config.noise: + if random.random() <= prob: + new_img = add_gasuss_noise(new_img) + if config.reverse: + if random.random() <= prob: + new_img = 255 - new_img + return new_img diff --git a/backend/ppocr/data/imaug/sast_process.py b/backend/ppocr/data/imaug/sast_process.py index 1536dceb..08d03b19 100644 --- a/backend/ppocr/data/imaug/sast_process.py +++ b/backend/ppocr/data/imaug/sast_process.py @@ -11,7 +11,10 @@ #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
#See the License for the specific language governing permissions and #limitations under the License. - +""" +This part code is refered from: +https://github.com/songdejia/EAST/blob/master/data_utils.py +""" import math import cv2 import numpy as np diff --git a/backend/ppocr/data/imaug/ssl_img_aug.py b/backend/ppocr/data/imaug/ssl_img_aug.py new file mode 100644 index 00000000..f9ed6ac3 --- /dev/null +++ b/backend/ppocr/data/imaug/ssl_img_aug.py @@ -0,0 +1,60 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import cv2 +import numpy as np +import random +from PIL import Image + +from .rec_img_aug import resize_norm_img + + +class SSLRotateResize(object): + def __init__(self, + image_shape, + padding=False, + select_all=True, + mode="train", + **kwargs): + self.image_shape = image_shape + self.padding = padding + self.select_all = select_all + self.mode = mode + + def __call__(self, data): + img = data["image"] + + data["image_r90"] = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE) + data["image_r180"] = cv2.rotate(data["image_r90"], + cv2.ROTATE_90_CLOCKWISE) + data["image_r270"] = cv2.rotate(data["image_r180"], + cv2.ROTATE_90_CLOCKWISE) + + images = [] + for key in ["image", "image_r90", "image_r180", "image_r270"]: + images.append( + resize_norm_img( + data.pop(key), + image_shape=self.image_shape, + padding=self.padding)[0]) + data["image"] = np.stack(images, axis=0) + data["label"] = np.array(list(range(4))) + if not self.select_all: + data["image"] = data["image"][0::2] # just choose 0 and 180 + data["label"] = data["label"][0:2] # label needs to be continuous + if self.mode == "test": + data["image"] = data["image"][0] + data["label"] = data["label"][0] + return data diff --git a/backend/ppocr/data/imaug/text_image_aug/augment.py b/backend/ppocr/data/imaug/text_image_aug/augment.py index 1aeff373..2d15dd5f 100644 --- a/backend/ppocr/data/imaug/text_image_aug/augment.py +++ b/backend/ppocr/data/imaug/text_image_aug/augment.py @@ -11,6 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +This code is refer from: +https://github.com/RubanSeven/Text-Image-Augmentation-python/blob/master/augment.py +""" import numpy as np from .warp_mls import WarpMLS diff --git a/backend/ppocr/data/imaug/text_image_aug/warp_mls.py b/backend/ppocr/data/imaug/text_image_aug/warp_mls.py index d6cbe749..75de1111 100644 --- a/backend/ppocr/data/imaug/text_image_aug/warp_mls.py +++ b/backend/ppocr/data/imaug/text_image_aug/warp_mls.py @@ -11,6 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+""" +This code is refer from: +https://github.com/RubanSeven/Text-Image-Augmentation-python/blob/master/warp_mls.py +""" import numpy as np @@ -161,4 +165,4 @@ def gen_img(self): dst = np.clip(dst, 0, 255) dst = np.array(dst, dtype=np.uint8) - return dst \ No newline at end of file + return dst diff --git a/backend/ppocr/data/imaug/vqa/__init__.py b/backend/ppocr/data/imaug/vqa/__init__.py new file mode 100644 index 00000000..a5025e79 --- /dev/null +++ b/backend/ppocr/data/imaug/vqa/__init__.py @@ -0,0 +1,19 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation + +__all__ = [ + 'VQATokenPad', 'VQASerTokenChunk', 'VQAReTokenChunk', 'VQAReTokenRelation' +] diff --git a/backend/ppocr/data/imaug/vqa/token/__init__.py b/backend/ppocr/data/imaug/vqa/token/__init__.py new file mode 100644 index 00000000..7c115661 --- /dev/null +++ b/backend/ppocr/data/imaug/vqa/token/__init__.py @@ -0,0 +1,17 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .vqa_token_chunk import VQASerTokenChunk, VQAReTokenChunk +from .vqa_token_pad import VQATokenPad +from .vqa_token_relation import VQAReTokenRelation diff --git a/backend/ppocr/data/imaug/vqa/token/vqa_token_chunk.py b/backend/ppocr/data/imaug/vqa/token/vqa_token_chunk.py new file mode 100644 index 00000000..1fa949e6 --- /dev/null +++ b/backend/ppocr/data/imaug/vqa/token/vqa_token_chunk.py @@ -0,0 +1,122 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from collections import defaultdict + + +class VQASerTokenChunk(object): + def __init__(self, max_seq_len=512, infer_mode=False, **kwargs): + self.max_seq_len = max_seq_len + self.infer_mode = infer_mode + + def __call__(self, data): + encoded_inputs_all = [] + seq_len = len(data['input_ids']) + for index in range(0, seq_len, self.max_seq_len): + chunk_beg = index + chunk_end = min(index + self.max_seq_len, seq_len) + encoded_inputs_example = {} + for key in data: + if key in [ + 'label', 'input_ids', 'labels', 'token_type_ids', + 'bbox', 'attention_mask' + ]: + if self.infer_mode and key == 'labels': + encoded_inputs_example[key] = data[key] + else: + encoded_inputs_example[key] = data[key][chunk_beg: + chunk_end] + else: + encoded_inputs_example[key] = data[key] + + encoded_inputs_all.append(encoded_inputs_example) + if len(encoded_inputs_all) == 0: + return None + return encoded_inputs_all[0] + + +class VQAReTokenChunk(object): + def __init__(self, + max_seq_len=512, + entities_labels=None, + infer_mode=False, + **kwargs): + self.max_seq_len = max_seq_len + self.entities_labels = { + 'HEADER': 0, + 'QUESTION': 1, + 'ANSWER': 2 + } if entities_labels is None else entities_labels + self.infer_mode = infer_mode + + def __call__(self, data): + # prepare data + entities = data.pop('entities') + relations = data.pop('relations') + encoded_inputs_all = [] + for index in range(0, len(data["input_ids"]), self.max_seq_len): + item = {} + for key in data: + if key in [ + 'label', 'input_ids', 'labels', 'token_type_ids', + 'bbox', 'attention_mask' + ]: + if self.infer_mode and key == 'labels': + item[key] = data[key] + else: + item[key] = data[key][index:index + self.max_seq_len] + else: + item[key] = data[key] + # select entity in current chunk + entities_in_this_span = [] + global_to_local_map = {} # + for entity_id, entity in enumerate(entities): + if (index <= entity["start"] < index + self.max_seq_len and + index <= entity["end"] < index + self.max_seq_len): + entity["start"] = entity["start"] - index + entity["end"] = entity["end"] - index + global_to_local_map[entity_id] = len(entities_in_this_span) + entities_in_this_span.append(entity) + + # select relations in current chunk + relations_in_this_span = [] + for relation in relations: + if (index <= relation["start_index"] < index + self.max_seq_len + and index <= relation["end_index"] < + index + self.max_seq_len): + relations_in_this_span.append({ + "head": global_to_local_map[relation["head"]], + "tail": global_to_local_map[relation["tail"]], + "start_index": relation["start_index"] - index, + "end_index": relation["end_index"] - index, + }) + item.update({ + "entities": self.reformat(entities_in_this_span), + "relations": self.reformat(relations_in_this_span), + }) + if len(item['entities']) > 0: + item['entities']['label'] = [ + self.entities_labels[x] for x in item['entities']['label'] + ] + encoded_inputs_all.append(item) + if len(encoded_inputs_all) == 0: + return None + return encoded_inputs_all[0] + + def reformat(self, data): + new_data = defaultdict(list) + for item in data: + for k, v in item.items(): + new_data[k].append(v) + return new_data diff --git a/backend/ppocr/data/imaug/vqa/token/vqa_token_pad.py b/backend/ppocr/data/imaug/vqa/token/vqa_token_pad.py new file mode 100644 index 00000000..8e5a20f9 --- /dev/null +++ b/backend/ppocr/data/imaug/vqa/token/vqa_token_pad.py @@ -0,0 +1,104 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import paddle +import numpy as np + + +class VQATokenPad(object): + def __init__(self, + max_seq_len=512, + pad_to_max_seq_len=True, + return_attention_mask=True, + return_token_type_ids=True, + truncation_strategy="longest_first", + return_overflowing_tokens=False, + return_special_tokens_mask=False, + infer_mode=False, + **kwargs): + self.max_seq_len = max_seq_len + self.pad_to_max_seq_len = max_seq_len + self.return_attention_mask = return_attention_mask + self.return_token_type_ids = return_token_type_ids + self.truncation_strategy = truncation_strategy + self.return_overflowing_tokens = return_overflowing_tokens + self.return_special_tokens_mask = return_special_tokens_mask + self.pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index + self.infer_mode = infer_mode + + def __call__(self, data): + needs_to_be_padded = self.pad_to_max_seq_len and len(data[ + "input_ids"]) < self.max_seq_len + + if needs_to_be_padded: + if 'tokenizer_params' in data: + tokenizer_params = data.pop('tokenizer_params') + else: + tokenizer_params = dict( + padding_side='right', pad_token_type_id=0, pad_token_id=1) + + difference = self.max_seq_len - len(data["input_ids"]) + if tokenizer_params['padding_side'] == 'right': + if self.return_attention_mask: + data["attention_mask"] = [1] * len(data[ + "input_ids"]) + [0] * difference + if self.return_token_type_ids: + data["token_type_ids"] = ( + data["token_type_ids"] + + [tokenizer_params['pad_token_type_id']] * difference) + if self.return_special_tokens_mask: + data["special_tokens_mask"] = data[ + "special_tokens_mask"] + [1] * difference + data["input_ids"] = data["input_ids"] + [ + tokenizer_params['pad_token_id'] + ] * difference + if not self.infer_mode: + data["labels"] = data[ + "labels"] + [self.pad_token_label_id] * difference + data["bbox"] = data["bbox"] + [[0, 0, 0, 0]] * difference + elif tokenizer_params['padding_side'] == 'left': + if self.return_attention_mask: + data["attention_mask"] = [0] * difference + [ + 1 + ] * len(data["input_ids"]) + if self.return_token_type_ids: + data["token_type_ids"] = ( + [tokenizer_params['pad_token_type_id']] * difference + + data["token_type_ids"]) + if self.return_special_tokens_mask: + data["special_tokens_mask"] = [ + 1 + ] * difference + data["special_tokens_mask"] + data["input_ids"] = [tokenizer_params['pad_token_id'] + ] * difference + data["input_ids"] + if not self.infer_mode: + data["labels"] = [self.pad_token_label_id + ] * difference + data["labels"] + data["bbox"] = [[0, 0, 0, 0]] * difference + data["bbox"] + else: + if self.return_attention_mask: + data["attention_mask"] = [1] * len(data["input_ids"]) + + for key in data: + if key in [ + 'input_ids', 'labels', 'token_type_ids', 'bbox', + 'attention_mask' + ]: + if self.infer_mode: + if key != 'labels': + length = min(len(data[key]), self.max_seq_len) + data[key] = data[key][:length] + else: + continue + data[key] = np.array(data[key], dtype='int64') + return data diff --git 
a/backend/ppocr/data/imaug/vqa/token/vqa_token_relation.py b/backend/ppocr/data/imaug/vqa/token/vqa_token_relation.py new file mode 100644 index 00000000..293988ff --- /dev/null +++ b/backend/ppocr/data/imaug/vqa/token/vqa_token_relation.py @@ -0,0 +1,67 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class VQAReTokenRelation(object): + def __init__(self, **kwargs): + pass + + def __call__(self, data): + """ + build relations + """ + entities = data['entities'] + relations = data['relations'] + id2label = data.pop('id2label') + empty_entity = data.pop('empty_entity') + entity_id_to_index_map = data.pop('entity_id_to_index_map') + + relations = list(set(relations)) + relations = [ + rel for rel in relations + if rel[0] not in empty_entity and rel[1] not in empty_entity + ] + kv_relations = [] + for rel in relations: + pair = [id2label[rel[0]], id2label[rel[1]]] + if pair == ["question", "answer"]: + kv_relations.append({ + "head": entity_id_to_index_map[rel[0]], + "tail": entity_id_to_index_map[rel[1]] + }) + elif pair == ["answer", "question"]: + kv_relations.append({ + "head": entity_id_to_index_map[rel[1]], + "tail": entity_id_to_index_map[rel[0]] + }) + else: + continue + relations = sorted( + [{ + "head": rel["head"], + "tail": rel["tail"], + "start_index": self.get_relation_span(rel, entities)[0], + "end_index": self.get_relation_span(rel, entities)[1], + } for rel in kv_relations], + key=lambda x: x["head"], ) + + data['relations'] = relations + return data + + def get_relation_span(self, rel, entities): + bound = [] + for entity_index in [rel["head"], rel["tail"]]: + bound.append(entities[entity_index]["start"]) + bound.append(entities[entity_index]["end"]) + return min(bound), max(bound) diff --git a/backend/ppocr/data/lmdb_dataset.py b/backend/ppocr/data/lmdb_dataset.py new file mode 100644 index 00000000..e1b49809 --- /dev/null +++ b/backend/ppocr/data/lmdb_dataset.py @@ -0,0 +1,118 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
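# Hand-made worked example for VQAReTokenRelation above: only question->answer pairs
# are kept (reversed pairs are flipped into that order), and each relation's span covers
# both linked entities. All ids and offsets below are invented for illustration.
from ppocr.data.imaug.vqa import VQAReTokenRelation

data = {
    "entities": [{"start": 0, "end": 4}, {"start": 5, "end": 9}],
    "relations": [(10, 11)],                       # (question_id, answer_id)
    "id2label": {10: "question", 11: "answer"},
    "empty_entity": set(),
    "entity_id_to_index_map": {10: 0, 11: 1},
}
out = VQAReTokenRelation()(data)
# out["relations"] == [{"head": 0, "tail": 1, "start_index": 0, "end_index": 9}]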
+import numpy as np +import os +from paddle.io import Dataset +import lmdb +import cv2 + +from .imaug import transform, create_operators + + +class LMDBDataSet(Dataset): + def __init__(self, config, mode, logger, seed=None): + super(LMDBDataSet, self).__init__() + + global_config = config['Global'] + dataset_config = config[mode]['dataset'] + loader_config = config[mode]['loader'] + batch_size = loader_config['batch_size_per_card'] + data_dir = dataset_config['data_dir'] + self.do_shuffle = loader_config['shuffle'] + + self.lmdb_sets = self.load_hierarchical_lmdb_dataset(data_dir) + logger.info("Initialize indexs of datasets:%s" % data_dir) + self.data_idx_order_list = self.dataset_traversal() + if self.do_shuffle: + np.random.shuffle(self.data_idx_order_list) + self.ops = create_operators(dataset_config['transforms'], global_config) + + ratio_list = dataset_config.get("ratio_list", [1.0]) + self.need_reset = True in [x < 1 for x in ratio_list] + + def load_hierarchical_lmdb_dataset(self, data_dir): + lmdb_sets = {} + dataset_idx = 0 + for dirpath, dirnames, filenames in os.walk(data_dir + '/'): + if not dirnames: + env = lmdb.open( + dirpath, + max_readers=32, + readonly=True, + lock=False, + readahead=False, + meminit=False) + txn = env.begin(write=False) + num_samples = int(txn.get('num-samples'.encode())) + lmdb_sets[dataset_idx] = {"dirpath":dirpath, "env":env, \ + "txn":txn, "num_samples":num_samples} + dataset_idx += 1 + return lmdb_sets + + def dataset_traversal(self): + lmdb_num = len(self.lmdb_sets) + total_sample_num = 0 + for lno in range(lmdb_num): + total_sample_num += self.lmdb_sets[lno]['num_samples'] + data_idx_order_list = np.zeros((total_sample_num, 2)) + beg_idx = 0 + for lno in range(lmdb_num): + tmp_sample_num = self.lmdb_sets[lno]['num_samples'] + end_idx = beg_idx + tmp_sample_num + data_idx_order_list[beg_idx:end_idx, 0] = lno + data_idx_order_list[beg_idx:end_idx, 1] \ + = list(range(tmp_sample_num)) + data_idx_order_list[beg_idx:end_idx, 1] += 1 + beg_idx = beg_idx + tmp_sample_num + return data_idx_order_list + + def get_img_data(self, value): + """get_img_data""" + if not value: + return None + imgdata = np.frombuffer(value, dtype='uint8') + if imgdata is None: + return None + imgori = cv2.imdecode(imgdata, 1) + if imgori is None: + return None + return imgori + + def get_lmdb_sample_info(self, txn, index): + label_key = 'label-%09d'.encode() % index + label = txn.get(label_key) + if label is None: + return None + label = label.decode('utf-8') + img_key = 'image-%09d'.encode() % index + imgbuf = txn.get(img_key) + return imgbuf, label + + def __getitem__(self, idx): + lmdb_idx, file_idx = self.data_idx_order_list[idx] + lmdb_idx = int(lmdb_idx) + file_idx = int(file_idx) + sample_info = self.get_lmdb_sample_info(self.lmdb_sets[lmdb_idx]['txn'], + file_idx) + if sample_info is None: + return self.__getitem__(np.random.randint(self.__len__())) + img, label = sample_info + data = {'image': img, 'label': label} + outs = transform(data, self.ops) + if outs is None: + return self.__getitem__(np.random.randint(self.__len__())) + return outs + + def __len__(self): + return self.data_idx_order_list.shape[0] diff --git a/backend/ppocr/data/pgnet_dataset.py b/backend/ppocr/data/pgnet_dataset.py new file mode 100644 index 00000000..6f80179c --- /dev/null +++ b/backend/ppocr/data/pgnet_dataset.py @@ -0,0 +1,106 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import os +from paddle.io import Dataset +from .imaug import transform, create_operators +import random + + +class PGDataSet(Dataset): + def __init__(self, config, mode, logger, seed=None): + super(PGDataSet, self).__init__() + + self.logger = logger + self.seed = seed + self.mode = mode + global_config = config['Global'] + dataset_config = config[mode]['dataset'] + loader_config = config[mode]['loader'] + + self.delimiter = dataset_config.get('delimiter', '\t') + label_file_list = dataset_config.pop('label_file_list') + data_source_num = len(label_file_list) + ratio_list = dataset_config.get("ratio_list", [1.0]) + if isinstance(ratio_list, (float, int)): + ratio_list = [float(ratio_list)] * int(data_source_num) + assert len( + ratio_list + ) == data_source_num, "The length of ratio_list should be the same as the file_list." + self.data_dir = dataset_config['data_dir'] + self.do_shuffle = loader_config['shuffle'] + + logger.info("Initialize indexs of datasets:%s" % label_file_list) + self.data_lines = self.get_image_info_list(label_file_list, ratio_list) + self.data_idx_order_list = list(range(len(self.data_lines))) + if mode.lower() == "train": + self.shuffle_data_random() + + self.ops = create_operators(dataset_config['transforms'], global_config) + + self.need_reset = True in [x < 1 for x in ratio_list] + + def shuffle_data_random(self): + if self.do_shuffle: + random.seed(self.seed) + random.shuffle(self.data_lines) + return + + def get_image_info_list(self, file_list, ratio_list): + if isinstance(file_list, str): + file_list = [file_list] + data_lines = [] + for idx, file in enumerate(file_list): + with open(file, "rb") as f: + lines = f.readlines() + if self.mode == "train" or ratio_list[idx] < 1.0: + random.seed(self.seed) + lines = random.sample(lines, + round(len(lines) * ratio_list[idx])) + data_lines.extend(lines) + return data_lines + + def __getitem__(self, idx): + file_idx = self.data_idx_order_list[idx] + data_line = self.data_lines[file_idx] + img_id = 0 + try: + data_line = data_line.decode('utf-8') + substr = data_line.strip("\n").split(self.delimiter) + file_name = substr[0] + label = substr[1] + img_path = os.path.join(self.data_dir, file_name) + if self.mode.lower() == 'eval': + try: + img_id = int(data_line.split(".")[0][7:]) + except: + img_id = 0 + data = {'img_path': img_path, 'label': label, 'img_id': img_id} + if not os.path.exists(img_path): + raise Exception("{} does not exist!".format(img_path)) + with open(data['img_path'], 'rb') as f: + img = f.read() + data['image'] = img + outs = transform(data, self.ops) + except Exception as e: + self.logger.error( + "When parsing line {}, error happened with msg: {}".format( + self.data_idx_order_list[idx], e)) + outs = None + if outs is None: + return self.__getitem__(np.random.randint(self.__len__())) + return outs + + def __len__(self): + return len(self.data_idx_order_list) diff --git a/backend/ppocr/data/pubtab_dataset.py 
b/backend/ppocr/data/pubtab_dataset.py new file mode 100644 index 00000000..671cda76 --- /dev/null +++ b/backend/ppocr/data/pubtab_dataset.py @@ -0,0 +1,114 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import os +import random +from paddle.io import Dataset +import json + +from .imaug import transform, create_operators + + +class PubTabDataSet(Dataset): + def __init__(self, config, mode, logger, seed=None): + super(PubTabDataSet, self).__init__() + self.logger = logger + + global_config = config['Global'] + dataset_config = config[mode]['dataset'] + loader_config = config[mode]['loader'] + + label_file_path = dataset_config.pop('label_file_path') + + self.data_dir = dataset_config['data_dir'] + self.do_shuffle = loader_config['shuffle'] + self.do_hard_select = False + if 'hard_select' in loader_config: + self.do_hard_select = loader_config['hard_select'] + self.hard_prob = loader_config['hard_prob'] + if self.do_hard_select: + self.img_select_prob = self.load_hard_select_prob() + self.table_select_type = None + if 'table_select_type' in loader_config: + self.table_select_type = loader_config['table_select_type'] + self.table_select_prob = loader_config['table_select_prob'] + + self.seed = seed + logger.info("Initialize indexs of datasets:%s" % label_file_path) + with open(label_file_path, "rb") as f: + self.data_lines = f.readlines() + self.data_idx_order_list = list(range(len(self.data_lines))) + if mode.lower() == "train": + self.shuffle_data_random() + self.ops = create_operators(dataset_config['transforms'], global_config) + + ratio_list = dataset_config.get("ratio_list", [1.0]) + self.need_reset = True in [x < 1 for x in ratio_list] + + def shuffle_data_random(self): + if self.do_shuffle: + random.seed(self.seed) + random.shuffle(self.data_lines) + return + + def __getitem__(self, idx): + try: + data_line = self.data_lines[idx] + data_line = data_line.decode('utf-8').strip("\n") + info = json.loads(data_line) + file_name = info['filename'] + select_flag = True + if self.do_hard_select: + prob = self.img_select_prob[file_name] + if prob < random.uniform(0, 1): + select_flag = False + + if self.table_select_type: + structure = info['html']['structure']['tokens'].copy() + structure_str = ''.join(structure) + table_type = "simple" + if 'colspan' in structure_str or 'rowspan' in structure_str: + table_type = "complex" + if table_type == "complex": + if self.table_select_prob < random.uniform(0, 1): + select_flag = False + + if select_flag: + cells = info['html']['cells'].copy() + structure = info['html']['structure'].copy() + img_path = os.path.join(self.data_dir, file_name) + data = { + 'img_path': img_path, + 'cells': cells, + 'structure': structure + } + if not os.path.exists(img_path): + raise Exception("{} does not exist!".format(img_path)) + with open(data['img_path'], 'rb') as f: + img = f.read() + data['image'] = img + outs = transform(data, self.ops) + else: + outs = None + except Exception as 
e: + self.logger.error( + "When parsing line {}, error happened with msg: {}".format( + data_line, e)) + outs = None + if outs is None: + return self.__getitem__(np.random.randint(self.__len__())) + return outs + + def __len__(self): + return len(self.data_idx_order_list) diff --git a/backend/ppocr/data/simple_dataset.py b/backend/ppocr/data/simple_dataset.py index d2a86b0f..b5da9b88 100644 --- a/backend/ppocr/data/simple_dataset.py +++ b/backend/ppocr/data/simple_dataset.py @@ -13,9 +13,10 @@ # limitations under the License. import numpy as np import os +import json import random +import traceback from paddle.io import Dataset - from .imaug import transform, create_operators @@ -23,6 +24,7 @@ class SimpleDataSet(Dataset): def __init__(self, config, mode, logger, seed=None): super(SimpleDataSet, self).__init__() self.logger = logger + self.mode = mode.lower() global_config = config['Global'] dataset_config = config[mode]['dataset'] @@ -40,14 +42,16 @@ def __init__(self, config, mode, logger, seed=None): ) == data_source_num, "The length of ratio_list should be the same as the file_list." self.data_dir = dataset_config['data_dir'] self.do_shuffle = loader_config['shuffle'] - self.seed = seed logger.info("Initialize indexs of datasets:%s" % label_file_list) self.data_lines = self.get_image_info_list(label_file_list, ratio_list) self.data_idx_order_list = list(range(len(self.data_lines))) - if mode.lower() == "train": + if self.mode == "train" and self.do_shuffle: self.shuffle_data_random() self.ops = create_operators(dataset_config['transforms'], global_config) + self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx", + 2) + self.need_reset = True in [x < 1 for x in ratio_list] def get_image_info_list(self, file_list, ratio_list): if isinstance(file_list, str): @@ -56,18 +60,63 @@ def get_image_info_list(self, file_list, ratio_list): for idx, file in enumerate(file_list): with open(file, "rb") as f: lines = f.readlines() - random.seed(self.seed) - lines = random.sample(lines, - round(len(lines) * ratio_list[idx])) + if self.mode == "train" or ratio_list[idx] < 1.0: + random.seed(self.seed) + lines = random.sample(lines, + round(len(lines) * ratio_list[idx])) data_lines.extend(lines) return data_lines def shuffle_data_random(self): - if self.do_shuffle: - random.seed(self.seed) - random.shuffle(self.data_lines) + random.seed(self.seed) + random.shuffle(self.data_lines) return + def _try_parse_filename_list(self, file_name): + # multiple images -> one gt label + if len(file_name) > 0 and file_name[0] == "[": + try: + info = json.loads(file_name) + file_name = random.choice(info) + except: + pass + return file_name + + def get_ext_data(self): + ext_data_num = 0 + for op in self.ops: + if hasattr(op, 'ext_data_num'): + ext_data_num = getattr(op, 'ext_data_num') + break + load_data_ops = self.ops[:self.ext_op_transform_idx] + ext_data = [] + + while len(ext_data) < ext_data_num: + file_idx = self.data_idx_order_list[np.random.randint(self.__len__( + ))] + data_line = self.data_lines[file_idx] + data_line = data_line.decode('utf-8') + substr = data_line.strip("\n").split(self.delimiter) + file_name = substr[0] + file_name = self._try_parse_filename_list(file_name) + label = substr[1] + img_path = os.path.join(self.data_dir, file_name) + data = {'img_path': img_path, 'label': label} + if not os.path.exists(img_path): + continue + with open(data['img_path'], 'rb') as f: + img = f.read() + data['image'] = img + data = transform(data, load_data_ops) + + if data is None: + continue + if 
'polys' in data.keys(): + if data['polys'].shape[1] != 4: + continue + ext_data.append(data) + return ext_data + def __getitem__(self, idx): file_idx = self.data_idx_order_list[idx] data_line = self.data_lines[file_idx] @@ -75,6 +124,7 @@ def __getitem__(self, idx): data_line = data_line.decode('utf-8') substr = data_line.strip("\n").split(self.delimiter) file_name = substr[0] + file_name = self._try_parse_filename_list(file_name) label = substr[1] img_path = os.path.join(self.data_dir, file_name) data = {'img_path': img_path, 'label': label} @@ -83,14 +133,18 @@ def __getitem__(self, idx): with open(data['img_path'], 'rb') as f: img = f.read() data['image'] = img + data['ext_data'] = self.get_ext_data() outs = transform(data, self.ops) - except Exception as e: + except: self.logger.error( "When parsing line {}, error happened with msg: {}".format( - data_line, e)) + data_line, traceback.format_exc())) outs = None if outs is None: - return self.__getitem__(np.random.randint(self.__len__())) + # during evaluation, we should fix the idx to get same results for many times of evaluation. + rnd_idx = np.random.randint(self.__len__( + )) if self.mode == "train" else (idx + 1) % self.__len__() + return self.__getitem__(rnd_idx) return outs def __len__(self): diff --git a/backend/ppocr/losses/__init__.py b/backend/ppocr/losses/__init__.py index 3881abf7..de8419b7 100755 --- a/backend/ppocr/losses/__init__.py +++ b/backend/ppocr/losses/__init__.py @@ -13,27 +13,56 @@ # limitations under the License. import copy +import paddle +import paddle.nn as nn +# basic_loss +from .basic_loss import LossFromOutput -def build_loss(config): - # det loss - from .det_db_loss import DBLoss - from .det_east_loss import EASTLoss - from .det_sast_loss import SASTLoss +# det loss +from .det_db_loss import DBLoss +from .det_east_loss import EASTLoss +from .det_sast_loss import SASTLoss +from .det_pse_loss import PSELoss +from .det_fce_loss import FCELoss + +# rec loss +from .rec_ctc_loss import CTCLoss +from .rec_att_loss import AttentionLoss +from .rec_srn_loss import SRNLoss +from .rec_nrtr_loss import NRTRLoss +from .rec_sar_loss import SARLoss +from .rec_aster_loss import AsterLoss +from .rec_pren_loss import PRENLoss +from .rec_multi_loss import MultiLoss + +# cls loss +from .cls_loss import ClsLoss + +# e2e loss +from .e2e_pg_loss import PGLoss +from .kie_sdmgr_loss import SDMGRLoss + +# basic loss function +from .basic_loss import DistanceLoss - # rec loss - from .rec_ctc_loss import CTCLoss - from .rec_att_loss import AttentionLoss - from .rec_srn_loss import SRNLoss +# combined loss function +from .combined_loss import CombinedLoss - # cls loss - from .cls_loss import ClsLoss +# table loss +from .table_att_loss import TableAttentionLoss +# vqa token loss +from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss + + +def build_loss(config): support_dict = [ - 'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss', - 'SRNLoss' + 'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'FCELoss', 'CTCLoss', + 'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', + 'NRTRLoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss', + 'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss' ] - config = copy.deepcopy(config) module_name = config.pop('name') assert module_name in support_dict, Exception('loss only support {}'.format( diff --git a/backend/ppocr/losses/ace_loss.py b/backend/ppocr/losses/ace_loss.py new file mode 100644 index 00000000..915b99e6 --- /dev/null +++ 
b/backend/ppocr/losses/ace_loss.py @@ -0,0 +1,52 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This code is refer from: https://github.com/viig99/LS-ACELoss + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +import paddle.nn as nn + + +class ACELoss(nn.Layer): + def __init__(self, **kwargs): + super().__init__() + self.loss_func = nn.CrossEntropyLoss( + weight=None, + ignore_index=0, + reduction='none', + soft_label=True, + axis=-1) + + def __call__(self, predicts, batch): + if isinstance(predicts, (list, tuple)): + predicts = predicts[-1] + + B, N = predicts.shape[:2] + div = paddle.to_tensor([N]).astype('float32') + + predicts = nn.functional.softmax(predicts, axis=-1) + aggregation_preds = paddle.sum(predicts, axis=1) + aggregation_preds = paddle.divide(aggregation_preds, div) + + length = batch[2].astype("float32") + batch = batch[3].astype("float32") + batch[:, 0] = paddle.subtract(div, length) + batch = paddle.divide(batch, div) + + loss = self.loss_func(aggregation_preds, batch) + return {"loss_ace": loss} diff --git a/backend/ppocr/losses/basic_loss.py b/backend/ppocr/losses/basic_loss.py new file mode 100644 index 00000000..2df96ea2 --- /dev/null +++ b/backend/ppocr/losses/basic_loss.py @@ -0,0 +1,155 @@ +#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. 
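# Minimal shape sketch for ACELoss above (Aggregation Cross-Entropy): per-step class
# probabilities are summed over the sequence, normalized by the sequence length, and
# compared against the normalized per-class character counts, with class 0 holding the
# number of unused (blank) positions. All values below are invented; only batch[2]
# (label lengths) and batch[3] (per-class counts) are read by ACELoss.
import paddle
from ppocr.losses.ace_loss import ACELoss

B, N, C = 2, 8, 5                                   # batch, time steps, classes
predicts = paddle.rand([B, N, C])
lengths = paddle.to_tensor([3.0, 5.0])              # batch[2]
counts = paddle.to_tensor([[0., 2., 1., 0., 0.],    # batch[3]: e.g. two of class 1, one of class 2
                           [0., 0., 0., 5., 0.]])
loss = ACELoss()(predicts, [None, None, lengths, counts])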
+ +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from paddle.nn import L1Loss +from paddle.nn import MSELoss as L2Loss +from paddle.nn import SmoothL1Loss + + +class CELoss(nn.Layer): + def __init__(self, epsilon=None): + super().__init__() + if epsilon is not None and (epsilon <= 0 or epsilon >= 1): + epsilon = None + self.epsilon = epsilon + + def _labelsmoothing(self, target, class_num): + if target.shape[-1] != class_num: + one_hot_target = F.one_hot(target, class_num) + else: + one_hot_target = target + soft_target = F.label_smooth(one_hot_target, epsilon=self.epsilon) + soft_target = paddle.reshape(soft_target, shape=[-1, class_num]) + return soft_target + + def forward(self, x, label): + loss_dict = {} + if self.epsilon is not None: + class_num = x.shape[-1] + label = self._labelsmoothing(label, class_num) + x = -F.log_softmax(x, axis=-1) + loss = paddle.sum(x * label, axis=-1) + else: + if label.shape[-1] == x.shape[-1]: + label = F.softmax(label, axis=-1) + soft_label = True + else: + soft_label = False + loss = F.cross_entropy(x, label=label, soft_label=soft_label) + return loss + + +class KLJSLoss(object): + def __init__(self, mode='kl'): + assert mode in ['kl', 'js', 'KL', 'JS' + ], "mode can only be one of ['kl', 'js', 'KL', 'JS']" + self.mode = mode + + def __call__(self, p1, p2, reduction="mean"): + + loss = paddle.multiply(p2, paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5)) + + if self.mode.lower() == "js": + loss += paddle.multiply( + p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5)) + loss *= 0.5 + if reduction == "mean": + loss = paddle.mean(loss, axis=[1, 2]) + elif reduction == "none" or reduction is None: + return loss + else: + loss = paddle.sum(loss, axis=[1, 2]) + + return loss + + +class DMLLoss(nn.Layer): + """ + DMLLoss + """ + + def __init__(self, act=None, use_log=False): + super().__init__() + if act is not None: + assert act in ["softmax", "sigmoid"] + if act == "softmax": + self.act = nn.Softmax(axis=-1) + elif act == "sigmoid": + self.act = nn.Sigmoid() + else: + self.act = None + + self.use_log = use_log + self.jskl_loss = KLJSLoss(mode="js") + + def _kldiv(self, x, target): + eps = 1.0e-10 + loss = target * (paddle.log(target + eps) - x) + # batch mean loss + loss = paddle.sum(loss) / loss.shape[0] + return loss + + def forward(self, out1, out2): + if self.act is not None: + out1 = self.act(out1) + 1e-10 + out2 = self.act(out2) + 1e-10 + if self.use_log: + # for recognition distillation, log is needed for feature map + log_out1 = paddle.log(out1) + log_out2 = paddle.log(out2) + loss = ( + self._kldiv(log_out1, out2) + self._kldiv(log_out2, out1)) / 2.0 + else: + # for detection distillation log is not needed + loss = self.jskl_loss(out1, out2) + return loss + + +class DistanceLoss(nn.Layer): + """ + DistanceLoss: + mode: loss mode + """ + + def __init__(self, mode="l2", **kargs): + super().__init__() + assert mode in ["l1", "l2", "smooth_l1"] + if mode == "l1": + self.loss_func = nn.L1Loss(**kargs) + elif mode == "l2": + self.loss_func = nn.MSELoss(**kargs) + elif mode == "smooth_l1": + self.loss_func = nn.SmoothL1Loss(**kargs) + + def forward(self, x, y): + return self.loss_func(x, y) + + +class LossFromOutput(nn.Layer): + def __init__(self, key='loss', reduction='none'): + super().__init__() + self.key = key + self.reduction = reduction + + def forward(self, predicts, batch): + loss = predicts[self.key] + if self.reduction == 'mean': + loss = paddle.mean(loss) + elif self.reduction == 'sum': + loss = paddle.sum(loss) + return 
{'loss': loss} diff --git a/backend/ppocr/losses/center_loss.py b/backend/ppocr/losses/center_loss.py new file mode 100644 index 00000000..f62b8af3 --- /dev/null +++ b/backend/ppocr/losses/center_loss.py @@ -0,0 +1,88 @@ +#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +# This code is refer from: https://github.com/KaiyangZhou/pytorch-center-loss + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import os +import pickle + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + + +class CenterLoss(nn.Layer): + """ + Reference: Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016. + """ + + def __init__(self, num_classes=6625, feat_dim=96, center_file_path=None): + super().__init__() + self.num_classes = num_classes + self.feat_dim = feat_dim + self.centers = paddle.randn( + shape=[self.num_classes, self.feat_dim]).astype("float64") + + if center_file_path is not None: + assert os.path.exists( + center_file_path + ), f"center path({center_file_path}) must exist when it is not None." + with open(center_file_path, 'rb') as f: + char_dict = pickle.load(f) + for key in char_dict.keys(): + self.centers[key] = paddle.to_tensor(char_dict[key]) + + def __call__(self, predicts, batch): + assert isinstance(predicts, (list, tuple)) + features, predicts = predicts + + feats_reshape = paddle.reshape( + features, [-1, features.shape[-1]]).astype("float64") + label = paddle.argmax(predicts, axis=2) + label = paddle.reshape(label, [label.shape[0] * label.shape[1]]) + + batch_size = feats_reshape.shape[0] + + #calc l2 distance between feats and centers + square_feat = paddle.sum(paddle.square(feats_reshape), + axis=1, + keepdim=True) + square_feat = paddle.expand(square_feat, [batch_size, self.num_classes]) + + square_center = paddle.sum(paddle.square(self.centers), + axis=1, + keepdim=True) + square_center = paddle.expand( + square_center, [self.num_classes, batch_size]).astype("float64") + square_center = paddle.transpose(square_center, [1, 0]) + + distmat = paddle.add(square_feat, square_center) + feat_dot_center = paddle.matmul(feats_reshape, + paddle.transpose(self.centers, [1, 0])) + distmat = distmat - 2.0 * feat_dot_center + + #generate the mask + classes = paddle.arange(self.num_classes).astype("int64") + label = paddle.expand( + paddle.unsqueeze(label, 1), (batch_size, self.num_classes)) + mask = paddle.equal( + paddle.expand(classes, [batch_size, self.num_classes]), + label).astype("float64") + dist = paddle.multiply(distmat, mask) + + loss = paddle.sum(paddle.clip(dist, min=1e-12, max=1e+12)) / batch_size + return {'loss_center': loss} diff --git a/backend/ppocr/losses/cls_loss.py b/backend/ppocr/losses/cls_loss.py index 41c7db02..abc5e5b7 100755 --- a/backend/ppocr/losses/cls_loss.py +++ b/backend/ppocr/losses/cls_loss.py @@ -24,7 +24,7 @@ def __init__(self, **kwargs): super(ClsLoss, self).__init__() self.loss_func = 
nn.CrossEntropyLoss(reduction='mean') - def __call__(self, predicts, batch): - label = batch[1] + def forward(self, predicts, batch): + label = batch[1].astype("int64") loss = self.loss_func(input=predicts, label=label) return {'loss': loss} diff --git a/backend/ppocr/losses/combined_loss.py b/backend/ppocr/losses/combined_loss.py new file mode 100644 index 00000000..f4cdee8f --- /dev/null +++ b/backend/ppocr/losses/combined_loss.py @@ -0,0 +1,69 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn + +from .rec_ctc_loss import CTCLoss +from .center_loss import CenterLoss +from .ace_loss import ACELoss +from .rec_sar_loss import SARLoss + +from .distillation_loss import DistillationCTCLoss +from .distillation_loss import DistillationSARLoss +from .distillation_loss import DistillationDMLLoss +from .distillation_loss import DistillationDistanceLoss, DistillationDBLoss, DistillationDilaDBLoss + + +class CombinedLoss(nn.Layer): + """ + CombinedLoss: + a combionation of loss function + """ + + def __init__(self, loss_config_list=None): + super().__init__() + self.loss_func = [] + self.loss_weight = [] + assert isinstance(loss_config_list, list), ( + 'operator config should be a list') + for config in loss_config_list: + assert isinstance(config, + dict) and len(config) == 1, "yaml format error" + name = list(config)[0] + param = config[name] + assert "weight" in param, "weight must be in param, but param just contains {}".format( + param.keys()) + self.loss_weight.append(param.pop("weight")) + self.loss_func.append(eval(name)(**param)) + + def forward(self, input, batch, **kargs): + loss_dict = {} + loss_all = 0. + for idx, loss_func in enumerate(self.loss_func): + loss = loss_func(input, batch, **kargs) + if isinstance(loss, paddle.Tensor): + loss = {"loss_{}_{}".format(str(loss), idx): loss} + + weight = self.loss_weight[idx] + + loss = {key: loss[key] * weight for key in loss} + + if "loss" in loss: + loss_all += loss["loss"] + else: + loss_all += paddle.add_n(list(loss.values())) + loss_dict.update(loss) + loss_dict["loss"] = loss_all + return loss_dict diff --git a/backend/ppocr/losses/det_basic_loss.py b/backend/ppocr/losses/det_basic_loss.py index 57b3667d..61ea579b 100644 --- a/backend/ppocr/losses/det_basic_loss.py +++ b/backend/ppocr/losses/det_basic_loss.py @@ -11,7 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +""" +This code is refer from: +https://github.com/WenmuZhou/DBNet.pytorch/blob/master/models/losses/basic_loss.py +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -75,12 +78,6 @@ def forward(self, pred, gt, mask=None): mask (variable): masked maps. 
return: (variable) balanced loss """ - # if self.main_loss_type in ['DiceLoss']: - # # For the loss that returns to scalar value, perform ohem on the mask - # mask = ohem_batch(pred, gt, mask, self.negative_ratio) - # loss = self.loss(pred, gt, mask) - # return loss - positive = gt * mask negative = (1 - gt) * mask @@ -154,52 +151,3 @@ def __init__(self, reduction='mean'): def forward(self, input, label, mask=None, weight=None, name=None): loss = F.binary_cross_entropy(input, label, reduction=self.reduction) return loss - - -def ohem_single(score, gt_text, training_mask, ohem_ratio): - pos_num = (int)(np.sum(gt_text > 0.5)) - ( - int)(np.sum((gt_text > 0.5) & (training_mask <= 0.5))) - - if pos_num == 0: - # selected_mask = gt_text.copy() * 0 # may be not good - selected_mask = training_mask - selected_mask = selected_mask.reshape( - 1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32') - return selected_mask - - neg_num = (int)(np.sum(gt_text <= 0.5)) - neg_num = (int)(min(pos_num * ohem_ratio, neg_num)) - - if neg_num == 0: - selected_mask = training_mask - selected_mask = selected_mask.reshape( - 1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32') - return selected_mask - - neg_score = score[gt_text <= 0.5] - # 将负样本得分从高到低排序 - neg_score_sorted = np.sort(-neg_score) - threshold = -neg_score_sorted[neg_num - 1] - # 选出 得分高的 负样本 和正样本 的 mask - selected_mask = ((score >= threshold) | - (gt_text > 0.5)) & (training_mask > 0.5) - selected_mask = selected_mask.reshape( - 1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32') - return selected_mask - - -def ohem_batch(scores, gt_texts, training_masks, ohem_ratio): - scores = scores.numpy() - gt_texts = gt_texts.numpy() - training_masks = training_masks.numpy() - - selected_masks = [] - for i in range(scores.shape[0]): - selected_masks.append( - ohem_single(scores[i, :, :], gt_texts[i, :, :], training_masks[ - i, :, :], ohem_ratio)) - - selected_masks = np.concatenate(selected_masks, 0) - selected_masks = paddle.to_variable(selected_masks) - - return selected_masks diff --git a/backend/ppocr/losses/det_db_loss.py b/backend/ppocr/losses/det_db_loss.py index b079aabf..708ffbdb 100755 --- a/backend/ppocr/losses/det_db_loss.py +++ b/backend/ppocr/losses/det_db_loss.py @@ -11,6 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +This code is refer from: +https://github.com/WenmuZhou/DBNet.pytorch/blob/master/models/losses/DB_loss.py +""" from __future__ import absolute_import from __future__ import division diff --git a/backend/ppocr/losses/det_fce_loss.py b/backend/ppocr/losses/det_fce_loss.py new file mode 100644 index 00000000..d7dfb5aa --- /dev/null +++ b/backend/ppocr/losses/det_fce_loss.py @@ -0,0 +1,227 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
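# Hedged numpy sketch of the hard-negative balancing performed by BalanceLoss above
# (positives = gt * mask, negatives = (1 - gt) * mask, keep at most negative_ratio x
# positive_count of the highest-loss negatives). Names and the final normalization are
# illustrative; the real implementation operates on paddle tensors.
import numpy as np

negative_ratio = 3
loss_map = np.random.rand(4, 4)                     # per-pixel loss
gt = (np.random.rand(4, 4) > 0.7).astype("float32")
mask = np.ones((4, 4), "float32")

positive = gt * mask
negative = (1 - gt) * mask
pos_count = int(positive.sum())
neg_count = min(int(negative.sum()), pos_count * negative_ratio)
hard_neg = np.sort((loss_map * negative).ravel())[::-1][:neg_count]   # hardest negatives
balance_loss = (float((loss_map * positive).sum()) + hard_neg.sum()) / (pos_count + neg_count + 1e-6)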
+""" +This code is refer from: +https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textdet/losses/fce_loss.py +""" + +import numpy as np +from paddle import nn +import paddle +import paddle.nn.functional as F +from functools import partial + + +def multi_apply(func, *args, **kwargs): + pfunc = partial(func, **kwargs) if kwargs else func + map_results = map(pfunc, *args) + return tuple(map(list, zip(*map_results))) + + +class FCELoss(nn.Layer): + """The class for implementing FCENet loss + FCENet(CVPR2021): Fourier Contour Embedding for Arbitrary-shaped + Text Detection + + [https://arxiv.org/abs/2104.10442] + + Args: + fourier_degree (int) : The maximum Fourier transform degree k. + num_sample (int) : The sampling points number of regression + loss. If it is too small, fcenet tends to be overfitting. + ohem_ratio (float): the negative/positive ratio in OHEM. + """ + + def __init__(self, fourier_degree, num_sample, ohem_ratio=3.): + super().__init__() + self.fourier_degree = fourier_degree + self.num_sample = num_sample + self.ohem_ratio = ohem_ratio + + def forward(self, preds, labels): + assert isinstance(preds, dict) + preds = preds['levels'] + + p3_maps, p4_maps, p5_maps = labels[1:] + assert p3_maps[0].shape[0] == 4 * self.fourier_degree + 5,\ + 'fourier degree not equal in FCEhead and FCEtarget' + + # to tensor + gts = [p3_maps, p4_maps, p5_maps] + for idx, maps in enumerate(gts): + gts[idx] = paddle.to_tensor(np.stack(maps)) + + losses = multi_apply(self.forward_single, preds, gts) + + loss_tr = paddle.to_tensor(0.).astype('float32') + loss_tcl = paddle.to_tensor(0.).astype('float32') + loss_reg_x = paddle.to_tensor(0.).astype('float32') + loss_reg_y = paddle.to_tensor(0.).astype('float32') + loss_all = paddle.to_tensor(0.).astype('float32') + + for idx, loss in enumerate(losses): + loss_all += sum(loss) + if idx == 0: + loss_tr += sum(loss) + elif idx == 1: + loss_tcl += sum(loss) + elif idx == 2: + loss_reg_x += sum(loss) + else: + loss_reg_y += sum(loss) + + results = dict( + loss=loss_all, + loss_text=loss_tr, + loss_center=loss_tcl, + loss_reg_x=loss_reg_x, + loss_reg_y=loss_reg_y, ) + return results + + def forward_single(self, pred, gt): + cls_pred = paddle.transpose(pred[0], (0, 2, 3, 1)) + reg_pred = paddle.transpose(pred[1], (0, 2, 3, 1)) + gt = paddle.transpose(gt, (0, 2, 3, 1)) + + k = 2 * self.fourier_degree + 1 + tr_pred = paddle.reshape(cls_pred[:, :, :, :2], (-1, 2)) + tcl_pred = paddle.reshape(cls_pred[:, :, :, 2:], (-1, 2)) + x_pred = paddle.reshape(reg_pred[:, :, :, 0:k], (-1, k)) + y_pred = paddle.reshape(reg_pred[:, :, :, k:2 * k], (-1, k)) + + tr_mask = gt[:, :, :, :1].reshape([-1]) + tcl_mask = gt[:, :, :, 1:2].reshape([-1]) + train_mask = gt[:, :, :, 2:3].reshape([-1]) + x_map = paddle.reshape(gt[:, :, :, 3:3 + k], (-1, k)) + y_map = paddle.reshape(gt[:, :, :, 3 + k:], (-1, k)) + + tr_train_mask = (train_mask * tr_mask).astype('bool') + tr_train_mask2 = paddle.concat( + [tr_train_mask.unsqueeze(1), tr_train_mask.unsqueeze(1)], axis=1) + # tr loss + loss_tr = self.ohem(tr_pred, tr_mask, train_mask) + # tcl loss + loss_tcl = paddle.to_tensor(0.).astype('float32') + tr_neg_mask = tr_train_mask.logical_not() + tr_neg_mask2 = paddle.concat( + [tr_neg_mask.unsqueeze(1), tr_neg_mask.unsqueeze(1)], axis=1) + if tr_train_mask.sum().item() > 0: + loss_tcl_pos = F.cross_entropy( + tcl_pred.masked_select(tr_train_mask2).reshape([-1, 2]), + tcl_mask.masked_select(tr_train_mask).astype('int64')) + loss_tcl_neg = F.cross_entropy( + 
tcl_pred.masked_select(tr_neg_mask2).reshape([-1, 2]), + tcl_mask.masked_select(tr_neg_mask).astype('int64')) + loss_tcl = loss_tcl_pos + 0.5 * loss_tcl_neg + + # regression loss + loss_reg_x = paddle.to_tensor(0.).astype('float32') + loss_reg_y = paddle.to_tensor(0.).astype('float32') + if tr_train_mask.sum().item() > 0: + weight = (tr_mask.masked_select(tr_train_mask.astype('bool')) + .astype('float32') + tcl_mask.masked_select( + tr_train_mask.astype('bool')).astype('float32')) / 2 + weight = weight.reshape([-1, 1]) + + ft_x, ft_y = self.fourier2poly(x_map, y_map) + ft_x_pre, ft_y_pre = self.fourier2poly(x_pred, y_pred) + + dim = ft_x.shape[1] + + tr_train_mask3 = paddle.concat( + [tr_train_mask.unsqueeze(1) for i in range(dim)], axis=1) + + loss_reg_x = paddle.mean(weight * F.smooth_l1_loss( + ft_x_pre.masked_select(tr_train_mask3).reshape([-1, dim]), + ft_x.masked_select(tr_train_mask3).reshape([-1, dim]), + reduction='none')) + loss_reg_y = paddle.mean(weight * F.smooth_l1_loss( + ft_y_pre.masked_select(tr_train_mask3).reshape([-1, dim]), + ft_y.masked_select(tr_train_mask3).reshape([-1, dim]), + reduction='none')) + + return loss_tr, loss_tcl, loss_reg_x, loss_reg_y + + def ohem(self, predict, target, train_mask): + + pos = (target * train_mask).astype('bool') + neg = ((1 - target) * train_mask).astype('bool') + + pos2 = paddle.concat([pos.unsqueeze(1), pos.unsqueeze(1)], axis=1) + neg2 = paddle.concat([neg.unsqueeze(1), neg.unsqueeze(1)], axis=1) + + n_pos = pos.astype('float32').sum() + + if n_pos.item() > 0: + loss_pos = F.cross_entropy( + predict.masked_select(pos2).reshape([-1, 2]), + target.masked_select(pos).astype('int64'), + reduction='sum') + loss_neg = F.cross_entropy( + predict.masked_select(neg2).reshape([-1, 2]), + target.masked_select(neg).astype('int64'), + reduction='none') + n_neg = min( + int(neg.astype('float32').sum().item()), + int(self.ohem_ratio * n_pos.astype('float32'))) + else: + loss_pos = paddle.to_tensor(0.) + loss_neg = F.cross_entropy( + predict.masked_select(neg2).reshape([-1, 2]), + target.masked_select(neg).astype('int64'), + reduction='none') + n_neg = 100 + if len(loss_neg) > n_neg: + loss_neg, _ = paddle.topk(loss_neg, n_neg) + + return (loss_pos + loss_neg.sum()) / (n_pos + n_neg).astype('float32') + + def fourier2poly(self, real_maps, imag_maps): + """Transform Fourier coefficient maps to polygon maps. 
+ + Args: + real_maps (tensor): A map composed of the real parts of the + Fourier coefficients, whose shape is (-1, 2k+1) + imag_maps (tensor):A map composed of the imag parts of the + Fourier coefficients, whose shape is (-1, 2k+1) + + Returns + x_maps (tensor): A map composed of the x value of the polygon + represented by n sample points (xn, yn), whose shape is (-1, n) + y_maps (tensor): A map composed of the y value of the polygon + represented by n sample points (xn, yn), whose shape is (-1, n) + """ + + k_vect = paddle.arange( + -self.fourier_degree, self.fourier_degree + 1, + dtype='float32').reshape([-1, 1]) + i_vect = paddle.arange( + 0, self.num_sample, dtype='float32').reshape([1, -1]) + + transform_matrix = 2 * np.pi / self.num_sample * paddle.matmul(k_vect, + i_vect) + + x1 = paddle.einsum('ak, kn-> an', real_maps, + paddle.cos(transform_matrix)) + x2 = paddle.einsum('ak, kn-> an', imag_maps, + paddle.sin(transform_matrix)) + y1 = paddle.einsum('ak, kn-> an', real_maps, + paddle.sin(transform_matrix)) + y2 = paddle.einsum('ak, kn-> an', imag_maps, + paddle.cos(transform_matrix)) + + x_maps = x1 - x2 + y_maps = y1 + y2 + + return x_maps, y_maps diff --git a/backend/ppocr/losses/det_pse_loss.py b/backend/ppocr/losses/det_pse_loss.py new file mode 100644 index 00000000..6b31343e --- /dev/null +++ b/backend/ppocr/losses/det_pse_loss.py @@ -0,0 +1,149 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is refer from: +https://github.com/whai362/PSENet/blob/python3/models/head/psenet_head.py +""" + +import paddle +from paddle import nn +from paddle.nn import functional as F +import numpy as np +from ppocr.utils.iou import iou + + +class PSELoss(nn.Layer): + def __init__(self, + alpha, + ohem_ratio=3, + kernel_sample_mask='pred', + reduction='sum', + eps=1e-6, + **kwargs): + """Implement PSE Loss. 
+ """ + super(PSELoss, self).__init__() + assert reduction in ['sum', 'mean', 'none'] + self.alpha = alpha + self.ohem_ratio = ohem_ratio + self.kernel_sample_mask = kernel_sample_mask + self.reduction = reduction + self.eps = eps + + def forward(self, outputs, labels): + predicts = outputs['maps'] + predicts = F.interpolate(predicts, scale_factor=4) + + texts = predicts[:, 0, :, :] + kernels = predicts[:, 1:, :, :] + gt_texts, gt_kernels, training_masks = labels[1:] + + # text loss + selected_masks = self.ohem_batch(texts, gt_texts, training_masks) + + loss_text = self.dice_loss(texts, gt_texts, selected_masks) + iou_text = iou((texts > 0).astype('int64'), + gt_texts, + training_masks, + reduce=False) + losses = dict(loss_text=loss_text, iou_text=iou_text) + + # kernel loss + loss_kernels = [] + if self.kernel_sample_mask == 'gt': + selected_masks = gt_texts * training_masks + elif self.kernel_sample_mask == 'pred': + selected_masks = ( + F.sigmoid(texts) > 0.5).astype('float32') * training_masks + + for i in range(kernels.shape[1]): + kernel_i = kernels[:, i, :, :] + gt_kernel_i = gt_kernels[:, i, :, :] + loss_kernel_i = self.dice_loss(kernel_i, gt_kernel_i, + selected_masks) + loss_kernels.append(loss_kernel_i) + loss_kernels = paddle.mean(paddle.stack(loss_kernels, axis=1), axis=1) + iou_kernel = iou((kernels[:, -1, :, :] > 0).astype('int64'), + gt_kernels[:, -1, :, :], + training_masks * gt_texts, + reduce=False) + losses.update(dict(loss_kernels=loss_kernels, iou_kernel=iou_kernel)) + loss = self.alpha * loss_text + (1 - self.alpha) * loss_kernels + losses['loss'] = loss + if self.reduction == 'sum': + losses = {x: paddle.sum(v) for x, v in losses.items()} + elif self.reduction == 'mean': + losses = {x: paddle.mean(v) for x, v in losses.items()} + return losses + + def dice_loss(self, input, target, mask): + input = F.sigmoid(input) + + input = input.reshape([input.shape[0], -1]) + target = target.reshape([target.shape[0], -1]) + mask = mask.reshape([mask.shape[0], -1]) + + input = input * mask + target = target * mask + + a = paddle.sum(input * target, 1) + b = paddle.sum(input * input, 1) + self.eps + c = paddle.sum(target * target, 1) + self.eps + d = (2 * a) / (b + c) + return 1 - d + + def ohem_single(self, score, gt_text, training_mask, ohem_ratio=3): + pos_num = int(paddle.sum((gt_text > 0.5).astype('float32'))) - int( + paddle.sum( + paddle.logical_and((gt_text > 0.5), (training_mask <= 0.5)) + .astype('float32'))) + + if pos_num == 0: + selected_mask = training_mask + selected_mask = selected_mask.reshape( + [1, selected_mask.shape[0], selected_mask.shape[1]]).astype( + 'float32') + return selected_mask + + neg_num = int(paddle.sum((gt_text <= 0.5).astype('float32'))) + neg_num = int(min(pos_num * ohem_ratio, neg_num)) + + if neg_num == 0: + selected_mask = training_mask + selected_mask = selected_mask.reshape( + [1, selected_mask.shape[0], selected_mask.shape[1]]).astype( + 'float32') + return selected_mask + + neg_score = paddle.masked_select(score, gt_text <= 0.5) + neg_score_sorted = paddle.sort(-neg_score) + threshold = -neg_score_sorted[neg_num - 1] + + selected_mask = paddle.logical_and( + paddle.logical_or((score >= threshold), (gt_text > 0.5)), + (training_mask > 0.5)) + selected_mask = selected_mask.reshape( + [1, selected_mask.shape[0], selected_mask.shape[1]]).astype( + 'float32') + return selected_mask + + def ohem_batch(self, scores, gt_texts, training_masks, ohem_ratio=3): + selected_masks = [] + for i in range(scores.shape[0]): + selected_masks.append( + 
self.ohem_single(scores[i, :, :], gt_texts[i, :, :], + training_masks[i, :, :], ohem_ratio)) + + selected_masks = paddle.concat(selected_masks, 0).astype('float32') + return selected_masks diff --git a/backend/ppocr/losses/distillation_loss.py b/backend/ppocr/losses/distillation_loss.py new file mode 100644 index 00000000..565b066d --- /dev/null +++ b/backend/ppocr/losses/distillation_loss.py @@ -0,0 +1,324 @@ +#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +import paddle +import paddle.nn as nn +import numpy as np +import cv2 + +from .rec_ctc_loss import CTCLoss +from .rec_sar_loss import SARLoss +from .basic_loss import DMLLoss +from .basic_loss import DistanceLoss +from .det_db_loss import DBLoss +from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss + + +def _sum_loss(loss_dict): + if "loss" in loss_dict.keys(): + return loss_dict + else: + loss_dict["loss"] = 0. + for k, value in loss_dict.items(): + if k == "loss": + continue + else: + loss_dict["loss"] += value + return loss_dict + + +class DistillationDMLLoss(DMLLoss): + """ + """ + + def __init__(self, + model_name_pairs=[], + act=None, + use_log=False, + key=None, + multi_head=False, + dis_head='ctc', + maps_name=None, + name="dml"): + super().__init__(act=act, use_log=use_log) + assert isinstance(model_name_pairs, list) + self.key = key + self.multi_head = multi_head + self.dis_head = dis_head + self.model_name_pairs = self._check_model_name_pairs(model_name_pairs) + self.name = name + self.maps_name = self._check_maps_name(maps_name) + + def _check_model_name_pairs(self, model_name_pairs): + if not isinstance(model_name_pairs, list): + return [] + elif isinstance(model_name_pairs[0], list) and isinstance( + model_name_pairs[0][0], str): + return model_name_pairs + else: + return [model_name_pairs] + + def _check_maps_name(self, maps_name): + if maps_name is None: + return None + elif type(maps_name) == str: + return [maps_name] + elif type(maps_name) == list: + return [maps_name] + else: + return None + + def _slice_out(self, outs): + new_outs = {} + for k in self.maps_name: + if k == "thrink_maps": + new_outs[k] = outs[:, 0, :, :] + elif k == "threshold_maps": + new_outs[k] = outs[:, 1, :, :] + elif k == "binary_maps": + new_outs[k] = outs[:, 2, :, :] + else: + continue + return new_outs + + def forward(self, predicts, batch): + loss_dict = dict() + for idx, pair in enumerate(self.model_name_pairs): + out1 = predicts[pair[0]] + out2 = predicts[pair[1]] + if self.key is not None: + out1 = out1[self.key] + out2 = out2[self.key] + + if self.maps_name is None: + if self.multi_head: + loss = super().forward(out1[self.dis_head], + out2[self.dis_head]) + else: + loss = super().forward(out1, out2) + if isinstance(loss, dict): + for key in loss: + loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1], + idx)] = loss[key] + else: + loss_dict["{}_{}".format(self.name, idx)] = loss + else: + outs1 = self._slice_out(out1) + outs2 = self._slice_out(out2) + for _c, k in 
enumerate(outs1.keys()): + loss = super().forward(outs1[k], outs2[k]) + if isinstance(loss, dict): + for key in loss: + loss_dict["{}_{}_{}_{}_{}".format(key, pair[ + 0], pair[1], self.maps_name, idx)] = loss[key] + else: + loss_dict["{}_{}_{}".format(self.name, self.maps_name[ + _c], idx)] = loss + + loss_dict = _sum_loss(loss_dict) + + return loss_dict + + +class DistillationCTCLoss(CTCLoss): + def __init__(self, + model_name_list=[], + key=None, + multi_head=False, + name="loss_ctc"): + super().__init__() + self.model_name_list = model_name_list + self.key = key + self.name = name + self.multi_head = multi_head + + def forward(self, predicts, batch): + loss_dict = dict() + for idx, model_name in enumerate(self.model_name_list): + out = predicts[model_name] + if self.key is not None: + out = out[self.key] + if self.multi_head: + assert 'ctc' in out, 'multi head has multi out' + loss = super().forward(out['ctc'], batch[:2] + batch[3:]) + else: + loss = super().forward(out, batch) + if isinstance(loss, dict): + for key in loss: + loss_dict["{}_{}_{}".format(self.name, model_name, + idx)] = loss[key] + else: + loss_dict["{}_{}".format(self.name, model_name)] = loss + return loss_dict + + +class DistillationSARLoss(SARLoss): + def __init__(self, + model_name_list=[], + key=None, + multi_head=False, + name="loss_sar", + **kwargs): + ignore_index = kwargs.get('ignore_index', 92) + super().__init__(ignore_index=ignore_index) + self.model_name_list = model_name_list + self.key = key + self.name = name + self.multi_head = multi_head + + def forward(self, predicts, batch): + loss_dict = dict() + for idx, model_name in enumerate(self.model_name_list): + out = predicts[model_name] + if self.key is not None: + out = out[self.key] + if self.multi_head: + assert 'sar' in out, 'multi head has multi out' + loss = super().forward(out['sar'], batch[:1] + batch[2:]) + else: + loss = super().forward(out, batch) + if isinstance(loss, dict): + for key in loss: + loss_dict["{}_{}_{}".format(self.name, model_name, + idx)] = loss[key] + else: + loss_dict["{}_{}".format(self.name, model_name)] = loss + return loss_dict + + +class DistillationDBLoss(DBLoss): + def __init__(self, + model_name_list=[], + balance_loss=True, + main_loss_type='DiceLoss', + alpha=5, + beta=10, + ohem_ratio=3, + eps=1e-6, + name="db", + **kwargs): + super().__init__() + self.model_name_list = model_name_list + self.name = name + self.key = None + + def forward(self, predicts, batch): + loss_dict = {} + for idx, model_name in enumerate(self.model_name_list): + out = predicts[model_name] + if self.key is not None: + out = out[self.key] + loss = super().forward(out, batch) + + if isinstance(loss, dict): + for key in loss.keys(): + if key == "loss": + continue + name = "{}_{}_{}".format(self.name, model_name, key) + loss_dict[name] = loss[key] + else: + loss_dict["{}_{}".format(self.name, model_name)] = loss + + loss_dict = _sum_loss(loss_dict) + return loss_dict + + +class DistillationDilaDBLoss(DBLoss): + def __init__(self, + model_name_pairs=[], + key=None, + balance_loss=True, + main_loss_type='DiceLoss', + alpha=5, + beta=10, + ohem_ratio=3, + eps=1e-6, + name="dila_dbloss"): + super().__init__() + self.model_name_pairs = model_name_pairs + self.name = name + self.key = key + + def forward(self, predicts, batch): + loss_dict = dict() + for idx, pair in enumerate(self.model_name_pairs): + stu_outs = predicts[pair[0]] + tch_outs = predicts[pair[1]] + if self.key is not None: + stu_preds = stu_outs[self.key] + tch_preds = tch_outs[self.key] 
+ + stu_shrink_maps = stu_preds[:, 0, :, :] + stu_binary_maps = stu_preds[:, 2, :, :] + + # dilation to teacher prediction + dilation_w = np.array([[1, 1], [1, 1]]) + th_shrink_maps = tch_preds[:, 0, :, :] + th_shrink_maps = th_shrink_maps.numpy() > 0.3 # thresh = 0.3 + dilate_maps = np.zeros_like(th_shrink_maps).astype(np.float32) + for i in range(th_shrink_maps.shape[0]): + dilate_maps[i] = cv2.dilate( + th_shrink_maps[i, :, :].astype(np.uint8), dilation_w) + th_shrink_maps = paddle.to_tensor(dilate_maps) + + label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = batch[ + 1:] + + # calculate the shrink map loss + bce_loss = self.alpha * self.bce_loss( + stu_shrink_maps, th_shrink_maps, label_shrink_mask) + loss_binary_maps = self.dice_loss(stu_binary_maps, th_shrink_maps, + label_shrink_mask) + + # k = f"{self.name}_{pair[0]}_{pair[1]}" + k = "{}_{}_{}".format(self.name, pair[0], pair[1]) + loss_dict[k] = bce_loss + loss_binary_maps + + loss_dict = _sum_loss(loss_dict) + return loss_dict + + +class DistillationDistanceLoss(DistanceLoss): + """ + """ + + def __init__(self, + mode="l2", + model_name_pairs=[], + key=None, + name="loss_distance", + **kargs): + super().__init__(mode=mode, **kargs) + assert isinstance(model_name_pairs, list) + self.key = key + self.model_name_pairs = model_name_pairs + self.name = name + "_l2" + + def forward(self, predicts, batch): + loss_dict = dict() + for idx, pair in enumerate(self.model_name_pairs): + out1 = predicts[pair[0]] + out2 = predicts[pair[1]] + if self.key is not None: + out1 = out1[self.key] + out2 = out2[self.key] + loss = super().forward(out1, out2) + if isinstance(loss, dict): + for key in loss: + loss_dict["{}_{}_{}".format(self.name, key, idx)] = loss[ + key] + else: + loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1], + idx)] = loss + return loss_dict diff --git a/backend/ppocr/losses/e2e_pg_loss.py b/backend/ppocr/losses/e2e_pg_loss.py new file mode 100644 index 00000000..10a8ed0a --- /dev/null +++ b/backend/ppocr/losses/e2e_pg_loss.py @@ -0,0 +1,140 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
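DistillationDilaDBLoss above does not compare the student with the raw teacher scores: the teacher's shrink map is first binarized at a fixed 0.3 threshold and then grown with a 2x2 dilation, and that enlarged mask becomes the pseudo label for the student's shrink and binary maps. A minimal NumPy/OpenCV sketch of just that preprocessing step, on a made-up batch instead of real predictions:

    import cv2
    import numpy as np

    # Toy teacher shrink-map batch of shape (N, H, W) with raw scores in [0, 1].
    teacher_maps = np.random.rand(2, 8, 8).astype(np.float32)

    # Binarize at the threshold used in the loss (0.3), then dilate each map
    # with a 2x2 kernel so every text region is slightly enlarged.
    dilation_kernel = np.array([[1, 1], [1, 1]], dtype=np.uint8)
    binary = (teacher_maps > 0.3).astype(np.uint8)
    dilated = np.zeros_like(teacher_maps, dtype=np.float32)
    for i in range(binary.shape[0]):
        dilated[i] = cv2.dilate(binary[i], dilation_kernel)

    print(dilated.shape)  # (2, 8, 8), values in {0.0, 1.0}

The BCE and dice terms are then computed against this dilated mask, so the student learns toward a slightly looser target than the teacher's exact prediction.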
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from paddle import nn +import paddle + +from .det_basic_loss import DiceLoss +from ppocr.utils.e2e_utils.extract_batchsize import pre_process + + +class PGLoss(nn.Layer): + def __init__(self, + tcl_bs, + max_text_length, + max_text_nums, + pad_num, + eps=1e-6, + **kwargs): + super(PGLoss, self).__init__() + self.tcl_bs = tcl_bs + self.max_text_nums = max_text_nums + self.max_text_length = max_text_length + self.pad_num = pad_num + self.dice_loss = DiceLoss(eps=eps) + + def border_loss(self, f_border, l_border, l_score, l_mask): + l_border_split, l_border_norm = paddle.tensor.split( + l_border, num_or_sections=[4, 1], axis=1) + f_border_split = f_border + b, c, h, w = l_border_norm.shape + l_border_norm_split = paddle.expand( + x=l_border_norm, shape=[b, 4 * c, h, w]) + b, c, h, w = l_score.shape + l_border_score = paddle.expand(x=l_score, shape=[b, 4 * c, h, w]) + b, c, h, w = l_mask.shape + l_border_mask = paddle.expand(x=l_mask, shape=[b, 4 * c, h, w]) + border_diff = l_border_split - f_border_split + abs_border_diff = paddle.abs(border_diff) + border_sign = abs_border_diff < 1.0 + border_sign = paddle.cast(border_sign, dtype='float32') + border_sign.stop_gradient = True + border_in_loss = 0.5 * abs_border_diff * abs_border_diff * border_sign + \ + (abs_border_diff - 0.5) * (1.0 - border_sign) + border_out_loss = l_border_norm_split * border_in_loss + border_loss = paddle.sum(border_out_loss * l_border_score * l_border_mask) / \ + (paddle.sum(l_border_score * l_border_mask) + 1e-5) + return border_loss + + def direction_loss(self, f_direction, l_direction, l_score, l_mask): + l_direction_split, l_direction_norm = paddle.tensor.split( + l_direction, num_or_sections=[2, 1], axis=1) + f_direction_split = f_direction + b, c, h, w = l_direction_norm.shape + l_direction_norm_split = paddle.expand( + x=l_direction_norm, shape=[b, 2 * c, h, w]) + b, c, h, w = l_score.shape + l_direction_score = paddle.expand(x=l_score, shape=[b, 2 * c, h, w]) + b, c, h, w = l_mask.shape + l_direction_mask = paddle.expand(x=l_mask, shape=[b, 2 * c, h, w]) + direction_diff = l_direction_split - f_direction_split + abs_direction_diff = paddle.abs(direction_diff) + direction_sign = abs_direction_diff < 1.0 + direction_sign = paddle.cast(direction_sign, dtype='float32') + direction_sign.stop_gradient = True + direction_in_loss = 0.5 * abs_direction_diff * abs_direction_diff * direction_sign + \ + (abs_direction_diff - 0.5) * (1.0 - direction_sign) + direction_out_loss = l_direction_norm_split * direction_in_loss + direction_loss = paddle.sum(direction_out_loss * l_direction_score * l_direction_mask) / \ + (paddle.sum(l_direction_score * l_direction_mask) + 1e-5) + return direction_loss + + def ctcloss(self, f_char, tcl_pos, tcl_mask, tcl_label, label_t): + f_char = paddle.transpose(f_char, [0, 2, 3, 1]) + tcl_pos = paddle.reshape(tcl_pos, [-1, 3]) + tcl_pos = paddle.cast(tcl_pos, dtype=int) + f_tcl_char = paddle.gather_nd(f_char, tcl_pos) + f_tcl_char = paddle.reshape(f_tcl_char, + [-1, 64, 37]) # len(Lexicon_Table)+1 + f_tcl_char_fg, f_tcl_char_bg = paddle.split(f_tcl_char, [36, 1], axis=2) + f_tcl_char_bg = f_tcl_char_bg * tcl_mask + (1.0 - tcl_mask) * 20.0 + b, c, l = tcl_mask.shape + tcl_mask_fg = paddle.expand(x=tcl_mask, shape=[b, c, 36 * l]) + tcl_mask_fg.stop_gradient = True + f_tcl_char_fg = f_tcl_char_fg * tcl_mask_fg + (1.0 - tcl_mask_fg) * ( + -20.0) + f_tcl_char_mask = 
paddle.concat([f_tcl_char_fg, f_tcl_char_bg], axis=2) + f_tcl_char_ld = paddle.transpose(f_tcl_char_mask, (1, 0, 2)) + N, B, _ = f_tcl_char_ld.shape + input_lengths = paddle.to_tensor([N] * B, dtype='int64') + cost = paddle.nn.functional.ctc_loss( + log_probs=f_tcl_char_ld, + labels=tcl_label, + input_lengths=input_lengths, + label_lengths=label_t, + blank=self.pad_num, + reduction='none') + cost = cost.mean() + return cost + + def forward(self, predicts, labels): + images, tcl_maps, tcl_label_maps, border_maps \ + , direction_maps, training_masks, label_list, pos_list, pos_mask = labels + # for all the batch_size + pos_list, pos_mask, label_list, label_t = pre_process( + label_list, pos_list, pos_mask, self.max_text_length, + self.max_text_nums, self.pad_num, self.tcl_bs) + + f_score, f_border, f_direction, f_char = predicts['f_score'], predicts['f_border'], predicts['f_direction'], \ + predicts['f_char'] + score_loss = self.dice_loss(f_score, tcl_maps, training_masks) + border_loss = self.border_loss(f_border, border_maps, tcl_maps, + training_masks) + direction_loss = self.direction_loss(f_direction, direction_maps, + tcl_maps, training_masks) + ctc_loss = self.ctcloss(f_char, pos_list, pos_mask, label_list, label_t) + loss_all = score_loss + border_loss + direction_loss + 5 * ctc_loss + + losses = { + 'loss': loss_all, + "score_loss": score_loss, + "border_loss": border_loss, + "direction_loss": direction_loss, + "ctc_loss": ctc_loss + } + return losses diff --git a/backend/ppocr/losses/kie_sdmgr_loss.py b/backend/ppocr/losses/kie_sdmgr_loss.py new file mode 100644 index 00000000..745671f5 --- /dev/null +++ b/backend/ppocr/losses/kie_sdmgr_loss.py @@ -0,0 +1,115 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# reference from : https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/kie/losses/sdmgr_loss.py + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from paddle import nn +import paddle + + +class SDMGRLoss(nn.Layer): + def __init__(self, node_weight=1.0, edge_weight=1.0, ignore=0): + super().__init__() + self.loss_node = nn.CrossEntropyLoss(ignore_index=ignore) + self.loss_edge = nn.CrossEntropyLoss(ignore_index=-1) + self.node_weight = node_weight + self.edge_weight = edge_weight + self.ignore = ignore + + def pre_process(self, gts, tag): + gts, tag = gts.numpy(), tag.numpy().tolist() + temp_gts = [] + batch = len(tag) + for i in range(batch): + num, recoder_len = tag[i][0], tag[i][1] + temp_gts.append( + paddle.to_tensor( + gts[i, :num, :num + 1], dtype='int64')) + return temp_gts + + def accuracy(self, pred, target, topk=1, thresh=None): + """Calculate accuracy according to the prediction and target. 
+ + Args: + pred (torch.Tensor): The model prediction, shape (N, num_class) + target (torch.Tensor): The target of each prediction, shape (N, ) + topk (int | tuple[int], optional): If the predictions in ``topk`` + matches the target, the predictions will be regarded as + correct ones. Defaults to 1. + thresh (float, optional): If not None, predictions with scores under + this threshold are considered incorrect. Default to None. + + Returns: + float | tuple[float]: If the input ``topk`` is a single integer, + the function will return a single float as accuracy. If + ``topk`` is a tuple containing multiple integers, the + function will return a tuple containing accuracies of + each ``topk`` number. + """ + assert isinstance(topk, (int, tuple)) + if isinstance(topk, int): + topk = (topk, ) + return_single = True + else: + return_single = False + + maxk = max(topk) + if pred.shape[0] == 0: + accu = [pred.new_tensor(0.) for i in range(len(topk))] + return accu[0] if return_single else accu + pred_value, pred_label = paddle.topk(pred, maxk, axis=1) + pred_label = pred_label.transpose( + [1, 0]) # transpose to shape (maxk, N) + correct = paddle.equal(pred_label, + (target.reshape([1, -1]).expand_as(pred_label))) + res = [] + for k in topk: + correct_k = paddle.sum(correct[:k].reshape([-1]).astype('float32'), + axis=0, + keepdim=True) + res.append( + paddle.multiply(correct_k, + paddle.to_tensor(100.0 / pred.shape[0]))) + return res[0] if return_single else res + + def forward(self, pred, batch): + node_preds, edge_preds = pred + gts, tag = batch[4], batch[5] + gts = self.pre_process(gts, tag) + node_gts, edge_gts = [], [] + for gt in gts: + node_gts.append(gt[:, 0]) + edge_gts.append(gt[:, 1:].reshape([-1])) + node_gts = paddle.concat(node_gts) + edge_gts = paddle.concat(edge_gts) + + node_valids = paddle.nonzero(node_gts != self.ignore).reshape([-1]) + edge_valids = paddle.nonzero(edge_gts != -1).reshape([-1]) + loss_node = self.loss_node(node_preds, node_gts) + loss_edge = self.loss_edge(edge_preds, edge_gts) + loss = self.node_weight * loss_node + self.edge_weight * loss_edge + return dict( + loss=loss, + loss_node=loss_node, + loss_edge=loss_edge, + acc_node=self.accuracy( + paddle.gather(node_preds, node_valids), + paddle.gather(node_gts, node_valids)), + acc_edge=self.accuracy( + paddle.gather(edge_preds, edge_valids), + paddle.gather(edge_gts, edge_valids))) diff --git a/backend/ppocr/losses/rec_aster_loss.py b/backend/ppocr/losses/rec_aster_loss.py new file mode 100644 index 00000000..fbb99d29 --- /dev/null +++ b/backend/ppocr/losses/rec_aster_loss.py @@ -0,0 +1,99 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
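The accuracy helper in SDMGRLoss reports, as a percentage, how many nodes (or edges) have their ground-truth class among the k highest-scoring predicted classes. The same computation restated with NumPy on toy data, purely as an illustration of the logic:

    import numpy as np

    # Toy logits for 4 nodes over 3 classes, and their ground-truth labels.
    pred = np.array([[0.1, 0.7, 0.2],
                     [0.6, 0.3, 0.1],
                     [0.2, 0.2, 0.6],
                     [0.5, 0.4, 0.1]])
    target = np.array([1, 0, 2, 1])

    def topk_accuracy(pred, target, k=1):
        # Indices of the k highest-scoring classes per row, then check whether
        # the true label appears among them (same idea as the paddle.topk path).
        topk_idx = np.argsort(-pred, axis=1)[:, :k]
        hits = (topk_idx == target[:, None]).any(axis=1)
        return 100.0 * hits.mean()

    print(topk_accuracy(pred, target, k=1))  # 75.0
    print(topk_accuracy(pred, target, k=2))  # 100.0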
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn + + +class CosineEmbeddingLoss(nn.Layer): + def __init__(self, margin=0.): + super(CosineEmbeddingLoss, self).__init__() + self.margin = margin + self.epsilon = 1e-12 + + def forward(self, x1, x2, target): + similarity = paddle.fluid.layers.reduce_sum( + x1 * x2, dim=-1) / (paddle.norm( + x1, axis=-1) * paddle.norm( + x2, axis=-1) + self.epsilon) + one_list = paddle.full_like(target, fill_value=1) + out = paddle.fluid.layers.reduce_mean( + paddle.where( + paddle.equal(target, one_list), 1. - similarity, + paddle.maximum( + paddle.zeros_like(similarity), similarity - self.margin))) + + return out + + +class AsterLoss(nn.Layer): + def __init__(self, + weight=None, + size_average=True, + ignore_index=-100, + sequence_normalize=False, + sample_normalize=True, + **kwargs): + super(AsterLoss, self).__init__() + self.weight = weight + self.size_average = size_average + self.ignore_index = ignore_index + self.sequence_normalize = sequence_normalize + self.sample_normalize = sample_normalize + self.loss_sem = CosineEmbeddingLoss() + self.is_cosin_loss = True + self.loss_func_rec = nn.CrossEntropyLoss(weight=None, reduction='none') + + def forward(self, predicts, batch): + targets = batch[1].astype("int64") + label_lengths = batch[2].astype('int64') + sem_target = batch[3].astype('float32') + embedding_vectors = predicts['embedding_vectors'] + rec_pred = predicts['rec_pred'] + + if not self.is_cosin_loss: + sem_loss = paddle.sum(self.loss_sem(embedding_vectors, sem_target)) + else: + label_target = paddle.ones([embedding_vectors.shape[0]]) + sem_loss = paddle.sum( + self.loss_sem(embedding_vectors, sem_target, label_target)) + + # rec loss + batch_size, def_max_length = targets.shape[0], targets.shape[1] + + mask = paddle.zeros([batch_size, def_max_length]) + for i in range(batch_size): + mask[i, :label_lengths[i]] = 1 + mask = paddle.cast(mask, "float32") + max_length = max(label_lengths) + assert max_length == rec_pred.shape[1] + targets = targets[:, :max_length] + mask = mask[:, :max_length] + rec_pred = paddle.reshape(rec_pred, [-1, rec_pred.shape[2]]) + input = nn.functional.log_softmax(rec_pred, axis=1) + targets = paddle.reshape(targets, [-1, 1]) + mask = paddle.reshape(mask, [-1, 1]) + output = -paddle.index_sample(input, index=targets) * mask + output = paddle.sum(output) + if self.sequence_normalize: + output = output / paddle.sum(mask) + if self.sample_normalize: + output = output / batch_size + + loss = output + sem_loss * 0.1 + return {'loss': loss} diff --git a/backend/ppocr/losses/rec_ctc_loss.py b/backend/ppocr/losses/rec_ctc_loss.py index 425de587..502fc8c5 100755 --- a/backend/ppocr/losses/rec_ctc_loss.py +++ b/backend/ppocr/losses/rec_ctc_loss.py @@ -21,16 +21,25 @@ class CTCLoss(nn.Layer): - def __init__(self, **kwargs): + def __init__(self, use_focal_loss=False, **kwargs): super(CTCLoss, self).__init__() self.loss_func = nn.CTCLoss(blank=0, reduction='none') + self.use_focal_loss = use_focal_loss - def __call__(self, predicts, batch): + def forward(self, predicts, batch): + if isinstance(predicts, (list, tuple)): + predicts = predicts[-1] predicts = predicts.transpose((1, 0, 2)) N, B, _ = predicts.shape - preds_lengths = paddle.to_tensor([N] * B, dtype='int64') + preds_lengths = paddle.to_tensor( + [N] * B, dtype='int64', place=paddle.CPUPlace()) labels = batch[1].astype("int32") label_lengths = batch[2].astype('int64') loss = 
self.loss_func(predicts, labels, preds_lengths, label_lengths) - loss = loss.mean() # sum + if self.use_focal_loss: + weight = paddle.exp(-loss) + weight = paddle.subtract(paddle.to_tensor([1.0]), weight) + weight = paddle.square(weight) + loss = paddle.multiply(loss, weight) + loss = loss.mean() return {'loss': loss} diff --git a/backend/ppocr/losses/rec_enhanced_ctc_loss.py b/backend/ppocr/losses/rec_enhanced_ctc_loss.py new file mode 100644 index 00000000..b57be646 --- /dev/null +++ b/backend/ppocr/losses/rec_enhanced_ctc_loss.py @@ -0,0 +1,70 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn +from .ace_loss import ACELoss +from .center_loss import CenterLoss +from .rec_ctc_loss import CTCLoss + + +class EnhancedCTCLoss(nn.Layer): + def __init__(self, + use_focal_loss=False, + use_ace_loss=False, + ace_loss_weight=0.1, + use_center_loss=False, + center_loss_weight=0.05, + num_classes=6625, + feat_dim=96, + init_center=False, + center_file_path=None, + **kwargs): + super(EnhancedCTCLoss, self).__init__() + self.ctc_loss_func = CTCLoss(use_focal_loss=use_focal_loss) + + self.use_ace_loss = False + if use_ace_loss: + self.use_ace_loss = use_ace_loss + self.ace_loss_func = ACELoss() + self.ace_loss_weight = ace_loss_weight + + self.use_center_loss = False + if use_center_loss: + self.use_center_loss = use_center_loss + self.center_loss_func = CenterLoss( + num_classes=num_classes, + feat_dim=feat_dim, + init_center=init_center, + center_file_path=center_file_path) + self.center_loss_weight = center_loss_weight + + def __call__(self, predicts, batch): + loss = self.ctc_loss_func(predicts, batch)["loss"] + + if self.use_center_loss: + center_loss = self.center_loss_func( + predicts, batch)["loss_center"] * self.center_loss_weight + loss = loss + center_loss + + if self.use_ace_loss: + ace_loss = self.ace_loss_func( + predicts, batch)["loss_ace"] * self.ace_loss_weight + loss = loss + ace_loss + + return {'enhanced_ctc_loss': loss} diff --git a/backend/ppocr/losses/rec_multi_loss.py b/backend/ppocr/losses/rec_multi_loss.py new file mode 100644 index 00000000..09f007af --- /dev/null +++ b/backend/ppocr/losses/rec_multi_loss.py @@ -0,0 +1,58 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
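When use_focal_loss is switched on, the patched CTCLoss keeps the per-sample losses (reduction='none'), multiplies each by (1 - exp(-loss))^2 so that samples the model already handles well contribute little, and only then averages. A small NumPy sketch of that re-weighting on invented per-sample losses:

    import numpy as np

    # Per-sample CTC losses as returned with reduction='none' for a batch of 4.
    per_sample_ctc = np.array([0.05, 0.30, 1.20, 2.50])

    # Focal-style weight: easy samples (small loss) get weights near zero,
    # hard samples keep most of their loss.
    weight = (1.0 - np.exp(-per_sample_ctc)) ** 2
    focal_loss = (per_sample_ctc * weight).mean()

    print(weight.round(3))             # [0.002 0.067 0.488 0.843]
    print(round(float(focal_loss), 4))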
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn + +from .rec_ctc_loss import CTCLoss +from .rec_sar_loss import SARLoss + + +class MultiLoss(nn.Layer): + def __init__(self, **kwargs): + super().__init__() + self.loss_funcs = {} + self.loss_list = kwargs.pop('loss_config_list') + self.weight_1 = kwargs.get('weight_1', 1.0) + self.weight_2 = kwargs.get('weight_2', 1.0) + self.gtc_loss = kwargs.get('gtc_loss', 'sar') + for loss_info in self.loss_list: + for name, param in loss_info.items(): + if param is not None: + kwargs.update(param) + loss = eval(name)(**kwargs) + self.loss_funcs[name] = loss + + def forward(self, predicts, batch): + self.total_loss = {} + total_loss = 0.0 + # batch [image, label_ctc, label_sar, length, valid_ratio] + for name, loss_func in self.loss_funcs.items(): + if name == 'CTCLoss': + loss = loss_func(predicts['ctc'], + batch[:2] + batch[3:])['loss'] * self.weight_1 + elif name == 'SARLoss': + loss = loss_func(predicts['sar'], + batch[:1] + batch[2:])['loss'] * self.weight_2 + else: + raise NotImplementedError( + '{} is not supported in MultiLoss yet'.format(name)) + self.total_loss[name] = loss + total_loss += loss + self.total_loss['loss'] = total_loss + return self.total_loss diff --git a/backend/ppocr/losses/rec_nrtr_loss.py b/backend/ppocr/losses/rec_nrtr_loss.py new file mode 100644 index 00000000..200a6d04 --- /dev/null +++ b/backend/ppocr/losses/rec_nrtr_loss.py @@ -0,0 +1,30 @@ +import paddle +from paddle import nn +import paddle.nn.functional as F + + +class NRTRLoss(nn.Layer): + def __init__(self, smoothing=True, **kwargs): + super(NRTRLoss, self).__init__() + self.loss_func = nn.CrossEntropyLoss(reduction='mean', ignore_index=0) + self.smoothing = smoothing + + def forward(self, pred, batch): + pred = pred.reshape([-1, pred.shape[2]]) + max_len = batch[2].max() + tgt = batch[1][:, 1:2 + max_len] + tgt = tgt.reshape([-1]) + if self.smoothing: + eps = 0.1 + n_class = pred.shape[1] + one_hot = F.one_hot(tgt, pred.shape[1]) + one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1) + log_prb = F.log_softmax(pred, axis=1) + non_pad_mask = paddle.not_equal( + tgt, paddle.zeros( + tgt.shape, dtype=tgt.dtype)) + loss = -(one_hot * log_prb).sum(axis=1) + loss = loss.masked_select(non_pad_mask).mean() + else: + loss = self.loss_func(pred, tgt) + return {'loss': loss} diff --git a/backend/ppocr/losses/rec_pren_loss.py b/backend/ppocr/losses/rec_pren_loss.py new file mode 100644 index 00000000..7bc53d29 --- /dev/null +++ b/backend/ppocr/losses/rec_pren_loss.py @@ -0,0 +1,30 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
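NRTRLoss optionally swaps plain cross-entropy for a label-smoothed variant: the one-hot target keeps 1 - eps on the true class, spreads eps over the remaining classes, and padding positions (class 0) are excluded from the mean. A NumPy restatement of that path, assuming toy logits and targets:

    import numpy as np

    def smoothed_ce(logits, targets, eps=0.1, pad_id=0):
        # logits: (T, C) raw scores, targets: (T,) class ids; positions whose
        # target equals pad_id are masked out, as in NRTRLoss.
        n_class = logits.shape[1]
        one_hot = np.eye(n_class)[targets]
        one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
        log_prb = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
        loss = -(one_hot * log_prb).sum(axis=1)
        mask = targets != pad_id
        return loss[mask].mean()

    logits = np.array([[2.0, 0.1, 0.1], [0.1, 2.0, 0.1], [0.1, 0.1, 0.1]])
    targets = np.array([1, 2, 0])  # the last position is padding
    print(round(float(smoothed_ce(logits, targets)), 4))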
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from paddle import nn + + +class PRENLoss(nn.Layer): + def __init__(self, **kwargs): + super(PRENLoss, self).__init__() + # note: 0 is padding idx + self.loss_func = nn.CrossEntropyLoss(reduction='mean', ignore_index=0) + + def forward(self, predicts, batch): + loss = self.loss_func(predicts, batch[1].astype('int64')) + return {'loss': loss} diff --git a/backend/ppocr/losses/rec_sar_loss.py b/backend/ppocr/losses/rec_sar_loss.py new file mode 100644 index 00000000..a4f83f03 --- /dev/null +++ b/backend/ppocr/losses/rec_sar_loss.py @@ -0,0 +1,29 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn + + +class SARLoss(nn.Layer): + def __init__(self, **kwargs): + super(SARLoss, self).__init__() + ignore_index = kwargs.get('ignore_index', 92) # 6626 + self.loss_func = paddle.nn.loss.CrossEntropyLoss( + reduction="mean", ignore_index=ignore_index) + + def forward(self, predicts, batch): + predict = predicts[:, : + -1, :] # ignore last index of outputs to be in same seq_len with targets + label = batch[1].astype( + "int64")[:, 1:] # ignore first index of target in loss calculation + batch_size, num_steps, num_classes = predict.shape[0], predict.shape[ + 1], predict.shape[2] + assert len(label.shape) == len(list(predict.shape)) - 1, \ + "The target's shape and inputs's shape is [N, d] and [N, num_steps]" + + inputs = paddle.reshape(predict, [-1, num_classes]) + targets = paddle.reshape(label, [-1]) + loss = self.loss_func(inputs, targets) + return {'loss': loss} diff --git a/backend/ppocr/losses/table_att_loss.py b/backend/ppocr/losses/table_att_loss.py new file mode 100644 index 00000000..d7fd99e6 --- /dev/null +++ b/backend/ppocr/losses/table_att_loss.py @@ -0,0 +1,109 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
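SARLoss only has to line the tensors up before a standard cross-entropy: the last prediction step and the first (start-token) label column are dropped so prediction and target share the same sequence length, after which both are flattened and padding is handled through ignore_index. A shape-only NumPy sketch with arbitrary toy sizes:

    import numpy as np

    B, T, C = 2, 5, 10                     # batch, padded length, classes
    predict = np.random.rand(B, T, C).astype(np.float32)
    label = np.random.randint(0, C, size=(B, T))

    pred_aligned = predict[:, :-1, :]      # drop last output step -> (B, T-1, C)
    label_aligned = label[:, 1:]           # drop start token      -> (B, T-1)

    inputs = pred_aligned.reshape(-1, C)   # (B*(T-1), C)
    targets = label_aligned.reshape(-1)    # (B*(T-1),)
    print(inputs.shape, targets.shape)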
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn +from paddle.nn import functional as F +from paddle import fluid + +class TableAttentionLoss(nn.Layer): + def __init__(self, structure_weight, loc_weight, use_giou=False, giou_weight=1.0, **kwargs): + super(TableAttentionLoss, self).__init__() + self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none') + self.structure_weight = structure_weight + self.loc_weight = loc_weight + self.use_giou = use_giou + self.giou_weight = giou_weight + + def giou_loss(self, preds, bbox, eps=1e-7, reduction='mean'): + ''' + :param preds:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,] + :param bbox:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,] + :return: loss + ''' + ix1 = fluid.layers.elementwise_max(preds[:, 0], bbox[:, 0]) + iy1 = fluid.layers.elementwise_max(preds[:, 1], bbox[:, 1]) + ix2 = fluid.layers.elementwise_min(preds[:, 2], bbox[:, 2]) + iy2 = fluid.layers.elementwise_min(preds[:, 3], bbox[:, 3]) + + iw = fluid.layers.clip(ix2 - ix1 + 1e-3, 0., 1e10) + ih = fluid.layers.clip(iy2 - iy1 + 1e-3, 0., 1e10) + + # overlap + inters = iw * ih + + # union + uni = (preds[:, 2] - preds[:, 0] + 1e-3) * (preds[:, 3] - preds[:, 1] + 1e-3 + ) + (bbox[:, 2] - bbox[:, 0] + 1e-3) * ( + bbox[:, 3] - bbox[:, 1] + 1e-3) - inters + eps + + # ious + ious = inters / uni + + ex1 = fluid.layers.elementwise_min(preds[:, 0], bbox[:, 0]) + ey1 = fluid.layers.elementwise_min(preds[:, 1], bbox[:, 1]) + ex2 = fluid.layers.elementwise_max(preds[:, 2], bbox[:, 2]) + ey2 = fluid.layers.elementwise_max(preds[:, 3], bbox[:, 3]) + ew = fluid.layers.clip(ex2 - ex1 + 1e-3, 0., 1e10) + eh = fluid.layers.clip(ey2 - ey1 + 1e-3, 0., 1e10) + + # enclose erea + enclose = ew * eh + eps + giou = ious - (enclose - uni) / enclose + + loss = 1 - giou + + if reduction == 'mean': + loss = paddle.mean(loss) + elif reduction == 'sum': + loss = paddle.sum(loss) + else: + raise NotImplementedError + return loss + + def forward(self, predicts, batch): + structure_probs = predicts['structure_probs'] + structure_targets = batch[1].astype("int64") + structure_targets = structure_targets[:, 1:] + if len(batch) == 6: + structure_mask = batch[5].astype("int64") + structure_mask = structure_mask[:, 1:] + structure_mask = paddle.reshape(structure_mask, [-1]) + structure_probs = paddle.reshape(structure_probs, [-1, structure_probs.shape[-1]]) + structure_targets = paddle.reshape(structure_targets, [-1]) + structure_loss = self.loss_func(structure_probs, structure_targets) + + if len(batch) == 6: + structure_loss = structure_loss * structure_mask + +# structure_loss = paddle.sum(structure_loss) * self.structure_weight + structure_loss = paddle.mean(structure_loss) * self.structure_weight + + loc_preds = predicts['loc_preds'] + loc_targets = batch[2].astype("float32") + loc_targets_mask = batch[4].astype("float32") + loc_targets = loc_targets[:, 1:, :] + loc_targets_mask = loc_targets_mask[:, 1:, :] + loc_loss = F.mse_loss(loc_preds * loc_targets_mask, loc_targets) * self.loc_weight + if self.use_giou: + loc_loss_giou = self.giou_loss(loc_preds * loc_targets_mask, loc_targets) * self.giou_weight + total_loss = structure_loss + loc_loss + loc_loss_giou + return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss, "loc_loss_giou":loc_loss_giou} + else: + total_loss = structure_loss + loc_loss + return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss} \ No newline at end of file diff 
--git a/backend/ppocr/losses/vqa_token_layoutlm_loss.py b/backend/ppocr/losses/vqa_token_layoutlm_loss.py new file mode 100755 index 00000000..244893d9 --- /dev/null +++ b/backend/ppocr/losses/vqa_token_layoutlm_loss.py @@ -0,0 +1,42 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from paddle import nn + + +class VQASerTokenLayoutLMLoss(nn.Layer): + def __init__(self, num_classes): + super().__init__() + self.loss_class = nn.CrossEntropyLoss() + self.num_classes = num_classes + self.ignore_index = self.loss_class.ignore_index + + def forward(self, predicts, batch): + labels = batch[1] + attention_mask = batch[4] + if attention_mask is not None: + active_loss = attention_mask.reshape([-1, ]) == 1 + active_outputs = predicts.reshape( + [-1, self.num_classes])[active_loss] + active_labels = labels.reshape([-1, ])[active_loss] + loss = self.loss_class(active_outputs, active_labels) + else: + loss = self.loss_class( + predicts.reshape([-1, self.num_classes]), + labels.reshape([-1, ])) + return {'loss': loss} diff --git a/backend/ppocr/metrics/__init__.py b/backend/ppocr/metrics/__init__.py index a0e7d912..c244066c 100644 --- a/backend/ppocr/metrics/__init__.py +++ b/backend/ppocr/metrics/__init__.py @@ -19,19 +19,29 @@ import copy -__all__ = ['build_metric'] +__all__ = ["build_metric"] +from .det_metric import DetMetric, DetFCEMetric +from .rec_metric import RecMetric +from .cls_metric import ClsMetric +from .e2e_metric import E2EMetric +from .distillation_metric import DistillationMetric +from .table_metric import TableMetric +from .kie_metric import KIEMetric +from .vqa_token_ser_metric import VQASerTokenMetric +from .vqa_token_re_metric import VQAReTokenMetric -def build_metric(config): - from .det_metric import DetMetric - from .rec_metric import RecMetric - from .cls_metric import ClsMetric - support_dict = ['DetMetric', 'RecMetric', 'ClsMetric'] +def build_metric(config): + support_dict = [ + "DetMetric", "DetFCEMetric", "RecMetric", "ClsMetric", "E2EMetric", + "DistillationMetric", "TableMetric", 'KIEMetric', 'VQASerTokenMetric', + 'VQAReTokenMetric' + ] config = copy.deepcopy(config) - module_name = config.pop('name') + module_name = config.pop("name") assert module_name in support_dict, Exception( - 'metric only support {}'.format(support_dict)) + "metric only support {}".format(support_dict)) module_class = eval(module_name)(**config) return module_class diff --git a/backend/ppocr/metrics/cls_metric.py b/backend/ppocr/metrics/cls_metric.py index 09817200..6c077518 100644 --- a/backend/ppocr/metrics/cls_metric.py +++ b/backend/ppocr/metrics/cls_metric.py @@ -16,6 +16,7 @@ class ClsMetric(object): def __init__(self, main_indicator='acc', **kwargs): self.main_indicator = main_indicator + self.eps = 1e-5 self.reset() def __call__(self, pred_label, *args, **kwargs): @@ -28,7 +29,7 @@ def __call__(self, 
pred_label, *args, **kwargs): all_num += 1 self.correct_num += correct_num self.all_num += all_num - return {'acc': correct_num / all_num, } + return {'acc': correct_num / (all_num + self.eps), } def get_metric(self): """ @@ -36,7 +37,7 @@ def get_metric(self): 'acc': 0 } """ - acc = self.correct_num / self.all_num + acc = self.correct_num / (self.all_num + self.eps) self.reset() return {'acc': acc} diff --git a/backend/ppocr/metrics/det_metric.py b/backend/ppocr/metrics/det_metric.py index 0f9e94df..dca94c09 100644 --- a/backend/ppocr/metrics/det_metric.py +++ b/backend/ppocr/metrics/det_metric.py @@ -16,7 +16,7 @@ from __future__ import division from __future__ import print_function -__all__ = ['DetMetric'] +__all__ = ['DetMetric', 'DetFCEMetric'] from .eval_det_iou import DetectionIoUEvaluator @@ -64,9 +64,91 @@ def get_metric(self): } """ - metircs = self.evaluator.combine_results(self.results) + metrics = self.evaluator.combine_results(self.results) self.reset() - return metircs + return metrics def reset(self): self.results = [] # clear results + + +class DetFCEMetric(object): + def __init__(self, main_indicator='hmean', **kwargs): + self.evaluator = DetectionIoUEvaluator() + self.main_indicator = main_indicator + self.reset() + + def __call__(self, preds, batch, **kwargs): + ''' + batch: a list produced by dataloaders. + image: np.ndarray of shape (N, C, H, W). + ratio_list: np.ndarray of shape(N,2) + polygons: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions. + ignore_tags: np.ndarray of shape (N, K), indicates whether a region is ignorable or not. + preds: a list of dict produced by post process + points: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions. + ''' + gt_polyons_batch = batch[2] + ignore_tags_batch = batch[3] + + for pred, gt_polyons, ignore_tags in zip(preds, gt_polyons_batch, + ignore_tags_batch): + # prepare gt + gt_info_list = [{ + 'points': gt_polyon, + 'text': '', + 'ignore': ignore_tag + } for gt_polyon, ignore_tag in zip(gt_polyons, ignore_tags)] + # prepare det + det_info_list = [{ + 'points': det_polyon, + 'text': '', + 'score': score + } for det_polyon, score in zip(pred['points'], pred['scores'])] + + for score_thr in self.results.keys(): + det_info_list_thr = [ + det_info for det_info in det_info_list + if det_info['score'] >= score_thr + ] + result = self.evaluator.evaluate_image(gt_info_list, + det_info_list_thr) + self.results[score_thr].append(result) + + def get_metric(self): + """ + return metrics {'heman':0, + 'thr 0.3':'precision: 0 recall: 0 hmean: 0', + 'thr 0.4':'precision: 0 recall: 0 hmean: 0', + 'thr 0.5':'precision: 0 recall: 0 hmean: 0', + 'thr 0.6':'precision: 0 recall: 0 hmean: 0', + 'thr 0.7':'precision: 0 recall: 0 hmean: 0', + 'thr 0.8':'precision: 0 recall: 0 hmean: 0', + 'thr 0.9':'precision: 0 recall: 0 hmean: 0', + } + """ + metrics = {} + hmean = 0 + for score_thr in self.results.keys(): + metric = self.evaluator.combine_results(self.results[score_thr]) + # for key, value in metric.items(): + # metrics['{}_{}'.format(key, score_thr)] = value + metric_str = 'precision:{:.5f} recall:{:.5f} hmean:{:.5f}'.format( + metric['precision'], metric['recall'], metric['hmean']) + metrics['thr {}'.format(score_thr)] = metric_str + hmean = max(hmean, metric['hmean']) + metrics['hmean'] = hmean + + self.reset() + return metrics + + def reset(self): + self.results = { + 0.3: [], + 0.4: [], + 0.5: [], + 0.6: [], + 0.7: [], + 0.8: [], + 0.9: [] + } # clear results diff --git 
a/backend/ppocr/metrics/distillation_metric.py b/backend/ppocr/metrics/distillation_metric.py new file mode 100644 index 00000000..c440cebd --- /dev/null +++ b/backend/ppocr/metrics/distillation_metric.py @@ -0,0 +1,73 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import copy + +from .rec_metric import RecMetric +from .det_metric import DetMetric +from .e2e_metric import E2EMetric +from .cls_metric import ClsMetric + + +class DistillationMetric(object): + def __init__(self, + key=None, + base_metric_name=None, + main_indicator=None, + **kwargs): + self.main_indicator = main_indicator + self.key = key + self.main_indicator = main_indicator + self.base_metric_name = base_metric_name + self.kwargs = kwargs + self.metrics = None + + def _init_metrcis(self, preds): + self.metrics = dict() + mod = importlib.import_module(__name__) + for key in preds: + self.metrics[key] = getattr(mod, self.base_metric_name)( + main_indicator=self.main_indicator, **self.kwargs) + self.metrics[key].reset() + + def __call__(self, preds, batch, **kwargs): + assert isinstance(preds, dict) + if self.metrics is None: + self._init_metrcis(preds) + output = dict() + for key in preds: + self.metrics[key].__call__(preds[key], batch, **kwargs) + + def get_metric(self): + """ + return metrics { + 'acc': 0, + 'norm_edit_dis': 0, + } + """ + output = dict() + for key in self.metrics: + metric = self.metrics[key].get_metric() + # main indicator + if key == self.key: + output.update(metric) + else: + for sub_key in metric: + output["{}_{}".format(key, sub_key)] = metric[sub_key] + return output + + def reset(self): + for key in self.metrics: + self.metrics[key].reset() diff --git a/backend/ppocr/metrics/e2e_metric.py b/backend/ppocr/metrics/e2e_metric.py new file mode 100644 index 00000000..2f8ba3b2 --- /dev/null +++ b/backend/ppocr/metrics/e2e_metric.py @@ -0,0 +1,86 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
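DistillationMetric keeps one base metric per sub-model (instantiated lazily from base_metric_name the first time predictions arrive) and, when reporting, exposes the key sub-model's numbers unprefixed while prefixing everything else with the model name. A toy illustration of that flattening, with invented numbers standing in for each sub-metric's get_metric() result:

    # Assume base_metric_name='RecMetric' and key='Student'; the dicts below
    # stand in for self.metrics[k].get_metric().
    per_model = {
        'Student': {'acc': 0.812, 'norm_edit_dis': 0.934},
        'Teacher': {'acc': 0.845, 'norm_edit_dis': 0.951},
    }
    key = 'Student'

    output = {}
    for model, metric in per_model.items():
        if model == key:
            output.update(metric)  # main indicator, unprefixed
        else:
            for sub_key, value in metric.items():
                output['{}_{}'.format(model, sub_key)] = value

    print(output)
    # {'acc': 0.812, 'norm_edit_dis': 0.934,
    #  'Teacher_acc': 0.845, 'Teacher_norm_edit_dis': 0.951}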
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +__all__ = ['E2EMetric'] + +from ppocr.utils.e2e_metric.Deteval import get_socre_A, get_socre_B, combine_results +from ppocr.utils.e2e_utils.extract_textpoint_slow import get_dict + + +class E2EMetric(object): + def __init__(self, + mode, + gt_mat_dir, + character_dict_path, + main_indicator='f_score_e2e', + **kwargs): + self.mode = mode + self.gt_mat_dir = gt_mat_dir + self.label_list = get_dict(character_dict_path) + self.max_index = len(self.label_list) + self.main_indicator = main_indicator + self.reset() + + def __call__(self, preds, batch, **kwargs): + if self.mode == 'A': + gt_polyons_batch = batch[2] + temp_gt_strs_batch = batch[3][0] + ignore_tags_batch = batch[4] + gt_strs_batch = [] + + for temp_list in temp_gt_strs_batch: + t = "" + for index in temp_list: + if index < self.max_index: + t += self.label_list[index] + gt_strs_batch.append(t) + + for pred, gt_polyons, gt_strs, ignore_tags in zip( + [preds], gt_polyons_batch, [gt_strs_batch], ignore_tags_batch): + # prepare gt + gt_info_list = [{ + 'points': gt_polyon, + 'text': gt_str, + 'ignore': ignore_tag + } for gt_polyon, gt_str, ignore_tag in + zip(gt_polyons, gt_strs, ignore_tags)] + # prepare det + e2e_info_list = [{ + 'points': det_polyon, + 'texts': pred_str + } for det_polyon, pred_str in + zip(pred['points'], pred['texts'])] + + result = get_socre_A(gt_info_list, e2e_info_list) + self.results.append(result) + else: + img_id = batch[5][0] + e2e_info_list = [{ + 'points': det_polyon, + 'texts': pred_str + } for det_polyon, pred_str in zip(preds['points'], preds['texts'])] + result = get_socre_B(self.gt_mat_dir, img_id, e2e_info_list) + self.results.append(result) + + def get_metric(self): + metrics = combine_results(self.results) + self.reset() + return metrics + + def reset(self): + self.results = [] # clear results diff --git a/backend/ppocr/metrics/eval_det_iou.py b/backend/ppocr/metrics/eval_det_iou.py index a2a3f418..bc05e7df 100644 --- a/backend/ppocr/metrics/eval_det_iou.py +++ b/backend/ppocr/metrics/eval_det_iou.py @@ -150,7 +150,7 @@ def compute_ap(confList, matchList, numGtCare): pairs.append({'gt': gtNum, 'det': detNum}) detMatchedNums.append(detNum) evaluationLog += "Match GT #" + \ - str(gtNum) + " with Det #" + str(detNum) + "\n" + str(gtNum) + " with Det #" + str(detNum) + "\n" numGtCare = (len(gtPols) - len(gtDontCarePolsNum)) numDetCare = (len(detPols) - len(detDontCarePolsNum)) @@ -162,28 +162,17 @@ def compute_ap(confList, matchList, numGtCare): precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare hmean = 0 if (precision + recall) == 0 else 2.0 * \ - precision * recall / (precision + recall) + precision * recall / (precision + recall) matchedSum += detMatched numGlobalCareGt += numGtCare numGlobalCareDet += numDetCare perSampleMetrics = { - 'precision': precision, - 'recall': recall, - 'hmean': hmean, - 'pairs': pairs, - 'iouMat': [] if len(detPols) > 100 else iouMat.tolist(), - 'gtPolPoints': gtPolPoints, - 'detPolPoints': detPolPoints, 'gtCare': numGtCare, 'detCare': numDetCare, - 'gtDontCare': gtDontCarePolsNum, - 'detDontCare': detDontCarePolsNum, 'detMatched': detMatched, - 'evaluationLog': evaluationLog } - return perSampleMetrics def combine_results(self, results): @@ -200,7 +189,8 @@ def combine_results(self, results): methodPrecision = 0 if numGlobalCareDet == 0 else float( matchedSum) / numGlobalCareDet methodHmean = 0 if methodRecall + methodPrecision == 0 else 
2 * \ - methodRecall * methodPrecision / (methodRecall + methodPrecision) + methodRecall * methodPrecision / ( + methodRecall + methodPrecision) # print(methodRecall, methodPrecision, methodHmean) # sys.exit(-1) methodMetrics = { diff --git a/backend/ppocr/metrics/kie_metric.py b/backend/ppocr/metrics/kie_metric.py new file mode 100644 index 00000000..28ab22b8 --- /dev/null +++ b/backend/ppocr/metrics/kie_metric.py @@ -0,0 +1,71 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# The code is refer from: https://github.com/open-mmlab/mmocr/blob/main/mmocr/core/evaluation/kie_metric.py + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import paddle + +__all__ = ['KIEMetric'] + + +class KIEMetric(object): + def __init__(self, main_indicator='hmean', **kwargs): + self.main_indicator = main_indicator + self.reset() + self.node = [] + self.gt = [] + + def __call__(self, preds, batch, **kwargs): + nodes, _ = preds + gts, tag = batch[4].squeeze(0), batch[5].tolist()[0] + gts = gts[:tag[0], :1].reshape([-1]) + self.node.append(nodes.numpy()) + self.gt.append(gts) + # result = self.compute_f1_score(nodes, gts) + # self.results.append(result) + + def compute_f1_score(self, preds, gts): + ignores = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 25] + C = preds.shape[1] + classes = np.array(sorted(set(range(C)) - set(ignores))) + hist = np.bincount( + (gts * C).astype('int64') + preds.argmax(1), minlength=C + **2).reshape([C, C]).astype('float32') + diag = np.diag(hist) + recalls = diag / hist.sum(1).clip(min=1) + precisions = diag / hist.sum(0).clip(min=1) + f1 = 2 * recalls * precisions / (recalls + precisions).clip(min=1e-8) + return f1[classes] + + def combine_results(self, results): + node = np.concatenate(self.node, 0) + gts = np.concatenate(self.gt, 0) + results = self.compute_f1_score(node, gts) + data = {'hmean': results.mean()} + return data + + def get_metric(self): + + metrics = self.combine_results(self.results) + self.reset() + return metrics + + def reset(self): + self.results = [] # clear results + self.node = [] + self.gt = [] diff --git a/backend/ppocr/metrics/rec_metric.py b/backend/ppocr/metrics/rec_metric.py index 66c084d7..515b9372 100644 --- a/backend/ppocr/metrics/rec_metric.py +++ b/backend/ppocr/metrics/rec_metric.py @@ -13,21 +13,38 @@ # limitations under the License. 
import Levenshtein +import string class RecMetric(object): - def __init__(self, main_indicator='acc', **kwargs): + def __init__(self, + main_indicator='acc', + is_filter=False, + ignore_space=True, + **kwargs): self.main_indicator = main_indicator + self.is_filter = is_filter + self.ignore_space = ignore_space + self.eps = 1e-5 self.reset() + def _normalize_text(self, text): + text = ''.join( + filter(lambda x: x in (string.digits + string.ascii_letters), text)) + return text.lower() + def __call__(self, pred_label, *args, **kwargs): preds, labels = pred_label correct_num = 0 all_num = 0 norm_edit_dis = 0.0 for (pred, pred_conf), (target, _) in zip(preds, labels): - pred = pred.replace(" ", "") - target = target.replace(" ", "") + if self.ignore_space: + pred = pred.replace(" ", "") + target = target.replace(" ", "") + if self.is_filter: + pred = self._normalize_text(pred) + target = self._normalize_text(target) norm_edit_dis += Levenshtein.distance(pred, target) / max( len(pred), len(target), 1) if pred == target: @@ -37,8 +54,8 @@ def __call__(self, pred_label, *args, **kwargs): self.all_num += all_num self.norm_edit_dis += norm_edit_dis return { - 'acc': correct_num / all_num, - 'norm_edit_dis': 1 - norm_edit_dis / all_num + 'acc': correct_num / (all_num + self.eps), + 'norm_edit_dis': 1 - norm_edit_dis / (all_num + self.eps) } def get_metric(self): @@ -48,8 +65,8 @@ def get_metric(self): 'norm_edit_dis': 0, } """ - acc = 1.0 * self.correct_num / self.all_num - norm_edit_dis = 1 - self.norm_edit_dis / self.all_num + acc = 1.0 * self.correct_num / (self.all_num + self.eps) + norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num + self.eps) self.reset() return {'acc': acc, 'norm_edit_dis': norm_edit_dis} diff --git a/backend/ppocr/metrics/table_metric.py b/backend/ppocr/metrics/table_metric.py new file mode 100644 index 00000000..ca4d6474 --- /dev/null +++ b/backend/ppocr/metrics/table_metric.py @@ -0,0 +1,51 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
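RecMetric scores a batch on two axes: exact-match accuracy and 1 minus the mean normalized Levenshtein distance, optionally after stripping spaces (ignore_space) and keeping only ASCII letters and digits (is_filter). A self-contained sketch of that scoring on two made-up prediction/label pairs, reusing the same Levenshtein package the metric imports:

    import string
    import Levenshtein

    pairs = [("hello world", "hello world"), ("PaddIe-OCR", "PaddleOCR")]

    def normalize(text):
        # is_filter path: keep ASCII letters/digits only, then lower-case.
        kept = filter(lambda c: c in string.digits + string.ascii_letters, text)
        return ''.join(kept).lower()

    eps = 1e-5
    correct, norm_edit_dis = 0, 0.0
    for pred, target in pairs:
        pred, target = pred.replace(" ", ""), target.replace(" ", "")  # ignore_space
        pred, target = normalize(pred), normalize(target)
        norm_edit_dis += Levenshtein.distance(pred, target) / max(
            len(pred), len(target), 1)
        correct += int(pred == target)

    print(correct / (len(pairs) + eps))            # exact-match accuracy
    print(1 - norm_edit_dis / (len(pairs) + eps))  # 1 - mean normalized edit distance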
+import numpy as np + + +class TableMetric(object): + def __init__(self, main_indicator='acc', **kwargs): + self.main_indicator = main_indicator + self.eps = 1e-5 + self.reset() + + def __call__(self, pred, batch, *args, **kwargs): + structure_probs = pred['structure_probs'].numpy() + structure_labels = batch[1] + correct_num = 0 + all_num = 0 + structure_probs = np.argmax(structure_probs, axis=2) + structure_labels = structure_labels[:, 1:] + batch_size = structure_probs.shape[0] + for bno in range(batch_size): + all_num += 1 + if (structure_probs[bno] == structure_labels[bno]).all(): + correct_num += 1 + self.correct_num += correct_num + self.all_num += all_num + return {'acc': correct_num * 1.0 / (all_num + self.eps), } + + def get_metric(self): + """ + return metrics { + 'acc': 0, + } + """ + acc = 1.0 * self.correct_num / (self.all_num + self.eps) + self.reset() + return {'acc': acc} + + def reset(self): + self.correct_num = 0 + self.all_num = 0 diff --git a/backend/ppocr/metrics/vqa_token_re_metric.py b/backend/ppocr/metrics/vqa_token_re_metric.py new file mode 100644 index 00000000..8a13bc08 --- /dev/null +++ b/backend/ppocr/metrics/vqa_token_re_metric.py @@ -0,0 +1,176 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
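TableMetric is an all-or-nothing structure accuracy: a sample counts as correct only if every predicted structure token (argmax over classes) matches the label sequence. A toy NumPy sketch, with the labels already aligned the way the metric slices them:

    import numpy as np

    structure_probs = np.array([
        [[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]],   # argmax -> [0, 1, 0]
        [[0.6, 0.4], [0.4, 0.6], [0.1, 0.9]],   # argmax -> [0, 1, 1]
    ])
    structure_labels = np.array([[0, 1, 0],
                                 [0, 1, 0]])

    preds = np.argmax(structure_probs, axis=2)
    correct = sum(int((p == l).all()) for p, l in zip(preds, structure_labels))
    eps = 1e-5
    print(correct / (len(preds) + eps))  # ~0.5, only the first sequence matches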
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import paddle + +__all__ = ['KIEMetric'] + + +class VQAReTokenMetric(object): + def __init__(self, main_indicator='hmean', **kwargs): + self.main_indicator = main_indicator + self.reset() + + def __call__(self, preds, batch, **kwargs): + pred_relations, relations, entities = preds + self.pred_relations_list.extend(pred_relations) + self.relations_list.extend(relations) + self.entities_list.extend(entities) + + def get_metric(self): + gt_relations = [] + for b in range(len(self.relations_list)): + rel_sent = [] + for head, tail in zip(self.relations_list[b]["head"], + self.relations_list[b]["tail"]): + rel = {} + rel["head_id"] = head + rel["head"] = (self.entities_list[b]["start"][rel["head_id"]], + self.entities_list[b]["end"][rel["head_id"]]) + rel["head_type"] = self.entities_list[b]["label"][rel[ + "head_id"]] + + rel["tail_id"] = tail + rel["tail"] = (self.entities_list[b]["start"][rel["tail_id"]], + self.entities_list[b]["end"][rel["tail_id"]]) + rel["tail_type"] = self.entities_list[b]["label"][rel[ + "tail_id"]] + + rel["type"] = 1 + rel_sent.append(rel) + gt_relations.append(rel_sent) + re_metrics = self.re_score( + self.pred_relations_list, gt_relations, mode="boundaries") + metrics = { + "precision": re_metrics["ALL"]["p"], + "recall": re_metrics["ALL"]["r"], + "hmean": re_metrics["ALL"]["f1"], + } + self.reset() + return metrics + + def reset(self): + self.pred_relations_list = [] + self.relations_list = [] + self.entities_list = [] + + def re_score(self, pred_relations, gt_relations, mode="strict"): + """Evaluate RE predictions + + Args: + pred_relations (list) : list of list of predicted relations (several relations in each sentence) + gt_relations (list) : list of list of ground truth relations + + rel = { "head": (start_idx (inclusive), end_idx (exclusive)), + "tail": (start_idx (inclusive), end_idx (exclusive)), + "head_type": ent_type, + "tail_type": ent_type, + "type": rel_type} + + vocab (Vocab) : dataset vocabulary + mode (str) : in 'strict' or 'boundaries'""" + + assert mode in ["strict", "boundaries"] + + relation_types = [v for v in [0, 1] if not v == 0] + scores = { + rel: { + "tp": 0, + "fp": 0, + "fn": 0 + } + for rel in relation_types + ["ALL"] + } + + # Count GT relations and Predicted relations + n_sents = len(gt_relations) + n_rels = sum([len([rel for rel in sent]) for sent in gt_relations]) + n_found = sum([len([rel for rel in sent]) for sent in pred_relations]) + + # Count TP, FP and FN per type + for pred_sent, gt_sent in zip(pred_relations, gt_relations): + for rel_type in relation_types: + # strict mode takes argument types into account + if mode == "strict": + pred_rels = {(rel["head"], rel["head_type"], rel["tail"], + rel["tail_type"]) + for rel in pred_sent + if rel["type"] == rel_type} + gt_rels = {(rel["head"], rel["head_type"], rel["tail"], + rel["tail_type"]) + for rel in gt_sent if rel["type"] == rel_type} + + # boundaries mode only takes argument spans into account + elif mode == "boundaries": + pred_rels = {(rel["head"], rel["tail"]) + for rel in pred_sent + if rel["type"] == rel_type} + gt_rels = {(rel["head"], rel["tail"]) + for rel in gt_sent if rel["type"] == rel_type} + + scores[rel_type]["tp"] += len(pred_rels & gt_rels) + scores[rel_type]["fp"] += len(pred_rels - gt_rels) + scores[rel_type]["fn"] += len(gt_rels - pred_rels) + + # Compute per entity Precision / Recall / F1 + for rel_type in 
scores.keys(): + if scores[rel_type]["tp"]: + scores[rel_type]["p"] = scores[rel_type]["tp"] / ( + scores[rel_type]["fp"] + scores[rel_type]["tp"]) + scores[rel_type]["r"] = scores[rel_type]["tp"] / ( + scores[rel_type]["fn"] + scores[rel_type]["tp"]) + else: + scores[rel_type]["p"], scores[rel_type]["r"] = 0, 0 + + if not scores[rel_type]["p"] + scores[rel_type]["r"] == 0: + scores[rel_type]["f1"] = ( + 2 * scores[rel_type]["p"] * scores[rel_type]["r"] / + (scores[rel_type]["p"] + scores[rel_type]["r"])) + else: + scores[rel_type]["f1"] = 0 + + # Compute micro F1 Scores + tp = sum([scores[rel_type]["tp"] for rel_type in relation_types]) + fp = sum([scores[rel_type]["fp"] for rel_type in relation_types]) + fn = sum([scores[rel_type]["fn"] for rel_type in relation_types]) + + if tp: + precision = tp / (tp + fp) + recall = tp / (tp + fn) + f1 = 2 * precision * recall / (precision + recall) + + else: + precision, recall, f1 = 0, 0, 0 + + scores["ALL"]["p"] = precision + scores["ALL"]["r"] = recall + scores["ALL"]["f1"] = f1 + scores["ALL"]["tp"] = tp + scores["ALL"]["fp"] = fp + scores["ALL"]["fn"] = fn + + # Compute Macro F1 Scores + scores["ALL"]["Macro_f1"] = np.mean( + [scores[ent_type]["f1"] for ent_type in relation_types]) + scores["ALL"]["Macro_p"] = np.mean( + [scores[ent_type]["p"] for ent_type in relation_types]) + scores["ALL"]["Macro_r"] = np.mean( + [scores[ent_type]["r"] for ent_type in relation_types]) + + return scores diff --git a/backend/ppocr/metrics/vqa_token_ser_metric.py b/backend/ppocr/metrics/vqa_token_ser_metric.py new file mode 100644 index 00000000..286d8add --- /dev/null +++ b/backend/ppocr/metrics/vqa_token_ser_metric.py @@ -0,0 +1,47 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import paddle + +__all__ = ['KIEMetric'] + + +class VQASerTokenMetric(object): + def __init__(self, main_indicator='hmean', **kwargs): + self.main_indicator = main_indicator + self.reset() + + def __call__(self, preds, batch, **kwargs): + preds, labels = preds + self.pred_list.extend(preds) + self.gt_list.extend(labels) + + def get_metric(self): + from seqeval.metrics import f1_score, precision_score, recall_score + metrics = { + "precision": precision_score(self.gt_list, self.pred_list), + "recall": recall_score(self.gt_list, self.pred_list), + "hmean": f1_score(self.gt_list, self.pred_list), + } + self.reset() + return metrics + + def reset(self): + self.pred_list = [] + self.gt_list = [] diff --git a/backend/ppocr/modeling/architectures/__init__.py b/backend/ppocr/modeling/architectures/__init__.py index 86eaf7c9..e9a01cf0 100755 --- a/backend/ppocr/modeling/architectures/__init__.py +++ b/backend/ppocr/modeling/architectures/__init__.py @@ -13,12 +13,20 @@ # limitations under the License. 
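
A minimal worked example of the "boundaries" mode that re_score uses for the VQA RE metric: each relation is reduced to its (head span, tail span) pair, and micro precision/recall/F1 follow from set overlap; the spans below are made up for illustration:

pred = {((0, 2), (3, 5)), ((6, 8), (9, 11))}   # predicted (head, tail) spans, one sentence
gt = {((0, 2), (3, 5))}                         # ground-truth spans
tp, fp, fn = len(pred & gt), len(pred - gt), len(gt - pred)
precision = tp / (tp + fp) if tp else 0
recall = tp / (tp + fn) if tp else 0
f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0
print(precision, recall, f1)                    # 0.5 1.0 0.666...
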
import copy +import importlib + +from .base_model import BaseModel +from .distillation_model import DistillationModel __all__ = ['build_model'] + def build_model(config): - from .base_model import BaseModel - config = copy.deepcopy(config) - module_class = BaseModel(config) - return module_class \ No newline at end of file + if not "name" in config: + arch = BaseModel(config) + else: + name = config.pop("name") + mod = importlib.import_module(__name__) + arch = getattr(mod, name)(config) + return arch diff --git a/backend/ppocr/modeling/architectures/base_model.py b/backend/ppocr/modeling/architectures/base_model.py index 09b6e034..c6b50d48 100644 --- a/backend/ppocr/modeling/architectures/base_model.py +++ b/backend/ppocr/modeling/architectures/base_model.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +14,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function - from paddle import nn from ppocr.modeling.transforms import build_transform from ppocr.modeling.backbones import build_backbone @@ -32,7 +31,6 @@ def __init__(self, config): config (dict): the super parameters for module. """ super(BaseModel, self).__init__() - in_channels = config.get('in_channels', 3) model_type = config['model_type'] # build transfrom, @@ -65,17 +63,38 @@ def __init__(self, config): in_channels = self.neck.out_channels # # build head, head is need for det, rec and cls - config["Head"]['in_channels'] = in_channels - self.head = build_head(config["Head"]) + if 'Head' not in config or config['Head'] is None: + self.use_head = False + else: + self.use_head = True + config["Head"]['in_channels'] = in_channels + self.head = build_head(config["Head"]) + + self.return_all_feats = config.get("return_all_feats", False) def forward(self, x, data=None): + y = dict() if self.use_transform: x = self.transform(x) x = self.backbone(x) + y["backbone_out"] = x if self.use_neck: x = self.neck(x) - if data is None: - x = self.head(x) + y["neck_out"] = x + if self.use_head: + x = self.head(x, targets=data) + # for multi head, save ctc neck out for udml + if isinstance(x, dict) and 'ctc_neck' in x.keys(): + y["neck_out"] = x["ctc_neck"] + y["head_out"] = x + elif isinstance(x, dict): + y.update(x) + else: + y["head_out"] = x + if self.return_all_feats: + if self.training: + return y + else: + return {"head_out": y["head_out"]} else: - x = self.head(x, data) - return x + return x diff --git a/backend/ppocr/modeling/architectures/distillation_model.py b/backend/ppocr/modeling/architectures/distillation_model.py new file mode 100644 index 00000000..cce8fd31 --- /dev/null +++ b/backend/ppocr/modeling/architectures/distillation_model.py @@ -0,0 +1,60 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
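
A minimal, self-contained sketch of the dispatch pattern build_model now follows: when the architecture config carries a "name" key, that class is looked up on the current module by name, otherwise the plain BaseModel path is used; the stub classes below stand in for the real ones:

import importlib

class BaseModel:
    def __init__(self, config): self.config = config

class DistillationModel:
    def __init__(self, config): self.config = config

def build_model(config):
    config = dict(config)                               # stand-in for copy.deepcopy
    if "name" not in config:
        return BaseModel(config)
    name = config.pop("name")
    mod = importlib.import_module(__name__)             # the module defining the classes
    return getattr(mod, name)(config)

print(type(build_model({"name": "DistillationModel", "Models": {}})).__name__)
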
+# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from paddle import nn +from ppocr.modeling.transforms import build_transform +from ppocr.modeling.backbones import build_backbone +from ppocr.modeling.necks import build_neck +from ppocr.modeling.heads import build_head +from .base_model import BaseModel +from ppocr.utils.save_load import load_pretrained_params + +__all__ = ['DistillationModel'] + + +class DistillationModel(nn.Layer): + def __init__(self, config): + """ + the module for OCR distillation. + args: + config (dict): the super parameters for module. + """ + super().__init__() + self.model_list = [] + self.model_name_list = [] + for key in config["Models"]: + model_config = config["Models"][key] + freeze_params = False + pretrained = None + if "freeze_params" in model_config: + freeze_params = model_config.pop("freeze_params") + if "pretrained" in model_config: + pretrained = model_config.pop("pretrained") + model = BaseModel(model_config) + if pretrained is not None: + load_pretrained_params(model, pretrained) + if freeze_params: + for param in model.parameters(): + param.trainable = False + self.model_list.append(self.add_sublayer(key, model)) + self.model_name_list.append(key) + + def forward(self, x, data=None): + result_dict = dict() + for idx, model_name in enumerate(self.model_name_list): + result_dict[model_name] = self.model_list[idx](x, data) + return result_dict diff --git a/backend/ppocr/modeling/backbones/__init__.py b/backend/ppocr/modeling/backbones/__init__.py index 03c15508..072d6e0f 100755 --- a/backend/ppocr/modeling/backbones/__init__.py +++ b/backend/ppocr/modeling/backbones/__init__.py @@ -12,26 +12,53 @@ # See the License for the specific language governing permissions and # limitations under the License. 
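
A minimal sketch of the DistillationModel pattern defined above: several named sub-models see the same input, a sub-model marked freeze_params has its parameters made non-trainable, and forward returns one output per model name; the Linear layers are placeholders for the real sub-model configs (assumes PaddlePaddle is installed):

import paddle
from paddle import nn

class TinyDistillation(nn.Layer):
    def __init__(self):
        super().__init__()
        self.model_list, self.model_name_list = [], []
        for key, freeze in [("Teacher", True), ("Student", False)]:
            model = nn.Linear(8, 4)
            if freeze:                                   # mirrors freeze_params=True
                for param in model.parameters():
                    param.trainable = False
            self.model_list.append(self.add_sublayer(key, model))
            self.model_name_list.append(key)

    def forward(self, x, data=None):
        return {name: self.model_list[idx](x)
                for idx, name in enumerate(self.model_name_list)}

out = TinyDistillation()(paddle.rand([2, 8]))
print(sorted(out.keys()))                                # ['Student', 'Teacher']
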
-__all__ = ['build_backbone'] +__all__ = ["build_backbone"] def build_backbone(config, model_type): - if model_type == 'det': + if model_type == "det" or model_type == "table": from .det_mobilenet_v3 import MobileNetV3 from .det_resnet_vd import ResNet from .det_resnet_vd_sast import ResNet_SAST - support_dict = ['MobileNetV3', 'ResNet', 'ResNet_SAST'] - elif model_type == 'rec' or model_type == 'cls': + support_dict = ["MobileNetV3", "ResNet", "ResNet_SAST"] + elif model_type == "rec" or model_type == "cls": from .rec_mobilenet_v3 import MobileNetV3 from .rec_resnet_vd import ResNet from .rec_resnet_fpn import ResNetFPN - support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN'] + from .rec_mv1_enhance import MobileNetV1Enhance + from .rec_nrtr_mtb import MTB + from .rec_resnet_31 import ResNet31 + from .rec_resnet_aster import ResNet_ASTER + from .rec_micronet import MicroNet + from .rec_efficientb3_pren import EfficientNetb3_PREN + from .rec_svtrnet import SVTRNet + support_dict = [ + 'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB', + "ResNet31", "ResNet_ASTER", 'MicroNet', 'EfficientNetb3_PREN', + 'SVTRNet' + ] + elif model_type == "e2e": + from .e2e_resnet_vd_pg import ResNet + support_dict = ['ResNet'] + elif model_type == 'kie': + from .kie_unet_sdmgr import Kie_backbone + support_dict = ['Kie_backbone'] + elif model_type == "table": + from .table_resnet_vd import ResNet + from .table_mobilenet_v3 import MobileNetV3 + support_dict = ["ResNet", "MobileNetV3"] + elif model_type == 'vqa': + from .vqa_layoutlm import LayoutLMForSer, LayoutLMv2ForSer, LayoutLMv2ForRe, LayoutXLMForSer, LayoutXLMForRe + support_dict = [ + "LayoutLMForSer", "LayoutLMv2ForSer", 'LayoutLMv2ForRe', + "LayoutXLMForSer", 'LayoutXLMForRe' + ] else: raise NotImplementedError - module_name = config.pop('name') + module_name = config.pop("name") assert module_name in support_dict, Exception( - 'when model typs is {}, backbone only support {}'.format(model_type, + "when model typs is {}, backbone only support {}".format(model_type, support_dict)) module_class = eval(module_name)(**config) return module_class diff --git a/backend/ppocr/modeling/backbones/det_mobilenet_v3.py b/backend/ppocr/modeling/backbones/det_mobilenet_v3.py index bb451bbe..05113ea8 100755 --- a/backend/ppocr/modeling/backbones/det_mobilenet_v3.py +++ b/backend/ppocr/modeling/backbones/det_mobilenet_v3.py @@ -102,8 +102,7 @@ def __init__(self, padding=1, groups=1, if_act=True, - act='hardswish', - name='conv1') + act='hardswish') self.stages = [] self.out_channels = [] @@ -125,8 +124,7 @@ def __init__(self, kernel_size=k, stride=s, use_se=se, - act=nl, - name="conv" + str(i + 2))) + act=nl)) inplanes = make_divisible(scale * c) i += 1 block_list.append( @@ -138,8 +136,7 @@ def __init__(self, padding=0, groups=1, if_act=True, - act='hardswish', - name='conv_last')) + act='hardswish')) self.stages.append(nn.Sequential(*block_list)) self.out_channels.append(make_divisible(scale * cls_ch_squeeze)) for i, stage in enumerate(self.stages): @@ -163,8 +160,7 @@ def __init__(self, padding, groups=1, if_act=True, - act=None, - name=None): + act=None): super(ConvBNLayer, self).__init__() self.if_act = if_act self.act = act @@ -175,16 +171,9 @@ def __init__(self, stride=stride, padding=padding, groups=groups, - weight_attr=ParamAttr(name=name + '_weights'), bias_attr=False) - self.bn = nn.BatchNorm( - num_channels=out_channels, - act=None, - param_attr=ParamAttr(name=name + "_bn_scale"), - bias_attr=ParamAttr(name=name + "_bn_offset"), - 
moving_mean_name=name + "_bn_mean", - moving_variance_name=name + "_bn_variance") + self.bn = nn.BatchNorm(num_channels=out_channels, act=None) def forward(self, x): x = self.conv(x) @@ -209,8 +198,7 @@ def __init__(self, kernel_size, stride, use_se, - act=None, - name=''): + act=None): super(ResidualUnit, self).__init__() self.if_shortcut = stride == 1 and in_channels == out_channels self.if_se = use_se @@ -222,8 +210,7 @@ def __init__(self, stride=1, padding=0, if_act=True, - act=act, - name=name + "_expand") + act=act) self.bottleneck_conv = ConvBNLayer( in_channels=mid_channels, out_channels=mid_channels, @@ -232,10 +219,9 @@ def __init__(self, padding=int((kernel_size - 1) // 2), groups=mid_channels, if_act=True, - act=act, - name=name + "_depthwise") + act=act) if self.if_se: - self.mid_se = SEModule(mid_channels, name=name + "_se") + self.mid_se = SEModule(mid_channels) self.linear_conv = ConvBNLayer( in_channels=mid_channels, out_channels=out_channels, @@ -243,8 +229,7 @@ def __init__(self, stride=1, padding=0, if_act=False, - act=None, - name=name + "_linear") + act=None) def forward(self, inputs): x = self.expand_conv(inputs) @@ -258,7 +243,7 @@ def forward(self, inputs): class SEModule(nn.Layer): - def __init__(self, in_channels, reduction=4, name=""): + def __init__(self, in_channels, reduction=4): super(SEModule, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2D(1) self.conv1 = nn.Conv2D( @@ -266,17 +251,13 @@ def __init__(self, in_channels, reduction=4, name=""): out_channels=in_channels // reduction, kernel_size=1, stride=1, - padding=0, - weight_attr=ParamAttr(name=name + "_1_weights"), - bias_attr=ParamAttr(name=name + "_1_offset")) + padding=0) self.conv2 = nn.Conv2D( in_channels=in_channels // reduction, out_channels=in_channels, kernel_size=1, stride=1, - padding=0, - weight_attr=ParamAttr(name + "_2_weights"), - bias_attr=ParamAttr(name=name + "_2_offset")) + padding=0) def forward(self, inputs): outputs = self.avg_pool(inputs) diff --git a/backend/ppocr/modeling/backbones/det_resnet_vd.py b/backend/ppocr/modeling/backbones/det_resnet_vd.py index 3bb4a0d5..8c955a4a 100644 --- a/backend/ppocr/modeling/backbones/det_resnet_vd.py +++ b/backend/ppocr/modeling/backbones/det_resnet_vd.py @@ -21,45 +21,116 @@ import paddle.nn as nn import paddle.nn.functional as F +from paddle.vision.ops import DeformConv2D +from paddle.regularizer import L2Decay +from paddle.nn.initializer import Normal, Constant, XavierUniform + __all__ = ["ResNet"] -class ConvBNLayer(nn.Layer): - def __init__( - self, +class DeformableConvV2(nn.Layer): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + weight_attr=None, + bias_attr=None, + lr_scale=1, + regularizer=None, + skip_quant=False, + dcn_bias_regularizer=L2Decay(0.), + dcn_bias_lr_scale=2.): + super(DeformableConvV2, self).__init__() + self.offset_channel = 2 * kernel_size**2 * groups + self.mask_channel = kernel_size**2 * groups + + if bias_attr: + # in FCOS-DCN head, specifically need learning_rate and regularizer + dcn_bias_attr = ParamAttr( + initializer=Constant(value=0), + regularizer=dcn_bias_regularizer, + learning_rate=dcn_bias_lr_scale) + else: + # in ResNet backbone, do not need bias + dcn_bias_attr = False + self.conv_dcn = DeformConv2D( in_channels, out_channels, kernel_size, - stride=1, - groups=1, - is_vd_mode=False, - act=None, - name=None, ): + stride=stride, + padding=(kernel_size - 1) // 2 * dilation, + dilation=dilation, + deformable_groups=groups, + 
weight_attr=weight_attr, + bias_attr=dcn_bias_attr) + + if lr_scale == 1 and regularizer is None: + offset_bias_attr = ParamAttr(initializer=Constant(0.)) + else: + offset_bias_attr = ParamAttr( + initializer=Constant(0.), + learning_rate=lr_scale, + regularizer=regularizer) + self.conv_offset = nn.Conv2D( + in_channels, + groups * 3 * kernel_size**2, + kernel_size, + stride=stride, + padding=(kernel_size - 1) // 2, + weight_attr=ParamAttr(initializer=Constant(0.0)), + bias_attr=offset_bias_attr) + if skip_quant: + self.conv_offset.skip_quant = True + + def forward(self, x): + offset_mask = self.conv_offset(x) + offset, mask = paddle.split( + offset_mask, + num_or_sections=[self.offset_channel, self.mask_channel], + axis=1) + mask = F.sigmoid(mask) + y = self.conv_dcn(x, offset, mask=mask) + return y + + +class ConvBNLayer(nn.Layer): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + groups=1, + is_vd_mode=False, + act=None, + is_dcn=False): super(ConvBNLayer, self).__init__() self.is_vd_mode = is_vd_mode self._pool2d_avg = nn.AvgPool2D( kernel_size=2, stride=2, padding=0, ceil_mode=True) - self._conv = nn.Conv2D( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=(kernel_size - 1) // 2, - groups=groups, - weight_attr=ParamAttr(name=name + "_weights"), - bias_attr=False) - if name == "conv1": - bn_name = "bn_" + name + if not is_dcn: + self._conv = nn.Conv2D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=(kernel_size - 1) // 2, + groups=groups, + bias_attr=False) else: - bn_name = "bn" + name[3:] - self._batch_norm = nn.BatchNorm( - out_channels, - act=act, - param_attr=ParamAttr(name=bn_name + '_scale'), - bias_attr=ParamAttr(bn_name + '_offset'), - moving_mean_name=bn_name + '_mean', - moving_variance_name=bn_name + '_variance') + self._conv = DeformableConvV2( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=(kernel_size - 1) // 2, + groups=2, #groups, + bias_attr=False) + self._batch_norm = nn.BatchNorm(out_channels, act=act) def forward(self, inputs): if self.is_vd_mode: @@ -70,34 +141,33 @@ def forward(self, inputs): class BottleneckBlock(nn.Layer): - def __init__(self, - in_channels, - out_channels, - stride, - shortcut=True, - if_first=False, - name=None): + def __init__( + self, + in_channels, + out_channels, + stride, + shortcut=True, + if_first=False, + is_dcn=False, ): super(BottleneckBlock, self).__init__() self.conv0 = ConvBNLayer( in_channels=in_channels, out_channels=out_channels, kernel_size=1, - act='relu', - name=name + "_branch2a") + act='relu') self.conv1 = ConvBNLayer( in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=stride, act='relu', - name=name + "_branch2b") + is_dcn=is_dcn) self.conv2 = ConvBNLayer( in_channels=out_channels, out_channels=out_channels * 4, kernel_size=1, - act=None, - name=name + "_branch2c") + act=None) if not shortcut: self.short = ConvBNLayer( @@ -105,8 +175,7 @@ def __init__(self, out_channels=out_channels * 4, kernel_size=1, stride=1, - is_vd_mode=False if if_first else True, - name=name + "_branch1") + is_vd_mode=False if if_first else True) self.shortcut = shortcut @@ -125,13 +194,13 @@ def forward(self, inputs): class BasicBlock(nn.Layer): - def __init__(self, - in_channels, - out_channels, - stride, - shortcut=True, - if_first=False, - name=None): + def __init__( + self, + in_channels, + out_channels, 
+ stride, + shortcut=True, + if_first=False, ): super(BasicBlock, self).__init__() self.stride = stride self.conv0 = ConvBNLayer( @@ -139,14 +208,12 @@ def __init__(self, out_channels=out_channels, kernel_size=3, stride=stride, - act='relu', - name=name + "_branch2a") + act='relu') self.conv1 = ConvBNLayer( in_channels=out_channels, out_channels=out_channels, kernel_size=3, - act=None, - name=name + "_branch2b") + act=None) if not shortcut: self.short = ConvBNLayer( @@ -154,8 +221,7 @@ def __init__(self, out_channels=out_channels, kernel_size=1, stride=1, - is_vd_mode=False if if_first else True, - name=name + "_branch1") + is_vd_mode=False if if_first else True) self.shortcut = shortcut @@ -173,7 +239,12 @@ def forward(self, inputs): class ResNet(nn.Layer): - def __init__(self, in_channels=3, layers=50, **kwargs): + def __init__(self, + in_channels=3, + layers=50, + dcn_stage=None, + out_indices=None, + **kwargs): super(ResNet, self).__init__() self.layers = layers @@ -196,27 +267,31 @@ def __init__(self, in_channels=3, layers=50, **kwargs): 1024] if layers >= 50 else [64, 64, 128, 256] num_filters = [64, 128, 256, 512] + self.dcn_stage = dcn_stage if dcn_stage is not None else [ + False, False, False, False + ] + self.out_indices = out_indices if out_indices is not None else [ + 0, 1, 2, 3 + ] + self.conv1_1 = ConvBNLayer( in_channels=in_channels, out_channels=32, kernel_size=3, stride=2, - act='relu', - name="conv1_1") + act='relu') self.conv1_2 = ConvBNLayer( in_channels=32, out_channels=32, kernel_size=3, stride=1, - act='relu', - name="conv1_2") + act='relu') self.conv1_3 = ConvBNLayer( in_channels=32, out_channels=64, kernel_size=3, stride=1, - act='relu', - name="conv1_3") + act='relu') self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1) self.stages = [] @@ -225,14 +300,8 @@ def __init__(self, in_channels=3, layers=50, **kwargs): for block in range(len(depth)): block_list = [] shortcut = False + is_dcn = self.dcn_stage[block] for i in range(depth[block]): - if layers in [101, 152] and block == 2: - if i == 0: - conv_name = "res" + str(block + 2) + "a" - else: - conv_name = "res" + str(block + 2) + "b" + str(i) - else: - conv_name = "res" + str(block + 2) + chr(97 + i) bottleneck_block = self.add_sublayer( 'bb_%d_%d' % (block, i), BottleneckBlock( @@ -242,17 +311,18 @@ def __init__(self, in_channels=3, layers=50, **kwargs): stride=2 if i == 0 and block != 0 else 1, shortcut=shortcut, if_first=block == i == 0, - name=conv_name)) + is_dcn=is_dcn)) shortcut = True block_list.append(bottleneck_block) - self.out_channels.append(num_filters[block] * 4) + if block in self.out_indices: + self.out_channels.append(num_filters[block] * 4) self.stages.append(nn.Sequential(*block_list)) else: for block in range(len(depth)): block_list = [] shortcut = False + # is_dcn = self.dcn_stage[block] for i in range(depth[block]): - conv_name = "res" + str(block + 2) + chr(97 + i) basic_block = self.add_sublayer( 'bb_%d_%d' % (block, i), BasicBlock( @@ -261,11 +331,11 @@ def __init__(self, in_channels=3, layers=50, **kwargs): out_channels=num_filters[block], stride=2 if i == 0 and block != 0 else 1, shortcut=shortcut, - if_first=block == i == 0, - name=conv_name)) + if_first=block == i == 0)) shortcut = True block_list.append(basic_block) - self.out_channels.append(num_filters[block]) + if block in self.out_indices: + self.out_channels.append(num_filters[block]) self.stages.append(nn.Sequential(*block_list)) def forward(self, inputs): @@ -274,7 +344,8 @@ def forward(self, inputs): y = 
self.conv1_3(y) y = self.pool2d_max(y) out = [] - for block in self.stages: + for i, block in enumerate(self.stages): y = block(y) - out.append(y) + if i in self.out_indices: + out.append(y) return out diff --git a/backend/ppocr/modeling/backbones/e2e_resnet_vd_pg.py b/backend/ppocr/modeling/backbones/e2e_resnet_vd_pg.py new file mode 100644 index 00000000..97afd346 --- /dev/null +++ b/backend/ppocr/modeling/backbones/e2e_resnet_vd_pg.py @@ -0,0 +1,265 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import ParamAttr +import paddle.nn as nn +import paddle.nn.functional as F + +__all__ = ["ResNet"] + + +class ConvBNLayer(nn.Layer): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + groups=1, + is_vd_mode=False, + act=None, + name=None, ): + super(ConvBNLayer, self).__init__() + + self.is_vd_mode = is_vd_mode + self._pool2d_avg = nn.AvgPool2D( + kernel_size=2, stride=2, padding=0, ceil_mode=True) + self._conv = nn.Conv2D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=(kernel_size - 1) // 2, + groups=groups, + weight_attr=ParamAttr(name=name + "_weights"), + bias_attr=False) + if name == "conv1": + bn_name = "bn_" + name + else: + bn_name = "bn" + name[3:] + self._batch_norm = nn.BatchNorm( + out_channels, + act=act, + param_attr=ParamAttr(name=bn_name + '_scale'), + bias_attr=ParamAttr(bn_name + '_offset'), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance') + + def forward(self, inputs): + y = self._conv(inputs) + y = self._batch_norm(y) + return y + + +class BottleneckBlock(nn.Layer): + def __init__(self, + in_channels, + out_channels, + stride, + shortcut=True, + if_first=False, + name=None): + super(BottleneckBlock, self).__init__() + + self.conv0 = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + act='relu', + name=name + "_branch2a") + self.conv1 = ConvBNLayer( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + stride=stride, + act='relu', + name=name + "_branch2b") + self.conv2 = ConvBNLayer( + in_channels=out_channels, + out_channels=out_channels * 4, + kernel_size=1, + act=None, + name=name + "_branch2c") + + if not shortcut: + self.short = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels * 4, + kernel_size=1, + stride=stride, + is_vd_mode=False if if_first else True, + name=name + "_branch1") + + self.shortcut = shortcut + + def forward(self, inputs): + y = self.conv0(inputs) + conv1 = self.conv1(y) + conv2 = self.conv2(conv1) + + if self.shortcut: + short = inputs + else: + short = self.short(inputs) + y = paddle.add(x=short, y=conv2) + y = F.relu(y) + return y + + +class BasicBlock(nn.Layer): + def __init__(self, + in_channels, + out_channels, + stride, + shortcut=True, 
+ if_first=False, + name=None): + super(BasicBlock, self).__init__() + self.stride = stride + self.conv0 = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=stride, + act='relu', + name=name + "_branch2a") + self.conv1 = ConvBNLayer( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + act=None, + name=name + "_branch2b") + + if not shortcut: + self.short = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + is_vd_mode=False if if_first else True, + name=name + "_branch1") + + self.shortcut = shortcut + + def forward(self, inputs): + y = self.conv0(inputs) + conv1 = self.conv1(y) + + if self.shortcut: + short = inputs + else: + short = self.short(inputs) + y = paddle.add(x=short, y=conv1) + y = F.relu(y) + return y + + +class ResNet(nn.Layer): + def __init__(self, in_channels=3, layers=50, **kwargs): + super(ResNet, self).__init__() + + self.layers = layers + supported_layers = [18, 34, 50, 101, 152, 200] + assert layers in supported_layers, \ + "supported layers are {} but input layer is {}".format( + supported_layers, layers) + + if layers == 18: + depth = [2, 2, 2, 2] + elif layers == 34 or layers == 50: + # depth = [3, 4, 6, 3] + depth = [3, 4, 6, 3, 3] + elif layers == 101: + depth = [3, 4, 23, 3] + elif layers == 152: + depth = [3, 8, 36, 3] + elif layers == 200: + depth = [3, 12, 48, 3] + num_channels = [64, 256, 512, 1024, + 2048] if layers >= 50 else [64, 64, 128, 256] + num_filters = [64, 128, 256, 512, 512] + + self.conv1_1 = ConvBNLayer( + in_channels=in_channels, + out_channels=64, + kernel_size=7, + stride=2, + act='relu', + name="conv1_1") + self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1) + + self.stages = [] + self.out_channels = [3, 64] + # num_filters = [64, 128, 256, 512, 512] + if layers >= 50: + for block in range(len(depth)): + block_list = [] + shortcut = False + for i in range(depth[block]): + if layers in [101, 152] and block == 2: + if i == 0: + conv_name = "res" + str(block + 2) + "a" + else: + conv_name = "res" + str(block + 2) + "b" + str(i) + else: + conv_name = "res" + str(block + 2) + chr(97 + i) + bottleneck_block = self.add_sublayer( + 'bb_%d_%d' % (block, i), + BottleneckBlock( + in_channels=num_channels[block] + if i == 0 else num_filters[block] * 4, + out_channels=num_filters[block], + stride=2 if i == 0 and block != 0 else 1, + shortcut=shortcut, + if_first=block == i == 0, + name=conv_name)) + shortcut = True + block_list.append(bottleneck_block) + self.out_channels.append(num_filters[block] * 4) + self.stages.append(nn.Sequential(*block_list)) + else: + for block in range(len(depth)): + block_list = [] + shortcut = False + for i in range(depth[block]): + conv_name = "res" + str(block + 2) + chr(97 + i) + basic_block = self.add_sublayer( + 'bb_%d_%d' % (block, i), + BasicBlock( + in_channels=num_channels[block] + if i == 0 else num_filters[block], + out_channels=num_filters[block], + stride=2 if i == 0 and block != 0 else 1, + shortcut=shortcut, + if_first=block == i == 0, + name=conv_name)) + shortcut = True + block_list.append(basic_block) + self.out_channels.append(num_filters[block]) + self.stages.append(nn.Sequential(*block_list)) + + def forward(self, inputs): + out = [inputs] + y = self.conv1_1(inputs) + out.append(y) + y = self.pool2d_max(y) + for block in self.stages: + y = block(y) + out.append(y) + return out diff --git a/backend/ppocr/modeling/backbones/kie_unet_sdmgr.py 
b/backend/ppocr/modeling/backbones/kie_unet_sdmgr.py new file mode 100644 index 00000000..545e4e75 --- /dev/null +++ b/backend/ppocr/modeling/backbones/kie_unet_sdmgr.py @@ -0,0 +1,186 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn +import numpy as np +import cv2 + +__all__ = ["Kie_backbone"] + + +class Encoder(nn.Layer): + def __init__(self, num_channels, num_filters): + super(Encoder, self).__init__() + self.conv1 = nn.Conv2D( + num_channels, + num_filters, + kernel_size=3, + stride=1, + padding=1, + bias_attr=False) + self.bn1 = nn.BatchNorm(num_filters, act='relu') + + self.conv2 = nn.Conv2D( + num_filters, + num_filters, + kernel_size=3, + stride=1, + padding=1, + bias_attr=False) + self.bn2 = nn.BatchNorm(num_filters, act='relu') + + self.pool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1) + + def forward(self, inputs): + x = self.conv1(inputs) + x = self.bn1(x) + x = self.conv2(x) + x = self.bn2(x) + x_pooled = self.pool(x) + return x, x_pooled + + +class Decoder(nn.Layer): + def __init__(self, num_channels, num_filters): + super(Decoder, self).__init__() + + self.conv1 = nn.Conv2D( + num_channels, + num_filters, + kernel_size=3, + stride=1, + padding=1, + bias_attr=False) + self.bn1 = nn.BatchNorm(num_filters, act='relu') + + self.conv2 = nn.Conv2D( + num_filters, + num_filters, + kernel_size=3, + stride=1, + padding=1, + bias_attr=False) + self.bn2 = nn.BatchNorm(num_filters, act='relu') + + self.conv0 = nn.Conv2D( + num_channels, + num_filters, + kernel_size=1, + stride=1, + padding=0, + bias_attr=False) + self.bn0 = nn.BatchNorm(num_filters, act='relu') + + def forward(self, inputs_prev, inputs): + x = self.conv0(inputs) + x = self.bn0(x) + x = paddle.nn.functional.interpolate( + x, scale_factor=2, mode='bilinear', align_corners=False) + x = paddle.concat([inputs_prev, x], axis=1) + x = self.conv1(x) + x = self.bn1(x) + x = self.conv2(x) + x = self.bn2(x) + return x + + +class UNet(nn.Layer): + def __init__(self): + super(UNet, self).__init__() + self.down1 = Encoder(num_channels=3, num_filters=16) + self.down2 = Encoder(num_channels=16, num_filters=32) + self.down3 = Encoder(num_channels=32, num_filters=64) + self.down4 = Encoder(num_channels=64, num_filters=128) + self.down5 = Encoder(num_channels=128, num_filters=256) + + self.up1 = Decoder(32, 16) + self.up2 = Decoder(64, 32) + self.up3 = Decoder(128, 64) + self.up4 = Decoder(256, 128) + self.out_channels = 16 + + def forward(self, inputs): + x1, _ = self.down1(inputs) + _, x2 = self.down2(x1) + _, x3 = self.down3(x2) + _, x4 = self.down4(x3) + _, x5 = self.down5(x4) + + x = self.up4(x4, x5) + x = self.up3(x3, x) + x = self.up2(x2, x) + x = self.up1(x1, x) + return x + + +class Kie_backbone(nn.Layer): + def __init__(self, in_channels, **kwargs): + super(Kie_backbone, self).__init__() + self.out_channels = 
16 + self.img_feat = UNet() + self.maxpool = nn.MaxPool2D(kernel_size=7) + + def bbox2roi(self, bbox_list): + rois_list = [] + rois_num = [] + for img_id, bboxes in enumerate(bbox_list): + rois_num.append(bboxes.shape[0]) + rois_list.append(bboxes) + rois = paddle.concat(rois_list, 0) + rois_num = paddle.to_tensor(rois_num, dtype='int32') + return rois, rois_num + + def pre_process(self, img, relations, texts, gt_bboxes, tag, img_size): + img, relations, texts, gt_bboxes, tag, img_size = img.numpy( + ), relations.numpy(), texts.numpy(), gt_bboxes.numpy(), tag.numpy( + ).tolist(), img_size.numpy() + temp_relations, temp_texts, temp_gt_bboxes = [], [], [] + h, w = int(np.max(img_size[:, 0])), int(np.max(img_size[:, 1])) + img = paddle.to_tensor(img[:, :, :h, :w]) + batch = len(tag) + for i in range(batch): + num, recoder_len = tag[i][0], tag[i][1] + temp_relations.append( + paddle.to_tensor( + relations[i, :num, :num, :], dtype='float32')) + temp_texts.append( + paddle.to_tensor( + texts[i, :num, :recoder_len], dtype='float32')) + temp_gt_bboxes.append( + paddle.to_tensor( + gt_bboxes[i, :num, ...], dtype='float32')) + return img, temp_relations, temp_texts, temp_gt_bboxes + + def forward(self, inputs): + img = inputs[0] + relations, texts, gt_bboxes, tag, img_size = inputs[1], inputs[ + 2], inputs[3], inputs[5], inputs[-1] + img, relations, texts, gt_bboxes = self.pre_process( + img, relations, texts, gt_bboxes, tag, img_size) + x = self.img_feat(img) + boxes, rois_num = self.bbox2roi(gt_bboxes) + feats = paddle.fluid.layers.roi_align( + x, + boxes, + spatial_scale=1.0, + pooled_height=7, + pooled_width=7, + rois_num=rois_num) + feats = self.maxpool(feats).squeeze(-1).squeeze(-1) + return [relations, texts, feats] diff --git a/backend/ppocr/modeling/backbones/rec_efficientb3_pren.py b/backend/ppocr/modeling/backbones/rec_efficientb3_pren.py new file mode 100644 index 00000000..57eef178 --- /dev/null +++ b/backend/ppocr/modeling/backbones/rec_efficientb3_pren.py @@ -0,0 +1,228 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Code is refer from: +https://github.com/RuijieJ/pren/blob/main/Nets/EfficientNet.py +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +from collections import namedtuple +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +__all__ = ['EfficientNetb3'] + + +class EffB3Params: + @staticmethod + def get_global_params(): + """ + The fllowing are efficientnetb3's arch superparams, but to fit for scene + text recognition task, the resolution(image_size) here is changed + from 300 to 64. 
+ """ + GlobalParams = namedtuple('GlobalParams', [ + 'drop_connect_rate', 'width_coefficient', 'depth_coefficient', + 'depth_divisor', 'image_size' + ]) + global_params = GlobalParams( + drop_connect_rate=0.3, + width_coefficient=1.2, + depth_coefficient=1.4, + depth_divisor=8, + image_size=64) + return global_params + + @staticmethod + def get_block_params(): + BlockParams = namedtuple('BlockParams', [ + 'kernel_size', 'num_repeat', 'input_filters', 'output_filters', + 'expand_ratio', 'id_skip', 'se_ratio', 'stride' + ]) + block_params = [ + BlockParams(3, 1, 32, 16, 1, True, 0.25, 1), + BlockParams(3, 2, 16, 24, 6, True, 0.25, 2), + BlockParams(5, 2, 24, 40, 6, True, 0.25, 2), + BlockParams(3, 3, 40, 80, 6, True, 0.25, 2), + BlockParams(5, 3, 80, 112, 6, True, 0.25, 1), + BlockParams(5, 4, 112, 192, 6, True, 0.25, 2), + BlockParams(3, 1, 192, 320, 6, True, 0.25, 1) + ] + return block_params + + +class EffUtils: + @staticmethod + def round_filters(filters, global_params): + """Calculate and round number of filters based on depth multiplier.""" + multiplier = global_params.width_coefficient + if not multiplier: + return filters + divisor = global_params.depth_divisor + filters *= multiplier + new_filters = int(filters + divisor / 2) // divisor * divisor + if new_filters < 0.9 * filters: + new_filters += divisor + return int(new_filters) + + @staticmethod + def round_repeats(repeats, global_params): + """Round number of filters based on depth multiplier.""" + multiplier = global_params.depth_coefficient + if not multiplier: + return repeats + return int(math.ceil(multiplier * repeats)) + + +class ConvBlock(nn.Layer): + def __init__(self, block_params): + super(ConvBlock, self).__init__() + self.block_args = block_params + self.has_se = (self.block_args.se_ratio is not None) and \ + (0 < self.block_args.se_ratio <= 1) + self.id_skip = block_params.id_skip + + # expansion phase + self.input_filters = self.block_args.input_filters + output_filters = \ + self.block_args.input_filters * self.block_args.expand_ratio + if self.block_args.expand_ratio != 1: + self.expand_conv = nn.Conv2D( + self.input_filters, output_filters, 1, bias_attr=False) + self.bn0 = nn.BatchNorm(output_filters) + + # depthwise conv phase + k = self.block_args.kernel_size + s = self.block_args.stride + self.depthwise_conv = nn.Conv2D( + output_filters, + output_filters, + groups=output_filters, + kernel_size=k, + stride=s, + padding='same', + bias_attr=False) + self.bn1 = nn.BatchNorm(output_filters) + + # squeeze and excitation layer, if desired + if self.has_se: + num_squeezed_channels = max(1, + int(self.block_args.input_filters * + self.block_args.se_ratio)) + self.se_reduce = nn.Conv2D(output_filters, num_squeezed_channels, 1) + self.se_expand = nn.Conv2D(num_squeezed_channels, output_filters, 1) + + # output phase + self.final_oup = self.block_args.output_filters + self.project_conv = nn.Conv2D( + output_filters, self.final_oup, 1, bias_attr=False) + self.bn2 = nn.BatchNorm(self.final_oup) + self.swish = nn.Swish() + + def drop_connect(self, inputs, p, training): + if not training: + return inputs + + batch_size = inputs.shape[0] + keep_prob = 1 - p + random_tensor = keep_prob + random_tensor += paddle.rand([batch_size, 1, 1, 1], dtype=inputs.dtype) + random_tensor = paddle.to_tensor(random_tensor, place=inputs.place) + binary_tensor = paddle.floor(random_tensor) + output = inputs / keep_prob * binary_tensor + return output + + def forward(self, inputs, drop_connect_rate=None): + # expansion and depthwise conv + x = 
inputs + if self.block_args.expand_ratio != 1: + x = self.swish(self.bn0(self.expand_conv(inputs))) + x = self.swish(self.bn1(self.depthwise_conv(x))) + + # squeeze and excitation + if self.has_se: + x_squeezed = F.adaptive_avg_pool2d(x, 1) + x_squeezed = self.se_expand(self.swish(self.se_reduce(x_squeezed))) + x = F.sigmoid(x_squeezed) * x + x = self.bn2(self.project_conv(x)) + + # skip conntection and drop connect + if self.id_skip and self.block_args.stride == 1 and \ + self.input_filters == self.final_oup: + if drop_connect_rate: + x = self.drop_connect( + x, p=drop_connect_rate, training=self.training) + x = x + inputs + return x + + +class EfficientNetb3_PREN(nn.Layer): + def __init__(self, in_channels): + super(EfficientNetb3_PREN, self).__init__() + self.blocks_params = EffB3Params.get_block_params() + self.global_params = EffB3Params.get_global_params() + self.out_channels = [] + # stem + stem_channels = EffUtils.round_filters(32, self.global_params) + self.conv_stem = nn.Conv2D( + in_channels, stem_channels, 3, 2, padding='same', bias_attr=False) + self.bn0 = nn.BatchNorm(stem_channels) + + self.blocks = [] + # to extract three feature maps for fpn based on efficientnetb3 backbone + self.concerned_block_idxes = [7, 17, 25] + concerned_idx = 0 + for i, block_params in enumerate(self.blocks_params): + block_params = block_params._replace( + input_filters=EffUtils.round_filters(block_params.input_filters, + self.global_params), + output_filters=EffUtils.round_filters( + block_params.output_filters, self.global_params), + num_repeat=EffUtils.round_repeats(block_params.num_repeat, + self.global_params)) + self.blocks.append( + self.add_sublayer("{}-0".format(i), ConvBlock(block_params))) + concerned_idx += 1 + if concerned_idx in self.concerned_block_idxes: + self.out_channels.append(block_params.output_filters) + if block_params.num_repeat > 1: + block_params = block_params._replace( + input_filters=block_params.output_filters, stride=1) + for j in range(block_params.num_repeat - 1): + self.blocks.append( + self.add_sublayer('{}-{}'.format(i, j + 1), + ConvBlock(block_params))) + concerned_idx += 1 + if concerned_idx in self.concerned_block_idxes: + self.out_channels.append(block_params.output_filters) + + self.swish = nn.Swish() + + def forward(self, inputs): + outs = [] + + x = self.swish(self.bn0(self.conv_stem(inputs))) + for idx, block in enumerate(self.blocks): + drop_connect_rate = self.global_params.drop_connect_rate + if drop_connect_rate: + drop_connect_rate *= float(idx) / len(self.blocks) + x = block(x, drop_connect_rate=drop_connect_rate) + if idx in self.concerned_block_idxes: + outs.append(x) + return outs diff --git a/backend/ppocr/modeling/backbones/rec_micronet.py b/backend/ppocr/modeling/backbones/rec_micronet.py new file mode 100644 index 00000000..b0ae5a14 --- /dev/null +++ b/backend/ppocr/modeling/backbones/rec_micronet.py @@ -0,0 +1,528 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
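
A minimal sketch of the channel-rounding rule EffUtils.round_filters applies when the EfficientNet-B3 width multiplier (1.2, depth divisor 8) scales a stage's filter count; the sample channel counts are illustrative:

def round_filters(filters, width_coefficient=1.2, depth_divisor=8):
    filters *= width_coefficient
    new_filters = int(filters + depth_divisor / 2) // depth_divisor * depth_divisor
    if new_filters < 0.9 * filters:                      # never shrink by more than ~10%
        new_filters += depth_divisor
    return int(new_filters)

print([round_filters(c) for c in (32, 40, 112, 192)])    # [40, 48, 136, 232]
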
+""" +This code is refer from: +https://github.com/liyunsheng13/micronet/blob/main/backbone/micronet.py +https://github.com/liyunsheng13/micronet/blob/main/backbone/activation.py +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +import paddle.nn as nn + +from ppocr.modeling.backbones.det_mobilenet_v3 import make_divisible + +M0_cfgs = [ + # s, n, c, ks, c1, c2, g1, g2, c3, g3, g4, y1, y2, y3, r + [2, 1, 8, 3, 2, 2, 0, 4, 8, 2, 2, 2, 0, 1, 1], + [2, 1, 12, 3, 2, 2, 0, 8, 12, 4, 4, 2, 2, 1, 1], + [2, 1, 16, 5, 2, 2, 0, 12, 16, 4, 4, 2, 2, 1, 1], + [1, 1, 32, 5, 1, 4, 4, 4, 32, 4, 4, 2, 2, 1, 1], + [2, 1, 64, 5, 1, 4, 8, 8, 64, 8, 8, 2, 2, 1, 1], + [1, 1, 96, 3, 1, 4, 8, 8, 96, 8, 8, 2, 2, 1, 2], + [1, 1, 384, 3, 1, 4, 12, 12, 0, 0, 0, 2, 2, 1, 2], +] +M1_cfgs = [ + # s, n, c, ks, c1, c2, g1, g2, c3, g3, g4 + [2, 1, 8, 3, 2, 2, 0, 6, 8, 2, 2, 2, 0, 1, 1], + [2, 1, 16, 3, 2, 2, 0, 8, 16, 4, 4, 2, 2, 1, 1], + [2, 1, 16, 5, 2, 2, 0, 16, 16, 4, 4, 2, 2, 1, 1], + [1, 1, 32, 5, 1, 6, 4, 4, 32, 4, 4, 2, 2, 1, 1], + [2, 1, 64, 5, 1, 6, 8, 8, 64, 8, 8, 2, 2, 1, 1], + [1, 1, 96, 3, 1, 6, 8, 8, 96, 8, 8, 2, 2, 1, 2], + [1, 1, 576, 3, 1, 6, 12, 12, 0, 0, 0, 2, 2, 1, 2], +] +M2_cfgs = [ + # s, n, c, ks, c1, c2, g1, g2, c3, g3, g4 + [2, 1, 12, 3, 2, 2, 0, 8, 12, 4, 4, 2, 0, 1, 1], + [2, 1, 16, 3, 2, 2, 0, 12, 16, 4, 4, 2, 2, 1, 1], + [1, 1, 24, 3, 2, 2, 0, 16, 24, 4, 4, 2, 2, 1, 1], + [2, 1, 32, 5, 1, 6, 6, 6, 32, 4, 4, 2, 2, 1, 1], + [1, 1, 32, 5, 1, 6, 8, 8, 32, 4, 4, 2, 2, 1, 2], + [1, 1, 64, 5, 1, 6, 8, 8, 64, 8, 8, 2, 2, 1, 2], + [2, 1, 96, 5, 1, 6, 8, 8, 96, 8, 8, 2, 2, 1, 2], + [1, 1, 128, 3, 1, 6, 12, 12, 128, 8, 8, 2, 2, 1, 2], + [1, 1, 768, 3, 1, 6, 16, 16, 0, 0, 0, 2, 2, 1, 2], +] +M3_cfgs = [ + # s, n, c, ks, c1, c2, g1, g2, c3, g3, g4 + [2, 1, 16, 3, 2, 2, 0, 12, 16, 4, 4, 0, 2, 0, 1], + [2, 1, 24, 3, 2, 2, 0, 16, 24, 4, 4, 0, 2, 0, 1], + [1, 1, 24, 3, 2, 2, 0, 24, 24, 4, 4, 0, 2, 0, 1], + [2, 1, 32, 5, 1, 6, 6, 6, 32, 4, 4, 0, 2, 0, 1], + [1, 1, 32, 5, 1, 6, 8, 8, 32, 4, 4, 0, 2, 0, 2], + [1, 1, 64, 5, 1, 6, 8, 8, 48, 8, 8, 0, 2, 0, 2], + [1, 1, 80, 5, 1, 6, 8, 8, 80, 8, 8, 0, 2, 0, 2], + [1, 1, 80, 5, 1, 6, 10, 10, 80, 8, 8, 0, 2, 0, 2], + [1, 1, 120, 5, 1, 6, 10, 10, 120, 10, 10, 0, 2, 0, 2], + [1, 1, 120, 5, 1, 6, 12, 12, 120, 10, 10, 0, 2, 0, 2], + [1, 1, 144, 3, 1, 6, 12, 12, 144, 12, 12, 0, 2, 0, 2], + [1, 1, 432, 3, 1, 3, 12, 12, 0, 0, 0, 0, 2, 0, 2], +] + + +def get_micronet_config(mode): + return eval(mode + '_cfgs') + + +class MaxGroupPooling(nn.Layer): + def __init__(self, channel_per_group=2): + super(MaxGroupPooling, self).__init__() + self.channel_per_group = channel_per_group + + def forward(self, x): + if self.channel_per_group == 1: + return x + # max op + b, c, h, w = x.shape + + # reshape + y = paddle.reshape(x, [b, c // self.channel_per_group, -1, h, w]) + out = paddle.max(y, axis=2) + return out + + +class SpatialSepConvSF(nn.Layer): + def __init__(self, inp, oups, kernel_size, stride): + super(SpatialSepConvSF, self).__init__() + + oup1, oup2 = oups + self.conv = nn.Sequential( + nn.Conv2D( + inp, + oup1, (kernel_size, 1), (stride, 1), (kernel_size // 2, 0), + bias_attr=False, + groups=1), + nn.BatchNorm2D(oup1), + nn.Conv2D( + oup1, + oup1 * oup2, (1, kernel_size), (1, stride), + (0, kernel_size // 2), + bias_attr=False, + groups=oup1), + nn.BatchNorm2D(oup1 * oup2), + ChannelShuffle(oup1), ) + + def forward(self, x): + out = self.conv(x) + return out + + +class 
ChannelShuffle(nn.Layer): + def __init__(self, groups): + super(ChannelShuffle, self).__init__() + self.groups = groups + + def forward(self, x): + b, c, h, w = x.shape + + channels_per_group = c // self.groups + + # reshape + x = paddle.reshape(x, [b, self.groups, channels_per_group, h, w]) + + x = paddle.transpose(x, (0, 2, 1, 3, 4)) + out = paddle.reshape(x, [b, -1, h, w]) + + return out + + +class StemLayer(nn.Layer): + def __init__(self, inp, oup, stride, groups=(4, 4)): + super(StemLayer, self).__init__() + + g1, g2 = groups + self.stem = nn.Sequential( + SpatialSepConvSF(inp, groups, 3, stride), + MaxGroupPooling(2) if g1 * g2 == 2 * oup else nn.ReLU6()) + + def forward(self, x): + out = self.stem(x) + return out + + +class DepthSpatialSepConv(nn.Layer): + def __init__(self, inp, expand, kernel_size, stride): + super(DepthSpatialSepConv, self).__init__() + + exp1, exp2 = expand + + hidden_dim = inp * exp1 + oup = inp * exp1 * exp2 + + self.conv = nn.Sequential( + nn.Conv2D( + inp, + inp * exp1, (kernel_size, 1), (stride, 1), + (kernel_size // 2, 0), + bias_attr=False, + groups=inp), + nn.BatchNorm2D(inp * exp1), + nn.Conv2D( + hidden_dim, + oup, (1, kernel_size), + 1, (0, kernel_size // 2), + bias_attr=False, + groups=hidden_dim), + nn.BatchNorm2D(oup)) + + def forward(self, x): + x = self.conv(x) + return x + + +class GroupConv(nn.Layer): + def __init__(self, inp, oup, groups=2): + super(GroupConv, self).__init__() + self.inp = inp + self.oup = oup + self.groups = groups + self.conv = nn.Sequential( + nn.Conv2D( + inp, oup, 1, 1, 0, bias_attr=False, groups=self.groups[0]), + nn.BatchNorm2D(oup)) + + def forward(self, x): + x = self.conv(x) + return x + + +class DepthConv(nn.Layer): + def __init__(self, inp, oup, kernel_size, stride): + super(DepthConv, self).__init__() + self.conv = nn.Sequential( + nn.Conv2D( + inp, + oup, + kernel_size, + stride, + kernel_size // 2, + bias_attr=False, + groups=inp), + nn.BatchNorm2D(oup)) + + def forward(self, x): + out = self.conv(x) + return out + + +class DYShiftMax(nn.Layer): + def __init__(self, + inp, + oup, + reduction=4, + act_max=1.0, + act_relu=True, + init_a=[0.0, 0.0], + init_b=[0.0, 0.0], + relu_before_pool=False, + g=None, + expansion=False): + super(DYShiftMax, self).__init__() + self.oup = oup + self.act_max = act_max * 2 + self.act_relu = act_relu + self.avg_pool = nn.Sequential(nn.ReLU() if relu_before_pool == True else + nn.Sequential(), nn.AdaptiveAvgPool2D(1)) + + self.exp = 4 if act_relu else 2 + self.init_a = init_a + self.init_b = init_b + + # determine squeeze + squeeze = make_divisible(inp // reduction, 4) + if squeeze < 4: + squeeze = 4 + + self.fc = nn.Sequential( + nn.Linear(inp, squeeze), + nn.ReLU(), nn.Linear(squeeze, oup * self.exp), nn.Hardsigmoid()) + + if g is None: + g = 1 + self.g = g[1] + if self.g != 1 and expansion: + self.g = inp // self.g + + self.gc = inp // self.g + index = paddle.to_tensor([range(inp)]) + index = paddle.reshape(index, [1, inp, 1, 1]) + index = paddle.reshape(index, [1, self.g, self.gc, 1, 1]) + indexgs = paddle.split(index, [1, self.g - 1], axis=1) + indexgs = paddle.concat((indexgs[1], indexgs[0]), axis=1) + indexs = paddle.split(indexgs, [1, self.gc - 1], axis=2) + indexs = paddle.concat((indexs[1], indexs[0]), axis=2) + self.index = paddle.reshape(indexs, [inp]) + self.expansion = expansion + + def forward(self, x): + x_in = x + x_out = x + + b, c, _, _ = x_in.shape + y = self.avg_pool(x_in) + y = paddle.reshape(y, [b, c]) + y = self.fc(y) + y = paddle.reshape(y, [b, self.oup * 
self.exp, 1, 1]) + y = (y - 0.5) * self.act_max + + n2, c2, h2, w2 = x_out.shape + x2 = paddle.to_tensor(x_out.numpy()[:, self.index.numpy(), :, :]) + + if self.exp == 4: + temp = y.shape + a1, b1, a2, b2 = paddle.split(y, temp[1] // self.oup, axis=1) + + a1 = a1 + self.init_a[0] + a2 = a2 + self.init_a[1] + + b1 = b1 + self.init_b[0] + b2 = b2 + self.init_b[1] + + z1 = x_out * a1 + x2 * b1 + z2 = x_out * a2 + x2 * b2 + + out = paddle.maximum(z1, z2) + + elif self.exp == 2: + temp = y.shape + a1, b1 = paddle.split(y, temp[1] // self.oup, axis=1) + a1 = a1 + self.init_a[0] + b1 = b1 + self.init_b[0] + out = x_out * a1 + x2 * b1 + + return out + + +class DYMicroBlock(nn.Layer): + def __init__(self, + inp, + oup, + kernel_size=3, + stride=1, + ch_exp=(2, 2), + ch_per_group=4, + groups_1x1=(1, 1), + depthsep=True, + shuffle=False, + activation_cfg=None): + super(DYMicroBlock, self).__init__() + + self.identity = stride == 1 and inp == oup + + y1, y2, y3 = activation_cfg['dy'] + act_reduction = 8 * activation_cfg['ratio'] + init_a = activation_cfg['init_a'] + init_b = activation_cfg['init_b'] + + t1 = ch_exp + gs1 = ch_per_group + hidden_fft, g1, g2 = groups_1x1 + hidden_dim2 = inp * t1[0] * t1[1] + + if gs1[0] == 0: + self.layers = nn.Sequential( + DepthSpatialSepConv(inp, t1, kernel_size, stride), + DYShiftMax( + hidden_dim2, + hidden_dim2, + act_max=2.0, + act_relu=True if y2 == 2 else False, + init_a=init_a, + reduction=act_reduction, + init_b=init_b, + g=gs1, + expansion=False) if y2 > 0 else nn.ReLU6(), + ChannelShuffle(gs1[1]) if shuffle else nn.Sequential(), + ChannelShuffle(hidden_dim2 // 2) + if shuffle and y2 != 0 else nn.Sequential(), + GroupConv(hidden_dim2, oup, (g1, g2)), + DYShiftMax( + oup, + oup, + act_max=2.0, + act_relu=False, + init_a=[1.0, 0.0], + reduction=act_reduction // 2, + init_b=[0.0, 0.0], + g=(g1, g2), + expansion=False) if y3 > 0 else nn.Sequential(), + ChannelShuffle(g2) if shuffle else nn.Sequential(), + ChannelShuffle(oup // 2) + if shuffle and oup % 2 == 0 and y3 != 0 else nn.Sequential(), ) + elif g2 == 0: + self.layers = nn.Sequential( + GroupConv(inp, hidden_dim2, gs1), + DYShiftMax( + hidden_dim2, + hidden_dim2, + act_max=2.0, + act_relu=False, + init_a=[1.0, 0.0], + reduction=act_reduction, + init_b=[0.0, 0.0], + g=gs1, + expansion=False) if y3 > 0 else nn.Sequential(), ) + else: + self.layers = nn.Sequential( + GroupConv(inp, hidden_dim2, gs1), + DYShiftMax( + hidden_dim2, + hidden_dim2, + act_max=2.0, + act_relu=True if y1 == 2 else False, + init_a=init_a, + reduction=act_reduction, + init_b=init_b, + g=gs1, + expansion=False) if y1 > 0 else nn.ReLU6(), + ChannelShuffle(gs1[1]) if shuffle else nn.Sequential(), + DepthSpatialSepConv(hidden_dim2, (1, 1), kernel_size, stride) + if depthsep else + DepthConv(hidden_dim2, hidden_dim2, kernel_size, stride), + nn.Sequential(), + DYShiftMax( + hidden_dim2, + hidden_dim2, + act_max=2.0, + act_relu=True if y2 == 2 else False, + init_a=init_a, + reduction=act_reduction, + init_b=init_b, + g=gs1, + expansion=True) if y2 > 0 else nn.ReLU6(), + ChannelShuffle(hidden_dim2 // 4) + if shuffle and y1 != 0 and y2 != 0 else nn.Sequential() + if y1 == 0 and y2 == 0 else ChannelShuffle(hidden_dim2 // 2), + GroupConv(hidden_dim2, oup, (g1, g2)), + DYShiftMax( + oup, + oup, + act_max=2.0, + act_relu=False, + init_a=[1.0, 0.0], + reduction=act_reduction // 2 + if oup < hidden_dim2 else act_reduction, + init_b=[0.0, 0.0], + g=(g1, g2), + expansion=False) if y3 > 0 else nn.Sequential(), + ChannelShuffle(g2) if shuffle else 
nn.Sequential(), + ChannelShuffle(oup // 2) + if shuffle and y3 != 0 else nn.Sequential(), ) + + def forward(self, x): + identity = x + out = self.layers(x) + + if self.identity: + out = out + identity + + return out + + +class MicroNet(nn.Layer): + """ + the MicroNet backbone network for recognition module. + Args: + mode(str): {'M0', 'M1', 'M2', 'M3'} + Four models are proposed based on four different computational costs (4M, 6M, 12M, 21M MAdds) + Default: 'M3'. + """ + + def __init__(self, mode='M3', **kwargs): + super(MicroNet, self).__init__() + + self.cfgs = get_micronet_config(mode) + + activation_cfg = {} + if mode == 'M0': + input_channel = 4 + stem_groups = 2, 2 + out_ch = 384 + activation_cfg['init_a'] = 1.0, 1.0 + activation_cfg['init_b'] = 0.0, 0.0 + elif mode == 'M1': + input_channel = 6 + stem_groups = 3, 2 + out_ch = 576 + activation_cfg['init_a'] = 1.0, 1.0 + activation_cfg['init_b'] = 0.0, 0.0 + elif mode == 'M2': + input_channel = 8 + stem_groups = 4, 2 + out_ch = 768 + activation_cfg['init_a'] = 1.0, 1.0 + activation_cfg['init_b'] = 0.0, 0.0 + elif mode == 'M3': + input_channel = 12 + stem_groups = 4, 3 + out_ch = 432 + activation_cfg['init_a'] = 1.0, 0.5 + activation_cfg['init_b'] = 0.0, 0.5 + else: + raise NotImplementedError("mode[" + mode + + "_model] is not implemented!") + + layers = [StemLayer(3, input_channel, stride=2, groups=stem_groups)] + + for idx, val in enumerate(self.cfgs): + s, n, c, ks, c1, c2, g1, g2, c3, g3, g4, y1, y2, y3, r = val + + t1 = (c1, c2) + gs1 = (g1, g2) + gs2 = (c3, g3, g4) + activation_cfg['dy'] = [y1, y2, y3] + activation_cfg['ratio'] = r + + output_channel = c + layers.append( + DYMicroBlock( + input_channel, + output_channel, + kernel_size=ks, + stride=s, + ch_exp=t1, + ch_per_group=gs1, + groups_1x1=gs2, + depthsep=True, + shuffle=True, + activation_cfg=activation_cfg, )) + input_channel = output_channel + for i in range(1, n): + layers.append( + DYMicroBlock( + input_channel, + output_channel, + kernel_size=ks, + stride=1, + ch_exp=t1, + ch_per_group=gs1, + groups_1x1=gs2, + depthsep=True, + shuffle=True, + activation_cfg=activation_cfg, )) + input_channel = output_channel + self.features = nn.Sequential(*layers) + + self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0) + + self.out_channels = make_divisible(out_ch) + + def forward(self, x): + x = self.features(x) + x = self.pool(x) + return x diff --git a/backend/ppocr/modeling/backbones/rec_mobilenet_v3.py b/backend/ppocr/modeling/backbones/rec_mobilenet_v3.py index 1ff17159..917e000d 100644 --- a/backend/ppocr/modeling/backbones/rec_mobilenet_v3.py +++ b/backend/ppocr/modeling/backbones/rec_mobilenet_v3.py @@ -26,8 +26,10 @@ def __init__(self, scale=0.5, large_stride=None, small_stride=None, + disable_se=False, **kwargs): super(MobileNetV3, self).__init__() + self.disable_se = disable_se if small_stride is None: small_stride = [2, 2, 2, 2] if large_stride is None: @@ -96,12 +98,12 @@ def __init__(self, padding=1, groups=1, if_act=True, - act='hardswish', - name='conv1') + act='hardswish') i = 0 block_list = [] inplanes = make_divisible(inplanes * scale) for (k, exp, c, se, nl, s) in cfg: + se = se and not self.disable_se block_list.append( ResidualUnit( in_channels=inplanes, @@ -110,8 +112,7 @@ def __init__(self, kernel_size=k, stride=s, use_se=se, - act=nl, - name='conv' + str(i + 2))) + act=nl)) inplanes = make_divisible(scale * c) i += 1 self.blocks = nn.Sequential(*block_list) @@ -124,8 +125,7 @@ def __init__(self, padding=0, groups=1, if_act=True, - act='hardswish', 
- name='conv_last') + act='hardswish') self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0) self.out_channels = make_divisible(scale * cls_ch_squeeze) diff --git a/backend/ppocr/modeling/backbones/rec_mv1_enhance.py b/backend/ppocr/modeling/backbones/rec_mv1_enhance.py new file mode 100644 index 00000000..bb6af5e8 --- /dev/null +++ b/backend/ppocr/modeling/backbones/rec_mv1_enhance.py @@ -0,0 +1,256 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This code is refer from: https://github.com/PaddlePaddle/PaddleClas/blob/develop/ppcls/arch/backbone/legendary_models/pp_lcnet.py + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import numpy as np +import paddle +from paddle import ParamAttr, reshape, transpose +import paddle.nn as nn +import paddle.nn.functional as F +from paddle.nn import Conv2D, BatchNorm, Linear, Dropout +from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D +from paddle.nn.initializer import KaimingNormal +from paddle.regularizer import L2Decay +from paddle.nn.functional import hardswish, hardsigmoid + + +class ConvBNLayer(nn.Layer): + def __init__(self, + num_channels, + filter_size, + num_filters, + stride, + padding, + channels=None, + num_groups=1, + act='hard_swish'): + super(ConvBNLayer, self).__init__() + + self._conv = Conv2D( + in_channels=num_channels, + out_channels=num_filters, + kernel_size=filter_size, + stride=stride, + padding=padding, + groups=num_groups, + weight_attr=ParamAttr(initializer=KaimingNormal()), + bias_attr=False) + + self._batch_norm = BatchNorm( + num_filters, + act=act, + param_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0))) + + def forward(self, inputs): + y = self._conv(inputs) + y = self._batch_norm(y) + return y + + +class DepthwiseSeparable(nn.Layer): + def __init__(self, + num_channels, + num_filters1, + num_filters2, + num_groups, + stride, + scale, + dw_size=3, + padding=1, + use_se=False): + super(DepthwiseSeparable, self).__init__() + self.use_se = use_se + self._depthwise_conv = ConvBNLayer( + num_channels=num_channels, + num_filters=int(num_filters1 * scale), + filter_size=dw_size, + stride=stride, + padding=padding, + num_groups=int(num_groups * scale)) + if use_se: + self._se = SEModule(int(num_filters1 * scale)) + self._pointwise_conv = ConvBNLayer( + num_channels=int(num_filters1 * scale), + filter_size=1, + num_filters=int(num_filters2 * scale), + stride=1, + padding=0) + + def forward(self, inputs): + y = self._depthwise_conv(inputs) + if self.use_se: + y = self._se(y) + y = self._pointwise_conv(y) + return y + + +class MobileNetV1Enhance(nn.Layer): + def __init__(self, + in_channels=3, + scale=0.5, + last_conv_stride=1, + last_pool_type='max', + **kwargs): + super().__init__() + self.scale = scale + self.block_list = [] + + self.conv1 = ConvBNLayer( + num_channels=3, + filter_size=3, + channels=3, + num_filters=int(32 * 
scale), + stride=2, + padding=1) + + conv2_1 = DepthwiseSeparable( + num_channels=int(32 * scale), + num_filters1=32, + num_filters2=64, + num_groups=32, + stride=1, + scale=scale) + self.block_list.append(conv2_1) + + conv2_2 = DepthwiseSeparable( + num_channels=int(64 * scale), + num_filters1=64, + num_filters2=128, + num_groups=64, + stride=1, + scale=scale) + self.block_list.append(conv2_2) + + conv3_1 = DepthwiseSeparable( + num_channels=int(128 * scale), + num_filters1=128, + num_filters2=128, + num_groups=128, + stride=1, + scale=scale) + self.block_list.append(conv3_1) + + conv3_2 = DepthwiseSeparable( + num_channels=int(128 * scale), + num_filters1=128, + num_filters2=256, + num_groups=128, + stride=(2, 1), + scale=scale) + self.block_list.append(conv3_2) + + conv4_1 = DepthwiseSeparable( + num_channels=int(256 * scale), + num_filters1=256, + num_filters2=256, + num_groups=256, + stride=1, + scale=scale) + self.block_list.append(conv4_1) + + conv4_2 = DepthwiseSeparable( + num_channels=int(256 * scale), + num_filters1=256, + num_filters2=512, + num_groups=256, + stride=(2, 1), + scale=scale) + self.block_list.append(conv4_2) + + for _ in range(5): + conv5 = DepthwiseSeparable( + num_channels=int(512 * scale), + num_filters1=512, + num_filters2=512, + num_groups=512, + stride=1, + dw_size=5, + padding=2, + scale=scale, + use_se=False) + self.block_list.append(conv5) + + conv5_6 = DepthwiseSeparable( + num_channels=int(512 * scale), + num_filters1=512, + num_filters2=1024, + num_groups=512, + stride=(2, 1), + dw_size=5, + padding=2, + scale=scale, + use_se=True) + self.block_list.append(conv5_6) + + conv6 = DepthwiseSeparable( + num_channels=int(1024 * scale), + num_filters1=1024, + num_filters2=1024, + num_groups=1024, + stride=last_conv_stride, + dw_size=5, + padding=2, + use_se=True, + scale=scale) + self.block_list.append(conv6) + + self.block_list = nn.Sequential(*self.block_list) + if last_pool_type == 'avg': + self.pool = nn.AvgPool2D(kernel_size=2, stride=2, padding=0) + else: + self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0) + self.out_channels = int(1024 * scale) + + def forward(self, inputs): + y = self.conv1(inputs) + y = self.block_list(y) + y = self.pool(y) + return y + + +class SEModule(nn.Layer): + def __init__(self, channel, reduction=4): + super(SEModule, self).__init__() + self.avg_pool = AdaptiveAvgPool2D(1) + self.conv1 = Conv2D( + in_channels=channel, + out_channels=channel // reduction, + kernel_size=1, + stride=1, + padding=0, + weight_attr=ParamAttr(), + bias_attr=ParamAttr()) + self.conv2 = Conv2D( + in_channels=channel // reduction, + out_channels=channel, + kernel_size=1, + stride=1, + padding=0, + weight_attr=ParamAttr(), + bias_attr=ParamAttr()) + + def forward(self, inputs): + outputs = self.avg_pool(inputs) + outputs = self.conv1(outputs) + outputs = F.relu(outputs) + outputs = self.conv2(outputs) + outputs = hardsigmoid(outputs) + return paddle.multiply(x=inputs, y=outputs) diff --git a/backend/ppocr/modeling/backbones/rec_nrtr_mtb.py b/backend/ppocr/modeling/backbones/rec_nrtr_mtb.py new file mode 100644 index 00000000..22e02a63 --- /dev/null +++ b/backend/ppocr/modeling/backbones/rec_nrtr_mtb.py @@ -0,0 +1,48 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle import nn +import paddle + + +class MTB(nn.Layer): + def __init__(self, cnn_num, in_channels): + super(MTB, self).__init__() + self.block = nn.Sequential() + self.out_channels = in_channels + self.cnn_num = cnn_num + if self.cnn_num == 2: + for i in range(self.cnn_num): + self.block.add_sublayer( + 'conv_{}'.format(i), + nn.Conv2D( + in_channels=in_channels + if i == 0 else 32 * (2**(i - 1)), + out_channels=32 * (2**i), + kernel_size=3, + stride=2, + padding=1)) + self.block.add_sublayer('relu_{}'.format(i), nn.ReLU()) + self.block.add_sublayer('bn_{}'.format(i), + nn.BatchNorm2D(32 * (2**i))) + + def forward(self, images): + x = self.block(images) + if self.cnn_num == 2: + # (b, w, h, c) + x = paddle.transpose(x, [0, 3, 2, 1]) + x_shape = paddle.shape(x) + x = paddle.reshape( + x, [x_shape[0], x_shape[1], x_shape[2] * x_shape[3]]) + return x diff --git a/backend/ppocr/modeling/backbones/rec_resnet_31.py b/backend/ppocr/modeling/backbones/rec_resnet_31.py new file mode 100644 index 00000000..96517013 --- /dev/null +++ b/backend/ppocr/modeling/backbones/rec_resnet_31.py @@ -0,0 +1,210 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +This code is refer from: +https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/layers/conv_layer.py +https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/backbones/resnet31_ocr.py +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import ParamAttr +import paddle.nn as nn +import paddle.nn.functional as F +import numpy as np + +__all__ = ["ResNet31"] + + +def conv3x3(in_channel, out_channel, stride=1): + return nn.Conv2D( + in_channel, + out_channel, + kernel_size=3, + stride=stride, + padding=1, + bias_attr=False) + + +class BasicBlock(nn.Layer): + expansion = 1 + + def __init__(self, in_channels, channels, stride=1, downsample=False): + super().__init__() + self.conv1 = conv3x3(in_channels, channels, stride) + self.bn1 = nn.BatchNorm2D(channels) + self.relu = nn.ReLU() + self.conv2 = conv3x3(channels, channels) + self.bn2 = nn.BatchNorm2D(channels) + self.downsample = downsample + if downsample: + self.downsample = nn.Sequential( + nn.Conv2D( + in_channels, + channels * self.expansion, + 1, + stride, + bias_attr=False), + nn.BatchNorm2D(channels * self.expansion), ) + else: + self.downsample = nn.Sequential() + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet31(nn.Layer): + ''' + Args: + in_channels (int): Number of channels of input image tensor. + layers (list[int]): List of BasicBlock number for each stage. + channels (list[int]): List of out_channels of Conv2d layer. + out_indices (None | Sequence[int]): Indices of output stages. + last_stage_pool (bool): If True, add `MaxPool2d` layer to last stage. 
+ ''' + + def __init__(self, + in_channels=3, + layers=[1, 2, 5, 3], + channels=[64, 128, 256, 256, 512, 512, 512], + out_indices=None, + last_stage_pool=False): + super(ResNet31, self).__init__() + assert isinstance(in_channels, int) + assert isinstance(last_stage_pool, bool) + + self.out_indices = out_indices + self.last_stage_pool = last_stage_pool + + # conv 1 (Conv Conv) + self.conv1_1 = nn.Conv2D( + in_channels, channels[0], kernel_size=3, stride=1, padding=1) + self.bn1_1 = nn.BatchNorm2D(channels[0]) + self.relu1_1 = nn.ReLU() + + self.conv1_2 = nn.Conv2D( + channels[0], channels[1], kernel_size=3, stride=1, padding=1) + self.bn1_2 = nn.BatchNorm2D(channels[1]) + self.relu1_2 = nn.ReLU() + + # conv 2 (Max-pooling, Residual block, Conv) + self.pool2 = nn.MaxPool2D( + kernel_size=2, stride=2, padding=0, ceil_mode=True) + self.block2 = self._make_layer(channels[1], channels[2], layers[0]) + self.conv2 = nn.Conv2D( + channels[2], channels[2], kernel_size=3, stride=1, padding=1) + self.bn2 = nn.BatchNorm2D(channels[2]) + self.relu2 = nn.ReLU() + + # conv 3 (Max-pooling, Residual block, Conv) + self.pool3 = nn.MaxPool2D( + kernel_size=2, stride=2, padding=0, ceil_mode=True) + self.block3 = self._make_layer(channels[2], channels[3], layers[1]) + self.conv3 = nn.Conv2D( + channels[3], channels[3], kernel_size=3, stride=1, padding=1) + self.bn3 = nn.BatchNorm2D(channels[3]) + self.relu3 = nn.ReLU() + + # conv 4 (Max-pooling, Residual block, Conv) + self.pool4 = nn.MaxPool2D( + kernel_size=(2, 1), stride=(2, 1), padding=0, ceil_mode=True) + self.block4 = self._make_layer(channels[3], channels[4], layers[2]) + self.conv4 = nn.Conv2D( + channels[4], channels[4], kernel_size=3, stride=1, padding=1) + self.bn4 = nn.BatchNorm2D(channels[4]) + self.relu4 = nn.ReLU() + + # conv 5 ((Max-pooling), Residual block, Conv) + self.pool5 = None + if self.last_stage_pool: + self.pool5 = nn.MaxPool2D( + kernel_size=2, stride=2, padding=0, ceil_mode=True) + self.block5 = self._make_layer(channels[4], channels[5], layers[3]) + self.conv5 = nn.Conv2D( + channels[5], channels[5], kernel_size=3, stride=1, padding=1) + self.bn5 = nn.BatchNorm2D(channels[5]) + self.relu5 = nn.ReLU() + + self.out_channels = channels[-1] + + def _make_layer(self, input_channels, output_channels, blocks): + layers = [] + for _ in range(blocks): + downsample = None + if input_channels != output_channels: + downsample = nn.Sequential( + nn.Conv2D( + input_channels, + output_channels, + kernel_size=1, + stride=1, + bias_attr=False), + nn.BatchNorm2D(output_channels), ) + + layers.append( + BasicBlock( + input_channels, output_channels, downsample=downsample)) + input_channels = output_channels + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1_1(x) + x = self.bn1_1(x) + x = self.relu1_1(x) + + x = self.conv1_2(x) + x = self.bn1_2(x) + x = self.relu1_2(x) + + outs = [] + for i in range(4): + layer_index = i + 2 + pool_layer = getattr(self, f'pool{layer_index}') + block_layer = getattr(self, f'block{layer_index}') + conv_layer = getattr(self, f'conv{layer_index}') + bn_layer = getattr(self, f'bn{layer_index}') + relu_layer = getattr(self, f'relu{layer_index}') + + if pool_layer is not None: + x = pool_layer(x) + x = block_layer(x) + x = conv_layer(x) + x = bn_layer(x) + x = relu_layer(x) + + outs.append(x) + + if self.out_indices is not None: + return tuple([outs[i] for i in self.out_indices]) + + return x diff --git a/backend/ppocr/modeling/backbones/rec_resnet_aster.py 
b/backend/ppocr/modeling/backbones/rec_resnet_aster.py new file mode 100644 index 00000000..6a2710df --- /dev/null +++ b/backend/ppocr/modeling/backbones/rec_resnet_aster.py @@ -0,0 +1,143 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is refer from: +https://github.com/ayumiymk/aster.pytorch/blob/master/lib/models/resnet_aster.py +""" +import paddle +import paddle.nn as nn + +import sys +import math + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2D( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=1, + bias_attr=False) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2D( + in_planes, out_planes, kernel_size=1, stride=stride, bias_attr=False) + + +def get_sinusoid_encoding(n_position, feat_dim, wave_length=10000): + # [n_position] + positions = paddle.arange(0, n_position) + # [feat_dim] + dim_range = paddle.arange(0, feat_dim) + dim_range = paddle.pow(wave_length, 2 * (dim_range // 2) / feat_dim) + # [n_position, feat_dim] + angles = paddle.unsqueeze( + positions, axis=1) / paddle.unsqueeze( + dim_range, axis=0) + angles = paddle.cast(angles, "float32") + angles[:, 0::2] = paddle.sin(angles[:, 0::2]) + angles[:, 1::2] = paddle.cos(angles[:, 1::2]) + return angles + + +class AsterBlock(nn.Layer): + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(AsterBlock, self).__init__() + self.conv1 = conv1x1(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2D(planes) + self.relu = nn.ReLU() + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2D(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + out += residual + out = self.relu(out) + return out + + +class ResNet_ASTER(nn.Layer): + """For aster or crnn""" + + def __init__(self, with_lstm=True, n_group=1, in_channels=3): + super(ResNet_ASTER, self).__init__() + self.with_lstm = with_lstm + self.n_group = n_group + + self.layer0 = nn.Sequential( + nn.Conv2D( + in_channels, + 32, + kernel_size=(3, 3), + stride=1, + padding=1, + bias_attr=False), + nn.BatchNorm2D(32), + nn.ReLU()) + + self.inplanes = 32 + self.layer1 = self._make_layer(32, 3, [2, 2]) # [16, 50] + self.layer2 = self._make_layer(64, 4, [2, 2]) # [8, 25] + self.layer3 = self._make_layer(128, 6, [2, 1]) # [4, 25] + self.layer4 = self._make_layer(256, 6, [2, 1]) # [2, 25] + self.layer5 = self._make_layer(512, 3, [2, 1]) # [1, 25] + + if with_lstm: + self.rnn = nn.LSTM(512, 256, direction="bidirect", num_layers=2) + self.out_channels = 2 * 256 + else: + self.out_channels = 512 + + def _make_layer(self, planes, blocks, stride): + downsample = None + if stride != [1, 1] or self.inplanes != planes: + downsample = nn.Sequential( + 
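+                # 1x1 conv + BN projects the shortcut so it matches the block's
+                # output channels and stride before the residual addition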
conv1x1(self.inplanes, planes, stride), nn.BatchNorm2D(planes)) + + layers = [] + layers.append(AsterBlock(self.inplanes, planes, stride, downsample)) + self.inplanes = planes + for _ in range(1, blocks): + layers.append(AsterBlock(self.inplanes, planes)) + return nn.Sequential(*layers) + + def forward(self, x): + x0 = self.layer0(x) + x1 = self.layer1(x0) + x2 = self.layer2(x1) + x3 = self.layer3(x2) + x4 = self.layer4(x3) + x5 = self.layer5(x4) + + cnn_feat = x5.squeeze(2) # [N, c, w] + cnn_feat = paddle.transpose(cnn_feat, perm=[0, 2, 1]) + if self.with_lstm: + rnn_feat, _ = self.rnn(cnn_feat) + return rnn_feat + else: + return cnn_feat diff --git a/backend/ppocr/modeling/backbones/rec_resnet_vd.py b/backend/ppocr/modeling/backbones/rec_resnet_vd.py index 6837ea0f..0187deb9 100644 --- a/backend/ppocr/modeling/backbones/rec_resnet_vd.py +++ b/backend/ppocr/modeling/backbones/rec_resnet_vd.py @@ -249,7 +249,7 @@ def __init__(self, in_channels=3, layers=50, **kwargs): name=conv_name)) shortcut = True self.block_list.append(bottleneck_block) - self.out_channels = num_filters[block] + self.out_channels = num_filters[block] * 4 else: for block in range(len(depth)): shortcut = False diff --git a/backend/ppocr/modeling/backbones/rec_svtrnet.py b/backend/ppocr/modeling/backbones/rec_svtrnet.py new file mode 100644 index 00000000..c57bf463 --- /dev/null +++ b/backend/ppocr/modeling/backbones/rec_svtrnet.py @@ -0,0 +1,584 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle import ParamAttr +from paddle.nn.initializer import KaimingNormal +import numpy as np +import paddle +import paddle.nn as nn +from paddle.nn.initializer import TruncatedNormal, Constant, Normal + +trunc_normal_ = TruncatedNormal(std=.02) +normal_ = Normal +zeros_ = Constant(value=0.) +ones_ = Constant(value=1.) + + +def drop_path(x, drop_prob=0., training=False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... + """ + if drop_prob == 0. 
or not training: + return x + keep_prob = paddle.to_tensor(1 - drop_prob) + shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1) + random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype) + random_tensor = paddle.floor(random_tensor) # binarize + output = x.divide(keep_prob) * random_tensor + return output + + +class ConvBNLayer(nn.Layer): + def __init__(self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=0, + bias_attr=False, + groups=1, + act=nn.GELU): + super().__init__() + self.conv = nn.Conv2D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + weight_attr=paddle.ParamAttr( + initializer=nn.initializer.KaimingUniform()), + bias_attr=bias_attr) + self.norm = nn.BatchNorm2D(out_channels) + self.act = act() + + def forward(self, inputs): + out = self.conv(inputs) + out = self.norm(out) + out = self.act(out) + return out + + +class DropPath(nn.Layer): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Identity(nn.Layer): + def __init__(self): + super(Identity, self).__init__() + + def forward(self, input): + return input + + +class Mlp(nn.Layer): + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class ConvMixer(nn.Layer): + def __init__( + self, + dim, + num_heads=8, + HW=[8, 25], + local_k=[3, 3], ): + super().__init__() + self.HW = HW + self.dim = dim + self.local_mixer = nn.Conv2D( + dim, + dim, + local_k, + 1, [local_k[0] // 2, local_k[1] // 2], + groups=num_heads, + weight_attr=ParamAttr(initializer=KaimingNormal())) + + def forward(self, x): + h = self.HW[0] + w = self.HW[1] + x = x.transpose([0, 2, 1]).reshape([0, self.dim, h, w]) + x = self.local_mixer(x) + x = x.flatten(2).transpose([0, 2, 1]) + return x + + +class Attention(nn.Layer): + def __init__(self, + dim, + num_heads=8, + mixer='Global', + HW=[8, 25], + local_k=[7, 11], + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.HW = HW + if HW is not None: + H = HW[0] + W = HW[1] + self.N = H * W + self.C = dim + if mixer == 'Local' and HW is not None: + hk = local_k[0] + wk = local_k[1] + mask = paddle.ones([H * W, H + hk - 1, W + wk - 1], dtype='float32') + for h in range(0, H): + for w in range(0, W): + mask[h * W + w, h:h + hk, w:w + wk] = 0. 
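+            # Crop the padded windows back to the H*W grid and flatten them;
+            # positions a query cannot attend to keep value 1 and are replaced
+            # with -inf below, so adding self.mask to the attention logits
+            # enforces the local window for the 'Local' mixer.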
+ mask_paddle = mask[:, hk // 2:H + hk // 2, wk // 2:W + wk // + 2].flatten(1) + mask_inf = paddle.full([H * W, H * W], '-inf', dtype='float32') + mask = paddle.where(mask_paddle < 1, mask_paddle, mask_inf) + self.mask = mask.unsqueeze([0, 1]) + self.mixer = mixer + + def forward(self, x): + if self.HW is not None: + N = self.N + C = self.C + else: + _, N, C = x.shape + qkv = self.qkv(x).reshape((0, N, 3, self.num_heads, C // + self.num_heads)).transpose((2, 0, 3, 1, 4)) + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + + attn = (q.matmul(k.transpose((0, 1, 3, 2)))) + if self.mixer == 'Local': + attn += self.mask + attn = nn.functional.softmax(attn, axis=-1) + attn = self.attn_drop(attn) + + x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape((0, N, C)) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Layer): + def __init__(self, + dim, + num_heads, + mixer='Global', + local_mixer=[7, 11], + HW=[8, 25], + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer='nn.LayerNorm', + epsilon=1e-6, + prenorm=True): + super().__init__() + if isinstance(norm_layer, str): + self.norm1 = eval(norm_layer)(dim, epsilon=epsilon) + else: + self.norm1 = norm_layer(dim) + if mixer == 'Global' or mixer == 'Local': + self.mixer = Attention( + dim, + num_heads=num_heads, + mixer=mixer, + HW=HW, + local_k=local_mixer, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop) + elif mixer == 'Conv': + self.mixer = ConvMixer( + dim, num_heads=num_heads, HW=HW, local_k=local_mixer) + else: + raise TypeError("The mixer must be one of [Global, Local, Conv]") + + self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity() + if isinstance(norm_layer, str): + self.norm2 = eval(norm_layer)(dim, epsilon=epsilon) + else: + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp_ratio = mlp_ratio + self.mlp = Mlp(in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + self.prenorm = prenorm + + def forward(self, x): + if self.prenorm: + x = self.norm1(x + self.drop_path(self.mixer(x))) + x = self.norm2(x + self.drop_path(self.mlp(x))) + else: + x = x + self.drop_path(self.mixer(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Layer): + """ Image to Patch Embedding + """ + + def __init__(self, + img_size=[32, 100], + in_channels=3, + embed_dim=768, + sub_num=2): + super().__init__() + num_patches = (img_size[1] // (2 ** sub_num)) * \ + (img_size[0] // (2 ** sub_num)) + self.img_size = img_size + self.num_patches = num_patches + self.embed_dim = embed_dim + self.norm = None + if sub_num == 2: + self.proj = nn.Sequential( + ConvBNLayer( + in_channels=in_channels, + out_channels=embed_dim // 2, + kernel_size=3, + stride=2, + padding=1, + act=nn.GELU, + bias_attr=None), + ConvBNLayer( + in_channels=embed_dim // 2, + out_channels=embed_dim, + kernel_size=3, + stride=2, + padding=1, + act=nn.GELU, + bias_attr=None)) + if sub_num == 3: + self.proj = nn.Sequential( + ConvBNLayer( + in_channels=in_channels, + out_channels=embed_dim // 4, + kernel_size=3, + stride=2, + padding=1, + act=nn.GELU, + bias_attr=None), + ConvBNLayer( + in_channels=embed_dim // 4, + out_channels=embed_dim // 2, + kernel_size=3, + stride=2, + padding=1, + act=nn.GELU, + bias_attr=None), + ConvBNLayer( + in_channels=embed_dim // 2, + out_channels=embed_dim, + kernel_size=3, + stride=2, + padding=1, + act=nn.GELU, + 
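+                    # third stride-2 conv: with sub_num=3 the patch embedding
+                    # downsamples the input by 8x in height and width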
bias_attr=None)) + + def forward(self, x): + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose((0, 2, 1)) + return x + + +class SubSample(nn.Layer): + def __init__(self, + in_channels, + out_channels, + types='Pool', + stride=[2, 1], + sub_norm='nn.LayerNorm', + act=None): + super().__init__() + self.types = types + if types == 'Pool': + self.avgpool = nn.AvgPool2D( + kernel_size=[3, 5], stride=stride, padding=[1, 2]) + self.maxpool = nn.MaxPool2D( + kernel_size=[3, 5], stride=stride, padding=[1, 2]) + self.proj = nn.Linear(in_channels, out_channels) + else: + self.conv = nn.Conv2D( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=1, + weight_attr=ParamAttr(initializer=KaimingNormal())) + self.norm = eval(sub_norm)(out_channels) + if act is not None: + self.act = act() + else: + self.act = None + + def forward(self, x): + + if self.types == 'Pool': + x1 = self.avgpool(x) + x2 = self.maxpool(x) + x = (x1 + x2) * 0.5 + out = self.proj(x.flatten(2).transpose((0, 2, 1))) + else: + x = self.conv(x) + out = x.flatten(2).transpose((0, 2, 1)) + out = self.norm(out) + if self.act is not None: + out = self.act(out) + + return out + + +class SVTRNet(nn.Layer): + def __init__( + self, + img_size=[32, 100], + in_channels=3, + embed_dim=[64, 128, 256], + depth=[3, 6, 3], + num_heads=[2, 4, 8], + mixer=['Local'] * 6 + ['Global'] * + 6, # Local atten, Global atten, Conv + local_mixer=[[7, 11], [7, 11], [7, 11]], + patch_merging='Conv', # Conv, Pool, None + mlp_ratio=4, + qkv_bias=True, + qk_scale=None, + drop_rate=0., + last_drop=0.1, + attn_drop_rate=0., + drop_path_rate=0.1, + norm_layer='nn.LayerNorm', + sub_norm='nn.LayerNorm', + epsilon=1e-6, + out_channels=192, + out_char_num=25, + block_unit='Block', + act='nn.GELU', + last_stage=True, + sub_num=2, + prenorm=True, + use_lenhead=False, + **kwargs): + super().__init__() + self.img_size = img_size + self.embed_dim = embed_dim + self.out_channels = out_channels + self.prenorm = prenorm + patch_merging = None if patch_merging != 'Conv' and patch_merging != 'Pool' else patch_merging + self.patch_embed = PatchEmbed( + img_size=img_size, + in_channels=in_channels, + embed_dim=embed_dim[0], + sub_num=sub_num) + num_patches = self.patch_embed.num_patches + self.HW = [img_size[0] // (2**sub_num), img_size[1] // (2**sub_num)] + self.pos_embed = self.create_parameter( + shape=[1, num_patches, embed_dim[0]], default_initializer=zeros_) + self.add_parameter("pos_embed", self.pos_embed) + self.pos_drop = nn.Dropout(p=drop_rate) + Block_unit = eval(block_unit) + + dpr = np.linspace(0, drop_path_rate, sum(depth)) + self.blocks1 = nn.LayerList([ + Block_unit( + dim=embed_dim[0], + num_heads=num_heads[0], + mixer=mixer[0:depth[0]][i], + HW=self.HW, + local_mixer=local_mixer[0], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + act_layer=eval(act), + attn_drop=attn_drop_rate, + drop_path=dpr[0:depth[0]][i], + norm_layer=norm_layer, + epsilon=epsilon, + prenorm=prenorm) for i in range(depth[0]) + ]) + if patch_merging is not None: + self.sub_sample1 = SubSample( + embed_dim[0], + embed_dim[1], + sub_norm=sub_norm, + stride=[2, 1], + types=patch_merging) + HW = [self.HW[0] // 2, self.HW[1]] + else: + HW = self.HW + self.patch_merging = patch_merging + self.blocks2 = nn.LayerList([ + Block_unit( + dim=embed_dim[1], + num_heads=num_heads[1], + 
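+                # the second stage consumes the next depth[1] entries of the
+                # mixer list and of the stochastic-depth (dpr) schedule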
mixer=mixer[depth[0]:depth[0] + depth[1]][i], + HW=HW, + local_mixer=local_mixer[1], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + act_layer=eval(act), + attn_drop=attn_drop_rate, + drop_path=dpr[depth[0]:depth[0] + depth[1]][i], + norm_layer=norm_layer, + epsilon=epsilon, + prenorm=prenorm) for i in range(depth[1]) + ]) + if patch_merging is not None: + self.sub_sample2 = SubSample( + embed_dim[1], + embed_dim[2], + sub_norm=sub_norm, + stride=[2, 1], + types=patch_merging) + HW = [self.HW[0] // 4, self.HW[1]] + else: + HW = self.HW + self.blocks3 = nn.LayerList([ + Block_unit( + dim=embed_dim[2], + num_heads=num_heads[2], + mixer=mixer[depth[0] + depth[1]:][i], + HW=HW, + local_mixer=local_mixer[2], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + act_layer=eval(act), + attn_drop=attn_drop_rate, + drop_path=dpr[depth[0] + depth[1]:][i], + norm_layer=norm_layer, + epsilon=epsilon, + prenorm=prenorm) for i in range(depth[2]) + ]) + self.last_stage = last_stage + if last_stage: + self.avg_pool = nn.AdaptiveAvgPool2D([1, out_char_num]) + self.last_conv = nn.Conv2D( + in_channels=embed_dim[2], + out_channels=self.out_channels, + kernel_size=1, + stride=1, + padding=0, + bias_attr=False) + self.hardswish = nn.Hardswish() + self.dropout = nn.Dropout(p=last_drop, mode="downscale_in_infer") + if not prenorm: + self.norm = eval(norm_layer)(embed_dim[-1], epsilon=epsilon) + self.use_lenhead = use_lenhead + if use_lenhead: + self.len_conv = nn.Linear(embed_dim[2], self.out_channels) + self.hardswish_len = nn.Hardswish() + self.dropout_len = nn.Dropout( + p=last_drop, mode="downscale_in_infer") + + trunc_normal_(self.pos_embed) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + zeros_(m.bias) + ones_(m.weight) + + def forward_features(self, x): + x = self.patch_embed(x) + x = x + self.pos_embed + x = self.pos_drop(x) + for blk in self.blocks1: + x = blk(x) + if self.patch_merging is not None: + x = self.sub_sample1( + x.transpose([0, 2, 1]).reshape( + [0, self.embed_dim[0], self.HW[0], self.HW[1]])) + for blk in self.blocks2: + x = blk(x) + if self.patch_merging is not None: + x = self.sub_sample2( + x.transpose([0, 2, 1]).reshape( + [0, self.embed_dim[1], self.HW[0] // 2, self.HW[1]])) + for blk in self.blocks3: + x = blk(x) + if not self.prenorm: + x = self.norm(x) + return x + + def forward(self, x): + x = self.forward_features(x) + if self.use_lenhead: + len_x = self.len_conv(x.mean(1)) + len_x = self.dropout_len(self.hardswish_len(len_x)) + if self.last_stage: + if self.patch_merging is not None: + h = self.HW[0] // 4 + else: + h = self.HW[0] + x = self.avg_pool( + x.transpose([0, 2, 1]).reshape( + [0, self.embed_dim[2], h, self.HW[1]])) + x = self.last_conv(x) + x = self.hardswish(x) + x = self.dropout(x) + if self.use_lenhead: + return x, len_x + return x diff --git a/backend/ppocr/modeling/backbones/vqa_layoutlm.py b/backend/ppocr/modeling/backbones/vqa_layoutlm.py new file mode 100644 index 00000000..ede5b7a3 --- /dev/null +++ b/backend/ppocr/modeling/backbones/vqa_layoutlm.py @@ -0,0 +1,172 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +from paddle import nn + +from paddlenlp.transformers import LayoutXLMModel, LayoutXLMForTokenClassification, LayoutXLMForRelationExtraction +from paddlenlp.transformers import LayoutLMModel, LayoutLMForTokenClassification +from paddlenlp.transformers import LayoutLMv2Model, LayoutLMv2ForTokenClassification, LayoutLMv2ForRelationExtraction + +__all__ = ["LayoutXLMForSer", 'LayoutLMForSer'] + +pretrained_model_dict = { + LayoutXLMModel: 'layoutxlm-base-uncased', + LayoutLMModel: 'layoutlm-base-uncased', + LayoutLMv2Model: 'layoutlmv2-base-uncased' +} + + +class NLPBaseModel(nn.Layer): + def __init__(self, + base_model_class, + model_class, + type='ser', + pretrained=True, + checkpoints=None, + **kwargs): + super(NLPBaseModel, self).__init__() + if checkpoints is not None: + self.model = model_class.from_pretrained(checkpoints) + else: + pretrained_model_name = pretrained_model_dict[base_model_class] + if pretrained: + base_model = base_model_class.from_pretrained( + pretrained_model_name) + else: + base_model = base_model_class( + **base_model_class.pretrained_init_configuration[ + pretrained_model_name]) + if type == 'ser': + self.model = model_class( + base_model, num_classes=kwargs['num_classes'], dropout=None) + else: + self.model = model_class(base_model, dropout=None) + self.out_channels = 1 + + +class LayoutLMForSer(NLPBaseModel): + def __init__(self, num_classes, pretrained=True, checkpoints=None, + **kwargs): + super(LayoutLMForSer, self).__init__( + LayoutLMModel, + LayoutLMForTokenClassification, + 'ser', + pretrained, + checkpoints, + num_classes=num_classes) + + def forward(self, x): + x = self.model( + input_ids=x[0], + bbox=x[2], + attention_mask=x[4], + token_type_ids=x[5], + position_ids=None, + output_hidden_states=False) + return x + + +class LayoutLMv2ForSer(NLPBaseModel): + def __init__(self, num_classes, pretrained=True, checkpoints=None, + **kwargs): + super(LayoutLMv2ForSer, self).__init__( + LayoutLMv2Model, + LayoutLMv2ForTokenClassification, + 'ser', + pretrained, + checkpoints, + num_classes=num_classes) + + def forward(self, x): + x = self.model( + input_ids=x[0], + bbox=x[2], + image=x[3], + attention_mask=x[4], + token_type_ids=x[5], + position_ids=None, + head_mask=None, + labels=None) + return x[0] + + +class LayoutXLMForSer(NLPBaseModel): + def __init__(self, num_classes, pretrained=True, checkpoints=None, + **kwargs): + super(LayoutXLMForSer, self).__init__( + LayoutXLMModel, + LayoutXLMForTokenClassification, + 'ser', + pretrained, + checkpoints, + num_classes=num_classes) + + def forward(self, x): + x = self.model( + input_ids=x[0], + bbox=x[2], + image=x[3], + attention_mask=x[4], + token_type_ids=x[5], + position_ids=None, + head_mask=None, + labels=None) + return x[0] + + +class LayoutLMv2ForRe(NLPBaseModel): + def __init__(self, pretrained=True, checkpoints=None, **kwargs): + super(LayoutLMv2ForRe, self).__init__(LayoutLMv2Model, + LayoutLMv2ForRelationExtraction, + 're', pretrained, checkpoints) + + def forward(self, 
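+                # x packs: [input_ids, bbox, image, attention_mask,
+                #           token_type_ids, entities, relations]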
x): + x = self.model( + input_ids=x[0], + bbox=x[1], + labels=None, + image=x[2], + attention_mask=x[3], + token_type_ids=x[4], + position_ids=None, + head_mask=None, + entities=x[5], + relations=x[6]) + return x + + +class LayoutXLMForRe(NLPBaseModel): + def __init__(self, pretrained=True, checkpoints=None, **kwargs): + super(LayoutXLMForRe, self).__init__(LayoutXLMModel, + LayoutXLMForRelationExtraction, + 're', pretrained, checkpoints) + + def forward(self, x): + x = self.model( + input_ids=x[0], + bbox=x[1], + labels=None, + image=x[2], + attention_mask=x[3], + token_type_ids=x[4], + position_ids=None, + head_mask=None, + entities=x[5], + relations=x[6]) + return x diff --git a/backend/ppocr/modeling/heads/__init__.py b/backend/ppocr/modeling/heads/__init__.py index efe05718..1670ea38 100755 --- a/backend/ppocr/modeling/heads/__init__.py +++ b/backend/ppocr/modeling/heads/__init__.py @@ -20,19 +20,37 @@ def build_head(config): from .det_db_head import DBHead from .det_east_head import EASTHead from .det_sast_head import SASTHead + from .det_pse_head import PSEHead + from .det_fce_head import FCEHead + from .e2e_pg_head import PGHead # rec head from .rec_ctc_head import CTCHead from .rec_att_head import AttentionHead from .rec_srn_head import SRNHead + from .rec_nrtr_head import Transformer + from .rec_sar_head import SARHead + from .rec_aster_head import AsterHead + from .rec_pren_head import PRENHead + from .rec_multi_head import MultiHead # cls head from .cls_head import ClsHead + + #kie head + from .kie_sdmgr_head import SDMGRHead + + from .table_att_head import TableAttentionHead + support_dict = [ - 'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead', - 'SRNHead' + 'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead', + 'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer', + 'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead', + 'MultiHead' ] + #table head + module_name = config.pop('name') assert module_name in support_dict, Exception('head only support {}'.format( support_dict)) diff --git a/backend/ppocr/modeling/heads/cls_head.py b/backend/ppocr/modeling/heads/cls_head.py index d9b78b84..91bfa615 100644 --- a/backend/ppocr/modeling/heads/cls_head.py +++ b/backend/ppocr/modeling/heads/cls_head.py @@ -43,7 +43,7 @@ def __init__(self, in_channels, class_dim, **kwargs): initializer=nn.initializer.Uniform(-stdv, stdv)), bias_attr=ParamAttr(name="fc_0.b_0"), ) - def forward(self, x): + def forward(self, x, targets=None): x = self.pool(x) x = paddle.reshape(x, shape=[x.shape[0], x.shape[1]]) x = self.fc(x) diff --git a/backend/ppocr/modeling/heads/det_db_head.py b/backend/ppocr/modeling/heads/det_db_head.py index ca18d74a..a686ae5a 100644 --- a/backend/ppocr/modeling/heads/det_db_head.py +++ b/backend/ppocr/modeling/heads/det_db_head.py @@ -23,64 +23,54 @@ from paddle import ParamAttr -def get_bias_attr(k, name): +def get_bias_attr(k): stdv = 1.0 / math.sqrt(k * 1.0) initializer = paddle.nn.initializer.Uniform(-stdv, stdv) - bias_attr = ParamAttr(initializer=initializer, name=name + "_b_attr") + bias_attr = ParamAttr(initializer=initializer) return bias_attr class Head(nn.Layer): - def __init__(self, in_channels, name_list): + def __init__(self, in_channels, name_list, kernel_list=[3, 2, 2], **kwargs): super(Head, self).__init__() + self.conv1 = nn.Conv2D( in_channels=in_channels, out_channels=in_channels // 4, - kernel_size=3, - padding=1, - weight_attr=ParamAttr(name=name_list[0] + '.w_0'), + kernel_size=kernel_list[0], + 
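+            # padding is derived from the configurable kernel size, so the
+            # default 3x3 conv keeps the feature-map resolution at stride 1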
padding=int(kernel_list[0] // 2), + weight_attr=ParamAttr(), bias_attr=False) self.conv_bn1 = nn.BatchNorm( num_channels=in_channels // 4, param_attr=ParamAttr( - name=name_list[1] + '.w_0', initializer=paddle.nn.initializer.Constant(value=1.0)), bias_attr=ParamAttr( - name=name_list[1] + '.b_0', initializer=paddle.nn.initializer.Constant(value=1e-4)), - moving_mean_name=name_list[1] + '.w_1', - moving_variance_name=name_list[1] + '.w_2', act='relu') self.conv2 = nn.Conv2DTranspose( in_channels=in_channels // 4, out_channels=in_channels // 4, - kernel_size=2, + kernel_size=kernel_list[1], stride=2, weight_attr=ParamAttr( - name=name_list[2] + '.w_0', initializer=paddle.nn.initializer.KaimingUniform()), - bias_attr=get_bias_attr(in_channels // 4, name_list[-1] + "conv2")) + bias_attr=get_bias_attr(in_channels // 4)) self.conv_bn2 = nn.BatchNorm( num_channels=in_channels // 4, param_attr=ParamAttr( - name=name_list[3] + '.w_0', initializer=paddle.nn.initializer.Constant(value=1.0)), bias_attr=ParamAttr( - name=name_list[3] + '.b_0', initializer=paddle.nn.initializer.Constant(value=1e-4)), - moving_mean_name=name_list[3] + '.w_1', - moving_variance_name=name_list[3] + '.w_2', act="relu") self.conv3 = nn.Conv2DTranspose( in_channels=in_channels // 4, out_channels=1, - kernel_size=2, + kernel_size=kernel_list[2], stride=2, weight_attr=ParamAttr( - name=name_list[4] + '.w_0', initializer=paddle.nn.initializer.KaimingUniform()), - bias_attr=get_bias_attr(in_channels // 4, name_list[-1] + "conv3"), - ) + bias_attr=get_bias_attr(in_channels // 4), ) def forward(self, x): x = self.conv1(x) @@ -111,13 +101,13 @@ def __init__(self, in_channels, k=50, **kwargs): 'conv2d_57', 'batch_norm_49', 'conv2d_transpose_2', 'batch_norm_50', 'conv2d_transpose_3', 'thresh' ] - self.binarize = Head(in_channels, binarize_name_list) - self.thresh = Head(in_channels, thresh_name_list) + self.binarize = Head(in_channels, binarize_name_list, **kwargs) + self.thresh = Head(in_channels, thresh_name_list, **kwargs) def step_function(self, x, y): return paddle.reciprocal(1 + paddle.exp(-self.k * (x - y))) - def forward(self, x): + def forward(self, x, targets=None): shrink_maps = self.binarize(x) if not self.training: return {'maps': shrink_maps} diff --git a/backend/ppocr/modeling/heads/det_east_head.py b/backend/ppocr/modeling/heads/det_east_head.py index 9d0c3c4c..004eb5d7 100644 --- a/backend/ppocr/modeling/heads/det_east_head.py +++ b/backend/ppocr/modeling/heads/det_east_head.py @@ -109,7 +109,7 @@ def __init__(self, in_channels, model_name, **kwargs): act=None, name="f_geo") - def forward(self, x): + def forward(self, x, targets=None): f_det = self.det_conv1(x) f_det = self.det_conv2(f_det) f_score = self.score_conv(f_det) diff --git a/backend/ppocr/modeling/heads/det_fce_head.py b/backend/ppocr/modeling/heads/det_fce_head.py new file mode 100644 index 00000000..9503989f --- /dev/null +++ b/backend/ppocr/modeling/heads/det_fce_head.py @@ -0,0 +1,99 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is refer from: +https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textdet/dense_heads/fce_head.py +""" + +from paddle import nn +from paddle import ParamAttr +import paddle.nn.functional as F +from paddle.nn.initializer import Normal +import paddle +from functools import partial + + +def multi_apply(func, *args, **kwargs): + pfunc = partial(func, **kwargs) if kwargs else func + map_results = map(pfunc, *args) + return tuple(map(list, zip(*map_results))) + + +class FCEHead(nn.Layer): + """The class for implementing FCENet head. + FCENet(CVPR2021): Fourier Contour Embedding for Arbitrary-shaped Text + Detection. + + [https://arxiv.org/abs/2104.10442] + + Args: + in_channels (int): The number of input channels. + scales (list[int]) : The scale of each layer. + fourier_degree (int) : The maximum Fourier transform degree k. + """ + + def __init__(self, in_channels, fourier_degree=5): + super().__init__() + assert isinstance(in_channels, int) + + self.downsample_ratio = 1.0 + self.in_channels = in_channels + self.fourier_degree = fourier_degree + self.out_channels_cls = 4 + self.out_channels_reg = (2 * self.fourier_degree + 1) * 2 + + self.out_conv_cls = nn.Conv2D( + in_channels=self.in_channels, + out_channels=self.out_channels_cls, + kernel_size=3, + stride=1, + padding=1, + groups=1, + weight_attr=ParamAttr( + name='cls_weights', + initializer=Normal( + mean=0., std=0.01)), + bias_attr=True) + self.out_conv_reg = nn.Conv2D( + in_channels=self.in_channels, + out_channels=self.out_channels_reg, + kernel_size=3, + stride=1, + padding=1, + groups=1, + weight_attr=ParamAttr( + name='reg_weights', + initializer=Normal( + mean=0., std=0.01)), + bias_attr=True) + + def forward(self, feats, targets=None): + cls_res, reg_res = multi_apply(self.forward_single, feats) + level_num = len(cls_res) + outs = {} + if not self.training: + for i in range(level_num): + tr_pred = F.softmax(cls_res[i][:, 0:2, :, :], axis=1) + tcl_pred = F.softmax(cls_res[i][:, 2:, :, :], axis=1) + outs['level_{}'.format(i)] = paddle.concat( + [tr_pred, tcl_pred, reg_res[i]], axis=1) + else: + preds = [[cls_res[i], reg_res[i]] for i in range(level_num)] + outs['levels'] = preds + return outs + + def forward_single(self, x): + cls_predict = self.out_conv_cls(x) + reg_predict = self.out_conv_reg(x) + return cls_predict, reg_predict diff --git a/backend/ppocr/modeling/heads/det_pse_head.py b/backend/ppocr/modeling/heads/det_pse_head.py new file mode 100644 index 00000000..32a5b48e --- /dev/null +++ b/backend/ppocr/modeling/heads/det_pse_head.py @@ -0,0 +1,37 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +This code is refer from: +https://github.com/whai362/PSENet/blob/python3/models/head/psenet_head.py +""" + +from paddle import nn + + +class PSEHead(nn.Layer): + def __init__(self, in_channels, hidden_dim=256, out_channels=7, **kwargs): + super(PSEHead, self).__init__() + self.conv1 = nn.Conv2D( + in_channels, hidden_dim, kernel_size=3, stride=1, padding=1) + self.bn1 = nn.BatchNorm2D(hidden_dim) + self.relu1 = nn.ReLU() + + self.conv2 = nn.Conv2D( + hidden_dim, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x, **kwargs): + out = self.conv1(x) + out = self.relu1(self.bn1(out)) + out = self.conv2(out) + return {'maps': out} diff --git a/backend/ppocr/modeling/heads/det_sast_head.py b/backend/ppocr/modeling/heads/det_sast_head.py index 263b2867..7a88a2db 100644 --- a/backend/ppocr/modeling/heads/det_sast_head.py +++ b/backend/ppocr/modeling/heads/det_sast_head.py @@ -116,7 +116,7 @@ def __init__(self, in_channels, **kwargs): self.head1 = SAST_Header1(in_channels) self.head2 = SAST_Header2(in_channels) - def forward(self, x): + def forward(self, x, targets=None): f_score, f_border = self.head1(x) f_tvo, f_tco = self.head2(x) diff --git a/backend/ppocr/modeling/heads/e2e_pg_head.py b/backend/ppocr/modeling/heads/e2e_pg_head.py new file mode 100644 index 00000000..274e1cda --- /dev/null +++ b/backend/ppocr/modeling/heads/e2e_pg_head.py @@ -0,0 +1,253 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import paddle +from paddle import nn +import paddle.nn.functional as F +from paddle import ParamAttr + + +class ConvBNLayer(nn.Layer): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + groups=1, + if_act=True, + act=None, + name=None): + super(ConvBNLayer, self).__init__() + self.if_act = if_act + self.act = act + self.conv = nn.Conv2D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + weight_attr=ParamAttr(name=name + '_weights'), + bias_attr=False) + + self.bn = nn.BatchNorm( + num_channels=out_channels, + act=act, + param_attr=ParamAttr(name="bn_" + name + "_scale"), + bias_attr=ParamAttr(name="bn_" + name + "_offset"), + moving_mean_name="bn_" + name + "_mean", + moving_variance_name="bn_" + name + "_variance", + use_global_stats=False) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + +class PGHead(nn.Layer): + """ + """ + + def __init__(self, in_channels, **kwargs): + super(PGHead, self).__init__() + self.conv_f_score1 = ConvBNLayer( + in_channels=in_channels, + out_channels=64, + kernel_size=1, + stride=1, + padding=0, + act='relu', + name="conv_f_score{}".format(1)) + self.conv_f_score2 = ConvBNLayer( + in_channels=64, + out_channels=64, + kernel_size=3, + stride=1, + padding=1, + act='relu', + name="conv_f_score{}".format(2)) + self.conv_f_score3 = ConvBNLayer( + in_channels=64, + out_channels=128, + kernel_size=1, + stride=1, + padding=0, + act='relu', + name="conv_f_score{}".format(3)) + + self.conv1 = nn.Conv2D( + in_channels=128, + out_channels=1, + kernel_size=3, + stride=1, + padding=1, + groups=1, + weight_attr=ParamAttr(name="conv_f_score{}".format(4)), + bias_attr=False) + + self.conv_f_boder1 = ConvBNLayer( + in_channels=in_channels, + out_channels=64, + kernel_size=1, + stride=1, + padding=0, + act='relu', + name="conv_f_boder{}".format(1)) + self.conv_f_boder2 = ConvBNLayer( + in_channels=64, + out_channels=64, + kernel_size=3, + stride=1, + padding=1, + act='relu', + name="conv_f_boder{}".format(2)) + self.conv_f_boder3 = ConvBNLayer( + in_channels=64, + out_channels=128, + kernel_size=1, + stride=1, + padding=0, + act='relu', + name="conv_f_boder{}".format(3)) + self.conv2 = nn.Conv2D( + in_channels=128, + out_channels=4, + kernel_size=3, + stride=1, + padding=1, + groups=1, + weight_attr=ParamAttr(name="conv_f_boder{}".format(4)), + bias_attr=False) + self.conv_f_char1 = ConvBNLayer( + in_channels=in_channels, + out_channels=128, + kernel_size=1, + stride=1, + padding=0, + act='relu', + name="conv_f_char{}".format(1)) + self.conv_f_char2 = ConvBNLayer( + in_channels=128, + out_channels=128, + kernel_size=3, + stride=1, + padding=1, + act='relu', + name="conv_f_char{}".format(2)) + self.conv_f_char3 = ConvBNLayer( + in_channels=128, + out_channels=256, + kernel_size=1, + stride=1, + padding=0, + act='relu', + name="conv_f_char{}".format(3)) + self.conv_f_char4 = ConvBNLayer( + in_channels=256, + out_channels=256, + kernel_size=3, + stride=1, + padding=1, + act='relu', + name="conv_f_char{}".format(4)) + self.conv_f_char5 = ConvBNLayer( + in_channels=256, + out_channels=256, + kernel_size=1, + stride=1, + padding=0, + act='relu', + name="conv_f_char{}".format(5)) + self.conv3 = nn.Conv2D( + in_channels=256, + out_channels=37, + kernel_size=3, + stride=1, + padding=1, + groups=1, + 
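+            # the character branch ends in a 37-channel classification map
+            # (out_channels=37 above), returned as predicts['f_char']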
weight_attr=ParamAttr(name="conv_f_char{}".format(6)), + bias_attr=False) + + self.conv_f_direc1 = ConvBNLayer( + in_channels=in_channels, + out_channels=64, + kernel_size=1, + stride=1, + padding=0, + act='relu', + name="conv_f_direc{}".format(1)) + self.conv_f_direc2 = ConvBNLayer( + in_channels=64, + out_channels=64, + kernel_size=3, + stride=1, + padding=1, + act='relu', + name="conv_f_direc{}".format(2)) + self.conv_f_direc3 = ConvBNLayer( + in_channels=64, + out_channels=128, + kernel_size=1, + stride=1, + padding=0, + act='relu', + name="conv_f_direc{}".format(3)) + self.conv4 = nn.Conv2D( + in_channels=128, + out_channels=2, + kernel_size=3, + stride=1, + padding=1, + groups=1, + weight_attr=ParamAttr(name="conv_f_direc{}".format(4)), + bias_attr=False) + + def forward(self, x, targets=None): + f_score = self.conv_f_score1(x) + f_score = self.conv_f_score2(f_score) + f_score = self.conv_f_score3(f_score) + f_score = self.conv1(f_score) + f_score = F.sigmoid(f_score) + + # f_border + f_border = self.conv_f_boder1(x) + f_border = self.conv_f_boder2(f_border) + f_border = self.conv_f_boder3(f_border) + f_border = self.conv2(f_border) + + f_char = self.conv_f_char1(x) + f_char = self.conv_f_char2(f_char) + f_char = self.conv_f_char3(f_char) + f_char = self.conv_f_char4(f_char) + f_char = self.conv_f_char5(f_char) + f_char = self.conv3(f_char) + + f_direction = self.conv_f_direc1(x) + f_direction = self.conv_f_direc2(f_direction) + f_direction = self.conv_f_direc3(f_direction) + f_direction = self.conv4(f_direction) + + predicts = {} + predicts['f_score'] = f_score + predicts['f_border'] = f_border + predicts['f_char'] = f_char + predicts['f_direction'] = f_direction + return predicts diff --git a/backend/ppocr/modeling/heads/kie_sdmgr_head.py b/backend/ppocr/modeling/heads/kie_sdmgr_head.py new file mode 100644 index 00000000..ac5f73fa --- /dev/null +++ b/backend/ppocr/modeling/heads/kie_sdmgr_head.py @@ -0,0 +1,207 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# reference from : https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/kie/heads/sdmgr_head.py + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import paddle +from paddle import nn +import paddle.nn.functional as F +from paddle import ParamAttr + + +class SDMGRHead(nn.Layer): + def __init__(self, + in_channels, + num_chars=92, + visual_dim=16, + fusion_dim=1024, + node_input=32, + node_embed=256, + edge_input=5, + edge_embed=256, + num_gnn=2, + num_classes=26, + bidirectional=False): + super().__init__() + + self.fusion = Block([visual_dim, node_embed], node_embed, fusion_dim) + self.node_embed = nn.Embedding(num_chars, node_input, 0) + hidden = node_embed // 2 if bidirectional else node_embed + self.rnn = nn.LSTM( + input_size=node_input, hidden_size=hidden, num_layers=1) + self.edge_embed = nn.Linear(edge_input, edge_embed) + self.gnn_layers = nn.LayerList( + [GNNLayer(node_embed, edge_embed) for _ in range(num_gnn)]) + self.node_cls = nn.Linear(node_embed, num_classes) + self.edge_cls = nn.Linear(edge_embed, 2) + + def forward(self, input, targets): + relations, texts, x = input + node_nums, char_nums = [], [] + for text in texts: + node_nums.append(text.shape[0]) + char_nums.append(paddle.sum((text > -1).astype(int), axis=-1)) + + max_num = max([char_num.max() for char_num in char_nums]) + all_nodes = paddle.concat([ + paddle.concat( + [text, paddle.zeros( + (text.shape[0], max_num - text.shape[1]))], -1) + for text in texts + ]) + temp = paddle.clip(all_nodes, min=0).astype(int) + embed_nodes = self.node_embed(temp) + rnn_nodes, _ = self.rnn(embed_nodes) + + b, h, w = rnn_nodes.shape + nodes = paddle.zeros([b, w]) + all_nums = paddle.concat(char_nums) + valid = paddle.nonzero((all_nums > 0).astype(int)) + temp_all_nums = ( + paddle.gather(all_nums, valid) - 1).unsqueeze(-1).unsqueeze(-1) + temp_all_nums = paddle.expand(temp_all_nums, [ + temp_all_nums.shape[0], temp_all_nums.shape[1], rnn_nodes.shape[-1] + ]) + temp_all_nodes = paddle.gather(rnn_nodes, valid) + N, C, A = temp_all_nodes.shape + one_hot = F.one_hot( + temp_all_nums[:, 0, :], num_classes=C).transpose([0, 2, 1]) + one_hot = paddle.multiply( + temp_all_nodes, one_hot.astype("float32")).sum(axis=1, keepdim=True) + t = one_hot.expand([N, 1, A]).squeeze(1) + nodes = paddle.scatter(nodes, valid.squeeze(1), t) + + if x is not None: + nodes = self.fusion([x, nodes]) + + all_edges = paddle.concat( + [rel.reshape([-1, rel.shape[-1]]) for rel in relations]) + embed_edges = self.edge_embed(all_edges.astype('float32')) + embed_edges = F.normalize(embed_edges) + + for gnn_layer in self.gnn_layers: + nodes, cat_nodes = gnn_layer(nodes, embed_edges, node_nums) + + node_cls, edge_cls = self.node_cls(nodes), self.edge_cls(cat_nodes) + return node_cls, edge_cls + + +class GNNLayer(nn.Layer): + def __init__(self, node_dim=256, edge_dim=256): + super().__init__() + self.in_fc = nn.Linear(node_dim * 2 + edge_dim, node_dim) + self.coef_fc = nn.Linear(node_dim, 1) + self.out_fc = nn.Linear(node_dim, node_dim) + self.relu = nn.ReLU() + + def forward(self, nodes, edges, nums): + start, cat_nodes = 0, [] + for num in nums: + sample_nodes = nodes[start:start + num] + cat_nodes.append( + paddle.concat([ + paddle.expand(sample_nodes.unsqueeze(1), [-1, num, -1]), + paddle.expand(sample_nodes.unsqueeze(0), [num, -1, -1]) + ], -1).reshape([num**2, -1])) + start += num + cat_nodes = paddle.concat([paddle.concat(cat_nodes), edges], -1) + cat_nodes = 
self.relu(self.in_fc(cat_nodes)) + coefs = self.coef_fc(cat_nodes) + + start, residuals = 0, [] + for num in nums: + residual = F.softmax( + -paddle.eye(num).unsqueeze(-1) * 1e9 + + coefs[start:start + num**2].reshape([num, num, -1]), 1) + residuals.append((residual * cat_nodes[start:start + num**2] + .reshape([num, num, -1])).sum(1)) + start += num**2 + + nodes += self.relu(self.out_fc(paddle.concat(residuals))) + return [nodes, cat_nodes] + + +class Block(nn.Layer): + def __init__(self, + input_dims, + output_dim, + mm_dim=1600, + chunks=20, + rank=15, + shared=False, + dropout_input=0., + dropout_pre_lin=0., + dropout_output=0., + pos_norm='before_cat'): + super().__init__() + self.rank = rank + self.dropout_input = dropout_input + self.dropout_pre_lin = dropout_pre_lin + self.dropout_output = dropout_output + assert (pos_norm in ['before_cat', 'after_cat']) + self.pos_norm = pos_norm + # Modules + self.linear0 = nn.Linear(input_dims[0], mm_dim) + self.linear1 = (self.linear0 + if shared else nn.Linear(input_dims[1], mm_dim)) + self.merge_linears0 = nn.LayerList() + self.merge_linears1 = nn.LayerList() + self.chunks = self.chunk_sizes(mm_dim, chunks) + for size in self.chunks: + ml0 = nn.Linear(size, size * rank) + self.merge_linears0.append(ml0) + ml1 = ml0 if shared else nn.Linear(size, size * rank) + self.merge_linears1.append(ml1) + self.linear_out = nn.Linear(mm_dim, output_dim) + + def forward(self, x): + x0 = self.linear0(x[0]) + x1 = self.linear1(x[1]) + bs = x1.shape[0] + if self.dropout_input > 0: + x0 = F.dropout(x0, p=self.dropout_input, training=self.training) + x1 = F.dropout(x1, p=self.dropout_input, training=self.training) + x0_chunks = paddle.split(x0, self.chunks, -1) + x1_chunks = paddle.split(x1, self.chunks, -1) + zs = [] + for x0_c, x1_c, m0, m1 in zip(x0_chunks, x1_chunks, self.merge_linears0, + self.merge_linears1): + m = m0(x0_c) * m1(x1_c) # bs x split_size*rank + m = m.reshape([bs, self.rank, -1]) + z = paddle.sum(m, 1) + if self.pos_norm == 'before_cat': + z = paddle.sqrt(F.relu(z)) - paddle.sqrt(F.relu(-z)) + z = F.normalize(z) + zs.append(z) + z = paddle.concat(zs, 1) + if self.pos_norm == 'after_cat': + z = paddle.sqrt(F.relu(z)) - paddle.sqrt(F.relu(-z)) + z = F.normalize(z) + + if self.dropout_pre_lin > 0: + z = F.dropout(z, p=self.dropout_pre_lin, training=self.training) + z = self.linear_out(z) + if self.dropout_output > 0: + z = F.dropout(z, p=self.dropout_output, training=self.training) + return z + + def chunk_sizes(self, dim, chunks): + split_size = (dim + chunks - 1) // chunks + sizes_list = [split_size] * chunks + sizes_list[-1] = sizes_list[-1] - (sum(sizes_list) - dim) + return sizes_list diff --git a/backend/ppocr/modeling/heads/multiheadAttention.py b/backend/ppocr/modeling/heads/multiheadAttention.py new file mode 100755 index 00000000..900865ba --- /dev/null +++ b/backend/ppocr/modeling/heads/multiheadAttention.py @@ -0,0 +1,163 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
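+
+# MultiheadAttention (defined below) implements multi-head attention with the
+# query/key/value input projections realised as 1x1 convolutions
+# (conv1/conv2/conv3) rather than a fused linear layer. Inputs follow the
+# [length, batch, embed_dim] layout described in the forward() docstring.
+#
+# Illustrative self-attention sketch (the sizes are assumptions made for this
+# example, not values used elsewhere in this repository):
+#
+#   attn = MultiheadAttention(embed_dim=512, num_heads=8)
+#   x = paddle.rand([10, 2, 512])   # [seq_len, batch, embed_dim]
+#   y = attn(x, x, x)               # -> [10, 2, 512]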
+ +import paddle +from paddle import nn +import paddle.nn.functional as F +from paddle.nn import Linear +from paddle.nn.initializer import XavierUniform as xavier_uniform_ +from paddle.nn.initializer import Constant as constant_ +from paddle.nn.initializer import XavierNormal as xavier_normal_ + +zeros_ = constant_(value=0.) +ones_ = constant_(value=1.) + + +class MultiheadAttention(nn.Layer): + """Allows the model to jointly attend to information + from different representation subspaces. + See reference: Attention Is All You Need + + .. math:: + \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O + \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) + + Args: + embed_dim: total dimension of the model + num_heads: parallel attention layers, or heads + + """ + + def __init__(self, + embed_dim, + num_heads, + dropout=0., + bias=True, + add_bias_kv=False, + add_zero_attn=False): + super(MultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim**-0.5 + self.out_proj = Linear(embed_dim, embed_dim, bias_attr=bias) + self._reset_parameters() + self.conv1 = paddle.nn.Conv2D( + in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1)) + self.conv2 = paddle.nn.Conv2D( + in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1)) + self.conv3 = paddle.nn.Conv2D( + in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1)) + + def _reset_parameters(self): + xavier_uniform_(self.out_proj.weight) + + def forward(self, + query, + key, + value, + key_padding_mask=None, + incremental_state=None, + attn_mask=None): + """ + Inputs of forward function + query: [target length, batch size, embed dim] + key: [sequence length, batch size, embed dim] + value: [sequence length, batch size, embed dim] + key_padding_mask: if True, mask padding based on batch size + incremental_state: if provided, previous time steps are cashed + need_weights: output attn_output_weights + static_kv: key and value are static + + Outputs of forward function + attn_output: [target length, batch size, embed dim] + attn_output_weights: [batch size, target length, sequence length] + """ + q_shape = paddle.shape(query) + src_shape = paddle.shape(key) + q = self._in_proj_q(query) + k = self._in_proj_k(key) + v = self._in_proj_v(value) + q *= self.scaling + q = paddle.transpose( + paddle.reshape( + q, [q_shape[0], q_shape[1], self.num_heads, self.head_dim]), + [1, 2, 0, 3]) + k = paddle.transpose( + paddle.reshape( + k, [src_shape[0], q_shape[1], self.num_heads, self.head_dim]), + [1, 2, 0, 3]) + v = paddle.transpose( + paddle.reshape( + v, [src_shape[0], q_shape[1], self.num_heads, self.head_dim]), + [1, 2, 0, 3]) + if key_padding_mask is not None: + assert key_padding_mask.shape[0] == q_shape[1] + assert key_padding_mask.shape[1] == src_shape[0] + attn_output_weights = paddle.matmul(q, + paddle.transpose(k, [0, 1, 3, 2])) + if attn_mask is not None: + attn_mask = paddle.unsqueeze(paddle.unsqueeze(attn_mask, 0), 0) + attn_output_weights += attn_mask + if key_padding_mask is not None: + attn_output_weights = paddle.reshape( + attn_output_weights, + [q_shape[1], self.num_heads, q_shape[0], src_shape[0]]) + key = paddle.unsqueeze(paddle.unsqueeze(key_padding_mask, 1), 2) + key = paddle.cast(key, 'float32') + y = paddle.full( + shape=paddle.shape(key), dtype='float32', 
fill_value='-inf') + y = paddle.where(key == 0., key, y) + attn_output_weights += y + attn_output_weights = F.softmax( + attn_output_weights.astype('float32'), + axis=-1, + dtype=paddle.float32 if attn_output_weights.dtype == paddle.float16 + else attn_output_weights.dtype) + attn_output_weights = F.dropout( + attn_output_weights, p=self.dropout, training=self.training) + + attn_output = paddle.matmul(attn_output_weights, v) + attn_output = paddle.reshape( + paddle.transpose(attn_output, [2, 0, 1, 3]), + [q_shape[0], q_shape[1], self.embed_dim]) + attn_output = self.out_proj(attn_output) + + return attn_output + + def _in_proj_q(self, query): + query = paddle.transpose(query, [1, 2, 0]) + query = paddle.unsqueeze(query, axis=2) + res = self.conv1(query) + res = paddle.squeeze(res, axis=2) + res = paddle.transpose(res, [2, 0, 1]) + return res + + def _in_proj_k(self, key): + key = paddle.transpose(key, [1, 2, 0]) + key = paddle.unsqueeze(key, axis=2) + res = self.conv2(key) + res = paddle.squeeze(res, axis=2) + res = paddle.transpose(res, [2, 0, 1]) + return res + + def _in_proj_v(self, value): + value = paddle.transpose(value, [1, 2, 0]) #(1, 2, 0) + value = paddle.unsqueeze(value, axis=2) + res = self.conv3(value) + res = paddle.squeeze(res, axis=2) + res = paddle.transpose(res, [2, 0, 1]) + return res diff --git a/backend/ppocr/modeling/heads/rec_aster_head.py b/backend/ppocr/modeling/heads/rec_aster_head.py new file mode 100644 index 00000000..c95e8fd3 --- /dev/null +++ b/backend/ppocr/modeling/heads/rec_aster_head.py @@ -0,0 +1,393 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
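+
+# AsterHead (defined below) is the ASTER attention-based recognition head: the
+# Embedding module compresses the encoder feature sequence into a fixed-size
+# embedding vector, and AttentionRecognitionHead decodes characters with a
+# GRU + attention decoder (DecoderUnit/AttentionUnit). During training the
+# decoder is fed the ground-truth targets (teacher forcing); at inference time
+# decoding is done with beam search using the configured beam_width.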
+""" +This code is refer from: +https://github.com/ayumiymk/aster.pytorch/blob/master/lib/models/attention_recognition_head.py +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +import paddle +from paddle import nn +from paddle.nn import functional as F + + +class AsterHead(nn.Layer): + def __init__(self, + in_channels, + out_channels, + sDim, + attDim, + max_len_labels, + time_step=25, + beam_width=5, + **kwargs): + super(AsterHead, self).__init__() + self.num_classes = out_channels + self.in_planes = in_channels + self.sDim = sDim + self.attDim = attDim + self.max_len_labels = max_len_labels + self.decoder = AttentionRecognitionHead(in_channels, out_channels, sDim, + attDim, max_len_labels) + self.time_step = time_step + self.embeder = Embedding(self.time_step, in_channels) + self.beam_width = beam_width + self.eos = self.num_classes - 3 + + def forward(self, x, targets=None, embed=None): + return_dict = {} + embedding_vectors = self.embeder(x) + + if self.training: + rec_targets, rec_lengths, _ = targets + rec_pred = self.decoder([x, rec_targets, rec_lengths], + embedding_vectors) + return_dict['rec_pred'] = rec_pred + return_dict['embedding_vectors'] = embedding_vectors + else: + rec_pred, rec_pred_scores = self.decoder.beam_search( + x, self.beam_width, self.eos, embedding_vectors) + return_dict['rec_pred'] = rec_pred + return_dict['rec_pred_scores'] = rec_pred_scores + return_dict['embedding_vectors'] = embedding_vectors + + return return_dict + + +class Embedding(nn.Layer): + def __init__(self, in_timestep, in_planes, mid_dim=4096, embed_dim=300): + super(Embedding, self).__init__() + self.in_timestep = in_timestep + self.in_planes = in_planes + self.embed_dim = embed_dim + self.mid_dim = mid_dim + self.eEmbed = nn.Linear( + in_timestep * in_planes, + self.embed_dim) # Embed encoder output to a word-embedding like + + def forward(self, x): + x = paddle.reshape(x, [paddle.shape(x)[0], -1]) + x = self.eEmbed(x) + return x + + +class AttentionRecognitionHead(nn.Layer): + """ + input: [b x 16 x 64 x in_planes] + output: probability sequence: [b x T x num_classes] + """ + + def __init__(self, in_channels, out_channels, sDim, attDim, max_len_labels): + super(AttentionRecognitionHead, self).__init__() + self.num_classes = out_channels # this is the output classes. So it includes the . + self.in_planes = in_channels + self.sDim = sDim + self.attDim = attDim + self.max_len_labels = max_len_labels + + self.decoder = DecoderUnit( + sDim=sDim, xDim=in_channels, yDim=self.num_classes, attDim=attDim) + + def forward(self, x, embed): + x, targets, lengths = x + batch_size = paddle.shape(x)[0] + # Decoder + state = self.decoder.get_initial_state(embed) + outputs = [] + for i in range(max(lengths)): + if i == 0: + y_prev = paddle.full( + shape=[batch_size], fill_value=self.num_classes) + else: + y_prev = targets[:, i - 1] + output, state = self.decoder(x, state, y_prev) + outputs.append(output) + outputs = paddle.concat([_.unsqueeze(1) for _ in outputs], 1) + return outputs + + # inference stage. 
+ def sample(self, x): + x, _, _ = x + batch_size = x.size(0) + # Decoder + state = paddle.zeros([1, batch_size, self.sDim]) + + predicted_ids, predicted_scores = [], [] + for i in range(self.max_len_labels): + if i == 0: + y_prev = paddle.full( + shape=[batch_size], fill_value=self.num_classes) + else: + y_prev = predicted + + output, state = self.decoder(x, state, y_prev) + output = F.softmax(output, axis=1) + score, predicted = output.max(1) + predicted_ids.append(predicted.unsqueeze(1)) + predicted_scores.append(score.unsqueeze(1)) + predicted_ids = paddle.concat([predicted_ids, 1]) + predicted_scores = paddle.concat([predicted_scores, 1]) + # return predicted_ids.squeeze(), predicted_scores.squeeze() + return predicted_ids, predicted_scores + + def beam_search(self, x, beam_width, eos, embed): + def _inflate(tensor, times, dim): + repeat_dims = [1] * tensor.dim() + repeat_dims[dim] = times + output = paddle.tile(tensor, repeat_dims) + return output + + # https://github.com/IBM/pytorch-seq2seq/blob/fede87655ddce6c94b38886089e05321dc9802af/seq2seq/models/TopKDecoder.py + batch_size, l, d = x.shape + x = paddle.tile( + paddle.transpose( + x.unsqueeze(1), perm=[1, 0, 2, 3]), [beam_width, 1, 1, 1]) + inflated_encoder_feats = paddle.reshape( + paddle.transpose( + x, perm=[1, 0, 2, 3]), [-1, l, d]) + + # Initialize the decoder + state = self.decoder.get_initial_state(embed, tile_times=beam_width) + + pos_index = paddle.reshape( + paddle.arange(batch_size) * beam_width, shape=[-1, 1]) + + # Initialize the scores + sequence_scores = paddle.full( + shape=[batch_size * beam_width, 1], fill_value=-float('Inf')) + index = [i * beam_width for i in range(0, batch_size)] + sequence_scores[index] = 0.0 + + # Initialize the input vector + y_prev = paddle.full( + shape=[batch_size * beam_width], fill_value=self.num_classes) + + # Store decisions for backtracking + stored_scores = list() + stored_predecessors = list() + stored_emitted_symbols = list() + + for i in range(self.max_len_labels): + output, state = self.decoder(inflated_encoder_feats, state, y_prev) + state = paddle.unsqueeze(state, axis=0) + log_softmax_output = paddle.nn.functional.log_softmax( + output, axis=1) + + sequence_scores = _inflate(sequence_scores, self.num_classes, 1) + sequence_scores += log_softmax_output + scores, candidates = paddle.topk( + paddle.reshape(sequence_scores, [batch_size, -1]), + beam_width, + axis=1) + + # Reshape input = (bk, 1) and sequence_scores = (bk, 1) + y_prev = paddle.reshape( + candidates % self.num_classes, shape=[batch_size * beam_width]) + sequence_scores = paddle.reshape( + scores, shape=[batch_size * beam_width, 1]) + + # Update fields for next timestep + pos_index = paddle.expand_as(pos_index, candidates) + predecessors = paddle.cast( + candidates / self.num_classes + pos_index, dtype='int64') + predecessors = paddle.reshape( + predecessors, shape=[batch_size * beam_width, 1]) + state = paddle.index_select( + state, index=predecessors.squeeze(), axis=1) + + # Update sequence socres and erase scores for symbol so that they aren't expanded + stored_scores.append(sequence_scores.clone()) + y_prev = paddle.reshape(y_prev, shape=[-1, 1]) + eos_prev = paddle.full_like(y_prev, fill_value=eos) + mask = eos_prev == y_prev + mask = paddle.nonzero(mask) + if mask.dim() > 0: + sequence_scores = sequence_scores.numpy() + mask = mask.numpy() + sequence_scores[mask] = -float('inf') + sequence_scores = paddle.to_tensor(sequence_scores) + + # Cache results for backtracking + 
stored_predecessors.append(predecessors) + y_prev = paddle.squeeze(y_prev) + stored_emitted_symbols.append(y_prev) + + # Do backtracking to return the optimal values + #====== backtrak ======# + # Initialize return variables given different types + p = list() + l = [[self.max_len_labels] * beam_width for _ in range(batch_size) + ] # Placeholder for lengths of top-k sequences + + # the last step output of the beams are not sorted + # thus they are sorted here + sorted_score, sorted_idx = paddle.topk( + paddle.reshape( + stored_scores[-1], shape=[batch_size, beam_width]), + beam_width) + + # initialize the sequence scores with the sorted last step beam scores + s = sorted_score.clone() + + batch_eos_found = [0] * batch_size # the number of EOS found + # in the backward loop below for each batch + t = self.max_len_labels - 1 + # initialize the back pointer with the sorted order of the last step beams. + # add pos_index for indexing variable with b*k as the first dimension. + t_predecessors = paddle.reshape( + sorted_idx + pos_index.expand_as(sorted_idx), + shape=[batch_size * beam_width]) + while t >= 0: + # Re-order the variables with the back pointer + current_symbol = paddle.index_select( + stored_emitted_symbols[t], index=t_predecessors, axis=0) + t_predecessors = paddle.index_select( + stored_predecessors[t].squeeze(), index=t_predecessors, axis=0) + eos_indices = stored_emitted_symbols[t] == eos + eos_indices = paddle.nonzero(eos_indices) + + if eos_indices.dim() > 0: + for i in range(eos_indices.shape[0] - 1, -1, -1): + # Indices of the EOS symbol for both variables + # with b*k as the first dimension, and b, k for + # the first two dimensions + idx = eos_indices[i] + b_idx = int(idx[0] / beam_width) + # The indices of the replacing position + # according to the replacement strategy noted above + res_k_idx = beam_width - (batch_eos_found[b_idx] % + beam_width) - 1 + batch_eos_found[b_idx] += 1 + res_idx = b_idx * beam_width + res_k_idx + + # Replace the old information in return variables + # with the new ended sequence information + t_predecessors[res_idx] = stored_predecessors[t][idx[0]] + current_symbol[res_idx] = stored_emitted_symbols[t][idx[0]] + s[b_idx, res_k_idx] = stored_scores[t][idx[0], 0] + l[b_idx][res_k_idx] = t + 1 + + # record the back tracked results + p.append(current_symbol) + t -= 1 + + # Sort and re-order again as the added ended sequences may change + # the order (very unlikely) + s, re_sorted_idx = s.topk(beam_width) + for b_idx in range(batch_size): + l[b_idx] = [ + l[b_idx][k_idx.item()] for k_idx in re_sorted_idx[b_idx, :] + ] + + re_sorted_idx = paddle.reshape( + re_sorted_idx + pos_index.expand_as(re_sorted_idx), + [batch_size * beam_width]) + + # Reverse the sequences and re-order at the same time + # It is reversed because the backtracking happens in reverse time order + p = [ + paddle.reshape( + paddle.index_select(step, re_sorted_idx, 0), + shape=[batch_size, beam_width, -1]) for step in reversed(p) + ] + p = paddle.concat(p, -1)[:, 0, :] + return p, paddle.ones_like(p) + + +class AttentionUnit(nn.Layer): + def __init__(self, sDim, xDim, attDim): + super(AttentionUnit, self).__init__() + + self.sDim = sDim + self.xDim = xDim + self.attDim = attDim + + self.sEmbed = nn.Linear(sDim, attDim) + self.xEmbed = nn.Linear(xDim, attDim) + self.wEmbed = nn.Linear(attDim, 1) + + def forward(self, x, sPrev): + batch_size, T, _ = x.shape # [b x T x xDim] + x = paddle.reshape(x, [-1, self.xDim]) # [(b x T) x xDim] + xProj = self.xEmbed(x) # [(b x T) x attDim] + xProj 
= paddle.reshape(xProj, [batch_size, T, -1]) # [b x T x attDim] + + sPrev = sPrev.squeeze(0) + sProj = self.sEmbed(sPrev) # [b x attDim] + sProj = paddle.unsqueeze(sProj, 1) # [b x 1 x attDim] + sProj = paddle.expand(sProj, + [batch_size, T, self.attDim]) # [b x T x attDim] + + sumTanh = paddle.tanh(sProj + xProj) + sumTanh = paddle.reshape(sumTanh, [-1, self.attDim]) + + vProj = self.wEmbed(sumTanh) # [(b x T) x 1] + vProj = paddle.reshape(vProj, [batch_size, T]) + alpha = F.softmax( + vProj, axis=1) # attention weights for each sample in the minibatch + return alpha + + +class DecoderUnit(nn.Layer): + def __init__(self, sDim, xDim, yDim, attDim): + super(DecoderUnit, self).__init__() + self.sDim = sDim + self.xDim = xDim + self.yDim = yDim + self.attDim = attDim + self.emdDim = attDim + + self.attention_unit = AttentionUnit(sDim, xDim, attDim) + self.tgt_embedding = nn.Embedding( + yDim + 1, self.emdDim, weight_attr=nn.initializer.Normal( + std=0.01)) # the last is used for + self.gru = nn.GRUCell(input_size=xDim + self.emdDim, hidden_size=sDim) + self.fc = nn.Linear( + sDim, + yDim, + weight_attr=nn.initializer.Normal(std=0.01), + bias_attr=nn.initializer.Constant(value=0)) + self.embed_fc = nn.Linear(300, self.sDim) + + def get_initial_state(self, embed, tile_times=1): + assert embed.shape[1] == 300 + state = self.embed_fc(embed) # N * sDim + if tile_times != 1: + state = state.unsqueeze(1) + trans_state = paddle.transpose(state, perm=[1, 0, 2]) + state = paddle.tile(trans_state, repeat_times=[tile_times, 1, 1]) + trans_state = paddle.transpose(state, perm=[1, 0, 2]) + state = paddle.reshape(trans_state, shape=[-1, self.sDim]) + state = state.unsqueeze(0) # 1 * N * sDim + return state + + def forward(self, x, sPrev, yPrev): + # x: feature sequence from the image decoder. 
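+        # sPrev: previous decoder hidden state with shape [1, batch, sDim].
+        # yPrev: symbol emitted (or ground-truth-fed) at the previous step; it
+        #        is embedded and concatenated with the attention context before
+        #        being passed to the GRU cell.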
+ batch_size, T, _ = x.shape + alpha = self.attention_unit(x, sPrev) + context = paddle.squeeze(paddle.matmul(alpha.unsqueeze(1), x), axis=1) + yPrev = paddle.cast(yPrev, dtype="int64") + yProj = self.tgt_embedding(yPrev) + + concat_context = paddle.concat([yProj, context], 1) + concat_context = paddle.squeeze(concat_context, 1) + sPrev = paddle.squeeze(sPrev, 0) + output, state = self.gru(concat_context, sPrev) + output = paddle.squeeze(output, axis=1) + output = self.fc(output) + return output, state \ No newline at end of file diff --git a/backend/ppocr/modeling/heads/rec_att_head.py b/backend/ppocr/modeling/heads/rec_att_head.py index 0d222714..ab8b119f 100644 --- a/backend/ppocr/modeling/heads/rec_att_head.py +++ b/backend/ppocr/modeling/heads/rec_att_head.py @@ -38,7 +38,7 @@ def _char_to_onehot(self, input_char, onehot_dim): return input_ont_hot def forward(self, inputs, targets=None, batch_max_length=25): - batch_size = inputs.shape[0] + batch_size = paddle.shape(inputs)[0] num_steps = batch_max_length hidden = paddle.zeros((batch_size, self.hidden_size)) @@ -53,7 +53,6 @@ def forward(self, inputs, targets=None, batch_max_length=25): output_hiddens.append(paddle.unsqueeze(outputs, axis=1)) output = paddle.concat(output_hiddens, axis=1) probs = self.generator(output) - else: targets = paddle.zeros(shape=[batch_size], dtype="int32") probs = None @@ -75,7 +74,8 @@ def forward(self, inputs, targets=None, batch_max_length=25): probs_step, axis=1)], axis=1) next_input = probs_step.argmax(axis=1) targets = next_input - + if not self.training: + probs = paddle.nn.functional.softmax(probs, axis=2) return probs diff --git a/backend/ppocr/modeling/heads/rec_ctc_head.py b/backend/ppocr/modeling/heads/rec_ctc_head.py index 69d4ef50..6c1cf065 100755 --- a/backend/ppocr/modeling/heads/rec_ctc_head.py +++ b/backend/ppocr/modeling/heads/rec_ctc_head.py @@ -23,32 +23,65 @@ from paddle.nn import functional as F -def get_para_bias_attr(l2_decay, k, name): +def get_para_bias_attr(l2_decay, k): regularizer = paddle.regularizer.L2Decay(l2_decay) stdv = 1.0 / math.sqrt(k * 1.0) initializer = nn.initializer.Uniform(-stdv, stdv) - weight_attr = ParamAttr( - regularizer=regularizer, initializer=initializer, name=name + "_w_attr") - bias_attr = ParamAttr( - regularizer=regularizer, initializer=initializer, name=name + "_b_attr") + weight_attr = ParamAttr(regularizer=regularizer, initializer=initializer) + bias_attr = ParamAttr(regularizer=regularizer, initializer=initializer) return [weight_attr, bias_attr] class CTCHead(nn.Layer): - def __init__(self, in_channels, out_channels, fc_decay=0.0004, **kwargs): + def __init__(self, + in_channels, + out_channels, + fc_decay=0.0004, + mid_channels=None, + return_feats=False, + **kwargs): super(CTCHead, self).__init__() - weight_attr, bias_attr = get_para_bias_attr( - l2_decay=fc_decay, k=in_channels, name='ctc_fc') - self.fc = nn.Linear( - in_channels, - out_channels, - weight_attr=weight_attr, - bias_attr=bias_attr, - name='ctc_fc') + if mid_channels is None: + weight_attr, bias_attr = get_para_bias_attr( + l2_decay=fc_decay, k=in_channels) + self.fc = nn.Linear( + in_channels, + out_channels, + weight_attr=weight_attr, + bias_attr=bias_attr) + else: + weight_attr1, bias_attr1 = get_para_bias_attr( + l2_decay=fc_decay, k=in_channels) + self.fc1 = nn.Linear( + in_channels, + mid_channels, + weight_attr=weight_attr1, + bias_attr=bias_attr1) + + weight_attr2, bias_attr2 = get_para_bias_attr( + l2_decay=fc_decay, k=mid_channels) + self.fc2 = nn.Linear( + mid_channels, 
+ out_channels, + weight_attr=weight_attr2, + bias_attr=bias_attr2) self.out_channels = out_channels + self.mid_channels = mid_channels + self.return_feats = return_feats + + def forward(self, x, targets=None): + if self.mid_channels is None: + predicts = self.fc(x) + else: + x = self.fc1(x) + predicts = self.fc2(x) - def forward(self, x, labels=None): - predicts = self.fc(x) + if self.return_feats: + result = (x, predicts) + else: + result = predicts if not self.training: predicts = F.softmax(predicts, axis=2) - return predicts + result = predicts + + return result diff --git a/backend/ppocr/modeling/heads/rec_multi_head.py b/backend/ppocr/modeling/heads/rec_multi_head.py new file mode 100644 index 00000000..ef78bf98 --- /dev/null +++ b/backend/ppocr/modeling/heads/rec_multi_head.py @@ -0,0 +1,73 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import paddle +from paddle import ParamAttr +import paddle.nn as nn +import paddle.nn.functional as F + +from ppocr.modeling.necks.rnn import Im2Seq, EncoderWithRNN, EncoderWithFC, SequenceEncoder, EncoderWithSVTR +from .rec_ctc_head import CTCHead +from .rec_sar_head import SARHead + + +class MultiHead(nn.Layer): + def __init__(self, in_channels, out_channels_list, **kwargs): + super().__init__() + self.head_list = kwargs.pop('head_list') + self.gtc_head = 'sar' + assert len(self.head_list) >= 2 + for idx, head_name in enumerate(self.head_list): + name = list(head_name)[0] + if name == 'SARHead': + # sar head + sar_args = self.head_list[idx][name] + self.sar_head = eval(name)(in_channels=in_channels, \ + out_channels=out_channels_list['SARLabelDecode'], **sar_args) + elif name == 'CTCHead': + # ctc neck + self.encoder_reshape = Im2Seq(in_channels) + neck_args = self.head_list[idx][name]['Neck'] + encoder_type = neck_args.pop('name') + self.encoder = encoder_type + self.ctc_encoder = SequenceEncoder(in_channels=in_channels, \ + encoder_type=encoder_type, **neck_args) + # ctc head + head_args = self.head_list[idx][name]['Head'] + self.ctc_head = eval(name)(in_channels=self.ctc_encoder.out_channels, \ + out_channels=out_channels_list['CTCLabelDecode'], **head_args) + else: + raise NotImplementedError( + '{} is not supported in MultiHead yet'.format(name)) + + def forward(self, x, targets=None): + ctc_encoder = self.ctc_encoder(x) + ctc_out = self.ctc_head(ctc_encoder, targets) + head_out = dict() + head_out['ctc'] = ctc_out + head_out['ctc_neck'] = ctc_encoder + # eval mode + if not self.training: + return ctc_out + if self.gtc_head == 'sar': + sar_out = self.sar_head(x, targets[1:]) + head_out['sar'] = sar_out + return head_out + else: + return head_out diff --git a/backend/ppocr/modeling/heads/rec_nrtr_head.py b/backend/ppocr/modeling/heads/rec_nrtr_head.py new file mode 100644 index 00000000..38ba0c91 --- /dev/null +++ 
b/backend/ppocr/modeling/heads/rec_nrtr_head.py @@ -0,0 +1,826 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import paddle +import copy +from paddle import nn +import paddle.nn.functional as F +from paddle.nn import LayerList +from paddle.nn.initializer import XavierNormal as xavier_uniform_ +from paddle.nn import Dropout, Linear, LayerNorm, Conv2D +import numpy as np +from ppocr.modeling.heads.multiheadAttention import MultiheadAttention +from paddle.nn.initializer import Constant as constant_ +from paddle.nn.initializer import XavierNormal as xavier_normal_ + +zeros_ = constant_(value=0.) +ones_ = constant_(value=1.) + + +class Transformer(nn.Layer): + """A transformer model. User is able to modify the attributes as needed. The architechture + is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer, + Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and + Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information + Processing Systems, pages 6000-6010. + + Args: + d_model: the number of expected features in the encoder/decoder inputs (default=512). + nhead: the number of heads in the multiheadattention models (default=8). + num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6). + num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + custom_encoder: custom encoder (default=None). + custom_decoder: custom decoder (default=None). 
+ + """ + + def __init__(self, + d_model=512, + nhead=8, + num_encoder_layers=6, + beam_size=0, + num_decoder_layers=6, + dim_feedforward=1024, + attention_dropout_rate=0.0, + residual_dropout_rate=0.1, + custom_encoder=None, + custom_decoder=None, + in_channels=0, + out_channels=0, + scale_embedding=True): + super(Transformer, self).__init__() + self.out_channels = out_channels + 1 + self.embedding = Embeddings( + d_model=d_model, + vocab=self.out_channels, + padding_idx=0, + scale_embedding=scale_embedding) + self.positional_encoding = PositionalEncoding( + dropout=residual_dropout_rate, + dim=d_model, ) + if custom_encoder is not None: + self.encoder = custom_encoder + else: + if num_encoder_layers > 0: + encoder_layer = TransformerEncoderLayer( + d_model, nhead, dim_feedforward, attention_dropout_rate, + residual_dropout_rate) + self.encoder = TransformerEncoder(encoder_layer, + num_encoder_layers) + else: + self.encoder = None + + if custom_decoder is not None: + self.decoder = custom_decoder + else: + decoder_layer = TransformerDecoderLayer( + d_model, nhead, dim_feedforward, attention_dropout_rate, + residual_dropout_rate) + self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers) + + self._reset_parameters() + self.beam_size = beam_size + self.d_model = d_model + self.nhead = nhead + self.tgt_word_prj = nn.Linear( + d_model, self.out_channels, bias_attr=False) + w0 = np.random.normal(0.0, d_model**-0.5, + (d_model, self.out_channels)).astype(np.float32) + self.tgt_word_prj.weight.set_value(w0) + self.apply(self._init_weights) + + def _init_weights(self, m): + + if isinstance(m, nn.Conv2D): + xavier_normal_(m.weight) + if m.bias is not None: + zeros_(m.bias) + + def forward_train(self, src, tgt): + tgt = tgt[:, :-1] + + tgt_key_padding_mask = self.generate_padding_mask(tgt) + tgt = self.embedding(tgt).transpose([1, 0, 2]) + tgt = self.positional_encoding(tgt) + tgt_mask = self.generate_square_subsequent_mask(tgt.shape[0]) + + if self.encoder is not None: + src = self.positional_encoding(src.transpose([1, 0, 2])) + memory = self.encoder(src) + else: + memory = src.squeeze(2).transpose([2, 0, 1]) + output = self.decoder( + tgt, + memory, + tgt_mask=tgt_mask, + memory_mask=None, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=None) + output = output.transpose([1, 0, 2]) + logit = self.tgt_word_prj(output) + return logit + + def forward(self, src, targets=None): + """Take in and process masked source/target sequences. + Args: + src: the sequence to the encoder (required). + tgt: the sequence to the decoder (required). + Shape: + - src: :math:`(S, N, E)`. + - tgt: :math:`(T, N, E)`. 
+ Examples: + >>> output = transformer_model(src, tgt) + """ + + if self.training: + max_len = targets[1].max() + tgt = targets[0][:, :2 + max_len] + return self.forward_train(src, tgt) + else: + if self.beam_size > 0: + return self.forward_beam(src) + else: + return self.forward_test(src) + + def forward_test(self, src): + bs = paddle.shape(src)[0] + if self.encoder is not None: + src = self.positional_encoding(paddle.transpose(src, [1, 0, 2])) + memory = self.encoder(src) + else: + memory = paddle.transpose(paddle.squeeze(src, 2), [2, 0, 1]) + dec_seq = paddle.full((bs, 1), 2, dtype=paddle.int64) + dec_prob = paddle.full((bs, 1), 1., dtype=paddle.float32) + for len_dec_seq in range(1, 25): + dec_seq_embed = paddle.transpose(self.embedding(dec_seq), [1, 0, 2]) + dec_seq_embed = self.positional_encoding(dec_seq_embed) + tgt_mask = self.generate_square_subsequent_mask( + paddle.shape(dec_seq_embed)[0]) + output = self.decoder( + dec_seq_embed, + memory, + tgt_mask=tgt_mask, + memory_mask=None, + tgt_key_padding_mask=None, + memory_key_padding_mask=None) + dec_output = paddle.transpose(output, [1, 0, 2]) + dec_output = dec_output[:, -1, :] + word_prob = F.softmax(self.tgt_word_prj(dec_output), axis=1) + preds_idx = paddle.argmax(word_prob, axis=1) + if paddle.equal_all( + preds_idx, + paddle.full( + paddle.shape(preds_idx), 3, dtype='int64')): + break + preds_prob = paddle.max(word_prob, axis=1) + dec_seq = paddle.concat( + [dec_seq, paddle.reshape(preds_idx, [-1, 1])], axis=1) + dec_prob = paddle.concat( + [dec_prob, paddle.reshape(preds_prob, [-1, 1])], axis=1) + return [dec_seq, dec_prob] + + def forward_beam(self, images): + ''' Translation work in one batch ''' + + def get_inst_idx_to_tensor_position_map(inst_idx_list): + ''' Indicate the position of an instance in a tensor. ''' + return { + inst_idx: tensor_position + for tensor_position, inst_idx in enumerate(inst_idx_list) + } + + def collect_active_part(beamed_tensor, curr_active_inst_idx, + n_prev_active_inst, n_bm): + ''' Collect tensor parts associated to active instances. ''' + + beamed_tensor_shape = paddle.shape(beamed_tensor) + n_curr_active_inst = len(curr_active_inst_idx) + new_shape = (n_curr_active_inst * n_bm, beamed_tensor_shape[1], + beamed_tensor_shape[2]) + + beamed_tensor = beamed_tensor.reshape([n_prev_active_inst, -1]) + beamed_tensor = beamed_tensor.index_select( + curr_active_inst_idx, axis=0) + beamed_tensor = beamed_tensor.reshape(new_shape) + + return beamed_tensor + + def collate_active_info(src_enc, inst_idx_to_position_map, + active_inst_idx_list): + # Sentences which are still active are collected, + # so the decoder will not run on completed sentences. 
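+            # Only the rows of the encoder memory that belong to still-active
+            # beams are kept, and the instance-index -> tensor-position map is
+            # rebuilt for the reduced set.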
+ + n_prev_active_inst = len(inst_idx_to_position_map) + active_inst_idx = [ + inst_idx_to_position_map[k] for k in active_inst_idx_list + ] + active_inst_idx = paddle.to_tensor(active_inst_idx, dtype='int64') + active_src_enc = collect_active_part( + src_enc.transpose([1, 0, 2]), active_inst_idx, + n_prev_active_inst, n_bm).transpose([1, 0, 2]) + active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map( + active_inst_idx_list) + return active_src_enc, active_inst_idx_to_position_map + + def beam_decode_step(inst_dec_beams, len_dec_seq, enc_output, + inst_idx_to_position_map, n_bm, + memory_key_padding_mask): + ''' Decode and update beam status, and then return active beam idx ''' + + def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq): + dec_partial_seq = [ + b.get_current_state() for b in inst_dec_beams if not b.done + ] + dec_partial_seq = paddle.stack(dec_partial_seq) + dec_partial_seq = dec_partial_seq.reshape([-1, len_dec_seq]) + return dec_partial_seq + + def predict_word(dec_seq, enc_output, n_active_inst, n_bm, + memory_key_padding_mask): + dec_seq = paddle.transpose(self.embedding(dec_seq), [1, 0, 2]) + dec_seq = self.positional_encoding(dec_seq) + tgt_mask = self.generate_square_subsequent_mask( + paddle.shape(dec_seq)[0]) + dec_output = self.decoder( + dec_seq, + enc_output, + tgt_mask=tgt_mask, + tgt_key_padding_mask=None, + memory_key_padding_mask=memory_key_padding_mask, ) + dec_output = paddle.transpose(dec_output, [1, 0, 2]) + dec_output = dec_output[:, + -1, :] # Pick the last step: (bh * bm) * d_h + word_prob = F.softmax(self.tgt_word_prj(dec_output), axis=1) + word_prob = paddle.reshape(word_prob, [n_active_inst, n_bm, -1]) + return word_prob + + def collect_active_inst_idx_list(inst_beams, word_prob, + inst_idx_to_position_map): + active_inst_idx_list = [] + for inst_idx, inst_position in inst_idx_to_position_map.items(): + is_inst_complete = inst_beams[inst_idx].advance(word_prob[ + inst_position]) + if not is_inst_complete: + active_inst_idx_list += [inst_idx] + + return active_inst_idx_list + + n_active_inst = len(inst_idx_to_position_map) + dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq) + word_prob = predict_word(dec_seq, enc_output, n_active_inst, n_bm, + None) + # Update the beam with predicted word prob information and collect incomplete instances + active_inst_idx_list = collect_active_inst_idx_list( + inst_dec_beams, word_prob, inst_idx_to_position_map) + return active_inst_idx_list + + def collect_hypothesis_and_scores(inst_dec_beams, n_best): + all_hyp, all_scores = [], [] + for inst_idx in range(len(inst_dec_beams)): + scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores() + all_scores += [scores[:n_best]] + hyps = [ + inst_dec_beams[inst_idx].get_hypothesis(i) + for i in tail_idxs[:n_best] + ] + all_hyp += [hyps] + return all_hyp, all_scores + + with paddle.no_grad(): + #-- Encode + if self.encoder is not None: + src = self.positional_encoding(images.transpose([1, 0, 2])) + src_enc = self.encoder(src) + else: + src_enc = images.squeeze(2).transpose([0, 2, 1]) + + n_bm = self.beam_size + src_shape = paddle.shape(src_enc) + inst_dec_beams = [Beam(n_bm) for _ in range(1)] + active_inst_idx_list = list(range(1)) + # Repeat data for beam search + src_enc = paddle.tile(src_enc, [1, n_bm, 1]) + inst_idx_to_position_map = get_inst_idx_to_tensor_position_map( + active_inst_idx_list) + # Decode + for len_dec_seq in range(1, 25): + src_enc_copy = src_enc.clone() + active_inst_idx_list = beam_decode_step( + inst_dec_beams, len_dec_seq, 
src_enc_copy, + inst_idx_to_position_map, n_bm, None) + if not active_inst_idx_list: + break # all instances have finished their path to + src_enc, inst_idx_to_position_map = collate_active_info( + src_enc_copy, inst_idx_to_position_map, + active_inst_idx_list) + batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams, + 1) + result_hyp = [] + hyp_scores = [] + for bs_hyp, score in zip(batch_hyp, batch_scores): + l = len(bs_hyp[0]) + bs_hyp_pad = bs_hyp[0] + [3] * (25 - l) + result_hyp.append(bs_hyp_pad) + score = float(score) / l + hyp_score = [score for _ in range(25)] + hyp_scores.append(hyp_score) + return [ + paddle.to_tensor( + np.array(result_hyp), dtype=paddle.int64), + paddle.to_tensor(hyp_scores) + ] + + def generate_square_subsequent_mask(self, sz): + """Generate a square mask for the sequence. The masked positions are filled with float('-inf'). + Unmasked positions are filled with float(0.0). + """ + mask = paddle.zeros([sz, sz], dtype='float32') + mask_inf = paddle.triu( + paddle.full( + shape=[sz, sz], dtype='float32', fill_value='-inf'), + diagonal=1) + mask = mask + mask_inf + return mask + + def generate_padding_mask(self, x): + padding_mask = paddle.equal(x, paddle.to_tensor(0, dtype=x.dtype)) + return padding_mask + + def _reset_parameters(self): + """Initiate parameters in the transformer model.""" + + for p in self.parameters(): + if p.dim() > 1: + xavier_uniform_(p) + + +class TransformerEncoder(nn.Layer): + """TransformerEncoder is a stack of N encoder layers + Args: + encoder_layer: an instance of the TransformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + norm: the layer normalization component (optional). + """ + + def __init__(self, encoder_layer, num_layers): + super(TransformerEncoder, self).__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + + def forward(self, src): + """Pass the input through the endocder layers in turn. + Args: + src: the sequnce to the encoder (required). + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + """ + output = src + + for i in range(self.num_layers): + output = self.layers[i](output, + src_mask=None, + src_key_padding_mask=None) + + return output + + +class TransformerDecoder(nn.Layer): + """TransformerDecoder is a stack of N decoder layers + + Args: + decoder_layer: an instance of the TransformerDecoderLayer() class (required). + num_layers: the number of sub-decoder-layers in the decoder (required). + norm: the layer normalization component (optional). + + """ + + def __init__(self, decoder_layer, num_layers): + super(TransformerDecoder, self).__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + + def forward(self, + tgt, + memory, + tgt_mask=None, + memory_mask=None, + tgt_key_padding_mask=None, + memory_key_padding_mask=None): + """Pass the inputs (and mask) through the decoder layer in turn. + + Args: + tgt: the sequence to the decoder (required). + memory: the sequnce from the last layer of the encoder (required). + tgt_mask: the mask for the tgt sequence (optional). + memory_mask: the mask for the memory sequence (optional). + tgt_key_padding_mask: the mask for the tgt keys per batch (optional). + memory_key_padding_mask: the mask for the memory keys per batch (optional). 
+ """ + output = tgt + for i in range(self.num_layers): + output = self.layers[i]( + output, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask) + + return output + + +class TransformerEncoderLayer(nn.Layer): + """TransformerEncoderLayer is made up of self-attn and feedforward network. + This standard encoder layer is based on the paper "Attention Is All You Need". + Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, + Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in + Neural Information Processing Systems, pages 6000-6010. Users may modify or implement + in a different way during application. + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + + """ + + def __init__(self, + d_model, + nhead, + dim_feedforward=2048, + attention_dropout_rate=0.0, + residual_dropout_rate=0.1): + super(TransformerEncoderLayer, self).__init__() + self.self_attn = MultiheadAttention( + d_model, nhead, dropout=attention_dropout_rate) + + self.conv1 = Conv2D( + in_channels=d_model, + out_channels=dim_feedforward, + kernel_size=(1, 1)) + self.conv2 = Conv2D( + in_channels=dim_feedforward, + out_channels=d_model, + kernel_size=(1, 1)) + + self.norm1 = LayerNorm(d_model) + self.norm2 = LayerNorm(d_model) + self.dropout1 = Dropout(residual_dropout_rate) + self.dropout2 = Dropout(residual_dropout_rate) + + def forward(self, src, src_mask=None, src_key_padding_mask=None): + """Pass the input through the endocder layer. + Args: + src: the sequnce to the encoder layer (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + """ + src2 = self.self_attn( + src, + src, + src, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask) + src = src + self.dropout1(src2) + src = self.norm1(src) + + src = paddle.transpose(src, [1, 2, 0]) + src = paddle.unsqueeze(src, 2) + src2 = self.conv2(F.relu(self.conv1(src))) + src2 = paddle.squeeze(src2, 2) + src2 = paddle.transpose(src2, [2, 0, 1]) + src = paddle.squeeze(src, 2) + src = paddle.transpose(src, [2, 0, 1]) + + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + +class TransformerDecoderLayer(nn.Layer): + """TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network. + This standard decoder layer is based on the paper "Attention Is All You Need". + Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, + Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in + Neural Information Processing Systems, pages 6000-6010. Users may modify or implement + in a different way during application. + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). 
+ + """ + + def __init__(self, + d_model, + nhead, + dim_feedforward=2048, + attention_dropout_rate=0.0, + residual_dropout_rate=0.1): + super(TransformerDecoderLayer, self).__init__() + self.self_attn = MultiheadAttention( + d_model, nhead, dropout=attention_dropout_rate) + self.multihead_attn = MultiheadAttention( + d_model, nhead, dropout=attention_dropout_rate) + + self.conv1 = Conv2D( + in_channels=d_model, + out_channels=dim_feedforward, + kernel_size=(1, 1)) + self.conv2 = Conv2D( + in_channels=dim_feedforward, + out_channels=d_model, + kernel_size=(1, 1)) + + self.norm1 = LayerNorm(d_model) + self.norm2 = LayerNorm(d_model) + self.norm3 = LayerNorm(d_model) + self.dropout1 = Dropout(residual_dropout_rate) + self.dropout2 = Dropout(residual_dropout_rate) + self.dropout3 = Dropout(residual_dropout_rate) + + def forward(self, + tgt, + memory, + tgt_mask=None, + memory_mask=None, + tgt_key_padding_mask=None, + memory_key_padding_mask=None): + """Pass the inputs (and mask) through the decoder layer. + + Args: + tgt: the sequence to the decoder layer (required). + memory: the sequnce from the last layer of the encoder (required). + tgt_mask: the mask for the tgt sequence (optional). + memory_mask: the mask for the memory sequence (optional). + tgt_key_padding_mask: the mask for the tgt keys per batch (optional). + memory_key_padding_mask: the mask for the memory keys per batch (optional). + + """ + tgt2 = self.self_attn( + tgt, + tgt, + tgt, + attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask) + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn( + tgt, + memory, + memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask) + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + + # default + tgt = paddle.transpose(tgt, [1, 2, 0]) + tgt = paddle.unsqueeze(tgt, 2) + tgt2 = self.conv2(F.relu(self.conv1(tgt))) + tgt2 = paddle.squeeze(tgt2, 2) + tgt2 = paddle.transpose(tgt2, [2, 0, 1]) + tgt = paddle.squeeze(tgt, 2) + tgt = paddle.transpose(tgt, [2, 0, 1]) + + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + +def _get_clones(module, N): + return LayerList([copy.deepcopy(module) for i in range(N)]) + + +class PositionalEncoding(nn.Layer): + """Inject some information about the relative or absolute position of the tokens + in the sequence. The positional encodings have the same dimension as + the embeddings, so that the two can be summed. Here, we use sine and cosine + functions of different frequencies. + .. math:: + \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model)) + \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model)) + \text{where pos is the word position and i is the embed idx) + Args: + d_model: the embed dim (required). + dropout: the dropout value (default=0.1). + max_len: the max. length of the incoming sequence (default=5000). 
+ Examples: + >>> pos_encoder = PositionalEncoding(d_model) + """ + + def __init__(self, dropout, dim, max_len=5000): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + pe = paddle.zeros([max_len, dim]) + position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1) + div_term = paddle.exp( + paddle.arange(0, dim, 2).astype('float32') * + (-math.log(10000.0) / dim)) + pe[:, 0::2] = paddle.sin(position * div_term) + pe[:, 1::2] = paddle.cos(position * div_term) + pe = paddle.unsqueeze(pe, 0) + pe = paddle.transpose(pe, [1, 0, 2]) + self.register_buffer('pe', pe) + + def forward(self, x): + """Inputs of forward function + Args: + x: the sequence fed to the positional encoder model (required). + Shape: + x: [sequence length, batch size, embed dim] + output: [sequence length, batch size, embed dim] + Examples: + >>> output = pos_encoder(x) + """ + x = x + self.pe[:paddle.shape(x)[0], :] + return self.dropout(x) + + +class PositionalEncoding_2d(nn.Layer): + """Inject some information about the relative or absolute position of the tokens + in the sequence. The positional encodings have the same dimension as + the embeddings, so that the two can be summed. Here, we use sine and cosine + functions of different frequencies. + .. math:: + \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model)) + \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model)) + \text{where pos is the word position and i is the embed idx) + Args: + d_model: the embed dim (required). + dropout: the dropout value (default=0.1). + max_len: the max. length of the incoming sequence (default=5000). + Examples: + >>> pos_encoder = PositionalEncoding(d_model) + """ + + def __init__(self, dropout, dim, max_len=5000): + super(PositionalEncoding_2d, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + pe = paddle.zeros([max_len, dim]) + position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1) + div_term = paddle.exp( + paddle.arange(0, dim, 2).astype('float32') * + (-math.log(10000.0) / dim)) + pe[:, 0::2] = paddle.sin(position * div_term) + pe[:, 1::2] = paddle.cos(position * div_term) + pe = paddle.transpose(paddle.unsqueeze(pe, 0), [1, 0, 2]) + self.register_buffer('pe', pe) + + self.avg_pool_1 = nn.AdaptiveAvgPool2D((1, 1)) + self.linear1 = nn.Linear(dim, dim) + self.linear1.weight.data.fill_(1.) + self.avg_pool_2 = nn.AdaptiveAvgPool2D((1, 1)) + self.linear2 = nn.Linear(dim, dim) + self.linear2.weight.data.fill_(1.) + + def forward(self, x): + """Inputs of forward function + Args: + x: the sequence fed to the positional encoder model (required). 
+ Shape: + x: [sequence length, batch size, embed dim] + output: [sequence length, batch size, embed dim] + Examples: + >>> output = pos_encoder(x) + """ + w_pe = self.pe[:paddle.shape(x)[-1], :] + w1 = self.linear1(self.avg_pool_1(x).squeeze()).unsqueeze(0) + w_pe = w_pe * w1 + w_pe = paddle.transpose(w_pe, [1, 2, 0]) + w_pe = paddle.unsqueeze(w_pe, 2) + + h_pe = self.pe[:paddle.shape(x).shape[-2], :] + w2 = self.linear2(self.avg_pool_2(x).squeeze()).unsqueeze(0) + h_pe = h_pe * w2 + h_pe = paddle.transpose(h_pe, [1, 2, 0]) + h_pe = paddle.unsqueeze(h_pe, 3) + + x = x + w_pe + h_pe + x = paddle.transpose( + paddle.reshape(x, + [x.shape[0], x.shape[1], x.shape[2] * x.shape[3]]), + [2, 0, 1]) + + return self.dropout(x) + + +class Embeddings(nn.Layer): + def __init__(self, d_model, vocab, padding_idx, scale_embedding): + super(Embeddings, self).__init__() + self.embedding = nn.Embedding(vocab, d_model, padding_idx=padding_idx) + w0 = np.random.normal(0.0, d_model**-0.5, + (vocab, d_model)).astype(np.float32) + self.embedding.weight.set_value(w0) + self.d_model = d_model + self.scale_embedding = scale_embedding + + def forward(self, x): + if self.scale_embedding: + x = self.embedding(x) + return x * math.sqrt(self.d_model) + return self.embedding(x) + + +class Beam(): + ''' Beam search ''' + + def __init__(self, size, device=False): + + self.size = size + self._done = False + # The score for each translation on the beam. + self.scores = paddle.zeros((size, ), dtype=paddle.float32) + self.all_scores = [] + # The backpointers at each time-step. + self.prev_ks = [] + # The outputs at each time-step. + self.next_ys = [paddle.full((size, ), 0, dtype=paddle.int64)] + self.next_ys[0][0] = 2 + + def get_current_state(self): + "Get the outputs for the current timestep." + return self.get_tentative_hypothesis() + + def get_current_origin(self): + "Get the backpointers for the current timestep." + return self.prev_ks[-1] + + @property + def done(self): + return self._done + + def advance(self, word_prob): + "Update beam status and check if finished or not." + num_words = word_prob.shape[1] + + # Sum the previous scores. + if len(self.prev_ks) > 0: + beam_lk = word_prob + self.scores.unsqueeze(1).expand_as(word_prob) + else: + beam_lk = word_prob[0] + + flat_beam_lk = beam_lk.reshape([-1]) + best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, + True) # 1st sort + self.all_scores.append(self.scores) + self.scores = best_scores + # bestScoresId is flattened as a (beam x word) array, + # so we need to calculate which word and beam each score came from + prev_k = best_scores_id // num_words + self.prev_ks.append(prev_k) + self.next_ys.append(best_scores_id - prev_k * num_words) + # End condition is when top-of-beam is EOS. + if self.next_ys[-1][0] == 3: + self._done = True + self.all_scores.append(self.scores) + + return self._done + + def sort_scores(self): + "Sort the scores." + return self.scores, paddle.to_tensor( + [i for i in range(int(self.scores.shape[0]))], dtype='int32') + + def get_the_best_score_and_idx(self): + "Get the score of the best in the beam." + scores, ids = self.sort_scores() + return scores[1], ids[1] + + def get_tentative_hypothesis(self): + "Get the decoded sequence for the current timestep." 
+ if len(self.next_ys) == 1: + dec_seq = self.next_ys[0].unsqueeze(1) + else: + _, keys = self.sort_scores() + hyps = [self.get_hypothesis(k) for k in keys] + hyps = [[2] + h for h in hyps] + dec_seq = paddle.to_tensor(hyps, dtype='int64') + return dec_seq + + def get_hypothesis(self, k): + """ Walk back to construct the full hypothesis. """ + hyp = [] + for j in range(len(self.prev_ks) - 1, -1, -1): + hyp.append(self.next_ys[j + 1][k]) + k = self.prev_ks[j][k] + return list(map(lambda x: x.item(), hyp[::-1])) diff --git a/backend/ppocr/modeling/heads/rec_pren_head.py b/backend/ppocr/modeling/heads/rec_pren_head.py new file mode 100644 index 00000000..c9e4b3e9 --- /dev/null +++ b/backend/ppocr/modeling/heads/rec_pren_head.py @@ -0,0 +1,34 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from paddle import nn +from paddle.nn import functional as F + + +class PRENHead(nn.Layer): + def __init__(self, in_channels, out_channels, **kwargs): + super(PRENHead, self).__init__() + self.linear = nn.Linear(in_channels, out_channels) + + def forward(self, x, targets=None): + predicts = self.linear(x) + + if not self.training: + predicts = F.softmax(predicts, axis=2) + + return predicts diff --git a/backend/ppocr/modeling/heads/rec_sar_head.py b/backend/ppocr/modeling/heads/rec_sar_head.py new file mode 100644 index 00000000..0e6b3440 --- /dev/null +++ b/backend/ppocr/modeling/heads/rec_sar_head.py @@ -0,0 +1,410 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is refer from: +https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/encoders/sar_encoder.py +https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/decoders/sar_decoder.py +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import paddle +from paddle import ParamAttr +import paddle.nn as nn +import paddle.nn.functional as F + + +class SAREncoder(nn.Layer): + """ + Args: + enc_bi_rnn (bool): If True, use bidirectional RNN in encoder. + enc_drop_rnn (float): Dropout probability of RNN layer in encoder. + enc_gru (bool): If True, use GRU, else LSTM in encoder. + d_model (int): Dim of channels from backbone. + d_enc (int): Dim of encoder RNN layer. 
+ mask (bool): If True, mask padding in RNN sequence. + """ + + def __init__(self, + enc_bi_rnn=False, + enc_drop_rnn=0.1, + enc_gru=False, + d_model=512, + d_enc=512, + mask=True, + **kwargs): + super().__init__() + assert isinstance(enc_bi_rnn, bool) + assert isinstance(enc_drop_rnn, (int, float)) + assert 0 <= enc_drop_rnn < 1.0 + assert isinstance(enc_gru, bool) + assert isinstance(d_model, int) + assert isinstance(d_enc, int) + assert isinstance(mask, bool) + + self.enc_bi_rnn = enc_bi_rnn + self.enc_drop_rnn = enc_drop_rnn + self.mask = mask + + # LSTM Encoder + if enc_bi_rnn: + direction = 'bidirectional' + else: + direction = 'forward' + kwargs = dict( + input_size=d_model, + hidden_size=d_enc, + num_layers=2, + time_major=False, + dropout=enc_drop_rnn, + direction=direction) + if enc_gru: + self.rnn_encoder = nn.GRU(**kwargs) + else: + self.rnn_encoder = nn.LSTM(**kwargs) + + # global feature transformation + encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1) + self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size) + + def forward(self, feat, img_metas=None): + if img_metas is not None: + assert len(img_metas[0]) == feat.shape[0] + + valid_ratios = None + if img_metas is not None and self.mask: + valid_ratios = img_metas[-1] + + h_feat = feat.shape[2] # bsz c h w + feat_v = F.max_pool2d( + feat, kernel_size=(h_feat, 1), stride=1, padding=0) + feat_v = feat_v.squeeze(2) # bsz * C * W + feat_v = paddle.transpose(feat_v, perm=[0, 2, 1]) # bsz * W * C + holistic_feat = self.rnn_encoder(feat_v)[0] # bsz * T * C + + if valid_ratios is not None: + valid_hf = [] + T = holistic_feat.shape[1] + for i in range(len(valid_ratios)): + valid_step = min(T, math.ceil(T * valid_ratios[i])) - 1 + valid_hf.append(holistic_feat[i, valid_step, :]) + valid_hf = paddle.stack(valid_hf, axis=0) + else: + valid_hf = holistic_feat[:, -1, :] # bsz * C + holistic_feat = self.linear(valid_hf) # bsz * C + + return holistic_feat + + +class BaseDecoder(nn.Layer): + def __init__(self, **kwargs): + super().__init__() + + def forward_train(self, feat, out_enc, targets, img_metas): + raise NotImplementedError + + def forward_test(self, feat, out_enc, img_metas): + raise NotImplementedError + + def forward(self, + feat, + out_enc, + label=None, + img_metas=None, + train_mode=True): + self.train_mode = train_mode + + if train_mode: + return self.forward_train(feat, out_enc, label, img_metas) + return self.forward_test(feat, out_enc, img_metas) + + +class ParallelSARDecoder(BaseDecoder): + """ + Args: + out_channels (int): Output class number. + enc_bi_rnn (bool): If True, use bidirectional RNN in encoder. + dec_bi_rnn (bool): If True, use bidirectional RNN in decoder. + dec_drop_rnn (float): Dropout of RNN layer in decoder. + dec_gru (bool): If True, use GRU, else LSTM in decoder. + d_model (int): Dim of channels from backbone. + d_enc (int): Dim of encoder RNN layer. + d_k (int): Dim of channels of attention module. + pred_dropout (float): Dropout probability of prediction layer. + max_seq_len (int): Maximum sequence length for decoding. + mask (bool): If True, mask padding in feature map. + start_idx (int): Index of start token. + padding_idx (int): Index of padding token. + pred_concat (bool): If True, concat glimpse feature from + attention with holistic feature and hidden state. 
+ """ + + def __init__( + self, + out_channels, # 90 + unknown + start + padding + enc_bi_rnn=False, + dec_bi_rnn=False, + dec_drop_rnn=0.0, + dec_gru=False, + d_model=512, + d_enc=512, + d_k=64, + pred_dropout=0.1, + max_text_length=30, + mask=True, + pred_concat=True, + **kwargs): + super().__init__() + + self.num_classes = out_channels + self.enc_bi_rnn = enc_bi_rnn + self.d_k = d_k + self.start_idx = out_channels - 2 + self.padding_idx = out_channels - 1 + self.max_seq_len = max_text_length + self.mask = mask + self.pred_concat = pred_concat + + encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1) + decoder_rnn_out_size = encoder_rnn_out_size * (int(dec_bi_rnn) + 1) + + # 2D attention layer + self.conv1x1_1 = nn.Linear(decoder_rnn_out_size, d_k) + self.conv3x3_1 = nn.Conv2D( + d_model, d_k, kernel_size=3, stride=1, padding=1) + self.conv1x1_2 = nn.Linear(d_k, 1) + + # Decoder RNN layer + if dec_bi_rnn: + direction = 'bidirectional' + else: + direction = 'forward' + + kwargs = dict( + input_size=encoder_rnn_out_size, + hidden_size=encoder_rnn_out_size, + num_layers=2, + time_major=False, + dropout=dec_drop_rnn, + direction=direction) + if dec_gru: + self.rnn_decoder = nn.GRU(**kwargs) + else: + self.rnn_decoder = nn.LSTM(**kwargs) + + # Decoder input embedding + self.embedding = nn.Embedding( + self.num_classes, + encoder_rnn_out_size, + padding_idx=self.padding_idx) + + # Prediction layer + self.pred_dropout = nn.Dropout(pred_dropout) + pred_num_classes = self.num_classes - 1 + if pred_concat: + fc_in_channel = decoder_rnn_out_size + d_model + encoder_rnn_out_size + else: + fc_in_channel = d_model + self.prediction = nn.Linear(fc_in_channel, pred_num_classes) + + def _2d_attention(self, + decoder_input, + feat, + holistic_feat, + valid_ratios=None): + + y = self.rnn_decoder(decoder_input)[0] + # y: bsz * (seq_len + 1) * hidden_size + + attn_query = self.conv1x1_1(y) # bsz * (seq_len + 1) * attn_size + bsz, seq_len, attn_size = attn_query.shape + attn_query = paddle.unsqueeze(attn_query, axis=[3, 4]) + # (bsz, seq_len + 1, attn_size, 1, 1) + + attn_key = self.conv3x3_1(feat) + # bsz * attn_size * h * w + attn_key = attn_key.unsqueeze(1) + # bsz * 1 * attn_size * h * w + + attn_weight = paddle.tanh(paddle.add(attn_key, attn_query)) + + # bsz * (seq_len + 1) * attn_size * h * w + attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 3, 4, 2]) + # bsz * (seq_len + 1) * h * w * attn_size + attn_weight = self.conv1x1_2(attn_weight) + # bsz * (seq_len + 1) * h * w * 1 + bsz, T, h, w, c = attn_weight.shape + assert c == 1 + + if valid_ratios is not None: + # cal mask of attention weight + for i in range(len(valid_ratios)): + valid_width = min(w, math.ceil(w * valid_ratios[i])) + if valid_width < w: + attn_weight[i, :, :, valid_width:, :] = float('-inf') + + attn_weight = paddle.reshape(attn_weight, [bsz, T, -1]) + attn_weight = F.softmax(attn_weight, axis=-1) + + attn_weight = paddle.reshape(attn_weight, [bsz, T, h, w, c]) + attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 4, 2, 3]) + # attn_weight: bsz * T * c * h * w + # feat: bsz * c * h * w + attn_feat = paddle.sum(paddle.multiply(feat.unsqueeze(1), attn_weight), + (3, 4), + keepdim=False) + # bsz * (seq_len + 1) * C + + # Linear transformation + if self.pred_concat: + hf_c = holistic_feat.shape[-1] + holistic_feat = paddle.expand( + holistic_feat, shape=[bsz, seq_len, hf_c]) + y = self.prediction(paddle.concat((y, attn_feat, holistic_feat), 2)) + else: + y = self.prediction(attn_feat) + # bsz * (seq_len + 1) * num_classes + if 
self.train_mode: + y = self.pred_dropout(y) + + return y + + def forward_train(self, feat, out_enc, label, img_metas): + ''' + img_metas: [label, valid_ratio] + ''' + if img_metas is not None: + assert len(img_metas[0]) == feat.shape[0] + + valid_ratios = None + if img_metas is not None and self.mask: + valid_ratios = img_metas[-1] + + lab_embedding = self.embedding(label) + # bsz * seq_len * emb_dim + out_enc = out_enc.unsqueeze(1) + # bsz * 1 * emb_dim + in_dec = paddle.concat((out_enc, lab_embedding), axis=1) + # bsz * (seq_len + 1) * C + out_dec = self._2d_attention( + in_dec, feat, out_enc, valid_ratios=valid_ratios) + # bsz * (seq_len + 1) * num_classes + + return out_dec[:, 1:, :] # bsz * seq_len * num_classes + + def forward_test(self, feat, out_enc, img_metas): + if img_metas is not None: + assert len(img_metas[0]) == feat.shape[0] + + valid_ratios = None + if img_metas is not None and self.mask: + valid_ratios = img_metas[-1] + + seq_len = self.max_seq_len + bsz = feat.shape[0] + start_token = paddle.full( + (bsz, ), fill_value=self.start_idx, dtype='int64') + # bsz + start_token = self.embedding(start_token) + # bsz * emb_dim + emb_dim = start_token.shape[1] + start_token = start_token.unsqueeze(1) + start_token = paddle.expand(start_token, shape=[bsz, seq_len, emb_dim]) + # bsz * seq_len * emb_dim + out_enc = out_enc.unsqueeze(1) + # bsz * 1 * emb_dim + decoder_input = paddle.concat((out_enc, start_token), axis=1) + # bsz * (seq_len + 1) * emb_dim + + outputs = [] + for i in range(1, seq_len + 1): + decoder_output = self._2d_attention( + decoder_input, feat, out_enc, valid_ratios=valid_ratios) + char_output = decoder_output[:, i, :] # bsz * num_classes + char_output = F.softmax(char_output, -1) + outputs.append(char_output) + max_idx = paddle.argmax(char_output, axis=1, keepdim=False) + char_embedding = self.embedding(max_idx) # bsz * emb_dim + if i < seq_len: + decoder_input[:, i + 1, :] = char_embedding + + outputs = paddle.stack(outputs, 1) # bsz * seq_len * num_classes + + return outputs + + +class SARHead(nn.Layer): + def __init__(self, + in_channels, + out_channels, + enc_dim=512, + max_text_length=30, + enc_bi_rnn=False, + enc_drop_rnn=0.1, + enc_gru=False, + dec_bi_rnn=False, + dec_drop_rnn=0.0, + dec_gru=False, + d_k=512, + pred_dropout=0.1, + pred_concat=True, + **kwargs): + super(SARHead, self).__init__() + + # encoder module + self.encoder = SAREncoder( + enc_bi_rnn=enc_bi_rnn, + enc_drop_rnn=enc_drop_rnn, + enc_gru=enc_gru, + d_model=in_channels, + d_enc=enc_dim) + + # decoder module + self.decoder = ParallelSARDecoder( + out_channels=out_channels, + enc_bi_rnn=enc_bi_rnn, + dec_bi_rnn=dec_bi_rnn, + dec_drop_rnn=dec_drop_rnn, + dec_gru=dec_gru, + d_model=in_channels, + d_enc=enc_dim, + d_k=d_k, + pred_dropout=pred_dropout, + max_text_length=max_text_length, + pred_concat=pred_concat) + + def forward(self, feat, targets=None): + ''' + img_metas: [label, valid_ratio] + ''' + holistic_feat = self.encoder(feat, targets) # bsz c + + if self.training: + label = targets[0] # label + label = paddle.to_tensor(label, dtype='int64') + final_out = self.decoder( + feat, holistic_feat, label, img_metas=targets) + else: + final_out = self.decoder( + feat, + holistic_feat, + label=None, + img_metas=targets, + train_mode=False) + # (bsz, seq_len, num_classes) + + return final_out diff --git a/backend/ppocr/modeling/heads/rec_srn_head.py b/backend/ppocr/modeling/heads/rec_srn_head.py index d2c7fc02..8d59e471 100644 --- a/backend/ppocr/modeling/heads/rec_srn_head.py +++ 
b/backend/ppocr/modeling/heads/rec_srn_head.py @@ -250,7 +250,8 @@ def __init__(self, in_channels, out_channels, max_text_length, num_heads, self.gsrm.wrap_encoder1.prepare_decoder.emb0 = self.gsrm.wrap_encoder0.prepare_decoder.emb0 - def forward(self, inputs, others): + def forward(self, inputs, targets=None): + others = targets[-4:] encoder_word_pos = others[0] gsrm_word_pos = others[1] gsrm_slf_attn_bias1 = others[2] diff --git a/backend/ppocr/modeling/heads/self_attention.py b/backend/ppocr/modeling/heads/self_attention.py index 51d5198f..6c27fdbe 100644 --- a/backend/ppocr/modeling/heads/self_attention.py +++ b/backend/ppocr/modeling/heads/self_attention.py @@ -285,8 +285,7 @@ def __init__(self, process_cmd, d_model, dropout_rate): elif cmd == "n": # add layer normalization self.functors.append( self.add_sublayer( - "layer_norm_%d" % len( - self.sublayers(include_sublayers=False)), + "layer_norm_%d" % len(self.sublayers()), paddle.nn.LayerNorm( normalized_shape=d_model, weight_attr=fluid.ParamAttr( @@ -320,9 +319,7 @@ def __init__(self, self.src_emb_dim = src_emb_dim self.src_max_len = src_max_len self.emb = paddle.nn.Embedding( - num_embeddings=self.src_max_len, - embedding_dim=self.src_emb_dim, - sparse=True) + num_embeddings=self.src_max_len, embedding_dim=self.src_emb_dim) self.dropout_rate = dropout_rate def forward(self, src_word, src_pos): diff --git a/backend/ppocr/modeling/heads/table_att_head.py b/backend/ppocr/modeling/heads/table_att_head.py new file mode 100644 index 00000000..e354f40d --- /dev/null +++ b/backend/ppocr/modeling/heads/table_att_head.py @@ -0,0 +1,246 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
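+
+# Rough usage sketch for the TableAttentionHead defined below (the channel and
+# spatial sizes are illustrative assumptions, not values taken from a shipped
+# config):
+#
+#   head = TableAttentionHead(in_channels=[96], hidden_size=256, loc_type=2)
+#   fea = paddle.rand([1, 96, 16, 16])  # H*W must match the Linear selected by in_max_len
+#   out = head([fea])                   # inference branch (no targets supplied)
+#   # out['structure_probs']: [1, max_elem_length + 1, 30]   (30 = elem_num)
+#   # out['loc_preds']:       [1, max_elem_length + 1, 4]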
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +import numpy as np + + +class TableAttentionHead(nn.Layer): + def __init__(self, + in_channels, + hidden_size, + loc_type, + in_max_len=488, + max_text_length=100, + max_elem_length=800, + max_cell_num=500, + **kwargs): + super(TableAttentionHead, self).__init__() + self.input_size = in_channels[-1] + self.hidden_size = hidden_size + self.elem_num = 30 + self.max_text_length = max_text_length + self.max_elem_length = max_elem_length + self.max_cell_num = max_cell_num + + self.structure_attention_cell = AttentionGRUCell( + self.input_size, hidden_size, self.elem_num, use_gru=False) + self.structure_generator = nn.Linear(hidden_size, self.elem_num) + self.loc_type = loc_type + self.in_max_len = in_max_len + + if self.loc_type == 1: + self.loc_generator = nn.Linear(hidden_size, 4) + else: + if self.in_max_len == 640: + self.loc_fea_trans = nn.Linear(400, self.max_elem_length + 1) + elif self.in_max_len == 800: + self.loc_fea_trans = nn.Linear(625, self.max_elem_length + 1) + else: + self.loc_fea_trans = nn.Linear(256, self.max_elem_length + 1) + self.loc_generator = nn.Linear(self.input_size + hidden_size, 4) + + def _char_to_onehot(self, input_char, onehot_dim): + input_ont_hot = F.one_hot(input_char, onehot_dim) + return input_ont_hot + + def forward(self, inputs, targets=None): + # if and else branch are both needed when you want to assign a variable + # if you modify the var in just one branch, then the modification will not work. + fea = inputs[-1] + if len(fea.shape) == 3: + pass + else: + last_shape = int(np.prod(fea.shape[2:])) # gry added + fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape]) + fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels) + batch_size = fea.shape[0] + + hidden = paddle.zeros((batch_size, self.hidden_size)) + output_hiddens = [] + if self.training and targets is not None: + structure = targets[0] + for i in range(self.max_elem_length + 1): + elem_onehots = self._char_to_onehot( + structure[:, i], onehot_dim=self.elem_num) + (outputs, hidden), alpha = self.structure_attention_cell( + hidden, fea, elem_onehots) + output_hiddens.append(paddle.unsqueeze(outputs, axis=1)) + output = paddle.concat(output_hiddens, axis=1) + structure_probs = self.structure_generator(output) + if self.loc_type == 1: + loc_preds = self.loc_generator(output) + loc_preds = F.sigmoid(loc_preds) + else: + loc_fea = fea.transpose([0, 2, 1]) + loc_fea = self.loc_fea_trans(loc_fea) + loc_fea = loc_fea.transpose([0, 2, 1]) + loc_concat = paddle.concat([output, loc_fea], axis=2) + loc_preds = self.loc_generator(loc_concat) + loc_preds = F.sigmoid(loc_preds) + else: + temp_elem = paddle.zeros(shape=[batch_size], dtype="int32") + structure_probs = None + loc_preds = None + elem_onehots = None + outputs = None + alpha = None + max_elem_length = paddle.to_tensor(self.max_elem_length) + i = 0 + while i < max_elem_length + 1: + elem_onehots = self._char_to_onehot( + temp_elem, onehot_dim=self.elem_num) + (outputs, hidden), alpha = self.structure_attention_cell( + hidden, fea, elem_onehots) + output_hiddens.append(paddle.unsqueeze(outputs, axis=1)) + structure_probs_step = self.structure_generator(outputs) + temp_elem = structure_probs_step.argmax(axis=1, dtype="int32") + i += 1 + + output = paddle.concat(output_hiddens, axis=1) + structure_probs = self.structure_generator(output) + 
structure_probs = F.softmax(structure_probs) + if self.loc_type == 1: + loc_preds = self.loc_generator(output) + loc_preds = F.sigmoid(loc_preds) + else: + loc_fea = fea.transpose([0, 2, 1]) + loc_fea = self.loc_fea_trans(loc_fea) + loc_fea = loc_fea.transpose([0, 2, 1]) + loc_concat = paddle.concat([output, loc_fea], axis=2) + loc_preds = self.loc_generator(loc_concat) + loc_preds = F.sigmoid(loc_preds) + return {'structure_probs': structure_probs, 'loc_preds': loc_preds} + + +class AttentionGRUCell(nn.Layer): + def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False): + super(AttentionGRUCell, self).__init__() + self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False) + self.h2h = nn.Linear(hidden_size, hidden_size) + self.score = nn.Linear(hidden_size, 1, bias_attr=False) + self.rnn = nn.GRUCell( + input_size=input_size + num_embeddings, hidden_size=hidden_size) + self.hidden_size = hidden_size + + def forward(self, prev_hidden, batch_H, char_onehots): + batch_H_proj = self.i2h(batch_H) + prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden), axis=1) + res = paddle.add(batch_H_proj, prev_hidden_proj) + res = paddle.tanh(res) + e = self.score(res) + alpha = F.softmax(e, axis=1) + alpha = paddle.transpose(alpha, [0, 2, 1]) + context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1) + concat_context = paddle.concat([context, char_onehots], 1) + cur_hidden = self.rnn(concat_context, prev_hidden) + return cur_hidden, alpha + + +class AttentionLSTM(nn.Layer): + def __init__(self, in_channels, out_channels, hidden_size, **kwargs): + super(AttentionLSTM, self).__init__() + self.input_size = in_channels + self.hidden_size = hidden_size + self.num_classes = out_channels + + self.attention_cell = AttentionLSTMCell( + in_channels, hidden_size, out_channels, use_gru=False) + self.generator = nn.Linear(hidden_size, out_channels) + + def _char_to_onehot(self, input_char, onehot_dim): + input_ont_hot = F.one_hot(input_char, onehot_dim) + return input_ont_hot + + def forward(self, inputs, targets=None, batch_max_length=25): + batch_size = inputs.shape[0] + num_steps = batch_max_length + + hidden = (paddle.zeros((batch_size, self.hidden_size)), paddle.zeros( + (batch_size, self.hidden_size))) + output_hiddens = [] + + if targets is not None: + for i in range(num_steps): + # one-hot vectors for a i-th char + char_onehots = self._char_to_onehot( + targets[:, i], onehot_dim=self.num_classes) + hidden, alpha = self.attention_cell(hidden, inputs, + char_onehots) + + hidden = (hidden[1][0], hidden[1][1]) + output_hiddens.append(paddle.unsqueeze(hidden[0], axis=1)) + output = paddle.concat(output_hiddens, axis=1) + probs = self.generator(output) + + else: + targets = paddle.zeros(shape=[batch_size], dtype="int32") + probs = None + + for i in range(num_steps): + char_onehots = self._char_to_onehot( + targets, onehot_dim=self.num_classes) + hidden, alpha = self.attention_cell(hidden, inputs, + char_onehots) + probs_step = self.generator(hidden[0]) + hidden = (hidden[1][0], hidden[1][1]) + if probs is None: + probs = paddle.unsqueeze(probs_step, axis=1) + else: + probs = paddle.concat( + [probs, paddle.unsqueeze( + probs_step, axis=1)], axis=1) + + next_input = probs_step.argmax(axis=1) + + targets = next_input + + return probs + + +class AttentionLSTMCell(nn.Layer): + def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False): + super(AttentionLSTMCell, self).__init__() + self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False) + self.h2h = 
nn.Linear(hidden_size, hidden_size) + self.score = nn.Linear(hidden_size, 1, bias_attr=False) + if not use_gru: + self.rnn = nn.LSTMCell( + input_size=input_size + num_embeddings, hidden_size=hidden_size) + else: + self.rnn = nn.GRUCell( + input_size=input_size + num_embeddings, hidden_size=hidden_size) + + self.hidden_size = hidden_size + + def forward(self, prev_hidden, batch_H, char_onehots): + batch_H_proj = self.i2h(batch_H) + prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden[0]), axis=1) + res = paddle.add(batch_H_proj, prev_hidden_proj) + res = paddle.tanh(res) + e = self.score(res) + + alpha = F.softmax(e, axis=1) + alpha = paddle.transpose(alpha, [0, 2, 1]) + context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1) + concat_context = paddle.concat([context, char_onehots], 1) + cur_hidden = self.rnn(concat_context, prev_hidden) + + return cur_hidden, alpha diff --git a/backend/ppocr/modeling/necks/__init__.py b/backend/ppocr/modeling/necks/__init__.py index 405e062b..e10b082d 100644 --- a/backend/ppocr/modeling/necks/__init__.py +++ b/backend/ppocr/modeling/necks/__init__.py @@ -14,12 +14,21 @@ __all__ = ['build_neck'] + def build_neck(config): - from .db_fpn import DBFPN + from .db_fpn import DBFPN, RSEFPN, LKPAN from .east_fpn import EASTFPN from .sast_fpn import SASTFPN from .rnn import SequenceEncoder - support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder'] + from .pg_fpn import PGFPN + from .table_fpn import TableFPN + from .fpn import FPN + from .fce_fpn import FCEFPN + from .pren_fpn import PRENFPN + support_dict = [ + 'FPN', 'FCEFPN', 'LKPAN', 'DBFPN', 'RSEFPN', 'EASTFPN', 'SASTFPN', + 'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN' + ] module_name = config.pop('name') assert module_name in support_dict, Exception('neck only support {}'.format( diff --git a/backend/ppocr/modeling/necks/db_fpn.py b/backend/ppocr/modeling/necks/db_fpn.py index 710023f3..93ed2dbf 100644 --- a/backend/ppocr/modeling/necks/db_fpn.py +++ b/backend/ppocr/modeling/necks/db_fpn.py @@ -20,6 +20,88 @@ from paddle import nn import paddle.nn.functional as F from paddle import ParamAttr +import os +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../../..'))) + +from ppocr.modeling.backbones.det_mobilenet_v3 import SEModule + + +class DSConv(nn.Layer): + def __init__(self, + in_channels, + out_channels, + kernel_size, + padding, + stride=1, + groups=None, + if_act=True, + act="relu", + **kwargs): + super(DSConv, self).__init__() + if groups == None: + groups = in_channels + self.if_act = if_act + self.act = act + self.conv1 = nn.Conv2D( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + bias_attr=False) + + self.bn1 = nn.BatchNorm(num_channels=in_channels, act=None) + + self.conv2 = nn.Conv2D( + in_channels=in_channels, + out_channels=int(in_channels * 4), + kernel_size=1, + stride=1, + bias_attr=False) + + self.bn2 = nn.BatchNorm(num_channels=int(in_channels * 4), act=None) + + self.conv3 = nn.Conv2D( + in_channels=int(in_channels * 4), + out_channels=out_channels, + kernel_size=1, + stride=1, + bias_attr=False) + self._c = [in_channels, out_channels] + if in_channels != out_channels: + self.conv_end = nn.Conv2D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + bias_attr=False) + + def forward(self, inputs): + + x = self.conv1(inputs) + x = self.bn1(x) + + x 
= self.conv2(x) + x = self.bn2(x) + if self.if_act: + if self.act == "relu": + x = F.relu(x) + elif self.act == "hardswish": + x = F.hardswish(x) + else: + print("The activation function({}) is selected incorrectly.". + format(self.act)) + exit() + + x = self.conv3(x) + if self._c[0] != self._c[1]: + x = x + self.conv_end(inputs) + return x class DBFPN(nn.Layer): @@ -32,61 +114,53 @@ def __init__(self, in_channels, out_channels, **kwargs): in_channels=in_channels[0], out_channels=self.out_channels, kernel_size=1, - weight_attr=ParamAttr( - name='conv2d_51.w_0', initializer=weight_attr), + weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False) self.in3_conv = nn.Conv2D( in_channels=in_channels[1], out_channels=self.out_channels, kernel_size=1, - weight_attr=ParamAttr( - name='conv2d_50.w_0', initializer=weight_attr), + weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False) self.in4_conv = nn.Conv2D( in_channels=in_channels[2], out_channels=self.out_channels, kernel_size=1, - weight_attr=ParamAttr( - name='conv2d_49.w_0', initializer=weight_attr), + weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False) self.in5_conv = nn.Conv2D( in_channels=in_channels[3], out_channels=self.out_channels, kernel_size=1, - weight_attr=ParamAttr( - name='conv2d_48.w_0', initializer=weight_attr), + weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False) self.p5_conv = nn.Conv2D( in_channels=self.out_channels, out_channels=self.out_channels // 4, kernel_size=3, padding=1, - weight_attr=ParamAttr( - name='conv2d_52.w_0', initializer=weight_attr), + weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False) self.p4_conv = nn.Conv2D( in_channels=self.out_channels, out_channels=self.out_channels // 4, kernel_size=3, padding=1, - weight_attr=ParamAttr( - name='conv2d_53.w_0', initializer=weight_attr), + weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False) self.p3_conv = nn.Conv2D( in_channels=self.out_channels, out_channels=self.out_channels // 4, kernel_size=3, padding=1, - weight_attr=ParamAttr( - name='conv2d_54.w_0', initializer=weight_attr), + weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False) self.p2_conv = nn.Conv2D( in_channels=self.out_channels, out_channels=self.out_channels // 4, kernel_size=3, padding=1, - weight_attr=ParamAttr( - name='conv2d_55.w_0', initializer=weight_attr), + weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False) def forward(self, x): @@ -114,3 +188,171 @@ def forward(self, x): fuse = paddle.concat([p5, p4, p3, p2], axis=1) return fuse + + +class RSELayer(nn.Layer): + def __init__(self, in_channels, out_channels, kernel_size, shortcut=True): + super(RSELayer, self).__init__() + weight_attr = paddle.nn.initializer.KaimingUniform() + self.out_channels = out_channels + self.in_conv = nn.Conv2D( + in_channels=in_channels, + out_channels=self.out_channels, + kernel_size=kernel_size, + padding=int(kernel_size // 2), + weight_attr=ParamAttr(initializer=weight_attr), + bias_attr=False) + self.se_block = SEModule(self.out_channels) + self.shortcut = shortcut + + def forward(self, ins): + x = self.in_conv(ins) + if self.shortcut: + out = x + self.se_block(x) + else: + out = self.se_block(x) + return out + + +class RSEFPN(nn.Layer): + def __init__(self, in_channels, out_channels, shortcut=True, **kwargs): + super(RSEFPN, self).__init__() + self.out_channels = out_channels + self.ins_conv = nn.LayerList() + self.inp_conv = nn.LayerList() + + for i in range(len(in_channels)): + self.ins_conv.append( + RSELayer( + 
in_channels[i], + out_channels, + kernel_size=1, + shortcut=shortcut)) + self.inp_conv.append( + RSELayer( + out_channels, + out_channels // 4, + kernel_size=3, + shortcut=shortcut)) + + def forward(self, x): + c2, c3, c4, c5 = x + + in5 = self.ins_conv[3](c5) + in4 = self.ins_conv[2](c4) + in3 = self.ins_conv[1](c3) + in2 = self.ins_conv[0](c2) + + out4 = in4 + F.upsample( + in5, scale_factor=2, mode="nearest", align_mode=1) # 1/16 + out3 = in3 + F.upsample( + out4, scale_factor=2, mode="nearest", align_mode=1) # 1/8 + out2 = in2 + F.upsample( + out3, scale_factor=2, mode="nearest", align_mode=1) # 1/4 + + p5 = self.inp_conv[3](in5) + p4 = self.inp_conv[2](out4) + p3 = self.inp_conv[1](out3) + p2 = self.inp_conv[0](out2) + + p5 = F.upsample(p5, scale_factor=8, mode="nearest", align_mode=1) + p4 = F.upsample(p4, scale_factor=4, mode="nearest", align_mode=1) + p3 = F.upsample(p3, scale_factor=2, mode="nearest", align_mode=1) + + fuse = paddle.concat([p5, p4, p3, p2], axis=1) + return fuse + + +class LKPAN(nn.Layer): + def __init__(self, in_channels, out_channels, mode='large', **kwargs): + super(LKPAN, self).__init__() + self.out_channels = out_channels + weight_attr = paddle.nn.initializer.KaimingUniform() + + self.ins_conv = nn.LayerList() + self.inp_conv = nn.LayerList() + # pan head + self.pan_head_conv = nn.LayerList() + self.pan_lat_conv = nn.LayerList() + + if mode.lower() == 'lite': + p_layer = DSConv + elif mode.lower() == 'large': + p_layer = nn.Conv2D + else: + raise ValueError( + "mode can only be one of ['lite', 'large'], but received {}". + format(mode)) + + for i in range(len(in_channels)): + self.ins_conv.append( + nn.Conv2D( + in_channels=in_channels[i], + out_channels=self.out_channels, + kernel_size=1, + weight_attr=ParamAttr(initializer=weight_attr), + bias_attr=False)) + + self.inp_conv.append( + p_layer( + in_channels=self.out_channels, + out_channels=self.out_channels // 4, + kernel_size=9, + padding=4, + weight_attr=ParamAttr(initializer=weight_attr), + bias_attr=False)) + + if i > 0: + self.pan_head_conv.append( + nn.Conv2D( + in_channels=self.out_channels // 4, + out_channels=self.out_channels // 4, + kernel_size=3, + padding=1, + stride=2, + weight_attr=ParamAttr(initializer=weight_attr), + bias_attr=False)) + self.pan_lat_conv.append( + p_layer( + in_channels=self.out_channels // 4, + out_channels=self.out_channels // 4, + kernel_size=9, + padding=4, + weight_attr=ParamAttr(initializer=weight_attr), + bias_attr=False)) + + def forward(self, x): + c2, c3, c4, c5 = x + + in5 = self.ins_conv[3](c5) + in4 = self.ins_conv[2](c4) + in3 = self.ins_conv[1](c3) + in2 = self.ins_conv[0](c2) + + out4 = in4 + F.upsample( + in5, scale_factor=2, mode="nearest", align_mode=1) # 1/16 + out3 = in3 + F.upsample( + out4, scale_factor=2, mode="nearest", align_mode=1) # 1/8 + out2 = in2 + F.upsample( + out3, scale_factor=2, mode="nearest", align_mode=1) # 1/4 + + f5 = self.inp_conv[3](in5) + f4 = self.inp_conv[2](out4) + f3 = self.inp_conv[1](out3) + f2 = self.inp_conv[0](out2) + + pan3 = f3 + self.pan_head_conv[0](f2) + pan4 = f4 + self.pan_head_conv[1](pan3) + pan5 = f5 + self.pan_head_conv[2](pan4) + + p2 = self.pan_lat_conv[0](f2) + p3 = self.pan_lat_conv[1](pan3) + p4 = self.pan_lat_conv[2](pan4) + p5 = self.pan_lat_conv[3](pan5) + + p5 = F.upsample(p5, scale_factor=8, mode="nearest", align_mode=1) + p4 = F.upsample(p4, scale_factor=4, mode="nearest", align_mode=1) + p3 = F.upsample(p3, scale_factor=2, mode="nearest", align_mode=1) + + fuse = paddle.concat([p5, p4, p3, p2], 
axis=1) + return fuse diff --git a/backend/ppocr/modeling/necks/fce_fpn.py b/backend/ppocr/modeling/necks/fce_fpn.py new file mode 100644 index 00000000..954e964e --- /dev/null +++ b/backend/ppocr/modeling/necks/fce_fpn.py @@ -0,0 +1,280 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is refer from: +https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.3/ppdet/modeling/necks/fpn.py +""" + +import paddle.nn as nn +import paddle.nn.functional as F +from paddle import ParamAttr +from paddle.nn.initializer import XavierUniform +from paddle.nn.initializer import Normal +from paddle.regularizer import L2Decay + +__all__ = ['FCEFPN'] + + +class ConvNormLayer(nn.Layer): + def __init__(self, + ch_in, + ch_out, + filter_size, + stride, + groups=1, + norm_type='bn', + norm_decay=0., + norm_groups=32, + lr_scale=1., + freeze_norm=False, + initializer=Normal( + mean=0., std=0.01)): + super(ConvNormLayer, self).__init__() + assert norm_type in ['bn', 'sync_bn', 'gn'] + + bias_attr = False + + self.conv = nn.Conv2D( + in_channels=ch_in, + out_channels=ch_out, + kernel_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=groups, + weight_attr=ParamAttr( + initializer=initializer, learning_rate=1.), + bias_attr=bias_attr) + + norm_lr = 0. if freeze_norm else 1. + param_attr = ParamAttr( + learning_rate=norm_lr, + regularizer=L2Decay(norm_decay) if norm_decay is not None else None) + bias_attr = ParamAttr( + learning_rate=norm_lr, + regularizer=L2Decay(norm_decay) if norm_decay is not None else None) + if norm_type == 'bn': + self.norm = nn.BatchNorm2D( + ch_out, weight_attr=param_attr, bias_attr=bias_attr) + elif norm_type == 'sync_bn': + self.norm = nn.SyncBatchNorm( + ch_out, weight_attr=param_attr, bias_attr=bias_attr) + elif norm_type == 'gn': + self.norm = nn.GroupNorm( + num_groups=norm_groups, + num_channels=ch_out, + weight_attr=param_attr, + bias_attr=bias_attr) + + def forward(self, inputs): + out = self.conv(inputs) + out = self.norm(out) + return out + + +class FCEFPN(nn.Layer): + """ + Feature Pyramid Network, see https://arxiv.org/abs/1612.03144 + Args: + in_channels (list[int]): input channels of each level which can be + derived from the output shape of backbone by from_config + out_channels (list[int]): output channel of each level + spatial_scales (list[float]): the spatial scales between input feature + maps and original input image which can be derived from the output + shape of backbone by from_config + has_extra_convs (bool): whether to add extra conv to the last level. + default False + extra_stage (int): the number of extra stages added to the last level. + default 1 + use_c5 (bool): Whether to use c5 as the input of extra stage, + otherwise p5 is used. default True + norm_type (string|None): The normalization type in FPN module. If + norm_type is None, norm will not be used after conv and if + norm_type is string, bn, gn, sync_bn are available. 
default None + norm_decay (float): weight decay for normalization layer weights. + default 0. + freeze_norm (bool): whether to freeze normalization layer. + default False + relu_before_extra_convs (bool): whether to add relu before extra convs. + default False + + """ + + def __init__(self, + in_channels, + out_channels, + spatial_scales=[0.25, 0.125, 0.0625, 0.03125], + has_extra_convs=False, + extra_stage=1, + use_c5=True, + norm_type=None, + norm_decay=0., + freeze_norm=False, + relu_before_extra_convs=True): + super(FCEFPN, self).__init__() + self.out_channels = out_channels + for s in range(extra_stage): + spatial_scales = spatial_scales + [spatial_scales[-1] / 2.] + self.spatial_scales = spatial_scales + self.has_extra_convs = has_extra_convs + self.extra_stage = extra_stage + self.use_c5 = use_c5 + self.relu_before_extra_convs = relu_before_extra_convs + self.norm_type = norm_type + self.norm_decay = norm_decay + self.freeze_norm = freeze_norm + + self.lateral_convs = [] + self.fpn_convs = [] + fan = out_channels * 3 * 3 + + # stage index 0,1,2,3 stands for res2,res3,res4,res5 on ResNet Backbone + # 0 <= st_stage < ed_stage <= 3 + st_stage = 4 - len(in_channels) + ed_stage = st_stage + len(in_channels) - 1 + for i in range(st_stage, ed_stage + 1): + if i == 3: + lateral_name = 'fpn_inner_res5_sum' + else: + lateral_name = 'fpn_inner_res{}_sum_lateral'.format(i + 2) + in_c = in_channels[i - st_stage] + if self.norm_type is not None: + lateral = self.add_sublayer( + lateral_name, + ConvNormLayer( + ch_in=in_c, + ch_out=out_channels, + filter_size=1, + stride=1, + norm_type=self.norm_type, + norm_decay=self.norm_decay, + freeze_norm=self.freeze_norm, + initializer=XavierUniform(fan_out=in_c))) + else: + lateral = self.add_sublayer( + lateral_name, + nn.Conv2D( + in_channels=in_c, + out_channels=out_channels, + kernel_size=1, + weight_attr=ParamAttr( + initializer=XavierUniform(fan_out=in_c)))) + self.lateral_convs.append(lateral) + + for i in range(st_stage, ed_stage + 1): + fpn_name = 'fpn_res{}_sum'.format(i + 2) + if self.norm_type is not None: + fpn_conv = self.add_sublayer( + fpn_name, + ConvNormLayer( + ch_in=out_channels, + ch_out=out_channels, + filter_size=3, + stride=1, + norm_type=self.norm_type, + norm_decay=self.norm_decay, + freeze_norm=self.freeze_norm, + initializer=XavierUniform(fan_out=fan))) + else: + fpn_conv = self.add_sublayer( + fpn_name, + nn.Conv2D( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + weight_attr=ParamAttr( + initializer=XavierUniform(fan_out=fan)))) + self.fpn_convs.append(fpn_conv) + + # add extra conv levels for RetinaNet(use_c5)/FCOS(use_p5) + if self.has_extra_convs: + for i in range(self.extra_stage): + lvl = ed_stage + 1 + i + if i == 0 and self.use_c5: + in_c = in_channels[-1] + else: + in_c = out_channels + extra_fpn_name = 'fpn_{}'.format(lvl + 2) + if self.norm_type is not None: + extra_fpn_conv = self.add_sublayer( + extra_fpn_name, + ConvNormLayer( + ch_in=in_c, + ch_out=out_channels, + filter_size=3, + stride=2, + norm_type=self.norm_type, + norm_decay=self.norm_decay, + freeze_norm=self.freeze_norm, + initializer=XavierUniform(fan_out=fan))) + else: + extra_fpn_conv = self.add_sublayer( + extra_fpn_name, + nn.Conv2D( + in_channels=in_c, + out_channels=out_channels, + kernel_size=3, + stride=2, + padding=1, + weight_attr=ParamAttr( + initializer=XavierUniform(fan_out=fan)))) + self.fpn_convs.append(extra_fpn_conv) + + @classmethod + def from_config(cls, cfg, input_shape): + return { + 
'in_channels': [i.channels for i in input_shape], + 'spatial_scales': [1.0 / i.stride for i in input_shape], + } + + def forward(self, body_feats): + laterals = [] + num_levels = len(body_feats) + + for i in range(num_levels): + laterals.append(self.lateral_convs[i](body_feats[i])) + + for i in range(1, num_levels): + lvl = num_levels - i + upsample = F.interpolate( + laterals[lvl], + scale_factor=2., + mode='nearest', ) + laterals[lvl - 1] += upsample + + fpn_output = [] + for lvl in range(num_levels): + fpn_output.append(self.fpn_convs[lvl](laterals[lvl])) + + if self.extra_stage > 0: + # use max pool to get more levels on top of outputs (Faster R-CNN, Mask R-CNN) + if not self.has_extra_convs: + assert self.extra_stage == 1, 'extra_stage should be 1 if FPN has not extra convs' + fpn_output.append(F.max_pool2d(fpn_output[-1], 1, stride=2)) + # add extra conv levels for RetinaNet(use_c5)/FCOS(use_p5) + else: + if self.use_c5: + extra_source = body_feats[-1] + else: + extra_source = fpn_output[-1] + fpn_output.append(self.fpn_convs[num_levels](extra_source)) + + for i in range(1, self.extra_stage): + if self.relu_before_extra_convs: + fpn_output.append(self.fpn_convs[num_levels + i](F.relu( + fpn_output[-1]))) + else: + fpn_output.append(self.fpn_convs[num_levels + i]( + fpn_output[-1])) + return fpn_output diff --git a/backend/ppocr/modeling/necks/fpn.py b/backend/ppocr/modeling/necks/fpn.py new file mode 100644 index 00000000..48c85b1e --- /dev/null +++ b/backend/ppocr/modeling/necks/fpn.py @@ -0,0 +1,138 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is refer from: +https://github.com/whai362/PSENet/blob/python3/models/neck/fpn.py +""" + +import paddle.nn as nn +import paddle +import math +import paddle.nn.functional as F + + +class Conv_BN_ReLU(nn.Layer): + def __init__(self, + in_planes, + out_planes, + kernel_size=1, + stride=1, + padding=0): + super(Conv_BN_ReLU, self).__init__() + self.conv = nn.Conv2D( + in_planes, + out_planes, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias_attr=False) + self.bn = nn.BatchNorm2D(out_planes, momentum=0.1) + self.relu = nn.ReLU() + + for m in self.sublayers(): + if isinstance(m, nn.Conv2D): + n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels + m.weight = paddle.create_parameter( + shape=m.weight.shape, + dtype='float32', + default_initializer=paddle.nn.initializer.Normal( + 0, math.sqrt(2. 
/ n))) + elif isinstance(m, nn.BatchNorm2D): + m.weight = paddle.create_parameter( + shape=m.weight.shape, + dtype='float32', + default_initializer=paddle.nn.initializer.Constant(1.0)) + m.bias = paddle.create_parameter( + shape=m.bias.shape, + dtype='float32', + default_initializer=paddle.nn.initializer.Constant(0.0)) + + def forward(self, x): + return self.relu(self.bn(self.conv(x))) + + +class FPN(nn.Layer): + def __init__(self, in_channels, out_channels): + super(FPN, self).__init__() + + # Top layer + self.toplayer_ = Conv_BN_ReLU( + in_channels[3], out_channels, kernel_size=1, stride=1, padding=0) + # Lateral layers + self.latlayer1_ = Conv_BN_ReLU( + in_channels[2], out_channels, kernel_size=1, stride=1, padding=0) + + self.latlayer2_ = Conv_BN_ReLU( + in_channels[1], out_channels, kernel_size=1, stride=1, padding=0) + + self.latlayer3_ = Conv_BN_ReLU( + in_channels[0], out_channels, kernel_size=1, stride=1, padding=0) + + # Smooth layers + self.smooth1_ = Conv_BN_ReLU( + out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + self.smooth2_ = Conv_BN_ReLU( + out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + self.smooth3_ = Conv_BN_ReLU( + out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + self.out_channels = out_channels * 4 + for m in self.sublayers(): + if isinstance(m, nn.Conv2D): + n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels + m.weight = paddle.create_parameter( + shape=m.weight.shape, + dtype='float32', + default_initializer=paddle.nn.initializer.Normal( + 0, math.sqrt(2. / n))) + elif isinstance(m, nn.BatchNorm2D): + m.weight = paddle.create_parameter( + shape=m.weight.shape, + dtype='float32', + default_initializer=paddle.nn.initializer.Constant(1.0)) + m.bias = paddle.create_parameter( + shape=m.bias.shape, + dtype='float32', + default_initializer=paddle.nn.initializer.Constant(0.0)) + + def _upsample(self, x, scale=1): + return F.upsample(x, scale_factor=scale, mode='bilinear') + + def _upsample_add(self, x, y, scale=1): + return F.upsample(x, scale_factor=scale, mode='bilinear') + y + + def forward(self, x): + f2, f3, f4, f5 = x + p5 = self.toplayer_(f5) + + f4 = self.latlayer1_(f4) + p4 = self._upsample_add(p5, f4, 2) + p4 = self.smooth1_(p4) + + f3 = self.latlayer2_(f3) + p3 = self._upsample_add(p4, f3, 2) + p3 = self.smooth2_(p3) + + f2 = self.latlayer3_(f2) + p2 = self._upsample_add(p3, f2, 2) + p2 = self.smooth3_(p2) + + p3 = self._upsample(p3, 2) + p4 = self._upsample(p4, 4) + p5 = self._upsample(p5, 8) + + fuse = paddle.concat([p2, p3, p4, p5], axis=1) + return fuse diff --git a/backend/ppocr/modeling/necks/pg_fpn.py b/backend/ppocr/modeling/necks/pg_fpn.py new file mode 100644 index 00000000..3f64539f --- /dev/null +++ b/backend/ppocr/modeling/necks/pg_fpn.py @@ -0,0 +1,314 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
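+
+# Data-flow sketch for the PGFPN defined below: seven input maps [c0, ..., c6]
+# are fused by a down branch (c0-c2) and an up branch (c2-c6) into a single
+# 128-channel map.  Channel counts follow the hard-coded layers in this file;
+# the spatial sizes are illustrative assumptions with each level at half the
+# previous resolution:
+#
+#   c0: (N,    3, H,    W)     c1: (N,   64, H/2,  W/2)    c2: (N,  256, H/4,  W/4)
+#   c3: (N,  512, H/8,  W/8)   c4: (N, 1024, H/16, W/16)   c5: (N, 2048, H/32, W/32)
+#   c6: (N, 2048, H/64, W/64)
+#
+#   f_common = pg_fpn([c0, c1, c2, c3, c4, c5, c6])   # (N, 128, H/4, W/4)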
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn +import paddle.nn.functional as F +from paddle import ParamAttr + + +class ConvBNLayer(nn.Layer): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + groups=1, + is_vd_mode=False, + act=None, + name=None): + super(ConvBNLayer, self).__init__() + + self.is_vd_mode = is_vd_mode + self._pool2d_avg = nn.AvgPool2D( + kernel_size=2, stride=2, padding=0, ceil_mode=True) + self._conv = nn.Conv2D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=(kernel_size - 1) // 2, + groups=groups, + weight_attr=ParamAttr(name=name + "_weights"), + bias_attr=False) + if name == "conv1": + bn_name = "bn_" + name + else: + bn_name = "bn" + name[3:] + self._batch_norm = nn.BatchNorm( + out_channels, + act=act, + param_attr=ParamAttr(name=bn_name + '_scale'), + bias_attr=ParamAttr(bn_name + '_offset'), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance', + use_global_stats=False) + + def forward(self, inputs): + y = self._conv(inputs) + y = self._batch_norm(y) + return y + + +class DeConvBNLayer(nn.Layer): + def __init__(self, + in_channels, + out_channels, + kernel_size=4, + stride=2, + padding=1, + groups=1, + if_act=True, + act=None, + name=None): + super(DeConvBNLayer, self).__init__() + + self.if_act = if_act + self.act = act + self.deconv = nn.Conv2DTranspose( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + weight_attr=ParamAttr(name=name + '_weights'), + bias_attr=False) + self.bn = nn.BatchNorm( + num_channels=out_channels, + act=act, + param_attr=ParamAttr(name="bn_" + name + "_scale"), + bias_attr=ParamAttr(name="bn_" + name + "_offset"), + moving_mean_name="bn_" + name + "_mean", + moving_variance_name="bn_" + name + "_variance", + use_global_stats=False) + + def forward(self, x): + x = self.deconv(x) + x = self.bn(x) + return x + + +class PGFPN(nn.Layer): + def __init__(self, in_channels, **kwargs): + super(PGFPN, self).__init__() + num_inputs = [2048, 2048, 1024, 512, 256] + num_outputs = [256, 256, 192, 192, 128] + self.out_channels = 128 + self.conv_bn_layer_1 = ConvBNLayer( + in_channels=3, + out_channels=32, + kernel_size=3, + stride=1, + act=None, + name='FPN_d1') + self.conv_bn_layer_2 = ConvBNLayer( + in_channels=64, + out_channels=64, + kernel_size=3, + stride=1, + act=None, + name='FPN_d2') + self.conv_bn_layer_3 = ConvBNLayer( + in_channels=256, + out_channels=128, + kernel_size=3, + stride=1, + act=None, + name='FPN_d3') + self.conv_bn_layer_4 = ConvBNLayer( + in_channels=32, + out_channels=64, + kernel_size=3, + stride=2, + act=None, + name='FPN_d4') + self.conv_bn_layer_5 = ConvBNLayer( + in_channels=64, + out_channels=64, + kernel_size=3, + stride=1, + act='relu', + name='FPN_d5') + self.conv_bn_layer_6 = ConvBNLayer( + in_channels=64, + out_channels=128, + kernel_size=3, + stride=2, + act=None, + name='FPN_d6') + self.conv_bn_layer_7 = ConvBNLayer( + in_channels=128, + out_channels=128, + kernel_size=3, + stride=1, + act='relu', + name='FPN_d7') + self.conv_bn_layer_8 = ConvBNLayer( + in_channels=128, + out_channels=128, + kernel_size=1, + stride=1, + act=None, + name='FPN_d8') + + self.conv_h0 = ConvBNLayer( + in_channels=num_inputs[0], + out_channels=num_outputs[0], + kernel_size=1, + stride=1, + act=None, + 
name="conv_h{}".format(0)) + self.conv_h1 = ConvBNLayer( + in_channels=num_inputs[1], + out_channels=num_outputs[1], + kernel_size=1, + stride=1, + act=None, + name="conv_h{}".format(1)) + self.conv_h2 = ConvBNLayer( + in_channels=num_inputs[2], + out_channels=num_outputs[2], + kernel_size=1, + stride=1, + act=None, + name="conv_h{}".format(2)) + self.conv_h3 = ConvBNLayer( + in_channels=num_inputs[3], + out_channels=num_outputs[3], + kernel_size=1, + stride=1, + act=None, + name="conv_h{}".format(3)) + self.conv_h4 = ConvBNLayer( + in_channels=num_inputs[4], + out_channels=num_outputs[4], + kernel_size=1, + stride=1, + act=None, + name="conv_h{}".format(4)) + + self.dconv0 = DeConvBNLayer( + in_channels=num_outputs[0], + out_channels=num_outputs[0 + 1], + name="dconv_{}".format(0)) + self.dconv1 = DeConvBNLayer( + in_channels=num_outputs[1], + out_channels=num_outputs[1 + 1], + act=None, + name="dconv_{}".format(1)) + self.dconv2 = DeConvBNLayer( + in_channels=num_outputs[2], + out_channels=num_outputs[2 + 1], + act=None, + name="dconv_{}".format(2)) + self.dconv3 = DeConvBNLayer( + in_channels=num_outputs[3], + out_channels=num_outputs[3 + 1], + act=None, + name="dconv_{}".format(3)) + self.conv_g1 = ConvBNLayer( + in_channels=num_outputs[1], + out_channels=num_outputs[1], + kernel_size=3, + stride=1, + act='relu', + name="conv_g{}".format(1)) + self.conv_g2 = ConvBNLayer( + in_channels=num_outputs[2], + out_channels=num_outputs[2], + kernel_size=3, + stride=1, + act='relu', + name="conv_g{}".format(2)) + self.conv_g3 = ConvBNLayer( + in_channels=num_outputs[3], + out_channels=num_outputs[3], + kernel_size=3, + stride=1, + act='relu', + name="conv_g{}".format(3)) + self.conv_g4 = ConvBNLayer( + in_channels=num_outputs[4], + out_channels=num_outputs[4], + kernel_size=3, + stride=1, + act='relu', + name="conv_g{}".format(4)) + self.convf = ConvBNLayer( + in_channels=num_outputs[4], + out_channels=num_outputs[4], + kernel_size=1, + stride=1, + act=None, + name="conv_f{}".format(4)) + + def forward(self, x): + c0, c1, c2, c3, c4, c5, c6 = x + # FPN_Down_Fusion + f = [c0, c1, c2] + g = [None, None, None] + h = [None, None, None] + h[0] = self.conv_bn_layer_1(f[0]) + h[1] = self.conv_bn_layer_2(f[1]) + h[2] = self.conv_bn_layer_3(f[2]) + + g[0] = self.conv_bn_layer_4(h[0]) + g[1] = paddle.add(g[0], h[1]) + g[1] = F.relu(g[1]) + g[1] = self.conv_bn_layer_5(g[1]) + g[1] = self.conv_bn_layer_6(g[1]) + + g[2] = paddle.add(g[1], h[2]) + g[2] = F.relu(g[2]) + g[2] = self.conv_bn_layer_7(g[2]) + f_down = self.conv_bn_layer_8(g[2]) + + # FPN UP Fusion + f1 = [c6, c5, c4, c3, c2] + g = [None, None, None, None, None] + h = [None, None, None, None, None] + h[0] = self.conv_h0(f1[0]) + h[1] = self.conv_h1(f1[1]) + h[2] = self.conv_h2(f1[2]) + h[3] = self.conv_h3(f1[3]) + h[4] = self.conv_h4(f1[4]) + + g[0] = self.dconv0(h[0]) + g[1] = paddle.add(g[0], h[1]) + g[1] = F.relu(g[1]) + g[1] = self.conv_g1(g[1]) + g[1] = self.dconv1(g[1]) + + g[2] = paddle.add(g[1], h[2]) + g[2] = F.relu(g[2]) + g[2] = self.conv_g2(g[2]) + g[2] = self.dconv2(g[2]) + + g[3] = paddle.add(g[2], h[3]) + g[3] = F.relu(g[3]) + g[3] = self.conv_g3(g[3]) + g[3] = self.dconv3(g[3]) + + g[4] = paddle.add(x=g[3], y=h[4]) + g[4] = F.relu(g[4]) + g[4] = self.conv_g4(g[4]) + f_up = self.convf(g[4]) + f_common = paddle.add(f_down, f_up) + f_common = F.relu(f_common) + return f_common diff --git a/backend/ppocr/modeling/necks/pren_fpn.py b/backend/ppocr/modeling/necks/pren_fpn.py new file mode 100644 index 00000000..afbdcea8 --- /dev/null +++ 
b/backend/ppocr/modeling/necks/pren_fpn.py @@ -0,0 +1,163 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Code is refer from: +https://github.com/RuijieJ/pren/blob/main/Nets/Aggregation.py +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn +import paddle.nn.functional as F + + +class PoolAggregate(nn.Layer): + def __init__(self, n_r, d_in, d_middle=None, d_out=None): + super(PoolAggregate, self).__init__() + if not d_middle: + d_middle = d_in + if not d_out: + d_out = d_in + + self.d_in = d_in + self.d_middle = d_middle + self.d_out = d_out + self.act = nn.Swish() + + self.n_r = n_r + self.aggs = self._build_aggs() + + def _build_aggs(self): + aggs = [] + for i in range(self.n_r): + aggs.append( + self.add_sublayer( + '{}'.format(i), + nn.Sequential( + ('conv1', nn.Conv2D( + self.d_in, self.d_middle, 3, 2, 1, bias_attr=False) + ), ('bn1', nn.BatchNorm(self.d_middle)), + ('act', self.act), ('conv2', nn.Conv2D( + self.d_middle, self.d_out, 3, 2, 1, bias_attr=False + )), ('bn2', nn.BatchNorm(self.d_out))))) + return aggs + + def forward(self, x): + b = x.shape[0] + outs = [] + for agg in self.aggs: + y = agg(x) + p = F.adaptive_avg_pool2d(y, 1) + outs.append(p.reshape((b, 1, self.d_out))) + out = paddle.concat(outs, 1) + return out + + +class WeightAggregate(nn.Layer): + def __init__(self, n_r, d_in, d_middle=None, d_out=None): + super(WeightAggregate, self).__init__() + if not d_middle: + d_middle = d_in + if not d_out: + d_out = d_in + + self.n_r = n_r + self.d_out = d_out + self.act = nn.Swish() + + self.conv_n = nn.Sequential( + ('conv1', nn.Conv2D( + d_in, d_in, 3, 1, 1, + bias_attr=False)), ('bn1', nn.BatchNorm(d_in)), + ('act1', self.act), ('conv2', nn.Conv2D( + d_in, n_r, 1, bias_attr=False)), ('bn2', nn.BatchNorm(n_r)), + ('act2', nn.Sigmoid())) + self.conv_d = nn.Sequential( + ('conv1', nn.Conv2D( + d_in, d_middle, 3, 1, 1, + bias_attr=False)), ('bn1', nn.BatchNorm(d_middle)), + ('act1', self.act), ('conv2', nn.Conv2D( + d_middle, d_out, 1, + bias_attr=False)), ('bn2', nn.BatchNorm(d_out))) + + def forward(self, x): + b, _, h, w = x.shape + + hmaps = self.conv_n(x) + fmaps = self.conv_d(x) + r = paddle.bmm( + hmaps.reshape((b, self.n_r, h * w)), + fmaps.reshape((b, self.d_out, h * w)).transpose((0, 2, 1))) + return r + + +class GCN(nn.Layer): + def __init__(self, d_in, n_in, d_out=None, n_out=None, dropout=0.1): + super(GCN, self).__init__() + if not d_out: + d_out = d_in + if not n_out: + n_out = d_in + + self.conv_n = nn.Conv1D(n_in, n_out, 1) + self.linear = nn.Linear(d_in, d_out) + self.dropout = nn.Dropout(dropout) + self.act = nn.Swish() + + def forward(self, x): + x = self.conv_n(x) + x = self.dropout(self.linear(x)) + return self.act(x) + + +class PRENFPN(nn.Layer): + def __init__(self, in_channels, n_r, d_model, max_len, dropout): + super(PRENFPN, self).__init__() + assert len(in_channels) == 3, 
"in_channels' length must be 3." + c1, c2, c3 = in_channels # the depths are from big to small + # build fpn + assert d_model % 3 == 0, "{} can't be divided by 3.".format(d_model) + self.agg_p1 = PoolAggregate(n_r, c1, d_out=d_model // 3) + self.agg_p2 = PoolAggregate(n_r, c2, d_out=d_model // 3) + self.agg_p3 = PoolAggregate(n_r, c3, d_out=d_model // 3) + + self.agg_w1 = WeightAggregate(n_r, c1, 4 * c1, d_model // 3) + self.agg_w2 = WeightAggregate(n_r, c2, 4 * c2, d_model // 3) + self.agg_w3 = WeightAggregate(n_r, c3, 4 * c3, d_model // 3) + + self.gcn_pool = GCN(d_model, n_r, d_model, max_len, dropout) + self.gcn_weight = GCN(d_model, n_r, d_model, max_len, dropout) + + self.out_channels = d_model + + def forward(self, inputs): + f3, f5, f7 = inputs + + rp1 = self.agg_p1(f3) + rp2 = self.agg_p2(f5) + rp3 = self.agg_p3(f7) + rp = paddle.concat([rp1, rp2, rp3], 2) # [b,nr,d] + + rw1 = self.agg_w1(f3) + rw2 = self.agg_w2(f5) + rw3 = self.agg_w3(f7) + rw = paddle.concat([rw1, rw2, rw3], 2) # [b,nr,d] + + y1 = self.gcn_pool(rp) + y2 = self.gcn_weight(rw) + y = 0.5 * (y1 + y2) + return y # [b,max_len,d] diff --git a/backend/ppocr/modeling/necks/rnn.py b/backend/ppocr/modeling/necks/rnn.py index de87b3d9..c8a774b8 100644 --- a/backend/ppocr/modeling/necks/rnn.py +++ b/backend/ppocr/modeling/necks/rnn.py @@ -16,9 +16,11 @@ from __future__ import division from __future__ import print_function +import paddle from paddle import nn from ppocr.modeling.heads.rec_ctc_head import get_para_bias_attr +from ppocr.modeling.backbones.rec_svtrnet import Block, ConvBNLayer, trunc_normal_, zeros_, ones_ class Im2Seq(nn.Layer): @@ -51,7 +53,7 @@ def __init__(self, in_channels, hidden_size): super(EncoderWithFC, self).__init__() self.out_channels = hidden_size weight_attr, bias_attr = get_para_bias_attr( - l2_decay=0.00001, k=in_channels, name='reduce_encoder_fea') + l2_decay=0.00001, k=in_channels) self.fc = nn.Linear( in_channels, hidden_size, @@ -64,29 +66,126 @@ def forward(self, x): return x +class EncoderWithSVTR(nn.Layer): + def __init__( + self, + in_channels, + dims=64, # XS + depth=2, + hidden_dims=120, + use_guide=False, + num_heads=8, + qkv_bias=True, + mlp_ratio=2.0, + drop_rate=0.1, + attn_drop_rate=0.1, + drop_path=0., + qk_scale=None): + super(EncoderWithSVTR, self).__init__() + self.depth = depth + self.use_guide = use_guide + self.conv1 = ConvBNLayer( + in_channels, in_channels // 8, padding=1, act=nn.Swish) + self.conv2 = ConvBNLayer( + in_channels // 8, hidden_dims, kernel_size=1, act=nn.Swish) + + self.svtr_block = nn.LayerList([ + Block( + dim=hidden_dims, + num_heads=num_heads, + mixer='Global', + HW=None, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + act_layer=nn.Swish, + attn_drop=attn_drop_rate, + drop_path=drop_path, + norm_layer='nn.LayerNorm', + epsilon=1e-05, + prenorm=False) for i in range(depth) + ]) + self.norm = nn.LayerNorm(hidden_dims, epsilon=1e-6) + self.conv3 = ConvBNLayer( + hidden_dims, in_channels, kernel_size=1, act=nn.Swish) + # last conv-nxn, the input is concat of input tensor and conv3 output tensor + self.conv4 = ConvBNLayer( + 2 * in_channels, in_channels // 8, padding=1, act=nn.Swish) + + self.conv1x1 = ConvBNLayer( + in_channels // 8, dims, kernel_size=1, act=nn.Swish) + self.out_channels = dims + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + 
zeros_(m.bias) + ones_(m.weight) + + def forward(self, x): + # for use guide + if self.use_guide: + z = x.clone() + z.stop_gradient = True + else: + z = x + # for short cut + h = z + # reduce dim + z = self.conv1(z) + z = self.conv2(z) + # SVTR global block + B, C, H, W = z.shape + z = z.flatten(2).transpose([0, 2, 1]) + for blk in self.svtr_block: + z = blk(z) + z = self.norm(z) + # last stage + z = z.reshape([0, H, W, C]).transpose([0, 3, 1, 2]) + z = self.conv3(z) + z = paddle.concat((h, z), axis=1) + z = self.conv1x1(self.conv4(z)) + return z + + class SequenceEncoder(nn.Layer): def __init__(self, in_channels, encoder_type, hidden_size=48, **kwargs): super(SequenceEncoder, self).__init__() self.encoder_reshape = Im2Seq(in_channels) self.out_channels = self.encoder_reshape.out_channels + self.encoder_type = encoder_type if encoder_type == 'reshape': self.only_reshape = True else: support_encoder_dict = { 'reshape': Im2Seq, 'fc': EncoderWithFC, - 'rnn': EncoderWithRNN + 'rnn': EncoderWithRNN, + 'svtr': EncoderWithSVTR } assert encoder_type in support_encoder_dict, '{} must in {}'.format( encoder_type, support_encoder_dict.keys()) - - self.encoder = support_encoder_dict[encoder_type]( - self.encoder_reshape.out_channels, hidden_size) + if encoder_type == "svtr": + self.encoder = support_encoder_dict[encoder_type]( + self.encoder_reshape.out_channels, **kwargs) + else: + self.encoder = support_encoder_dict[encoder_type]( + self.encoder_reshape.out_channels, hidden_size) self.out_channels = self.encoder.out_channels self.only_reshape = False def forward(self, x): - x = self.encoder_reshape(x) - if not self.only_reshape: + if self.encoder_type != 'svtr': + x = self.encoder_reshape(x) + if not self.only_reshape: + x = self.encoder(x) + return x + else: x = self.encoder(x) - return x + x = self.encoder_reshape(x) + return x diff --git a/backend/ppocr/modeling/necks/table_fpn.py b/backend/ppocr/modeling/necks/table_fpn.py new file mode 100644 index 00000000..734f15af --- /dev/null +++ b/backend/ppocr/modeling/necks/table_fpn.py @@ -0,0 +1,110 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn +import paddle.nn.functional as F +from paddle import ParamAttr + + +class TableFPN(nn.Layer): + def __init__(self, in_channels, out_channels, **kwargs): + super(TableFPN, self).__init__() + self.out_channels = 512 + weight_attr = paddle.nn.initializer.KaimingUniform() + self.in2_conv = nn.Conv2D( + in_channels=in_channels[0], + out_channels=self.out_channels, + kernel_size=1, + weight_attr=ParamAttr(initializer=weight_attr), + bias_attr=False) + self.in3_conv = nn.Conv2D( + in_channels=in_channels[1], + out_channels=self.out_channels, + kernel_size=1, + stride = 1, + weight_attr=ParamAttr(initializer=weight_attr), + bias_attr=False) + self.in4_conv = nn.Conv2D( + in_channels=in_channels[2], + out_channels=self.out_channels, + kernel_size=1, + weight_attr=ParamAttr(initializer=weight_attr), + bias_attr=False) + self.in5_conv = nn.Conv2D( + in_channels=in_channels[3], + out_channels=self.out_channels, + kernel_size=1, + weight_attr=ParamAttr(initializer=weight_attr), + bias_attr=False) + self.p5_conv = nn.Conv2D( + in_channels=self.out_channels, + out_channels=self.out_channels // 4, + kernel_size=3, + padding=1, + weight_attr=ParamAttr(initializer=weight_attr), + bias_attr=False) + self.p4_conv = nn.Conv2D( + in_channels=self.out_channels, + out_channels=self.out_channels // 4, + kernel_size=3, + padding=1, + weight_attr=ParamAttr(initializer=weight_attr), + bias_attr=False) + self.p3_conv = nn.Conv2D( + in_channels=self.out_channels, + out_channels=self.out_channels // 4, + kernel_size=3, + padding=1, + weight_attr=ParamAttr(initializer=weight_attr), + bias_attr=False) + self.p2_conv = nn.Conv2D( + in_channels=self.out_channels, + out_channels=self.out_channels // 4, + kernel_size=3, + padding=1, + weight_attr=ParamAttr(initializer=weight_attr), + bias_attr=False) + self.fuse_conv = nn.Conv2D( + in_channels=self.out_channels * 4, + out_channels=512, + kernel_size=3, + padding=1, + weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False) + + def forward(self, x): + c2, c3, c4, c5 = x + + in5 = self.in5_conv(c5) + in4 = self.in4_conv(c4) + in3 = self.in3_conv(c3) + in2 = self.in2_conv(c2) + + out4 = in4 + F.upsample( + in5, size=in4.shape[2:4], mode="nearest", align_mode=1) # 1/16 + out3 = in3 + F.upsample( + out4, size=in3.shape[2:4], mode="nearest", align_mode=1) # 1/8 + out2 = in2 + F.upsample( + out3, size=in2.shape[2:4], mode="nearest", align_mode=1) # 1/4 + + p4 = F.upsample(out4, size=in5.shape[2:4], mode="nearest", align_mode=1) + p3 = F.upsample(out3, size=in5.shape[2:4], mode="nearest", align_mode=1) + p2 = F.upsample(out2, size=in5.shape[2:4], mode="nearest", align_mode=1) + fuse = paddle.concat([in5, p4, p3, p2], axis=1) + fuse_conv = self.fuse_conv(fuse) * 0.005 + return [c5 + fuse_conv] diff --git a/backend/ppocr/modeling/transforms/__init__.py b/backend/ppocr/modeling/transforms/__init__.py index 78eaeccc..405ab3cc 100755 --- a/backend/ppocr/modeling/transforms/__init__.py +++ b/backend/ppocr/modeling/transforms/__init__.py @@ -17,8 +17,9 @@ def build_transform(config): from .tps import TPS + from .stn import STN_ON - support_dict = ['TPS'] + support_dict = ['TPS', 'STN_ON'] module_name = config.pop('name') assert module_name in support_dict, Exception( diff --git a/backend/ppocr/modeling/transforms/stn.py b/backend/ppocr/modeling/transforms/stn.py new file mode 100644 index 00000000..6f2bdda0 --- /dev/null +++ 
b/backend/ppocr/modeling/transforms/stn.py @@ -0,0 +1,135 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is refer from: +https://github.com/ayumiymk/aster.pytorch/blob/master/lib/models/stn_head.py +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import paddle +from paddle import nn, ParamAttr +from paddle.nn import functional as F +import numpy as np + +from .tps_spatial_transformer import TPSSpatialTransformer + + +def conv3x3_block(in_channels, out_channels, stride=1): + n = 3 * 3 * out_channels + w = math.sqrt(2. / n) + conv_layer = nn.Conv2D( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=1, + weight_attr=nn.initializer.Normal( + mean=0.0, std=w), + bias_attr=nn.initializer.Constant(0)) + block = nn.Sequential(conv_layer, nn.BatchNorm2D(out_channels), nn.ReLU()) + return block + + +class STN(nn.Layer): + def __init__(self, in_channels, num_ctrlpoints, activation='none'): + super(STN, self).__init__() + self.in_channels = in_channels + self.num_ctrlpoints = num_ctrlpoints + self.activation = activation + self.stn_convnet = nn.Sequential( + conv3x3_block(in_channels, 32), #32x64 + nn.MaxPool2D( + kernel_size=2, stride=2), + conv3x3_block(32, 64), #16x32 + nn.MaxPool2D( + kernel_size=2, stride=2), + conv3x3_block(64, 128), # 8*16 + nn.MaxPool2D( + kernel_size=2, stride=2), + conv3x3_block(128, 256), # 4*8 + nn.MaxPool2D( + kernel_size=2, stride=2), + conv3x3_block(256, 256), # 2*4, + nn.MaxPool2D( + kernel_size=2, stride=2), + conv3x3_block(256, 256)) # 1*2 + self.stn_fc1 = nn.Sequential( + nn.Linear( + 2 * 256, + 512, + weight_attr=nn.initializer.Normal(0, 0.001), + bias_attr=nn.initializer.Constant(0)), + nn.BatchNorm1D(512), + nn.ReLU()) + fc2_bias = self.init_stn() + self.stn_fc2 = nn.Linear( + 512, + num_ctrlpoints * 2, + weight_attr=nn.initializer.Constant(0.0), + bias_attr=nn.initializer.Assign(fc2_bias)) + + def init_stn(self): + margin = 0.01 + sampling_num_per_side = int(self.num_ctrlpoints / 2) + ctrl_pts_x = np.linspace(margin, 1. - margin, sampling_num_per_side) + ctrl_pts_y_top = np.ones(sampling_num_per_side) * margin + ctrl_pts_y_bottom = np.ones(sampling_num_per_side) * (1 - margin) + ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) + ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) + ctrl_points = np.concatenate( + [ctrl_pts_top, ctrl_pts_bottom], axis=0).astype(np.float32) + if self.activation == 'none': + pass + elif self.activation == 'sigmoid': + ctrl_points = -np.log(1. / ctrl_points - 1.) 
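+            # inverse sigmoid (logit): stn_fc2's weights are initialized to zero, so at the start of
+            # training its output is just this bias, and F.sigmoid(bias) in forward() reproduces the
+            # target control points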
+ ctrl_points = paddle.to_tensor(ctrl_points) + fc2_bias = paddle.reshape( + ctrl_points, shape=[ctrl_points.shape[0] * ctrl_points.shape[1]]) + return fc2_bias + + def forward(self, x): + x = self.stn_convnet(x) + batch_size, _, h, w = x.shape + x = paddle.reshape(x, shape=(batch_size, -1)) + img_feat = self.stn_fc1(x) + x = self.stn_fc2(0.1 * img_feat) + if self.activation == 'sigmoid': + x = F.sigmoid(x) + x = paddle.reshape(x, shape=[-1, self.num_ctrlpoints, 2]) + return img_feat, x + + +class STN_ON(nn.Layer): + def __init__(self, in_channels, tps_inputsize, tps_outputsize, + num_control_points, tps_margins, stn_activation): + super(STN_ON, self).__init__() + self.tps = TPSSpatialTransformer( + output_image_size=tuple(tps_outputsize), + num_control_points=num_control_points, + margins=tuple(tps_margins)) + self.stn_head = STN(in_channels=in_channels, + num_ctrlpoints=num_control_points, + activation=stn_activation) + self.tps_inputsize = tps_inputsize + self.out_channels = in_channels + + def forward(self, image): + stn_input = paddle.nn.functional.interpolate( + image, self.tps_inputsize, mode="bilinear", align_corners=True) + stn_img_feat, ctrl_points = self.stn_head(stn_input) + x, _ = self.tps(image, ctrl_points) + return x diff --git a/backend/ppocr/modeling/transforms/tps.py b/backend/ppocr/modeling/transforms/tps.py index 78338edf..9bdab0f8 100644 --- a/backend/ppocr/modeling/transforms/tps.py +++ b/backend/ppocr/modeling/transforms/tps.py @@ -11,6 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +This code is refer from: +https://github.com/clovaai/deep-text-recognition-benchmark/blob/master/modules/transformation.py +""" from __future__ import absolute_import from __future__ import division @@ -230,15 +234,9 @@ def build_P_paddle(self, I_r_size): def build_inv_delta_C_paddle(self, C): """ Return inv_delta_C which is needed to calculate T """ F = self.F - hat_C = paddle.zeros((F, F), dtype='float64') # F x F - for i in range(0, F): - for j in range(i, F): - if i == j: - hat_C[i, j] = 1 - else: - r = paddle.norm(C[i] - C[j]) - hat_C[i, j] = r - hat_C[j, i] = r + hat_eye = paddle.eye(F, dtype='float64') # F x F + hat_C = paddle.norm( + C.reshape([1, F, 2]) - C.reshape([F, 1, 2]), axis=2) + hat_eye hat_C = (hat_C**2) * paddle.log(hat_C) delta_C = paddle.concat( # F+3 x F+3 [ diff --git a/backend/ppocr/modeling/transforms/tps_spatial_transformer.py b/backend/ppocr/modeling/transforms/tps_spatial_transformer.py new file mode 100644 index 00000000..cb1cb10a --- /dev/null +++ b/backend/ppocr/modeling/transforms/tps_spatial_transformer.py @@ -0,0 +1,156 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +This code is refer from: +https://github.com/ayumiymk/aster.pytorch/blob/master/lib/models/tps_spatial_transformer.py +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import paddle +from paddle import nn, ParamAttr +from paddle.nn import functional as F +import numpy as np +import itertools + + +def grid_sample(input, grid, canvas=None): + input.stop_gradient = False + output = F.grid_sample(input, grid) + if canvas is None: + return output + else: + input_mask = paddle.ones(shape=input.shape) + output_mask = F.grid_sample(input_mask, grid) + padded_output = output * output_mask + canvas * (1 - output_mask) + return padded_output + + +# phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2 +def compute_partial_repr(input_points, control_points): + N = input_points.shape[0] + M = control_points.shape[0] + pairwise_diff = paddle.reshape( + input_points, shape=[N, 1, 2]) - paddle.reshape( + control_points, shape=[1, M, 2]) + # original implementation, very slow + # pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance + pairwise_diff_square = pairwise_diff * pairwise_diff + pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :, + 1] + repr_matrix = 0.5 * pairwise_dist * paddle.log(pairwise_dist) + # fix numerical error for 0 * log(0), substitute all nan with 0 + mask = np.array(repr_matrix != repr_matrix) + repr_matrix[mask] = 0 + return repr_matrix + + +# output_ctrl_pts are specified, according to our task. +def build_output_control_points(num_control_points, margins): + margin_x, margin_y = margins + num_ctrl_pts_per_side = num_control_points // 2 + ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x, num_ctrl_pts_per_side) + ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y + ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y) + ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) + ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) + output_ctrl_pts_arr = np.concatenate( + [ctrl_pts_top, ctrl_pts_bottom], axis=0) + output_ctrl_pts = paddle.to_tensor(output_ctrl_pts_arr) + return output_ctrl_pts + + +class TPSSpatialTransformer(nn.Layer): + def __init__(self, + output_image_size=None, + num_control_points=None, + margins=None): + super(TPSSpatialTransformer, self).__init__() + self.output_image_size = output_image_size + self.num_control_points = num_control_points + self.margins = margins + + self.target_height, self.target_width = output_image_size + target_control_points = build_output_control_points(num_control_points, + margins) + N = num_control_points + + # create padded kernel matrix + forward_kernel = paddle.zeros(shape=[N + 3, N + 3]) + target_control_partial_repr = compute_partial_repr( + target_control_points, target_control_points) + target_control_partial_repr = paddle.cast(target_control_partial_repr, + forward_kernel.dtype) + forward_kernel[:N, :N] = target_control_partial_repr + forward_kernel[:N, -3] = 1 + forward_kernel[-3, :N] = 1 + target_control_points = paddle.cast(target_control_points, + forward_kernel.dtype) + forward_kernel[:N, -2:] = target_control_points + forward_kernel[-2:, :N] = paddle.transpose( + target_control_points, perm=[1, 0]) + # compute inverse matrix + inverse_kernel = paddle.inverse(forward_kernel) + + # create target cordinate matrix + HW = self.target_height * self.target_width + target_coordinate = list( + itertools.product( + range(self.target_height), 
range(self.target_width))) + target_coordinate = paddle.to_tensor(target_coordinate) # HW x 2 + Y, X = paddle.split( + target_coordinate, target_coordinate.shape[1], axis=1) + Y = Y / (self.target_height - 1) + X = X / (self.target_width - 1) + target_coordinate = paddle.concat( + [X, Y], axis=1) # convert from (y, x) to (x, y) + target_coordinate_partial_repr = compute_partial_repr( + target_coordinate, target_control_points) + target_coordinate_repr = paddle.concat( + [ + target_coordinate_partial_repr, paddle.ones(shape=[HW, 1]), + target_coordinate + ], + axis=1) + + # register precomputed matrices + self.inverse_kernel = inverse_kernel + self.padding_matrix = paddle.zeros(shape=[3, 2]) + self.target_coordinate_repr = target_coordinate_repr + self.target_control_points = target_control_points + + def forward(self, input, source_control_points): + assert source_control_points.ndimension() == 3 + assert source_control_points.shape[1] == self.num_control_points + assert source_control_points.shape[2] == 2 + batch_size = paddle.shape(source_control_points)[0] + + padding_matrix = paddle.expand( + self.padding_matrix, shape=[batch_size, 3, 2]) + Y = paddle.concat([source_control_points, padding_matrix], 1) + mapping_matrix = paddle.matmul(self.inverse_kernel, Y) + source_coordinate = paddle.matmul(self.target_coordinate_repr, + mapping_matrix) + + grid = paddle.reshape( + source_coordinate, + shape=[-1, self.target_height, self.target_width, 2]) + grid = paddle.clip(grid, 0, + 1) # the source_control_points may be out of [0, 1]. + # the input to grid_sample is normalized [-1, 1], but what we get is [0, 1] + grid = 2.0 * grid - 1.0 + output_maps = grid_sample(input, grid, canvas=None) + return output_maps, source_coordinate diff --git a/backend/ppocr/optimizer/__init__.py b/backend/ppocr/optimizer/__init__.py index c729103a..a6bd2ebb 100644 --- a/backend/ppocr/optimizer/__init__.py +++ b/backend/ppocr/optimizer/__init__.py @@ -25,15 +25,12 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch): from . import learning_rate lr_config.update({'epochs': epochs, 'step_each_epoch': step_each_epoch}) - if 'name' in lr_config: - lr_name = lr_config.pop('name') - lr = getattr(learning_rate, lr_name)(**lr_config)() - else: - lr = lr_config['learning_rate'] + lr_name = lr_config.pop('name', 'Const') + lr = getattr(learning_rate, lr_name)(**lr_config)() return lr -def build_optimizer(config, epochs, step_each_epoch, parameters): +def build_optimizer(config, epochs, step_each_epoch, model): from . 
import regularizer, optimizer config = copy.deepcopy(config) # step1 build lr @@ -42,8 +39,12 @@ def build_optimizer(config, epochs, step_each_epoch, parameters): # step2 build regularization if 'regularizer' in config and config['regularizer'] is not None: reg_config = config.pop('regularizer') - reg_name = reg_config.pop('name') + 'Decay' + reg_name = reg_config.pop('name') + if not hasattr(regularizer, reg_name): + reg_name += 'Decay' reg = getattr(regularizer, reg_name)(**reg_config)() + elif 'weight_decay' in config: + reg = config.pop('weight_decay') else: reg = None @@ -58,4 +59,4 @@ def build_optimizer(config, epochs, step_each_epoch, parameters): weight_decay=reg, grad_clip=grad_clip, **config) - return optim(parameters), lr + return optim(model), lr diff --git a/backend/ppocr/optimizer/learning_rate.py b/backend/ppocr/optimizer/learning_rate.py index e1b10992..fe251f36 100644 --- a/backend/ppocr/optimizer/learning_rate.py +++ b/backend/ppocr/optimizer/learning_rate.py @@ -18,7 +18,7 @@ from __future__ import unicode_literals from paddle.optimizer import lr -from .lr_scheduler import CyclicalCosineDecay +from .lr_scheduler import CyclicalCosineDecay, OneCycleDecay class Linear(object): @@ -226,3 +226,85 @@ def __call__(self): end_lr=self.learning_rate, last_epoch=self.last_epoch) return learning_rate + + +class OneCycle(object): + """ + One Cycle learning rate decay + Args: + max_lr(float): Upper learning rate boundaries + epochs(int): total training epochs + step_each_epoch(int): steps each epoch + anneal_strategy(str): {‘cos’, ‘linear’} Specifies the annealing strategy: “cos” for cosine annealing, “linear” for linear annealing. + Default: ‘cos’ + three_phase(bool): If True, use a third phase of the schedule to annihilate the learning rate according to ‘final_div_factor’ + instead of modifying the second phase (the first two phases will be symmetrical about the step indicated by ‘pct_start’). + last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate. + """ + + def __init__(self, + max_lr, + epochs, + step_each_epoch, + anneal_strategy='cos', + three_phase=False, + warmup_epoch=0, + last_epoch=-1, + **kwargs): + super(OneCycle, self).__init__() + self.max_lr = max_lr + self.epochs = epochs + self.steps_per_epoch = step_each_epoch + self.anneal_strategy = anneal_strategy + self.three_phase = three_phase + self.last_epoch = last_epoch + self.warmup_epoch = round(warmup_epoch * step_each_epoch) + + def __call__(self): + learning_rate = OneCycleDecay( + max_lr=self.max_lr, + epochs=self.epochs, + steps_per_epoch=self.steps_per_epoch, + anneal_strategy=self.anneal_strategy, + three_phase=self.three_phase, + last_epoch=self.last_epoch) + if self.warmup_epoch > 0: + learning_rate = lr.LinearWarmup( + learning_rate=learning_rate, + warmup_steps=self.warmup_epoch, + start_lr=0.0, + end_lr=self.max_lr, + last_epoch=self.last_epoch) + return learning_rate + + +class Const(object): + """ + Const learning rate decay + Args: + learning_rate(float): initial learning rate + step_each_epoch(int): steps each epoch + last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate. 
+ """ + + def __init__(self, + learning_rate, + step_each_epoch, + warmup_epoch=0, + last_epoch=-1, + **kwargs): + super(Const, self).__init__() + self.learning_rate = learning_rate + self.last_epoch = last_epoch + self.warmup_epoch = round(warmup_epoch * step_each_epoch) + + def __call__(self): + learning_rate = self.learning_rate + if self.warmup_epoch > 0: + learning_rate = lr.LinearWarmup( + learning_rate=learning_rate, + warmup_steps=self.warmup_epoch, + start_lr=0.0, + end_lr=self.learning_rate, + last_epoch=self.last_epoch) + return learning_rate diff --git a/backend/ppocr/optimizer/lr_scheduler.py b/backend/ppocr/optimizer/lr_scheduler.py index 21aec737..f62f1f3b 100644 --- a/backend/ppocr/optimizer/lr_scheduler.py +++ b/backend/ppocr/optimizer/lr_scheduler.py @@ -47,3 +47,116 @@ def get_lr(self): lr = self.eta_min + 0.5 * (self.base_lr - self.eta_min) * \ (1 + math.cos(math.pi * reletive_epoch / self.cycle)) return lr + + +class OneCycleDecay(LRScheduler): + """ + One Cycle learning rate decay + A learning rate which can be referred in https://arxiv.org/abs/1708.07120 + Code refered in https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR + """ + + def __init__(self, + max_lr, + epochs=None, + steps_per_epoch=None, + pct_start=0.3, + anneal_strategy='cos', + div_factor=25., + final_div_factor=1e4, + three_phase=False, + last_epoch=-1, + verbose=False): + + # Validate total_steps + if epochs <= 0 or not isinstance(epochs, int): + raise ValueError( + "Expected positive integer epochs, but got {}".format(epochs)) + if steps_per_epoch <= 0 or not isinstance(steps_per_epoch, int): + raise ValueError( + "Expected positive integer steps_per_epoch, but got {}".format( + steps_per_epoch)) + self.total_steps = epochs * steps_per_epoch + + self.max_lr = max_lr + self.initial_lr = self.max_lr / div_factor + self.min_lr = self.initial_lr / final_div_factor + + if three_phase: + self._schedule_phases = [ + { + 'end_step': float(pct_start * self.total_steps) - 1, + 'start_lr': self.initial_lr, + 'end_lr': self.max_lr, + }, + { + 'end_step': float(2 * pct_start * self.total_steps) - 2, + 'start_lr': self.max_lr, + 'end_lr': self.initial_lr, + }, + { + 'end_step': self.total_steps - 1, + 'start_lr': self.initial_lr, + 'end_lr': self.min_lr, + }, + ] + else: + self._schedule_phases = [ + { + 'end_step': float(pct_start * self.total_steps) - 1, + 'start_lr': self.initial_lr, + 'end_lr': self.max_lr, + }, + { + 'end_step': self.total_steps - 1, + 'start_lr': self.max_lr, + 'end_lr': self.min_lr, + }, + ] + + # Validate pct_start + if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float): + raise ValueError( + "Expected float between 0 and 1 pct_start, but got {}".format( + pct_start)) + + # Validate anneal_strategy + if anneal_strategy not in ['cos', 'linear']: + raise ValueError( + "anneal_strategy must by one of 'cos' or 'linear', instead got {}". + format(anneal_strategy)) + elif anneal_strategy == 'cos': + self.anneal_func = self._annealing_cos + elif anneal_strategy == 'linear': + self.anneal_func = self._annealing_linear + + super(OneCycleDecay, self).__init__(max_lr, last_epoch, verbose) + + def _annealing_cos(self, start, end, pct): + "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0." + cos_out = math.cos(math.pi * pct) + 1 + return end + (start - end) / 2.0 * cos_out + + def _annealing_linear(self, start, end, pct): + "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0." 
+ return (end - start) * pct + start + + def get_lr(self): + computed_lr = 0.0 + step_num = self.last_epoch + + if step_num > self.total_steps: + raise ValueError( + "Tried to step {} times. The specified number of total steps is {}" + .format(step_num + 1, self.total_steps)) + start_step = 0 + for i, phase in enumerate(self._schedule_phases): + end_step = phase['end_step'] + if step_num <= end_step or i == len(self._schedule_phases) - 1: + pct = (step_num - start_step) / (end_step - start_step) + computed_lr = self.anneal_func(phase['start_lr'], + phase['end_lr'], pct) + break + start_step = phase['end_step'] + + return computed_lr diff --git a/backend/ppocr/optimizer/optimizer.py b/backend/ppocr/optimizer/optimizer.py index 8215b92d..dd8544e2 100644 --- a/backend/ppocr/optimizer/optimizer.py +++ b/backend/ppocr/optimizer/optimizer.py @@ -42,13 +42,16 @@ def __init__(self, self.weight_decay = weight_decay self.grad_clip = grad_clip - def __call__(self, parameters): + def __call__(self, model): + train_params = [ + param for param in model.parameters() if param.trainable is True + ] opt = optim.Momentum( learning_rate=self.learning_rate, momentum=self.momentum, weight_decay=self.weight_decay, grad_clip=self.grad_clip, - parameters=parameters) + parameters=train_params) return opt @@ -75,7 +78,10 @@ def __init__(self, self.name = name self.lazy_mode = lazy_mode - def __call__(self, parameters): + def __call__(self, model): + train_params = [ + param for param in model.parameters() if param.trainable is True + ] opt = optim.Adam( learning_rate=self.learning_rate, beta1=self.beta1, @@ -85,7 +91,7 @@ def __call__(self, parameters): grad_clip=self.grad_clip, name=self.name, lazy_mode=self.lazy_mode, - parameters=parameters) + parameters=train_params) return opt @@ -117,7 +123,10 @@ def __init__(self, self.weight_decay = weight_decay self.grad_clip = grad_clip - def __call__(self, parameters): + def __call__(self, model): + train_params = [ + param for param in model.parameters() if param.trainable is True + ] opt = optim.RMSProp( learning_rate=self.learning_rate, momentum=self.momentum, @@ -125,5 +134,101 @@ def __call__(self, parameters): epsilon=self.epsilon, weight_decay=self.weight_decay, grad_clip=self.grad_clip, - parameters=parameters) + parameters=train_params) return opt + + +class Adadelta(object): + def __init__(self, + learning_rate=0.001, + epsilon=1e-08, + rho=0.95, + parameter_list=None, + weight_decay=None, + grad_clip=None, + name=None, + **kwargs): + self.learning_rate = learning_rate + self.epsilon = epsilon + self.rho = rho + self.parameter_list = parameter_list + self.learning_rate = learning_rate + self.weight_decay = weight_decay + self.grad_clip = grad_clip + self.name = name + + def __call__(self, model): + train_params = [ + param for param in model.parameters() if param.trainable is True + ] + opt = optim.Adadelta( + learning_rate=self.learning_rate, + epsilon=self.epsilon, + rho=self.rho, + weight_decay=self.weight_decay, + grad_clip=self.grad_clip, + name=self.name, + parameters=train_params) + return opt + + +class AdamW(object): + def __init__(self, + learning_rate=0.001, + beta1=0.9, + beta2=0.999, + epsilon=1e-8, + weight_decay=0.01, + multi_precision=False, + grad_clip=None, + no_weight_decay_name=None, + one_dim_param_no_weight_decay=False, + name=None, + lazy_mode=False, + **args): + super().__init__() + self.learning_rate = learning_rate + self.beta1 = beta1 + self.beta2 = beta2 + self.epsilon = epsilon + self.grad_clip = grad_clip + self.weight_decay = 0.01 
if weight_decay is None else weight_decay + self.grad_clip = grad_clip + self.name = name + self.lazy_mode = lazy_mode + self.multi_precision = multi_precision + self.no_weight_decay_name_list = no_weight_decay_name.split( + ) if no_weight_decay_name else [] + self.one_dim_param_no_weight_decay = one_dim_param_no_weight_decay + + def __call__(self, model): + parameters = [ + param for param in model.parameters() if param.trainable is True + ] + + self.no_weight_decay_param_name_list = [ + p.name for n, p in model.named_parameters() + if any(nd in n for nd in self.no_weight_decay_name_list) + ] + + if self.one_dim_param_no_weight_decay: + self.no_weight_decay_param_name_list += [ + p.name for n, p in model.named_parameters() if len(p.shape) == 1 + ] + + opt = optim.AdamW( + learning_rate=self.learning_rate, + beta1=self.beta1, + beta2=self.beta2, + epsilon=self.epsilon, + parameters=parameters, + weight_decay=self.weight_decay, + multi_precision=self.multi_precision, + grad_clip=self.grad_clip, + name=self.name, + lazy_mode=self.lazy_mode, + apply_decay_param_fun=self._apply_decay_param_fun) + return opt + + def _apply_decay_param_fun(self, name): + return name not in self.no_weight_decay_param_name_list diff --git a/backend/ppocr/optimizer/regularizer.py b/backend/ppocr/optimizer/regularizer.py index c6396f33..2ce68f71 100644 --- a/backend/ppocr/optimizer/regularizer.py +++ b/backend/ppocr/optimizer/regularizer.py @@ -29,24 +29,23 @@ class L1Decay(object): def __init__(self, factor=0.0): super(L1Decay, self).__init__() - self.regularization_coeff = factor + self.coeff = factor def __call__(self): - reg = paddle.regularizer.L1Decay(self.regularization_coeff) + reg = paddle.regularizer.L1Decay(self.coeff) return reg class L2Decay(object): """ - L2 Weight Decay Regularization, which encourages the weights to be sparse. + L2 Weight Decay Regularization, which helps to prevent the model over-fitting. Args: factor(float): regularization coeff. Default:0.0. 
""" def __init__(self, factor=0.0): super(L2Decay, self).__init__() - self.regularization_coeff = factor + self.coeff = float(factor) def __call__(self): - reg = paddle.regularizer.L2Decay(self.regularization_coeff) - return reg + return self.coeff \ No newline at end of file diff --git a/backend/ppocr/postprocess/__init__.py b/backend/ppocr/postprocess/__init__.py index 0156e438..f50b5f1c 100644 --- a/backend/ppocr/postprocess/__init__.py +++ b/backend/ppocr/postprocess/__init__.py @@ -21,21 +21,38 @@ __all__ = ['build_post_process'] +from .db_postprocess import DBPostProcess, DistillationDBPostProcess +from .east_postprocess import EASTPostProcess +from .sast_postprocess import SASTPostProcess +from .fce_postprocess import FCEPostProcess +from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \ + DistillationCTCLabelDecode, TableLabelDecode, NRTRLabelDecode, SARLabelDecode, \ + SEEDLabelDecode, PRENLabelDecode +from .cls_postprocess import ClsPostProcess +from .pg_postprocess import PGPostProcess +from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess +from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess -def build_post_process(config, global_config=None): - from .db_postprocess import DBPostProcess - from .east_postprocess import EASTPostProcess - from .sast_postprocess import SASTPostProcess - from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode - from .cls_postprocess import ClsPostProcess +def build_post_process(config, global_config=None): support_dict = [ - 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', - 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode' + 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'FCEPostProcess', + 'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', + 'PGPostProcess', 'DistillationCTCLabelDecode', 'TableLabelDecode', + 'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode', + 'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess', + 'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode', + 'DistillationSARLabelDecode' ] + if config['name'] == 'PSEPostProcess': + from .pse_postprocess import PSEPostProcess + support_dict.append('PSEPostProcess') + config = copy.deepcopy(config) module_name = config.pop('name') + if module_name == "None": + return if global_config is not None: config.update(global_config) assert module_name in support_dict, Exception( diff --git a/backend/ppocr/postprocess/cls_postprocess.py b/backend/ppocr/postprocess/cls_postprocess.py index 77e7f46d..9a27ba08 100644 --- a/backend/ppocr/postprocess/cls_postprocess.py +++ b/backend/ppocr/postprocess/cls_postprocess.py @@ -17,17 +17,26 @@ class ClsPostProcess(object): """ Convert between text-label and text-index """ - def __init__(self, label_list, **kwargs): + def __init__(self, label_list=None, key=None, **kwargs): super(ClsPostProcess, self).__init__() self.label_list = label_list + self.key = key def __call__(self, preds, label=None, *args, **kwargs): + if self.key is not None: + preds = preds[self.key] + + label_list = self.label_list + if label_list is None: + label_list = {idx: idx for idx in range(preds.shape[-1])} + if isinstance(preds, paddle.Tensor): preds = preds.numpy() + pred_idxs = preds.argmax(axis=1) - decode_out = [(self.label_list[idx], preds[i, idx]) + decode_out = [(label_list[idx], preds[i, idx]) for i, idx in enumerate(pred_idxs)] if label is None: return decode_out - label = [(self.label_list[idx], 1.0) for idx in 
label] + label = [(label_list[idx], 1.0) for idx in label] return decode_out, label diff --git a/backend/ppocr/postprocess/db_postprocess.py b/backend/ppocr/postprocess/db_postprocess.py index 91729e0a..6542a1bf 100755 --- a/backend/ppocr/postprocess/db_postprocess.py +++ b/backend/ppocr/postprocess/db_postprocess.py @@ -11,7 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +""" +This code is refered from: +https://github.com/WenmuZhou/DBNet.pytorch/blob/master/post_processing/seg_detector_representer.py +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -34,12 +37,18 @@ def __init__(self, max_candidates=1000, unclip_ratio=2.0, use_dilation=False, + score_mode="fast", **kwargs): self.thresh = thresh self.box_thresh = box_thresh self.max_candidates = max_candidates self.unclip_ratio = unclip_ratio self.min_size = 3 + self.score_mode = score_mode + assert score_mode in [ + "slow", "fast" + ], "Score mode must be in [slow, fast] but got: {}".format(score_mode) + self.dilation_kernel = None if not use_dilation else np.array( [[1, 1], [1, 1]]) @@ -69,7 +78,10 @@ def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): if sside < self.min_size: continue points = np.array(points) - score = self.box_score_fast(pred, points.reshape(-1, 2)) + if self.score_mode == "fast": + score = self.box_score_fast(pred, points.reshape(-1, 2)) + else: + score = self.box_score_slow(pred, contour) if self.box_thresh > score: continue @@ -120,12 +132,15 @@ def get_mini_boxes(self, contour): return box, min(bounding_box[1]) def box_score_fast(self, bitmap, _box): + ''' + box_score_fast: use bbox mean score as the mean score + ''' h, w = bitmap.shape[:2] box = _box.copy() - xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1) - xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int), 0, w - 1) - ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int), 0, h - 1) - ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int), 0, h - 1) + xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int32), 0, w - 1) + xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int32), 0, w - 1) + ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int32), 0, h - 1) + ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int32), 0, h - 1) mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) box[:, 0] = box[:, 0] - xmin @@ -133,6 +148,27 @@ def box_score_fast(self, bitmap, _box): cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1) return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] + def box_score_slow(self, bitmap, contour): + ''' + box_score_slow: use polyon mean score as the mean score + ''' + h, w = bitmap.shape[:2] + contour = contour.copy() + contour = np.reshape(contour, (-1, 2)) + + xmin = np.clip(np.min(contour[:, 0]), 0, w - 1) + xmax = np.clip(np.max(contour[:, 0]), 0, w - 1) + ymin = np.clip(np.min(contour[:, 1]), 0, h - 1) + ymax = np.clip(np.max(contour[:, 1]), 0, h - 1) + + mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) + + contour[:, 0] = contour[:, 0] - xmin + contour[:, 1] = contour[:, 1] - ymin + + cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1) + return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] + def __call__(self, outs_dict, shape_list): pred = outs_dict['maps'] if isinstance(pred, paddle.Tensor): @@ -154,3 +190,31 @@ 
def __call__(self, outs_dict, shape_list): boxes_batch.append({'points': boxes}) return boxes_batch + + +class DistillationDBPostProcess(object): + def __init__(self, + model_name=["student"], + key=None, + thresh=0.3, + box_thresh=0.6, + max_candidates=1000, + unclip_ratio=1.5, + use_dilation=False, + score_mode="fast", + **kwargs): + self.model_name = model_name + self.key = key + self.post_process = DBPostProcess( + thresh=thresh, + box_thresh=box_thresh, + max_candidates=max_candidates, + unclip_ratio=unclip_ratio, + use_dilation=use_dilation, + score_mode=score_mode) + + def __call__(self, predicts, shape_list): + results = {} + for k in self.model_name: + results[k] = self.post_process(predicts[k], shape_list=shape_list) + return results diff --git a/backend/ppocr/postprocess/east_postprocess.py b/backend/ppocr/postprocess/east_postprocess.py index ceee727a..c194c81c 100755 --- a/backend/ppocr/postprocess/east_postprocess.py +++ b/backend/ppocr/postprocess/east_postprocess.py @@ -29,6 +29,7 @@ class EASTPostProcess(object): """ The post process for EAST. """ + def __init__(self, score_thresh=0.8, cover_thresh=0.1, @@ -38,11 +39,6 @@ def __init__(self, self.score_thresh = score_thresh self.cover_thresh = cover_thresh self.nms_thresh = nms_thresh - - # c++ la-nms is faster, but only support python 3.5 - self.is_python35 = False - if sys.version_info.major == 3 and sys.version_info.minor == 5: - self.is_python35 = True def restore_rectangle_quad(self, origin, geometry): """ @@ -64,6 +60,7 @@ def detect(self, """ restore text boxes from score map and geo map """ + score_map = score_map[0] geo_map = np.swapaxes(geo_map, 1, 0) geo_map = np.swapaxes(geo_map, 1, 2) @@ -79,10 +76,14 @@ def detect(self, boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32) boxes[:, :8] = text_box_restored.reshape((-1, 8)) boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]] - if self.is_python35: + + try: import lanms boxes = lanms.merge_quadrangle_n9(boxes, nms_thresh) - else: + except: + print( + 'you should install lanms by pip3 install lanms-nova to speed up nms_locality' + ) boxes = nms_locality(boxes.astype(np.float64), nms_thresh) if boxes.shape[0] == 0: return [] @@ -139,4 +140,4 @@ def __call__(self, outs_dict, shape_list): continue boxes_norm.append(box) dt_boxes_list.append({'points': np.array(boxes_norm)}) - return dt_boxes_list \ No newline at end of file + return dt_boxes_list diff --git a/backend/ppocr/postprocess/fce_postprocess.py b/backend/ppocr/postprocess/fce_postprocess.py new file mode 100755 index 00000000..8e0716f9 --- /dev/null +++ b/backend/ppocr/postprocess/fce_postprocess.py @@ -0,0 +1,241 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +This code is refer from: +https://github.com/open-mmlab/mmocr/blob/v0.3.0/mmocr/models/textdet/postprocess/wrapper.py +""" + +import cv2 +import paddle +import numpy as np +from numpy.fft import ifft +from ppocr.utils.poly_nms import poly_nms, valid_boundary + + +def fill_hole(input_mask): + h, w = input_mask.shape + canvas = np.zeros((h + 2, w + 2), np.uint8) + canvas[1:h + 1, 1:w + 1] = input_mask.copy() + + mask = np.zeros((h + 4, w + 4), np.uint8) + + cv2.floodFill(canvas, mask, (0, 0), 1) + canvas = canvas[1:h + 1, 1:w + 1].astype(np.bool) + + return ~canvas | input_mask + + +def fourier2poly(fourier_coeff, num_reconstr_points=50): + """ Inverse Fourier transform + Args: + fourier_coeff (ndarray): Fourier coefficients shaped (n, 2k+1), + with n and k being candidates number and Fourier degree + respectively. + num_reconstr_points (int): Number of reconstructed polygon points. + Returns: + Polygons (ndarray): The reconstructed polygons shaped (n, n') + """ + + a = np.zeros((len(fourier_coeff), num_reconstr_points), dtype='complex') + k = (len(fourier_coeff[0]) - 1) // 2 + + a[:, 0:k + 1] = fourier_coeff[:, k:] + a[:, -k:] = fourier_coeff[:, :k] + + poly_complex = ifft(a) * num_reconstr_points + polygon = np.zeros((len(fourier_coeff), num_reconstr_points, 2)) + polygon[:, :, 0] = poly_complex.real + polygon[:, :, 1] = poly_complex.imag + return polygon.astype('int32').reshape((len(fourier_coeff), -1)) + + +class FCEPostProcess(object): + """ + The post process for FCENet. + """ + + def __init__(self, + scales, + fourier_degree=5, + num_reconstr_points=50, + decoding_type='fcenet', + score_thr=0.3, + nms_thr=0.1, + alpha=1.0, + beta=1.0, + box_type='poly', + **kwargs): + + self.scales = scales + self.fourier_degree = fourier_degree + self.num_reconstr_points = num_reconstr_points + self.decoding_type = decoding_type + self.score_thr = score_thr + self.nms_thr = nms_thr + self.alpha = alpha + self.beta = beta + self.box_type = box_type + + def __call__(self, preds, shape_list): + score_maps = [] + for key, value in preds.items(): + if isinstance(value, paddle.Tensor): + value = value.numpy() + cls_res = value[:, :4, :, :] + reg_res = value[:, 4:, :, :] + score_maps.append([cls_res, reg_res]) + + return self.get_boundary(score_maps, shape_list) + + def resize_boundary(self, boundaries, scale_factor): + """Rescale boundaries via scale_factor. + + Args: + boundaries (list[list[float]]): The boundary list. Each boundary + with size 2k+1 with k>=4. + scale_factor(ndarray): The scale factor of size (4,). + + Returns: + boundaries (list[list[float]]): The scaled boundaries. 
+ """ + boxes = [] + scores = [] + for b in boundaries: + sz = len(b) + valid_boundary(b, True) + scores.append(b[-1]) + b = (np.array(b[:sz - 1]) * + (np.tile(scale_factor[:2], int( + (sz - 1) / 2)).reshape(1, sz - 1))).flatten().tolist() + boxes.append(np.array(b).reshape([-1, 2])) + + return np.array(boxes, dtype=np.float32), scores + + def get_boundary(self, score_maps, shape_list): + assert len(score_maps) == len(self.scales) + boundaries = [] + for idx, score_map in enumerate(score_maps): + scale = self.scales[idx] + boundaries = boundaries + self._get_boundary_single(score_map, + scale) + + # nms + boundaries = poly_nms(boundaries, self.nms_thr) + boundaries, scores = self.resize_boundary( + boundaries, (1 / shape_list[0, 2:]).tolist()[::-1]) + + boxes_batch = [dict(points=boundaries, scores=scores)] + return boxes_batch + + def _get_boundary_single(self, score_map, scale): + assert len(score_map) == 2 + assert score_map[1].shape[1] == 4 * self.fourier_degree + 2 + + return self.fcenet_decode( + preds=score_map, + fourier_degree=self.fourier_degree, + num_reconstr_points=self.num_reconstr_points, + scale=scale, + alpha=self.alpha, + beta=self.beta, + box_type=self.box_type, + score_thr=self.score_thr, + nms_thr=self.nms_thr) + + def fcenet_decode(self, + preds, + fourier_degree, + num_reconstr_points, + scale, + alpha=1.0, + beta=2.0, + box_type='poly', + score_thr=0.3, + nms_thr=0.1): + """Decoding predictions of FCENet to instances. + + Args: + preds (list(Tensor)): The head output tensors. + fourier_degree (int): The maximum Fourier transform degree k. + num_reconstr_points (int): The points number of the polygon + reconstructed from predicted Fourier coefficients. + scale (int): The down-sample scale of the prediction. + alpha (float) : The parameter to calculate final scores. Score_{final} + = (Score_{text region} ^ alpha) + * (Score_{text center region}^ beta) + beta (float) : The parameter to calculate final score. + box_type (str): Boundary encoding type 'poly' or 'quad'. + score_thr (float) : The threshold used to filter out the final + candidates. + nms_thr (float) : The threshold of nms. + + Returns: + boundaries (list[list[float]]): The instance boundary and confidence + list. 
+ """ + assert isinstance(preds, list) + assert len(preds) == 2 + assert box_type in ['poly', 'quad'] + + cls_pred = preds[0][0] + tr_pred = cls_pred[0:2] + tcl_pred = cls_pred[2:] + + reg_pred = preds[1][0].transpose([1, 2, 0]) + x_pred = reg_pred[:, :, :2 * fourier_degree + 1] + y_pred = reg_pred[:, :, 2 * fourier_degree + 1:] + + score_pred = (tr_pred[1]**alpha) * (tcl_pred[1]**beta) + tr_pred_mask = (score_pred) > score_thr + tr_mask = fill_hole(tr_pred_mask) + + tr_contours, _ = cv2.findContours( + tr_mask.astype(np.uint8), cv2.RETR_TREE, + cv2.CHAIN_APPROX_SIMPLE) # opencv4 + + mask = np.zeros_like(tr_mask) + boundaries = [] + for cont in tr_contours: + deal_map = mask.copy().astype(np.int8) + cv2.drawContours(deal_map, [cont], -1, 1, -1) + + score_map = score_pred * deal_map + score_mask = score_map > 0 + xy_text = np.argwhere(score_mask) + dxy = xy_text[:, 1] + xy_text[:, 0] * 1j + + x, y = x_pred[score_mask], y_pred[score_mask] + c = x + y * 1j + c[:, fourier_degree] = c[:, fourier_degree] + dxy + c *= scale + + polygons = fourier2poly(c, num_reconstr_points) + score = score_map[score_mask].reshape(-1, 1) + polygons = poly_nms(np.hstack((polygons, score)).tolist(), nms_thr) + + boundaries = boundaries + polygons + + boundaries = poly_nms(boundaries, nms_thr) + + if box_type == 'quad': + new_boundaries = [] + for boundary in boundaries: + poly = np.array(boundary[:-1]).reshape(-1, 2).astype(np.float32) + score = boundary[-1] + points = cv2.boxPoints(cv2.minAreaRect(poly)) + points = np.int0(points) + new_boundaries.append(points.reshape(-1).tolist() + [score]) + boundaries = new_boundaries + + return boundaries diff --git a/backend/ppocr/postprocess/locality_aware_nms.py b/backend/ppocr/postprocess/locality_aware_nms.py index 53280cc1..d305ef68 100644 --- a/backend/ppocr/postprocess/locality_aware_nms.py +++ b/backend/ppocr/postprocess/locality_aware_nms.py @@ -1,5 +1,6 @@ """ Locality aware nms. +This code is refered from: https://github.com/songdejia/EAST/blob/master/locality_aware_nms.py """ import numpy as np diff --git a/backend/ppocr/postprocess/pg_postprocess.py b/backend/ppocr/postprocess/pg_postprocess.py new file mode 100644 index 00000000..0b145518 --- /dev/null +++ b/backend/ppocr/postprocess/pg_postprocess.py @@ -0,0 +1,52 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys + +__dir__ = os.path.dirname(__file__) +sys.path.append(__dir__) +sys.path.append(os.path.join(__dir__, '..')) +from ppocr.utils.e2e_utils.pgnet_pp_utils import PGNet_PostProcess + + +class PGPostProcess(object): + """ + The post process for PGNet. 
+ """ + + def __init__(self, character_dict_path, valid_set, score_thresh, mode, + **kwargs): + self.character_dict_path = character_dict_path + self.valid_set = valid_set + self.score_thresh = score_thresh + self.mode = mode + + # c++ la-nms is faster, but only support python 3.5 + self.is_python35 = False + if sys.version_info.major == 3 and sys.version_info.minor == 5: + self.is_python35 = True + + def __call__(self, outs_dict, shape_list): + post = PGNet_PostProcess(self.character_dict_path, self.valid_set, + self.score_thresh, outs_dict, shape_list) + if self.mode == 'fast': + data = post.pg_postprocess_fast() + else: + data = post.pg_postprocess_slow() + return data diff --git a/backend/ppocr/postprocess/pse_postprocess/__init__.py b/backend/ppocr/postprocess/pse_postprocess/__init__.py new file mode 100644 index 00000000..680473bf --- /dev/null +++ b/backend/ppocr/postprocess/pse_postprocess/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .pse_postprocess import PSEPostProcess \ No newline at end of file diff --git a/backend/ppocr/postprocess/pse_postprocess/pse/README.md b/backend/ppocr/postprocess/pse_postprocess/pse/README.md new file mode 100644 index 00000000..6a19d5d1 --- /dev/null +++ b/backend/ppocr/postprocess/pse_postprocess/pse/README.md @@ -0,0 +1,6 @@ +## 编译 +This code is refer from: +https://github.com/whai362/PSENet/blob/python3/models/post_processing/pse +```python +python3 setup.py build_ext --inplace +``` diff --git a/backend/ppocr/postprocess/pse_postprocess/pse/__init__.py b/backend/ppocr/postprocess/pse_postprocess/pse/__init__.py new file mode 100644 index 00000000..1903a914 --- /dev/null +++ b/backend/ppocr/postprocess/pse_postprocess/pse/__init__.py @@ -0,0 +1,29 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +import os +import subprocess + +python_path = sys.executable + +ori_path = os.getcwd() +os.chdir('ppocr/postprocess/pse_postprocess/pse') +if subprocess.call( + '{} setup.py build_ext --inplace'.format(python_path), shell=True) != 0: + raise RuntimeError( + 'Cannot compile pse: {}, if your system is windows, you need to install all the default components of `desktop development using C++` in visual studio 2019+'. 
+ format(os.path.dirname(os.path.realpath(__file__)))) +os.chdir(ori_path) + +from .pse import pse diff --git a/backend/ppocr/postprocess/pse_postprocess/pse/pse.pyx b/backend/ppocr/postprocess/pse_postprocess/pse/pse.pyx new file mode 100644 index 00000000..b2be49e9 --- /dev/null +++ b/backend/ppocr/postprocess/pse_postprocess/pse/pse.pyx @@ -0,0 +1,70 @@ + +import numpy as np +import cv2 +cimport numpy as np +cimport cython +cimport libcpp +cimport libcpp.pair +cimport libcpp.queue +from libcpp.pair cimport * +from libcpp.queue cimport * + +@cython.boundscheck(False) +@cython.wraparound(False) +cdef np.ndarray[np.int32_t, ndim=2] _pse(np.ndarray[np.uint8_t, ndim=3] kernels, + np.ndarray[np.int32_t, ndim=2] label, + int kernel_num, + int label_num, + float min_area=0): + cdef np.ndarray[np.int32_t, ndim=2] pred + pred = np.zeros((label.shape[0], label.shape[1]), dtype=np.int32) + + for label_idx in range(1, label_num): + if np.sum(label == label_idx) < min_area: + label[label == label_idx] = 0 + + cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] que = \ + queue[libcpp.pair.pair[np.int16_t,np.int16_t]]() + cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] nxt_que = \ + queue[libcpp.pair.pair[np.int16_t,np.int16_t]]() + cdef np.int16_t* dx = [-1, 1, 0, 0] + cdef np.int16_t* dy = [0, 0, -1, 1] + cdef np.int16_t tmpx, tmpy + + points = np.array(np.where(label > 0)).transpose((1, 0)) + for point_idx in range(points.shape[0]): + tmpx, tmpy = points[point_idx, 0], points[point_idx, 1] + que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy)) + pred[tmpx, tmpy] = label[tmpx, tmpy] + + cdef libcpp.pair.pair[np.int16_t,np.int16_t] cur + cdef int cur_label + for kernel_idx in range(kernel_num - 1, -1, -1): + while not que.empty(): + cur = que.front() + que.pop() + cur_label = pred[cur.first, cur.second] + + is_edge = True + for j in range(4): + tmpx = cur.first + dx[j] + tmpy = cur.second + dy[j] + if tmpx < 0 or tmpx >= label.shape[0] or tmpy < 0 or tmpy >= label.shape[1]: + continue + if kernels[kernel_idx, tmpx, tmpy] == 0 or pred[tmpx, tmpy] > 0: + continue + + que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy)) + pred[tmpx, tmpy] = cur_label + is_edge = False + if is_edge: + nxt_que.push(cur) + + que, nxt_que = nxt_que, que + + return pred + +def pse(kernels, min_area): + kernel_num = kernels.shape[0] + label_num, label = cv2.connectedComponents(kernels[-1], connectivity=4) + return _pse(kernels[:-1], label, kernel_num, label_num, min_area) \ No newline at end of file diff --git a/backend/ppocr/postprocess/pse_postprocess/pse/setup.py b/backend/ppocr/postprocess/pse_postprocess/pse/setup.py new file mode 100644 index 00000000..03746782 --- /dev/null +++ b/backend/ppocr/postprocess/pse_postprocess/pse/setup.py @@ -0,0 +1,14 @@ +from distutils.core import setup, Extension +from Cython.Build import cythonize +import numpy + +setup(ext_modules=cythonize(Extension( + 'pse', + sources=['pse.pyx'], + language='c++', + include_dirs=[numpy.get_include()], + library_dirs=[], + libraries=[], + extra_compile_args=['-O3'], + extra_link_args=[] +))) diff --git a/backend/ppocr/postprocess/pse_postprocess/pse_postprocess.py b/backend/ppocr/postprocess/pse_postprocess/pse_postprocess.py new file mode 100755 index 00000000..34f1b8c9 --- /dev/null +++ b/backend/ppocr/postprocess/pse_postprocess/pse_postprocess.py @@ -0,0 +1,118 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. 
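Editor's note: the Cython routine in `pse.pyx` above implements PSENet's progressive scale expansion: connected components of the smallest kernel seed the label map, and those labels are then grown breadth-first through each successively larger kernel. The sketch below is a rough, unoptimized pure-Python/NumPy restatement of that idea for readability only; it is not part of the patch, and it assumes `kernels` is a `(K, H, W)` uint8 array ordered largest kernel first, as `PSEPostProcess` produces.

```python
from collections import deque

import cv2
import numpy as np


def pse_py(kernels, min_area=0):
    """Unvectorized sketch of pse.pyx: kernels is (K, H, W) uint8, largest kernel first."""
    label_num, label = cv2.connectedComponents(kernels[-1], connectivity=4)
    pred = np.zeros_like(label)

    # Drop seed components that are too small, then queue the surviving seed pixels.
    for idx in range(1, label_num):
        if (label == idx).sum() < min_area:
            label[label == idx] = 0
    queue = deque()
    for x, y in zip(*np.where(label > 0)):
        pred[x, y] = label[x, y]
        queue.append((x, y))

    # Grow the seed labels outward through each larger kernel, smallest to largest.
    for kernel in kernels[-2::-1]:
        next_queue = deque()
        while queue:
            x, y = queue.popleft()
            is_edge = True
            for dx, dy in ((-1, 0), (1, 0), (0, -1), (0, 1)):
                nx, ny = x + dx, y + dy
                if (0 <= nx < pred.shape[0] and 0 <= ny < pred.shape[1]
                        and kernel[nx, ny] and pred[nx, ny] == 0):
                    pred[nx, ny] = pred[x, y]
                    queue.append((nx, ny))
                    is_edge = False
            if is_edge:
                next_queue.append((x, y))  # may still expand in the next, larger kernel
        queue = next_queue
    return pred
```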
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is refer from: +https://github.com/whai362/PSENet/blob/python3/models/head/psenet_head.py +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import cv2 +import paddle +from paddle.nn import functional as F + +from ppocr.postprocess.pse_postprocess.pse import pse + + +class PSEPostProcess(object): + """ + The post process for PSE. + """ + + def __init__(self, + thresh=0.5, + box_thresh=0.85, + min_area=16, + box_type='quad', + scale=4, + **kwargs): + assert box_type in ['quad', 'poly'], 'Only quad and poly is supported' + self.thresh = thresh + self.box_thresh = box_thresh + self.min_area = min_area + self.box_type = box_type + self.scale = scale + + def __call__(self, outs_dict, shape_list): + pred = outs_dict['maps'] + if not isinstance(pred, paddle.Tensor): + pred = paddle.to_tensor(pred) + pred = F.interpolate( + pred, scale_factor=4 // self.scale, mode='bilinear') + + score = F.sigmoid(pred[:, 0, :, :]) + + kernels = (pred > self.thresh).astype('float32') + text_mask = kernels[:, 0, :, :] + kernels[:, 0:, :, :] = kernels[:, 0:, :, :] * text_mask + + score = score.numpy() + kernels = kernels.numpy().astype(np.uint8) + + boxes_batch = [] + for batch_index in range(pred.shape[0]): + boxes, scores = self.boxes_from_bitmap(score[batch_index], + kernels[batch_index], + shape_list[batch_index]) + + boxes_batch.append({'points': boxes, 'scores': scores}) + return boxes_batch + + def boxes_from_bitmap(self, score, kernels, shape): + label = pse(kernels, self.min_area) + return self.generate_box(score, label, shape) + + def generate_box(self, score, label, shape): + src_h, src_w, ratio_h, ratio_w = shape + label_num = np.max(label) + 1 + + boxes = [] + scores = [] + for i in range(1, label_num): + ind = label == i + points = np.array(np.where(ind)).transpose((1, 0))[:, ::-1] + + if points.shape[0] < self.min_area: + label[ind] = 0 + continue + + score_i = np.mean(score[ind]) + if score_i < self.box_thresh: + label[ind] = 0 + continue + + if self.box_type == 'quad': + rect = cv2.minAreaRect(points) + bbox = cv2.boxPoints(rect) + elif self.box_type == 'poly': + box_height = np.max(points[:, 1]) + 10 + box_width = np.max(points[:, 0]) + 10 + + mask = np.zeros((box_height, box_width), np.uint8) + mask[points[:, 1], points[:, 0]] = 255 + + contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, + cv2.CHAIN_APPROX_SIMPLE) + bbox = np.squeeze(contours[0], 1) + else: + raise NotImplementedError + + bbox[:, 0] = np.clip(np.round(bbox[:, 0] / ratio_w), 0, src_w) + bbox[:, 1] = np.clip(np.round(bbox[:, 1] / ratio_h), 0, src_h) + boxes.append(bbox) + scores.append(score_i) + return boxes, scores diff --git a/backend/ppocr/postprocess/rec_postprocess.py b/backend/ppocr/postprocess/rec_postprocess.py index b0517982..bf0fd890 100644 --- a/backend/ppocr/postprocess/rec_postprocess.py +++ b/backend/ppocr/postprocess/rec_postprocess.py @@ -11,54 +11,34 @@ # WITHOUT 
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import numpy as np -import string import paddle from paddle.nn import functional as F +import re class BaseRecLabelDecode(object): """ Convert between text-label and text-index """ - def __init__(self, - character_dict_path=None, - character_type='ch', - use_space_char=False): - support_character_type = [ - 'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean', - 'it', 'es', 'pt', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs_latin', - 'oc', 'rs_cyrillic', 'bg', 'uk', 'be', 'te', 'kn', 'ch_tra', 'hi', - 'mr', 'ne', 'EN' - ] - assert character_type in support_character_type, "Only {} are supported now but get {}".format( - support_character_type, character_type) - + def __init__(self, character_dict_path=None, use_space_char=False): self.beg_str = "sos" self.end_str = "eos" - if character_type == "en": + self.character_str = [] + if character_dict_path is None: self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" dict_character = list(self.character_str) - elif character_type == "EN_symbol": - # same with ASTER setting (use 94 char). - self.character_str = string.printable[:-6] - dict_character = list(self.character_str) - elif character_type in support_character_type: - self.character_str = "" - assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format( - character_type) + else: with open(character_dict_path, "rb") as fin: lines = fin.readlines() for line in lines: line = line.decode('utf-8').strip("\n").strip("\r\n") - self.character_str += line + self.character_str.append(line) if use_space_char: - self.character_str += " " + self.character_str.append(" ") dict_character = list(self.character_str) - else: - raise NotImplementedError - self.character_type = character_type dict_character = self.add_special_char(dict_character) self.dict = {} for i, char in enumerate(dict_character): @@ -74,24 +54,26 @@ def decode(self, text_index, text_prob=None, is_remove_duplicate=False): ignored_tokens = self.get_ignored_tokens() batch_size = len(text_index) for batch_idx in range(batch_size): - char_list = [] - conf_list = [] - for idx in range(len(text_index[batch_idx])): - if text_index[batch_idx][idx] in ignored_tokens: - continue - if is_remove_duplicate: - # only for predict - if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ - batch_idx][idx]: - continue - char_list.append(self.character[int(text_index[batch_idx][ - idx])]) - if text_prob is not None: - conf_list.append(text_prob[batch_idx][idx]) - else: - conf_list.append(1) + selection = np.ones(len(text_index[batch_idx]), dtype=bool) + if is_remove_duplicate: + selection[1:] = text_index[batch_idx][1:] != text_index[ + batch_idx][:-1] + for ignored_token in ignored_tokens: + selection &= text_index[batch_idx] != ignored_token + + char_list = [ + self.character[text_id] + for text_id in text_index[batch_idx][selection] + ] + if text_prob is not None: + conf_list = text_prob[batch_idx][selection] + else: + conf_list = [1] * len(selection) + if len(conf_list) == 0: + conf_list = [0] + text = ''.join(char_list) - result_list.append((text, np.mean(conf_list))) + result_list.append((text, np.mean(conf_list).tolist())) return result_list def get_ignored_tokens(self): @@ -101,15 +83,14 @@ def get_ignored_tokens(self): class CTCLabelDecode(BaseRecLabelDecode): """ Convert between text-label and text-index """ 
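Editor's note: the rewritten `BaseRecLabelDecode.decode` above replaces the per-character loop with a boolean selection mask. A tiny, self-contained illustration of that masking on made-up CTC indices (the character list and input values are invented for the example and are not part of the patch):

```python
import numpy as np

# Index 0 plays the role of the CTC 'blank' token, as in CTCLabelDecode.add_special_char.
characters = ['blank', 'a', 'b', 'c']
text_index = np.array([1, 1, 0, 2, 2, 0, 3])  # raw argmax output for one sample

selection = np.ones(len(text_index), dtype=bool)
selection[1:] = text_index[1:] != text_index[:-1]  # drop repeated indices
selection &= text_index != 0                       # drop the blank token

print(''.join(characters[i] for i in text_index[selection]))  # -> "abc"
```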
- def __init__(self, - character_dict_path=None, - character_type='ch', - use_space_char=False, + def __init__(self, character_dict_path=None, use_space_char=False, **kwargs): super(CTCLabelDecode, self).__init__(character_dict_path, - character_type, use_space_char) + use_space_char) def __call__(self, preds, label=None, *args, **kwargs): + if isinstance(preds, tuple) or isinstance(preds, list): + preds = preds[-1] if isinstance(preds, paddle.Tensor): preds = preds.numpy() preds_idx = preds.argmax(axis=2) @@ -125,16 +106,111 @@ def add_special_char(self, dict_character): return dict_character -class AttnLabelDecode(BaseRecLabelDecode): - """ Convert between text-label and text-index """ +class DistillationCTCLabelDecode(CTCLabelDecode): + """ + Convert + Convert between text-label and text-index + """ def __init__(self, character_dict_path=None, - character_type='ch', use_space_char=False, + model_name=["student"], + key=None, + multi_head=False, + **kwargs): + super(DistillationCTCLabelDecode, self).__init__(character_dict_path, + use_space_char) + if not isinstance(model_name, list): + model_name = [model_name] + self.model_name = model_name + + self.key = key + self.multi_head = multi_head + + def __call__(self, preds, label=None, *args, **kwargs): + output = dict() + for name in self.model_name: + pred = preds[name] + if self.key is not None: + pred = pred[self.key] + if self.multi_head and isinstance(pred, dict): + pred = pred['ctc'] + output[name] = super().__call__(pred, label=label, *args, **kwargs) + return output + + +class NRTRLabelDecode(BaseRecLabelDecode): + """ Convert between text-label and text-index """ + + def __init__(self, character_dict_path=None, use_space_char=True, **kwargs): + super(NRTRLabelDecode, self).__init__(character_dict_path, + use_space_char) + + def __call__(self, preds, label=None, *args, **kwargs): + + if len(preds) == 2: + preds_id = preds[0] + preds_prob = preds[1] + if isinstance(preds_id, paddle.Tensor): + preds_id = preds_id.numpy() + if isinstance(preds_prob, paddle.Tensor): + preds_prob = preds_prob.numpy() + if preds_id[0][0] == 2: + preds_idx = preds_id[:, 1:] + preds_prob = preds_prob[:, 1:] + else: + preds_idx = preds_id + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) + if label is None: + return text + label = self.decode(label[:, 1:]) + else: + if isinstance(preds, paddle.Tensor): + preds = preds.numpy() + preds_idx = preds.argmax(axis=2) + preds_prob = preds.max(axis=2) + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) + if label is None: + return text + label = self.decode(label[:, 1:]) + return text, label + + def add_special_char(self, dict_character): + dict_character = ['blank', '', '', ''] + dict_character + return dict_character + + def decode(self, text_index, text_prob=None, is_remove_duplicate=False): + """ convert text-index into text-label. 
""" + result_list = [] + batch_size = len(text_index) + for batch_idx in range(batch_size): + char_list = [] + conf_list = [] + for idx in range(len(text_index[batch_idx])): + if text_index[batch_idx][idx] == 3: # end + break + try: + char_list.append(self.character[int(text_index[batch_idx][ + idx])]) + except: + continue + if text_prob is not None: + conf_list.append(text_prob[batch_idx][idx]) + else: + conf_list.append(1) + text = ''.join(char_list) + result_list.append((text.lower(), np.mean(conf_list).tolist())) + return result_list + + +class AttnLabelDecode(BaseRecLabelDecode): + """ Convert between text-label and text-index """ + + def __init__(self, character_dict_path=None, use_space_char=False, **kwargs): super(AttnLabelDecode, self).__init__(character_dict_path, - character_type, use_space_char) + use_space_char) def add_special_char(self, dict_character): self.beg_str = "sos" @@ -169,7 +245,7 @@ def decode(self, text_index, text_prob=None, is_remove_duplicate=False): else: conf_list.append(1) text = ''.join(char_list) - result_list.append((text, np.mean(conf_list))) + result_list.append((text, np.mean(conf_list).tolist())) return result_list def __call__(self, preds, label=None, *args, **kwargs): @@ -208,16 +284,95 @@ def get_beg_end_flag_idx(self, beg_or_end): return idx +class SEEDLabelDecode(BaseRecLabelDecode): + """ Convert between text-label and text-index """ + + def __init__(self, character_dict_path=None, use_space_char=False, + **kwargs): + super(SEEDLabelDecode, self).__init__(character_dict_path, + use_space_char) + + def add_special_char(self, dict_character): + self.padding_str = "padding" + self.end_str = "eos" + self.unknown = "unknown" + dict_character = dict_character + [ + self.end_str, self.padding_str, self.unknown + ] + return dict_character + + def get_ignored_tokens(self): + end_idx = self.get_beg_end_flag_idx("eos") + return [end_idx] + + def get_beg_end_flag_idx(self, beg_or_end): + if beg_or_end == "sos": + idx = np.array(self.dict[self.beg_str]) + elif beg_or_end == "eos": + idx = np.array(self.dict[self.end_str]) + else: + assert False, "unsupport type %s in get_beg_end_flag_idx" % beg_or_end + return idx + + def decode(self, text_index, text_prob=None, is_remove_duplicate=False): + """ convert text-index into text-label. 
""" + result_list = [] + [end_idx] = self.get_ignored_tokens() + batch_size = len(text_index) + for batch_idx in range(batch_size): + char_list = [] + conf_list = [] + for idx in range(len(text_index[batch_idx])): + if int(text_index[batch_idx][idx]) == int(end_idx): + break + if is_remove_duplicate: + # only for predict + if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ + batch_idx][idx]: + continue + char_list.append(self.character[int(text_index[batch_idx][ + idx])]) + if text_prob is not None: + conf_list.append(text_prob[batch_idx][idx]) + else: + conf_list.append(1) + text = ''.join(char_list) + result_list.append((text, np.mean(conf_list).tolist())) + return result_list + + def __call__(self, preds, label=None, *args, **kwargs): + """ + text = self.decode(text) + if label is None: + return text + else: + label = self.decode(label, is_remove_duplicate=False) + return text, label + """ + preds_idx = preds["rec_pred"] + if isinstance(preds_idx, paddle.Tensor): + preds_idx = preds_idx.numpy() + if "rec_pred_scores" in preds: + preds_idx = preds["rec_pred"] + preds_prob = preds["rec_pred_scores"] + else: + preds_idx = preds["rec_pred"].argmax(axis=2) + preds_prob = preds["rec_pred"].max(axis=2) + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) + if label is None: + return text + label = self.decode(label, is_remove_duplicate=False) + return text, label + + class SRNLabelDecode(BaseRecLabelDecode): """ Convert between text-label and text-index """ - def __init__(self, - character_dict_path=None, - character_type='en', - use_space_char=False, + def __init__(self, character_dict_path=None, use_space_char=False, **kwargs): super(SRNLabelDecode, self).__init__(character_dict_path, - character_type, use_space_char) + use_space_char) + self.max_text_length = kwargs.get('max_text_length', 25) def __call__(self, preds, label=None, *args, **kwargs): pred = preds['predict'] @@ -229,9 +384,9 @@ def __call__(self, preds, label=None, *args, **kwargs): preds_idx = np.argmax(pred, axis=1) preds_prob = np.max(pred, axis=1) - preds_idx = np.reshape(preds_idx, [-1, 25]) + preds_idx = np.reshape(preds_idx, [-1, self.max_text_length]) - preds_prob = np.reshape(preds_prob, [-1, 25]) + preds_prob = np.reshape(preds_prob, [-1, self.max_text_length]) text = self.decode(preds_idx, preds_prob) @@ -266,7 +421,7 @@ def decode(self, text_index, text_prob=None, is_remove_duplicate=False): conf_list.append(1) text = ''.join(char_list) - result_list.append((text, np.mean(conf_list))) + result_list.append((text, np.mean(conf_list).tolist())) return result_list def add_special_char(self, dict_character): @@ -287,3 +442,313 @@ def get_beg_end_flag_idx(self, beg_or_end): assert False, "unsupport type %s in get_beg_end_flag_idx" \ % beg_or_end return idx + + +class TableLabelDecode(object): + """ """ + + def __init__(self, character_dict_path, **kwargs): + list_character, list_elem = self.load_char_elem_dict( + character_dict_path) + list_character = self.add_special_char(list_character) + list_elem = self.add_special_char(list_elem) + self.dict_character = {} + self.dict_idx_character = {} + for i, char in enumerate(list_character): + self.dict_idx_character[i] = char + self.dict_character[char] = i + self.dict_elem = {} + self.dict_idx_elem = {} + for i, elem in enumerate(list_elem): + self.dict_idx_elem[i] = elem + self.dict_elem[elem] = i + + def load_char_elem_dict(self, character_dict_path): + list_character = [] + list_elem = [] + with open(character_dict_path, "rb") as fin: + lines = 
fin.readlines() + substr = lines[0].decode('utf-8').strip("\n").strip("\r\n").split( + "\t") + character_num = int(substr[0]) + elem_num = int(substr[1]) + for cno in range(1, 1 + character_num): + character = lines[cno].decode('utf-8').strip("\n").strip("\r\n") + list_character.append(character) + for eno in range(1 + character_num, 1 + character_num + elem_num): + elem = lines[eno].decode('utf-8').strip("\n").strip("\r\n") + list_elem.append(elem) + return list_character, list_elem + + def add_special_char(self, list_character): + self.beg_str = "sos" + self.end_str = "eos" + list_character = [self.beg_str] + list_character + [self.end_str] + return list_character + + def __call__(self, preds): + structure_probs = preds['structure_probs'] + loc_preds = preds['loc_preds'] + if isinstance(structure_probs, paddle.Tensor): + structure_probs = structure_probs.numpy() + if isinstance(loc_preds, paddle.Tensor): + loc_preds = loc_preds.numpy() + structure_idx = structure_probs.argmax(axis=2) + structure_probs = structure_probs.max(axis=2) + structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode( + structure_idx, structure_probs, 'elem') + res_html_code_list = [] + res_loc_list = [] + batch_num = len(structure_str) + for bno in range(batch_num): + res_loc = [] + for sno in range(len(structure_str[bno])): + text = structure_str[bno][sno] + if text in ['', ' 0 and tmp_elem_idx == end_idx: + break + if tmp_elem_idx in ignored_tokens: + continue + + char_list.append(current_dict[tmp_elem_idx]) + elem_pos_list.append(idx) + score_list.append(structure_probs[batch_idx, idx]) + elem_idx_list.append(tmp_elem_idx) + result_list.append(char_list) + result_pos_list.append(elem_pos_list) + result_score_list.append(score_list) + result_elem_idx_list.append(elem_idx_list) + return result_list, result_pos_list, result_score_list, result_elem_idx_list + + def get_ignored_tokens(self, char_or_elem): + beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem) + end_idx = self.get_beg_end_flag_idx("end", char_or_elem) + return [beg_idx, end_idx] + + def get_beg_end_flag_idx(self, beg_or_end, char_or_elem): + if char_or_elem == "char": + if beg_or_end == "beg": + idx = self.dict_character[self.beg_str] + elif beg_or_end == "end": + idx = self.dict_character[self.end_str] + else: + assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \ + % beg_or_end + elif char_or_elem == "elem": + if beg_or_end == "beg": + idx = self.dict_elem[self.beg_str] + elif beg_or_end == "end": + idx = self.dict_elem[self.end_str] + else: + assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \ + % beg_or_end + else: + assert False, "Unsupport type %s in char_or_elem" \ + % char_or_elem + return idx + + +class SARLabelDecode(BaseRecLabelDecode): + """ Convert between text-label and text-index """ + + def __init__(self, character_dict_path=None, use_space_char=False, + **kwargs): + super(SARLabelDecode, self).__init__(character_dict_path, + use_space_char) + + self.rm_symbol = kwargs.get('rm_symbol', False) + + def add_special_char(self, dict_character): + beg_end_str = "" + unknown_str = "" + padding_str = "" + dict_character = dict_character + [unknown_str] + self.unknown_idx = len(dict_character) - 1 + dict_character = dict_character + [beg_end_str] + self.start_idx = len(dict_character) - 1 + self.end_idx = len(dict_character) - 1 + dict_character = dict_character + [padding_str] + self.padding_idx = len(dict_character) - 1 + return dict_character + + def decode(self, text_index, 
text_prob=None, is_remove_duplicate=False): + """ convert text-index into text-label. """ + result_list = [] + ignored_tokens = self.get_ignored_tokens() + + batch_size = len(text_index) + for batch_idx in range(batch_size): + char_list = [] + conf_list = [] + for idx in range(len(text_index[batch_idx])): + if text_index[batch_idx][idx] in ignored_tokens: + continue + if int(text_index[batch_idx][idx]) == int(self.end_idx): + if text_prob is None and idx == 0: + continue + else: + break + if is_remove_duplicate: + # only for predict + if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ + batch_idx][idx]: + continue + char_list.append(self.character[int(text_index[batch_idx][ + idx])]) + if text_prob is not None: + conf_list.append(text_prob[batch_idx][idx]) + else: + conf_list.append(1) + text = ''.join(char_list) + if self.rm_symbol: + comp = re.compile('[^A-Z^a-z^0-9^\u4e00-\u9fa5]') + text = text.lower() + text = comp.sub('', text) + result_list.append((text, np.mean(conf_list).tolist())) + return result_list + + def __call__(self, preds, label=None, *args, **kwargs): + if isinstance(preds, paddle.Tensor): + preds = preds.numpy() + preds_idx = preds.argmax(axis=2) + preds_prob = preds.max(axis=2) + + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) + + if label is None: + return text + label = self.decode(label, is_remove_duplicate=False) + return text, label + + def get_ignored_tokens(self): + return [self.padding_idx] + + +class DistillationSARLabelDecode(SARLabelDecode): + """ + Convert + Convert between text-label and text-index + """ + + def __init__(self, + character_dict_path=None, + use_space_char=False, + model_name=["student"], + key=None, + multi_head=False, + **kwargs): + super(DistillationSARLabelDecode, self).__init__(character_dict_path, + use_space_char) + if not isinstance(model_name, list): + model_name = [model_name] + self.model_name = model_name + + self.key = key + self.multi_head = multi_head + + def __call__(self, preds, label=None, *args, **kwargs): + output = dict() + for name in self.model_name: + pred = preds[name] + if self.key is not None: + pred = pred[self.key] + if self.multi_head and isinstance(pred, dict): + pred = pred['sar'] + output[name] = super().__call__(pred, label=label, *args, **kwargs) + return output + + +class PRENLabelDecode(BaseRecLabelDecode): + """ Convert between text-label and text-index """ + + def __init__(self, character_dict_path=None, use_space_char=False, + **kwargs): + super(PRENLabelDecode, self).__init__(character_dict_path, + use_space_char) + + def add_special_char(self, dict_character): + padding_str = '' # 0 + end_str = '' # 1 + unknown_str = '' # 2 + + dict_character = [padding_str, end_str, unknown_str] + dict_character + self.padding_idx = 0 + self.end_idx = 1 + self.unknown_idx = 2 + + return dict_character + + def decode(self, text_index, text_prob=None): + """ convert text-index into text-label. 
""" + result_list = [] + batch_size = len(text_index) + + for batch_idx in range(batch_size): + char_list = [] + conf_list = [] + for idx in range(len(text_index[batch_idx])): + if text_index[batch_idx][idx] == self.end_idx: + break + if text_index[batch_idx][idx] in \ + [self.padding_idx, self.unknown_idx]: + continue + char_list.append(self.character[int(text_index[batch_idx][ + idx])]) + if text_prob is not None: + conf_list.append(text_prob[batch_idx][idx]) + else: + conf_list.append(1) + + text = ''.join(char_list) + if len(text) > 0: + result_list.append((text, np.mean(conf_list).tolist())) + else: + # here confidence of empty recog result is 1 + result_list.append(('', 1)) + return result_list + + def __call__(self, preds, label=None, *args, **kwargs): + preds = preds.numpy() + preds_idx = preds.argmax(axis=2) + preds_prob = preds.max(axis=2) + text = self.decode(preds_idx, preds_prob) + if label is None: + return text + label = self.decode(label) + return text, label diff --git a/backend/ppocr/postprocess/sast_postprocess.py b/backend/ppocr/postprocess/sast_postprocess.py index f011e7e5..bee75c05 100755 --- a/backend/ppocr/postprocess/sast_postprocess.py +++ b/backend/ppocr/postprocess/sast_postprocess.py @@ -18,6 +18,7 @@ import os import sys + __dir__ = os.path.dirname(__file__) sys.path.append(__dir__) sys.path.append(os.path.join(__dir__, '..')) @@ -49,12 +50,12 @@ def __init__(self, self.shrink_ratio_of_width = shrink_ratio_of_width self.expand_scale = expand_scale self.tcl_map_thresh = tcl_map_thresh - + # c++ la-nms is faster, but only support python 3.5 self.is_python35 = False if sys.version_info.major == 3 and sys.version_info.minor == 5: self.is_python35 = True - + def point_pair2poly(self, point_pair_list): """ Transfer vertical point_pairs into poly point in clockwise. @@ -66,31 +67,42 @@ def point_pair2poly(self, point_pair_list): point_list[idx] = point_pair[0] point_list[point_num - 1 - idx] = point_pair[1] return np.array(point_list).reshape(-1, 2) - - def shrink_quad_along_width(self, quad, begin_width_ratio=0., end_width_ratio=1.): + + def shrink_quad_along_width(self, + quad, + begin_width_ratio=0., + end_width_ratio=1.): """ Generate shrink_quad_along_width. """ - ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32) + ratio_pair = np.array( + [[begin_width_ratio], [end_width_ratio]], dtype=np.float32) p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]]) - + def expand_poly_along_width(self, poly, shrink_ratio_of_width=0.3): """ expand poly along width. 
""" point_num = poly.shape[0] - left_quad = np.array([poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32) + left_quad = np.array( + [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32) left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \ - (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6) - left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio, 1.0) - right_quad = np.array([poly[point_num // 2 - 2], poly[point_num // 2 - 1], - poly[point_num // 2], poly[point_num // 2 + 1]], dtype=np.float32) + (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6) + left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio, + 1.0) + right_quad = np.array( + [ + poly[point_num // 2 - 2], poly[point_num // 2 - 1], + poly[point_num // 2], poly[point_num // 2 + 1] + ], + dtype=np.float32) right_ratio = 1.0 + \ - shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \ - (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6) - right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0, right_ratio) + shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \ + (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6) + right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0, + right_ratio) poly[0] = left_quad_expand[0] poly[-1] = left_quad_expand[-1] poly[point_num // 2 - 1] = right_quad_expand[1] @@ -100,7 +112,7 @@ def expand_poly_along_width(self, poly, shrink_ratio_of_width=0.3): def restore_quad(self, tcl_map, tcl_map_thresh, tvo_map): """Restore quad.""" xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh) - xy_text = xy_text[:, ::-1] # (n, 2) + xy_text = xy_text[:, ::-1] # (n, 2) # Sort the text boxes via the y axis xy_text = xy_text[np.argsort(xy_text[:, 1])] @@ -112,7 +124,7 @@ def restore_quad(self, tcl_map, tcl_map_thresh, tvo_map): point_num = int(tvo_map.shape[-1] / 2) assert point_num == 4 tvo_map = tvo_map[xy_text[:, 1], xy_text[:, 0], :] - xy_text_tile = np.tile(xy_text, (1, point_num)) # (n, point_num * 2) + xy_text_tile = np.tile(xy_text, (1, point_num)) # (n, point_num * 2) quads = xy_text_tile - tvo_map return scores, quads, xy_text @@ -121,14 +133,12 @@ def quad_area(self, quad): """ compute area of a quad. """ - edge = [ - (quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]), - (quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]), - (quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]), - (quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1]) - ] + edge = [(quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]), + (quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]), + (quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]), + (quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1])] return np.sum(edge) / 2. - + def nms(self, dets): if self.is_python35: import lanms @@ -141,7 +151,7 @@ def cluster_by_quads_tco(self, tcl_map, tcl_map_thresh, quads, tco_map): """ Cluster pixels in tcl_map based on quads. 
""" - instance_count = quads.shape[0] + 1 # contain background + instance_count = quads.shape[0] + 1 # contain background instance_label_map = np.zeros(tcl_map.shape[:2], dtype=np.int32) if instance_count == 1: return instance_count, instance_label_map @@ -149,18 +159,19 @@ def cluster_by_quads_tco(self, tcl_map, tcl_map_thresh, quads, tco_map): # predict text center xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh) n = xy_text.shape[0] - xy_text = xy_text[:, ::-1] # (n, 2) - tco = tco_map[xy_text[:, 1], xy_text[:, 0], :] # (n, 2) + xy_text = xy_text[:, ::-1] # (n, 2) + tco = tco_map[xy_text[:, 1], xy_text[:, 0], :] # (n, 2) pred_tc = xy_text - tco - + # get gt text center m = quads.shape[0] - gt_tc = np.mean(quads, axis=1) # (m, 2) + gt_tc = np.mean(quads, axis=1) # (m, 2) - pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :], (1, m, 1)) # (n, m, 2) - gt_tc_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1)) # (n, m, 2) - dist_mat = np.linalg.norm(pred_tc_tile - gt_tc_tile, axis=2) # (n, m) - xy_text_assign = np.argmin(dist_mat, axis=1) + 1 # (n,) + pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :], + (1, m, 1)) # (n, m, 2) + gt_tc_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1)) # (n, m, 2) + dist_mat = np.linalg.norm(pred_tc_tile - gt_tc_tile, axis=2) # (n, m) + xy_text_assign = np.argmin(dist_mat, axis=1) + 1 # (n,) instance_label_map[xy_text[:, 1], xy_text[:, 0]] = xy_text_assign return instance_count, instance_label_map @@ -169,26 +180,47 @@ def estimate_sample_pts_num(self, quad, xy_text): """ Estimate sample points number. """ - eh = (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2])) / 2.0 - ew = (np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])) / 2.0 + eh = (np.linalg.norm(quad[0] - quad[3]) + + np.linalg.norm(quad[1] - quad[2])) / 2.0 + ew = (np.linalg.norm(quad[0] - quad[1]) + + np.linalg.norm(quad[2] - quad[3])) / 2.0 dense_sample_pts_num = max(2, int(ew)) - dense_xy_center_line = xy_text[np.linspace(0, xy_text.shape[0] - 1, dense_sample_pts_num, - endpoint=True, dtype=np.float32).astype(np.int32)] - - dense_xy_center_line_diff = dense_xy_center_line[1:] - dense_xy_center_line[:-1] - estimate_arc_len = np.sum(np.linalg.norm(dense_xy_center_line_diff, axis=1)) + dense_xy_center_line = xy_text[np.linspace( + 0, + xy_text.shape[0] - 1, + dense_sample_pts_num, + endpoint=True, + dtype=np.float32).astype(np.int32)] + + dense_xy_center_line_diff = dense_xy_center_line[ + 1:] - dense_xy_center_line[:-1] + estimate_arc_len = np.sum( + np.linalg.norm( + dense_xy_center_line_diff, axis=1)) sample_pts_num = max(2, int(estimate_arc_len / eh)) return sample_pts_num - def detect_sast(self, tcl_map, tvo_map, tbo_map, tco_map, ratio_w, ratio_h, src_w, src_h, - shrink_ratio_of_width=0.3, tcl_map_thresh=0.5, offset_expand=1.0, out_strid=4.0): + def detect_sast(self, + tcl_map, + tvo_map, + tbo_map, + tco_map, + ratio_w, + ratio_h, + src_w, + src_h, + shrink_ratio_of_width=0.3, + tcl_map_thresh=0.5, + offset_expand=1.0, + out_strid=4.0): """ first resize the tcl_map, tvo_map and tbo_map to the input_size, then restore the polys """ # restore quad - scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh, tvo_map) + scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh, + tvo_map) dets = np.hstack((quads, scores)).astype(np.float32, copy=False) dets = self.nms(dets) if dets.shape[0] == 0: @@ -202,7 +234,8 @@ def detect_sast(self, tcl_map, tvo_map, tbo_map, tco_map, ratio_w, ratio_h, src_ # instance segmentation # 
instance_count, instance_label_map = cv2.connectedComponents(tcl_map.astype(np.uint8), connectivity=8) - instance_count, instance_label_map = self.cluster_by_quads_tco(tcl_map, tcl_map_thresh, quads, tco_map) + instance_count, instance_label_map = self.cluster_by_quads_tco( + tcl_map, tcl_map_thresh, quads, tco_map) # restore single poly with tcl instance. poly_list = [] @@ -212,10 +245,10 @@ def detect_sast(self, tcl_map, tvo_map, tbo_map, tco_map, ratio_w, ratio_h, src_ q_area = quad_areas[instance_idx - 1] if q_area < 5: continue - + # - len1 = float(np.linalg.norm(quad[0] -quad[1])) - len2 = float(np.linalg.norm(quad[1] -quad[2])) + len1 = float(np.linalg.norm(quad[0] - quad[1])) + len2 = float(np.linalg.norm(quad[1] - quad[2])) min_len = min(len1, len2) if min_len < 3: continue @@ -225,16 +258,18 @@ def detect_sast(self, tcl_map, tvo_map, tbo_map, tco_map, ratio_w, ratio_h, src_ continue # filter low confidence instance - xy_text_scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0] + xy_text_scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0] if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.1: - # if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.05: + # if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.05: continue # sort xy_text - left_center_pt = np.array([[(quad[0, 0] + quad[-1, 0]) / 2.0, - (quad[0, 1] + quad[-1, 1]) / 2.0]]) # (1, 2) - right_center_pt = np.array([[(quad[1, 0] + quad[2, 0]) / 2.0, - (quad[1, 1] + quad[2, 1]) / 2.0]]) # (1, 2) + left_center_pt = np.array( + [[(quad[0, 0] + quad[-1, 0]) / 2.0, + (quad[0, 1] + quad[-1, 1]) / 2.0]]) # (1, 2) + right_center_pt = np.array( + [[(quad[1, 0] + quad[2, 0]) / 2.0, + (quad[1, 1] + quad[2, 1]) / 2.0]]) # (1, 2) proj_unit_vec = (right_center_pt - left_center_pt) / \ (np.linalg.norm(right_center_pt - left_center_pt) + 1e-6) proj_value = np.sum(xy_text * proj_unit_vec, axis=1) @@ -245,33 +280,45 @@ def detect_sast(self, tcl_map, tvo_map, tbo_map, tco_map, ratio_w, ratio_h, src_ sample_pts_num = self.estimate_sample_pts_num(quad, xy_text) else: sample_pts_num = self.sample_pts_num - xy_center_line = xy_text[np.linspace(0, xy_text.shape[0] - 1, sample_pts_num, - endpoint=True, dtype=np.float32).astype(np.int32)] + xy_center_line = xy_text[np.linspace( + 0, + xy_text.shape[0] - 1, + sample_pts_num, + endpoint=True, + dtype=np.float32).astype(np.int32)] point_pair_list = [] for x, y in xy_center_line: # get corresponding offset offset = tbo_map[y, x, :].reshape(2, 2) if offset_expand != 1.0: - offset_length = np.linalg.norm(offset, axis=1, keepdims=True) - expand_length = np.clip(offset_length * (offset_expand - 1), a_min=0.5, a_max=3.0) + offset_length = np.linalg.norm( + offset, axis=1, keepdims=True) + expand_length = np.clip( + offset_length * (offset_expand - 1), + a_min=0.5, + a_max=3.0) offset_detal = offset / offset_length * expand_length - offset = offset + offset_detal - # original point + offset = offset + offset_detal + # original point ori_yx = np.array([y, x], dtype=np.float32) - point_pair = (ori_yx + offset)[:, ::-1]* out_strid / np.array([ratio_w, ratio_h]).reshape(-1, 2) + point_pair = (ori_yx + offset)[:, ::-1] * out_strid / np.array( + [ratio_w, ratio_h]).reshape(-1, 2) point_pair_list.append(point_pair) # ndarry: (x, 2), expand poly along width detected_poly = self.point_pair2poly(point_pair_list) - detected_poly = self.expand_poly_along_width(detected_poly, shrink_ratio_of_width) - detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w) - detected_poly[:, 1] = 
np.clip(detected_poly[:, 1], a_min=0, a_max=src_h) + detected_poly = self.expand_poly_along_width(detected_poly, + shrink_ratio_of_width) + detected_poly[:, 0] = np.clip( + detected_poly[:, 0], a_min=0, a_max=src_w) + detected_poly[:, 1] = np.clip( + detected_poly[:, 1], a_min=0, a_max=src_h) poly_list.append(detected_poly) return poly_list - def __call__(self, outs_dict, shape_list): + def __call__(self, outs_dict, shape_list): score_list = outs_dict['f_score'] border_list = outs_dict['f_border'] tvo_list = outs_dict['f_tvo'] @@ -281,20 +328,28 @@ def __call__(self, outs_dict, shape_list): border_list = border_list.numpy() tvo_list = tvo_list.numpy() tco_list = tco_list.numpy() - + img_num = len(shape_list) poly_lists = [] for ino in range(img_num): - p_score = score_list[ino].transpose((1,2,0)) - p_border = border_list[ino].transpose((1,2,0)) - p_tvo = tvo_list[ino].transpose((1,2,0)) - p_tco = tco_list[ino].transpose((1,2,0)) + p_score = score_list[ino].transpose((1, 2, 0)) + p_border = border_list[ino].transpose((1, 2, 0)) + p_tvo = tvo_list[ino].transpose((1, 2, 0)) + p_tco = tco_list[ino].transpose((1, 2, 0)) src_h, src_w, ratio_h, ratio_w = shape_list[ino] - poly_list = self.detect_sast(p_score, p_tvo, p_border, p_tco, ratio_w, ratio_h, src_w, src_h, - shrink_ratio_of_width=self.shrink_ratio_of_width, - tcl_map_thresh=self.tcl_map_thresh, offset_expand=self.expand_scale) + poly_list = self.detect_sast( + p_score, + p_tvo, + p_border, + p_tco, + ratio_w, + ratio_h, + src_w, + src_h, + shrink_ratio_of_width=self.shrink_ratio_of_width, + tcl_map_thresh=self.tcl_map_thresh, + offset_expand=self.expand_scale) poly_lists.append({'points': np.array(poly_list)}) return poly_lists - diff --git a/backend/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py b/backend/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py new file mode 100644 index 00000000..1d55d13d --- /dev/null +++ b/backend/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py @@ -0,0 +1,51 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
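Editor's note: the point mapping in `detect_sast` above converts prediction-grid coordinates back to source-image pixels by multiplying by the output stride, dividing by the resize ratios, and clipping to the frame. A toy numeric check of that mapping (all values invented, not part of the patch):

```python
import numpy as np

out_stride = 4.0
ratio_w, ratio_h = 0.5, 0.5            # frame was resized to half size before inference
src_w, src_h = 1280, 720

grid_pts = np.array([[10.0, 6.0], [400.0, 200.0]])  # (x, y) on the prediction grid
src_pts = grid_pts * out_stride / np.array([ratio_w, ratio_h])
src_pts[:, 0] = np.clip(src_pts[:, 0], a_min=0, a_max=src_w)
src_pts[:, 1] = np.clip(src_pts[:, 1], a_min=0, a_max=src_h)
print(src_pts)  # [[  80.   48.] [1280.  720.]] -> second point clipped to the frame border
```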
+import paddle + + +class VQAReTokenLayoutLMPostProcess(object): + """ Convert between text-label and text-index """ + + def __init__(self, **kwargs): + super(VQAReTokenLayoutLMPostProcess, self).__init__() + + def __call__(self, preds, label=None, *args, **kwargs): + if label is not None: + return self._metric(preds, label) + else: + return self._infer(preds, *args, **kwargs) + + def _metric(self, preds, label): + return preds['pred_relations'], label[6], label[5] + + def _infer(self, preds, *args, **kwargs): + ser_results = kwargs['ser_results'] + entity_idx_dict_batch = kwargs['entity_idx_dict_batch'] + pred_relations = preds['pred_relations'] + + # merge relations and ocr info + results = [] + for pred_relation, ser_result, entity_idx_dict in zip( + pred_relations, ser_results, entity_idx_dict_batch): + result = [] + used_tail_id = [] + for relation in pred_relation: + if relation['tail_id'] in used_tail_id: + continue + used_tail_id.append(relation['tail_id']) + ocr_info_head = ser_result[entity_idx_dict[relation['head_id']]] + ocr_info_tail = ser_result[entity_idx_dict[relation['tail_id']]] + result.append((ocr_info_head, ocr_info_tail)) + results.append(result) + return results diff --git a/backend/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py b/backend/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py new file mode 100644 index 00000000..782cdea6 --- /dev/null +++ b/backend/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py @@ -0,0 +1,93 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
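Editor's note: `VQAReTokenLayoutLMPostProcess._infer` above pairs each predicted head entity with its tail entity through `entity_idx_dict`, skipping any tail that was already consumed. A minimal dry run of that loop with invented SER results and relation ids (purely illustrative, not part of the patch):

```python
# Invented inputs: two candidate relations point at the same answer entity,
# so the second one is dropped by the used_tail_id check.
pred_relations = [[{'head_id': 0, 'tail_id': 1}, {'head_id': 2, 'tail_id': 1}]]
ser_results = [[{'text': 'Name:'}, {'text': 'Alice'}, {'text': 'Name'}]]
entity_idx_dict_batch = [{0: 0, 1: 1, 2: 2}]

results = []
for pred_relation, ser_result, entity_idx_dict in zip(
        pred_relations, ser_results, entity_idx_dict_batch):
    result, used_tail_id = [], []
    for relation in pred_relation:
        if relation['tail_id'] in used_tail_id:
            continue
        used_tail_id.append(relation['tail_id'])
        head = ser_result[entity_idx_dict[relation['head_id']]]
        tail = ser_result[entity_idx_dict[relation['tail_id']]]
        result.append((head, tail))
    results.append(result)
print(results)  # [[({'text': 'Name:'}, {'text': 'Alice'})]]
```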
+import numpy as np +import paddle +from ppocr.utils.utility import load_vqa_bio_label_maps + + +class VQASerTokenLayoutLMPostProcess(object): + """ Convert between text-label and text-index """ + + def __init__(self, class_path, **kwargs): + super(VQASerTokenLayoutLMPostProcess, self).__init__() + label2id_map, self.id2label_map = load_vqa_bio_label_maps(class_path) + + self.label2id_map_for_draw = dict() + for key in label2id_map: + if key.startswith("I-"): + self.label2id_map_for_draw[key] = label2id_map["B" + key[1:]] + else: + self.label2id_map_for_draw[key] = label2id_map[key] + + self.id2label_map_for_show = dict() + for key in self.label2id_map_for_draw: + val = self.label2id_map_for_draw[key] + if key == "O": + self.id2label_map_for_show[val] = key + if key.startswith("B-") or key.startswith("I-"): + self.id2label_map_for_show[val] = key[2:] + else: + self.id2label_map_for_show[val] = key + + def __call__(self, preds, batch=None, *args, **kwargs): + if isinstance(preds, paddle.Tensor): + preds = preds.numpy() + + if batch is not None: + return self._metric(preds, batch[1]) + else: + return self._infer(preds, **kwargs) + + def _metric(self, preds, label): + pred_idxs = preds.argmax(axis=2) + decode_out_list = [[] for _ in range(pred_idxs.shape[0])] + label_decode_out_list = [[] for _ in range(pred_idxs.shape[0])] + + for i in range(pred_idxs.shape[0]): + for j in range(pred_idxs.shape[1]): + if label[i, j] != -100: + label_decode_out_list[i].append(self.id2label_map[label[i, + j]]) + decode_out_list[i].append(self.id2label_map[pred_idxs[i, + j]]) + return decode_out_list, label_decode_out_list + + def _infer(self, preds, attention_masks, segment_offset_ids, ocr_infos): + results = [] + + for pred, attention_mask, segment_offset_id, ocr_info in zip( + preds, attention_masks, segment_offset_ids, ocr_infos): + pred = np.argmax(pred, axis=1) + pred = [self.id2label_map[idx] for idx in pred] + + for idx in range(len(segment_offset_id)): + if idx == 0: + start_id = 0 + else: + start_id = segment_offset_id[idx - 1] + + end_id = segment_offset_id[idx] + + curr_pred = pred[start_id:end_id] + curr_pred = [self.label2id_map_for_draw[p] for p in curr_pred] + + if len(curr_pred) <= 0: + pred_id = 0 + else: + counts = np.bincount(curr_pred) + pred_id = np.argmax(counts) + ocr_info[idx]["pred_id"] = int(pred_id) + ocr_info[idx]["pred"] = self.id2label_map_for_show[int(pred_id)] + results.append(ocr_info) + return results diff --git a/backend/ppocr/utils/dict/arabic_dict.txt b/backend/ppocr/utils/dict/arabic_dict.txt new file mode 100644 index 00000000..916d421c --- /dev/null +++ b/backend/ppocr/utils/dict/arabic_dict.txt @@ -0,0 +1,161 @@ +! +# +$ +% +& +' +( ++ +, +- +. +/ +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 +: +? 
+@ +A +B +C +D +E +F +G +H +I +J +K +L +M +N +O +P +Q +R +S +T +U +V +W +X +Y +Z +_ +a +b +c +d +e +f +g +h +i +j +k +l +m +n +o +p +q +r +s +t +u +v +w +x +y +z +É +é +ء +آ +أ +ؤ +إ +ئ +ا +ب +ة +ت +ث +ج +ح +خ +د +ذ +ر +ز +س +ش +ص +ض +ط +ظ +ع +غ +ف +ق +ك +ل +م +ن +ه +و +ى +ي +ً +ٌ +ٍ +َ +ُ +ِ +ّ +ْ +ٓ +ٔ +ٰ +ٱ +ٹ +پ +چ +ڈ +ڑ +ژ +ک +ڭ +گ +ں +ھ +ۀ +ہ +ۂ +ۃ +ۆ +ۇ +ۈ +ۋ +ی +ې +ے +ۓ +ە +١ +٢ +٣ +٤ +٥ +٦ +٧ +٨ +٩ diff --git a/backend/ppocr/utils/ppocr_keys_v1.txt b/backend/ppocr/utils/dict/ch_dict.txt similarity index 100% rename from backend/ppocr/utils/ppocr_keys_v1.txt rename to backend/ppocr/utils/dict/ch_dict.txt diff --git a/backend/ppocr/utils/dict/ch_tra_dict.txt b/backend/ppocr/utils/dict/chinese_cht_dict.txt similarity index 100% rename from backend/ppocr/utils/dict/ch_tra_dict.txt rename to backend/ppocr/utils/dict/chinese_cht_dict.txt diff --git a/backend/ppocr/utils/dict/cyrillic_dict.txt b/backend/ppocr/utils/dict/cyrillic_dict.txt new file mode 100644 index 00000000..2b6f6649 --- /dev/null +++ b/backend/ppocr/utils/dict/cyrillic_dict.txt @@ -0,0 +1,163 @@ + +! +# +$ +% +& +' +( ++ +, +- +. +/ +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 +: +? +@ +A +B +C +D +E +F +G +H +I +J +K +L +M +N +O +P +Q +R +S +T +U +V +W +X +Y +Z +_ +a +b +c +d +e +f +g +h +i +j +k +l +m +n +o +p +q +r +s +t +u +v +w +x +y +z +É +é +Ё +Є +І +Ј +Љ +Ў +А +Б +В +Г +Д +Е +Ж +З +И +Й +К +Л +М +Н +О +П +Р +С +Т +У +Ф +Х +Ц +Ч +Ш +Щ +Ъ +Ы +Ь +Э +Ю +Я +а +б +в +г +д +е +ж +з +и +й +к +л +м +н +о +п +р +с +т +у +ф +х +ц +ч +ш +щ +ъ +ы +ь +э +ю +я +ё +ђ +є +і +ј +љ +њ +ћ +ў +џ +Ґ +ґ diff --git a/backend/ppocr/utils/dict/devanagari_dict.txt b/backend/ppocr/utils/dict/devanagari_dict.txt new file mode 100644 index 00000000..f5592306 --- /dev/null +++ b/backend/ppocr/utils/dict/devanagari_dict.txt @@ -0,0 +1,167 @@ + +! +# +$ +% +& +' +( ++ +, +- +. +/ +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 +: +? +@ +A +B +C +D +E +F +G +H +I +J +K +L +M +N +O +P +Q +R +S +T +U +V +W +X +Y +Z +_ +a +b +c +d +e +f +g +h +i +j +k +l +m +n +o +p +q +r +s +t +u +v +w +x +y +z +É +é +ँ +ं +ः +अ +आ +इ +ई +उ +ऊ +ऋ +ए +ऐ +ऑ +ओ +औ +क +ख +ग +घ +ङ +च +छ +ज +झ +ञ +ट +ठ +ड +ढ +ण +त +थ +द +ध +न +ऩ +प +फ +ब +भ +म +य +र +ऱ +ल +ळ +व +श +ष +स +ह +़ +ा +ि +ी +ु +ू +ृ +ॅ +े +ै +ॉ +ो +ौ +् +॒ +क़ +ख़ +ग़ +ज़ +ड़ +ढ़ +फ़ +ॠ +। +० +१ +२ +३ +४ +५ +६ +७ +८ +९ +॰ diff --git a/backend/ppocr/utils/dict/en_dict.txt b/backend/ppocr/utils/dict/en_dict.txt index 6fbd99f4..7677d31b 100644 --- a/backend/ppocr/utils/dict/en_dict.txt +++ b/backend/ppocr/utils/dict/en_dict.txt @@ -8,32 +8,13 @@ 7 8 9 -a -b -c -d -e -f -g -h -i -j -k -l -m -n -o -p -q -r -s -t -u -v -w -x -y -z +: +; +< += +> +? +@ A B C @@ -60,4 +41,55 @@ W X Y Z +[ +\ +] +^ +_ +` +a +b +c +d +e +f +g +h +i +j +k +l +m +n +o +p +q +r +s +t +u +v +w +x +y +z +{ +| +} +~ +! +" +# +$ +% +& +' +( +) +* ++ +, +- +. +/ diff --git a/backend/ppocr/utils/dict/ka_dict.txt b/backend/ppocr/utils/dict/ka_dict.txt new file mode 100644 index 00000000..d506b691 --- /dev/null +++ b/backend/ppocr/utils/dict/ka_dict.txt @@ -0,0 +1,153 @@ +k +a +_ +i +m +g +/ +1 +2 +I +L +S +V +R +C +0 +v +l +6 +4 +8 +. +j +p +ಗ +ು +ಣ +ಪ +ಡ +ಿ +ಸ +ಲ +ಾ +ದ +್ +7 +5 +3 +ವ +ಷ +ಬ +ಹ +ೆ +9 +ಅ +ಳ +ನ +ರ +ಉ +ಕ +ಎ +ೇ +ಂ +ೈ +ೊ +ೀ +ಯ +ೋ +ತ +ಶ +ಭ +ಧ +ಚ +ಜ +ೂ +ಮ +ಒ +ೃ +ಥ +ಇ +ಟ +ಖ +ಆ +ಞ +ಫ +- +ಢ +ಊ +ಓ +ಐ +ಃ +ಘ +ಝ +ೌ +ಠ +ಛ +ಔ +ಏ +ಈ +ಋ +೨ +೦ +೧ +೮ +೯ +೪ +, +೫ +೭ +೩ +೬ +ಙ +s +c +e +n +w +o +u +t +d +E +A +T +B +Z +N +G +O +q +z +r +x +P +K +M +J +U +D +f +F +h +b +W +Y +y +H +X +Q +' +# +& +! +@ +$ +: +% +é +É +( +? 
++ + diff --git a/backend/ppocr/utils/dict/kie_dict/xfund_class_list.txt b/backend/ppocr/utils/dict/kie_dict/xfund_class_list.txt new file mode 100644 index 00000000..faded9f9 --- /dev/null +++ b/backend/ppocr/utils/dict/kie_dict/xfund_class_list.txt @@ -0,0 +1,4 @@ +OTHER +QUESTION +ANSWER +HEADER diff --git a/backend/ppocr/utils/dict/latin_dict.txt b/backend/ppocr/utils/dict/latin_dict.txt new file mode 100644 index 00000000..e166bf33 --- /dev/null +++ b/backend/ppocr/utils/dict/latin_dict.txt @@ -0,0 +1,185 @@ + +! +" +# +$ +% +& +' +( +) +* ++ +, +- +. +/ +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 +: +; +< += +> +? +@ +A +B +C +D +E +F +G +H +I +J +K +L +M +N +O +P +Q +R +S +T +U +V +W +X +Y +Z +[ +] +_ +` +a +b +c +d +e +f +g +h +i +j +k +l +m +n +o +p +q +r +s +t +u +v +w +x +y +z +{ +} +¡ +£ +§ +ª +« +­ +° +² +³ +´ +µ +· +º +» +¿ +À +Á + +Ä +Å +Ç +È +É +Ê +Ë +Ì +Í +Î +Ï +Ò +Ó +Ô +Õ +Ö +Ú +Ü +Ý +ß +à +á +â +ã +ä +å +æ +ç +è +é +ê +ë +ì +í +î +ï +ñ +ò +ó +ô +õ +ö +ø +ù +ú +û +ü +ý +ą +Ć +ć +Č +č +Đ +đ +ę +ı +Ł +ł +ō +Œ +œ +Š +š +Ÿ +Ž +ž +ʒ +β +δ +ε +з +Ṡ +‘ +€ +™ diff --git a/backend/ppocr/utils/dict/layout_dict/layout_cdla_dict.txt b/backend/ppocr/utils/dict/layout_dict/layout_cdla_dict.txt new file mode 100644 index 00000000..8be0f486 --- /dev/null +++ b/backend/ppocr/utils/dict/layout_dict/layout_cdla_dict.txt @@ -0,0 +1,10 @@ +text +title +figure +figure_caption +table +table_caption +header +footer +reference +equation \ No newline at end of file diff --git a/backend/ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt b/backend/ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt new file mode 100644 index 00000000..ca6acf4e --- /dev/null +++ b/backend/ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt @@ -0,0 +1,5 @@ +text +title +list +table +figure \ No newline at end of file diff --git a/backend/ppocr/utils/dict/layout_dict/layout_table_dict.txt b/backend/ppocr/utils/dict/layout_dict/layout_table_dict.txt new file mode 100644 index 00000000..faea15ea --- /dev/null +++ b/backend/ppocr/utils/dict/layout_dict/layout_table_dict.txt @@ -0,0 +1 @@ +table \ No newline at end of file diff --git a/backend/ppocr/utils/dict/pu_dict.txt b/backend/ppocr/utils/dict/pu_dict.txt new file mode 100644 index 00000000..9500fae6 --- /dev/null +++ b/backend/ppocr/utils/dict/pu_dict.txt @@ -0,0 +1,130 @@ +p +u +_ +i +m +g +/ +8 +I +L +S +V +R +C +2 +0 +1 +v +a +l +6 +7 +4 +5 +. +j + +q +e +s +t +ã +o +x +9 +c +n +r +z +ç +õ +3 +A +U +d +º +ô +­ +, +E +; +ó +á +b +D +? +ú +ê +- +h +P +f +à +N +í +O +M +G +É +é +â +F +: +T +Á +" +Q +) +W +J +B +H +( +ö +% +Ö +« +w +K +y +! +k +] +' +Z ++ +Ç +Õ +Y +À +X +µ +» +ª +Í +ü +ä +´ +è +ñ +ß +ï +Ú +ë +Ô +Ï +Ó +[ +Ì +< + +ò +§ +³ +ø +å +# +$ +& +@ diff --git a/backend/ppocr/utils/dict/rs_dict.txt b/backend/ppocr/utils/dict/rs_dict.txt new file mode 100644 index 00000000..d1ce46d2 --- /dev/null +++ b/backend/ppocr/utils/dict/rs_dict.txt @@ -0,0 +1,91 @@ +r +s +_ +i +m +g +/ +1 +I +L +S +V +R +C +2 +0 +v +a +l +7 +5 +8 +6 +. +j +p + +t +d +9 +3 +e +š +4 +k +u +ć +c +n +đ +o +z +č +b +ž +f +Z +T +h +M +F +O +Š +B +H +A +E +Đ +Ž +D +P +G +Č +K +U +N +J +Ć +w +y +W +x +Y +X +q +Q +# +& +$ +, +- +% +' +@ +! +: +? +( +É +é ++ diff --git a/backend/ppocr/utils/dict/rsc_dict.txt b/backend/ppocr/utils/dict/rsc_dict.txt new file mode 100644 index 00000000..95dd4636 --- /dev/null +++ b/backend/ppocr/utils/dict/rsc_dict.txt @@ -0,0 +1,134 @@ +r +s +c +_ +i +m +g +/ +5 +I +L +S +V +R +C +2 +0 +1 +v +a +l +9 +7 +8 +. 
+j +p +м +а +с +и +р +ћ +е +ш +3 +4 +о +г +н +з +в +л +6 +т +ж +у +к +п +њ +д +ч +С +ј +ф +ц +љ +х +О +И +А +б +Ш +К +ђ +џ +М +В +З +Д +Р +У +Н +Т +Б +? +П +Х +Ј +Ц +Г +Љ +Л +Ф +e +n +w +E +F +A +N +f +o +b +M +G +t +y +W +k +P +u +H +B +T +z +h +O +Y +d +U +K +D +x +X +J +Z +Q +q +' +- +@ +é +# +! +, +% +$ +: +& ++ +( +É + diff --git a/backend/ppocr/utils/dict/ru_dict.txt b/backend/ppocr/utils/dict/ru_dict.txt index 3b0cf3a8..aff9c16e 100644 --- a/backend/ppocr/utils/dict/ru_dict.txt +++ b/backend/ppocr/utils/dict/ru_dict.txt @@ -1,65 +1,16 @@ -к -в -а -з -и -у -р -о -н -я -х -п -л -ы -г -е -т -м -д -ж -ш -ь -с -ё -б -й -ч -ю -ц -щ -М -э -ф -А -ъ -С -Ф -Ю -В -К -Т -Н -О -Э -У -И -Г -Л -Р -Д -Б -Ш -П -З -Х -Е -Ж -Я -Ц -Ч -Й -Щ + +! +# +$ +% +& +' +( ++ +, +- +. +/ 0 1 2 @@ -70,32 +21,9 @@ 7 8 9 -a -b -c -d -e -f -g -h -i -j -k -l -m -n -o -p -q -r -s -t -u -v -w -x -y -z +: +? +@ A B C @@ -122,4 +50,114 @@ W X Y Z - +_ +a +b +c +d +e +f +g +h +i +j +k +l +m +n +o +p +q +r +s +t +u +v +w +x +y +z +É +é +Ё +Є +І +Ј +Љ +Ў +А +Б +В +Г +Д +Е +Ж +З +И +Й +К +Л +М +Н +О +П +Р +С +Т +У +Ф +Х +Ц +Ч +Ш +Щ +Ъ +Ы +Ь +Э +Ю +Я +а +б +в +г +д +е +ж +з +и +й +к +л +м +н +о +п +р +с +т +у +ф +х +ц +ч +ш +щ +ъ +ы +ь +э +ю +я +ё +ђ +є +і +ј +љ +њ +ћ +ў +џ +Ґ +ґ diff --git a/backend/ppocr/utils/ic15_dict.txt b/backend/ppocr/utils/dict/spin_dict.txt similarity index 51% rename from backend/ppocr/utils/ic15_dict.txt rename to backend/ppocr/utils/dict/spin_dict.txt index 47406036..8ee8347f 100644 --- a/backend/ppocr/utils/ic15_dict.txt +++ b/backend/ppocr/utils/dict/spin_dict.txt @@ -33,4 +33,36 @@ v w x y -z \ No newline at end of file +z +: +( +' +- +, +% +> +. +[ +? +) +" += +_ +* +] +; +& ++ +$ +@ +/ +| +! +< +# +` +{ +~ +\ +} +^ \ No newline at end of file diff --git a/backend/ppocr/utils/dict/ta_dict.txt b/backend/ppocr/utils/dict/ta_dict.txt index d1bae501..19d81892 100644 --- a/backend/ppocr/utils/dict/ta_dict.txt +++ b/backend/ppocr/utils/dict/ta_dict.txt @@ -22,7 +22,7 @@ l 8 . j -p +p ப ூ த diff --git a/backend/ppocr/utils/dict/table_dict.txt b/backend/ppocr/utils/dict/table_dict.txt new file mode 100644 index 00000000..2ef028c7 --- /dev/null +++ b/backend/ppocr/utils/dict/table_dict.txt @@ -0,0 +1,277 @@ +← + +☆ +─ +α + + +⋅ +$ +ω +ψ +χ +( +υ +≥ +σ +, +ρ +ε +0 +■ +4 +8 +✗ +b +< +✓ +Ψ +Ω +€ +D +3 +Π +H +║ + +L +Φ +Χ +θ +P +κ +λ +μ +T +ξ +X +β +γ +δ +\ +ζ +η +` +d + +h +f +l +Θ +p +√ +t + +x +Β +Γ +Δ +| +ǂ +ɛ +j +̧ +➢ +⁡ +̌ +′ +« +△ +▲ +# + +' +Ι ++ +¶ +/ +▼ +⇑ +□ +· +7 +▪ +; +? +➔ +∩ +C +÷ +G +⇒ +K + +O +S +С +W +Α +[ +○ +_ +● +‡ +c +z +g + +o + +〈 +〉 +s +⩽ +w +φ +ʹ +{ +» +∣ +̆ +e +ˆ +∈ +τ +◆ +ι +∅ +∆ +∙ +∘ +Ø +ß +✔ +∞ +∑ +− +× +◊ +∗ +∖ +˃ +˂ +∫ +" +i +& +π +↔ +* +∥ +æ +∧ +. +⁄ +ø +Q +∼ +6 +⁎ +: +★ +> +a +B +≈ +F +J +̄ +N +♯ +R +V + +― +Z +♣ +^ +¤ +¥ +§ + +¢ +£ +≦ +­ +≤ +‖ +Λ +© +n +↓ +→ +↑ +r +° +± +v + +♂ +k +♀ +~ +ᅟ +̇ +@ +” +♦ +ł +® +⊕ +„ +! + +% +⇓ +) +- +1 +5 +9 += +А +A +‰ +⋆ +Σ +E +◦ +I +※ +M +m +̨ +⩾ +† + +• +U +Y +
 +] +̸ +2 +‐ +– +‒ +̂ +— +̀ +́ +’ +‘ +⋮ +⋯ +̊ +“ +̈ +≧ +q +u +ı +y + +​ +̃ +} +ν diff --git a/backend/ppocr/utils/dict/table_master_structure_dict.txt b/backend/ppocr/utils/dict/table_master_structure_dict.txt new file mode 100644 index 00000000..95ab2539 --- /dev/null +++ b/backend/ppocr/utils/dict/table_master_structure_dict.txt @@ -0,0 +1,39 @@ + + + + + + + + + + + colspan="2" + colspan="3" + + + rowspan="2" + colspan="4" + colspan="6" + rowspan="3" + colspan="9" + colspan="10" + colspan="7" + rowspan="4" + rowspan="5" + rowspan="9" + colspan="8" + rowspan="8" + rowspan="6" + rowspan="7" + rowspan="10" + + + + + + + + diff --git a/backend/ppocr/utils/dict/table_structure_dict.txt b/backend/ppocr/utils/dict/table_structure_dict.txt new file mode 100644 index 00000000..8edb10b8 --- /dev/null +++ b/backend/ppocr/utils/dict/table_structure_dict.txt @@ -0,0 +1,28 @@ + + + + + + + + + + colspan="2" + colspan="3" + rowspan="2" + colspan="4" + colspan="6" + rowspan="3" + colspan="9" + colspan="10" + colspan="7" + rowspan="4" + rowspan="5" + rowspan="9" + colspan="8" + rowspan="8" + rowspan="6" + rowspan="7" + rowspan="10" \ No newline at end of file diff --git a/backend/ppocr/utils/dict/table_structure_dict_ch.txt b/backend/ppocr/utils/dict/table_structure_dict_ch.txt new file mode 100644 index 00000000..0c59c0e9 --- /dev/null +++ b/backend/ppocr/utils/dict/table_structure_dict_ch.txt @@ -0,0 +1,48 @@ + + + + + + + + + + colspan="2" + colspan="3" + colspan="4" + colspan="5" + colspan="6" + colspan="7" + colspan="8" + colspan="9" + colspan="10" + colspan="11" + colspan="12" + colspan="13" + colspan="14" + colspan="15" + colspan="16" + colspan="17" + colspan="18" + colspan="19" + colspan="20" + rowspan="2" + rowspan="3" + rowspan="4" + rowspan="5" + rowspan="6" + rowspan="7" + rowspan="8" + rowspan="9" + rowspan="10" + rowspan="11" + rowspan="12" + rowspan="13" + rowspan="14" + rowspan="15" + rowspan="16" + rowspan="17" + rowspan="18" + rowspan="19" + rowspan="20" diff --git a/backend/ppocr/utils/dict/xi_dict.txt b/backend/ppocr/utils/dict/xi_dict.txt new file mode 100644 index 00000000..f195f1ea --- /dev/null +++ b/backend/ppocr/utils/dict/xi_dict.txt @@ -0,0 +1,110 @@ +x +i +_ +m +g +/ +1 +0 +I +L +S +V +R +C +2 +v +a +l +3 +6 +4 +5 +. +j +p + +Q +u +e +r +o +8 +7 +n +c +9 +t +b +é +q +d +ó +y +F +s +, +O +í +T +f +" +U +M +h +: +P +H +A +E +D +z +N +á +ñ +ú +% +; +è ++ +Y +- +B +G +( +) +¿ +? +w +¡ +! +X +É +K +k +Á +ü +Ú +« +» +J +' +ö +W +Z +º +Ö +­ +[ +] +Ç +ç +à +ä +û +ò +Í +ê +ô +ø +ª diff --git a/backend/ppocr/utils/e2e_metric/Deteval.py b/backend/ppocr/utils/e2e_metric/Deteval.py new file mode 100755 index 00000000..45567a7d --- /dev/null +++ b/backend/ppocr/utils/e2e_metric/Deteval.py @@ -0,0 +1,574 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np +import scipy.io as io +from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area + + +def get_socre_A(gt_dir, pred_dict): + allInputs = 1 + + def input_reading_mod(pred_dict): + """This helper reads input from txt files""" + det = [] + n = len(pred_dict) + for i in range(n): + points = pred_dict[i]['points'] + text = pred_dict[i]['texts'] + point = ",".join(map(str, points.reshape(-1, ))) + det.append([point, text]) + return det + + def gt_reading_mod(gt_dict): + """This helper reads groundtruths from mat files""" + gt = [] + n = len(gt_dict) + for i in range(n): + points = gt_dict[i]['points'].tolist() + h = len(points) + text = gt_dict[i]['text'] + xx = [ + np.array( + ['x:'], dtype=' 1): + gt_x = list(map(int, np.squeeze(gt[1]))) + gt_y = list(map(int, np.squeeze(gt[3]))) + for det_id, detection in enumerate(detections): + detection_orig = detection + detection = [float(x) for x in detection[0].split(',')] + detection = list(map(int, detection)) + det_x = detection[0::2] + det_y = detection[1::2] + det_gt_iou = iod(det_x, det_y, gt_x, gt_y) + if det_gt_iou > threshold: + detections[det_id] = [] + + detections[:] = [item for item in detections if item != []] + return detections + + def sigma_calculation(det_x, det_y, gt_x, gt_y): + """ + sigma = inter_area / gt_area + """ + return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) / + area(gt_x, gt_y)), 2) + + def tau_calculation(det_x, det_y, gt_x, gt_y): + if area(det_x, det_y) == 0.0: + return 0 + return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) / + area(det_x, det_y)), 2) + + ##############################Initialization################################### + # global_sigma = [] + # global_tau = [] + # global_pred_str = [] + # global_gt_str = [] + ############################################################################### + + for input_id in range(allInputs): + if (input_id != '.DS_Store') and (input_id != 'Pascal_result.txt') and ( + input_id != 'Pascal_result_curved.txt') and (input_id != 'Pascal_result_non_curved.txt') and ( + input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \ + and (input_id != 'Deteval_result_non_curved.txt'): + detections = input_reading_mod(pred_dict) + groundtruths = gt_reading_mod(gt_dir) + detections = detection_filtering( + detections, + groundtruths) # filters detections overlapping with DC area + dc_id = [] + for i in range(len(groundtruths)): + if groundtruths[i][5] == '#': + dc_id.append(i) + cnt = 0 + for a in dc_id: + num = a - cnt + del groundtruths[num] + cnt += 1 + + local_sigma_table = np.zeros((len(groundtruths), len(detections))) + local_tau_table = np.zeros((len(groundtruths), len(detections))) + local_pred_str = {} + local_gt_str = {} + + for gt_id, gt in enumerate(groundtruths): + if len(detections) > 0: + for det_id, detection in enumerate(detections): + detection_orig = detection + detection = [float(x) for x in detection[0].split(',')] + detection = list(map(int, detection)) + pred_seq_str = detection_orig[1].strip() + det_x = detection[0::2] + det_y = detection[1::2] + gt_x = list(map(int, np.squeeze(gt[1]))) + gt_y = list(map(int, np.squeeze(gt[3]))) + gt_seq_str = str(gt[4].tolist()[0]) + + local_sigma_table[gt_id, det_id] = sigma_calculation( + det_x, det_y, gt_x, gt_y) + local_tau_table[gt_id, det_id] = tau_calculation( + det_x, det_y, gt_x, gt_y) + local_pred_str[det_id] = pred_seq_str + local_gt_str[gt_id] = gt_seq_str + + global_sigma = local_sigma_table + global_tau = local_tau_table 
+ global_pred_str = local_pred_str + global_gt_str = local_gt_str + + single_data = {} + single_data['sigma'] = global_sigma + single_data['global_tau'] = global_tau + single_data['global_pred_str'] = global_pred_str + single_data['global_gt_str'] = global_gt_str + return single_data + + +def get_socre_B(gt_dir, img_id, pred_dict): + allInputs = 1 + + def input_reading_mod(pred_dict): + """This helper reads input from txt files""" + det = [] + n = len(pred_dict) + for i in range(n): + points = pred_dict[i]['points'] + text = pred_dict[i]['texts'] + point = ",".join(map(str, points.reshape(-1, ))) + det.append([point, text]) + return det + + def gt_reading_mod(gt_dir, gt_id): + gt = io.loadmat('%s/poly_gt_img%s.mat' % (gt_dir, gt_id)) + gt = gt['polygt'] + return gt + + def detection_filtering(detections, groundtruths, threshold=0.5): + for gt_id, gt in enumerate(groundtruths): + if (gt[5] == '#') and (gt[1].shape[1] > 1): + gt_x = list(map(int, np.squeeze(gt[1]))) + gt_y = list(map(int, np.squeeze(gt[3]))) + for det_id, detection in enumerate(detections): + detection_orig = detection + detection = [float(x) for x in detection[0].split(',')] + detection = list(map(int, detection)) + det_x = detection[0::2] + det_y = detection[1::2] + det_gt_iou = iod(det_x, det_y, gt_x, gt_y) + if det_gt_iou > threshold: + detections[det_id] = [] + + detections[:] = [item for item in detections if item != []] + return detections + + def sigma_calculation(det_x, det_y, gt_x, gt_y): + """ + sigma = inter_area / gt_area + """ + return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) / + area(gt_x, gt_y)), 2) + + def tau_calculation(det_x, det_y, gt_x, gt_y): + if area(det_x, det_y) == 0.0: + return 0 + return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) / + area(det_x, det_y)), 2) + + ##############################Initialization################################### + # global_sigma = [] + # global_tau = [] + # global_pred_str = [] + # global_gt_str = [] + ############################################################################### + + for input_id in range(allInputs): + if (input_id != '.DS_Store') and (input_id != 'Pascal_result.txt') and ( + input_id != 'Pascal_result_curved.txt') and (input_id != 'Pascal_result_non_curved.txt') and ( + input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \ + and (input_id != 'Deteval_result_non_curved.txt'): + detections = input_reading_mod(pred_dict) + groundtruths = gt_reading_mod(gt_dir, img_id).tolist() + detections = detection_filtering( + detections, + groundtruths) # filters detections overlapping with DC area + dc_id = [] + for i in range(len(groundtruths)): + if groundtruths[i][5] == '#': + dc_id.append(i) + cnt = 0 + for a in dc_id: + num = a - cnt + del groundtruths[num] + cnt += 1 + + local_sigma_table = np.zeros((len(groundtruths), len(detections))) + local_tau_table = np.zeros((len(groundtruths), len(detections))) + local_pred_str = {} + local_gt_str = {} + + for gt_id, gt in enumerate(groundtruths): + if len(detections) > 0: + for det_id, detection in enumerate(detections): + detection_orig = detection + detection = [float(x) for x in detection[0].split(',')] + detection = list(map(int, detection)) + pred_seq_str = detection_orig[1].strip() + det_x = detection[0::2] + det_y = detection[1::2] + gt_x = list(map(int, np.squeeze(gt[1]))) + gt_y = list(map(int, np.squeeze(gt[3]))) + gt_seq_str = str(gt[4].tolist()[0]) + + local_sigma_table[gt_id, det_id] = sigma_calculation( + det_x, det_y, gt_x, gt_y) + 
local_tau_table[gt_id, det_id] = tau_calculation( + det_x, det_y, gt_x, gt_y) + local_pred_str[det_id] = pred_seq_str + local_gt_str[gt_id] = gt_seq_str + + global_sigma = local_sigma_table + global_tau = local_tau_table + global_pred_str = local_pred_str + global_gt_str = local_gt_str + + single_data = {} + single_data['sigma'] = global_sigma + single_data['global_tau'] = global_tau + single_data['global_pred_str'] = global_pred_str + single_data['global_gt_str'] = global_gt_str + return single_data + + +def combine_results(all_data): + tr = 0.7 + tp = 0.6 + fsc_k = 0.8 + k = 2 + global_sigma = [] + global_tau = [] + global_pred_str = [] + global_gt_str = [] + for data in all_data: + global_sigma.append(data['sigma']) + global_tau.append(data['global_tau']) + global_pred_str.append(data['global_pred_str']) + global_gt_str.append(data['global_gt_str']) + + global_accumulative_recall = 0 + global_accumulative_precision = 0 + total_num_gt = 0 + total_num_det = 0 + hit_str_count = 0 + hit_count = 0 + + def one_to_one(local_sigma_table, local_tau_table, + local_accumulative_recall, local_accumulative_precision, + global_accumulative_recall, global_accumulative_precision, + gt_flag, det_flag, idy): + hit_str_num = 0 + for gt_id in range(num_gt): + gt_matching_qualified_sigma_candidates = np.where( + local_sigma_table[gt_id, :] > tr) + gt_matching_num_qualified_sigma_candidates = gt_matching_qualified_sigma_candidates[ + 0].shape[0] + gt_matching_qualified_tau_candidates = np.where( + local_tau_table[gt_id, :] > tp) + gt_matching_num_qualified_tau_candidates = gt_matching_qualified_tau_candidates[ + 0].shape[0] + + det_matching_qualified_sigma_candidates = np.where( + local_sigma_table[:, gt_matching_qualified_sigma_candidates[0]] + > tr) + det_matching_num_qualified_sigma_candidates = det_matching_qualified_sigma_candidates[ + 0].shape[0] + det_matching_qualified_tau_candidates = np.where( + local_tau_table[:, gt_matching_qualified_tau_candidates[0]] > + tp) + det_matching_num_qualified_tau_candidates = det_matching_qualified_tau_candidates[ + 0].shape[0] + + if (gt_matching_num_qualified_sigma_candidates == 1) and (gt_matching_num_qualified_tau_candidates == 1) and \ + (det_matching_num_qualified_sigma_candidates == 1) and ( + det_matching_num_qualified_tau_candidates == 1): + global_accumulative_recall = global_accumulative_recall + 1.0 + global_accumulative_precision = global_accumulative_precision + 1.0 + local_accumulative_recall = local_accumulative_recall + 1.0 + local_accumulative_precision = local_accumulative_precision + 1.0 + + gt_flag[0, gt_id] = 1 + matched_det_id = np.where(local_sigma_table[gt_id, :] > tr) + # recg start + gt_str_cur = global_gt_str[idy][gt_id] + pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[ + 0]] + if pred_str_cur == gt_str_cur: + hit_str_num += 1 + else: + if pred_str_cur.lower() == gt_str_cur.lower(): + hit_str_num += 1 + # recg end + det_flag[0, matched_det_id] = 1 + return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num + + def one_to_many(local_sigma_table, local_tau_table, + local_accumulative_recall, local_accumulative_precision, + global_accumulative_recall, global_accumulative_precision, + gt_flag, det_flag, idy): + hit_str_num = 0 + for gt_id in range(num_gt): + # skip the following if the groundtruth was matched + if gt_flag[0, gt_id] > 0: + continue + + non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0) + 
num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0] + + if num_non_zero_in_sigma >= k: + ####search for all detections that overlaps with this groundtruth + qualified_tau_candidates = np.where((local_tau_table[ + gt_id, :] >= tp) & (det_flag[0, :] == 0)) + num_qualified_tau_candidates = qualified_tau_candidates[ + 0].shape[0] + + if num_qualified_tau_candidates == 1: + if ((local_tau_table[gt_id, qualified_tau_candidates] >= tp) + and + (local_sigma_table[gt_id, qualified_tau_candidates] >= + tr)): + # became an one-to-one case + global_accumulative_recall = global_accumulative_recall + 1.0 + global_accumulative_precision = global_accumulative_precision + 1.0 + local_accumulative_recall = local_accumulative_recall + 1.0 + local_accumulative_precision = local_accumulative_precision + 1.0 + + gt_flag[0, gt_id] = 1 + det_flag[0, qualified_tau_candidates] = 1 + # recg start + gt_str_cur = global_gt_str[idy][gt_id] + pred_str_cur = global_pred_str[idy][ + qualified_tau_candidates[0].tolist()[0]] + if pred_str_cur == gt_str_cur: + hit_str_num += 1 + else: + if pred_str_cur.lower() == gt_str_cur.lower(): + hit_str_num += 1 + # recg end + elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates]) + >= tr): + gt_flag[0, gt_id] = 1 + det_flag[0, qualified_tau_candidates] = 1 + # recg start + gt_str_cur = global_gt_str[idy][gt_id] + pred_str_cur = global_pred_str[idy][ + qualified_tau_candidates[0].tolist()[0]] + if pred_str_cur == gt_str_cur: + hit_str_num += 1 + else: + if pred_str_cur.lower() == gt_str_cur.lower(): + hit_str_num += 1 + # recg end + + global_accumulative_recall = global_accumulative_recall + fsc_k + global_accumulative_precision = global_accumulative_precision + num_qualified_tau_candidates * fsc_k + + local_accumulative_recall = local_accumulative_recall + fsc_k + local_accumulative_precision = local_accumulative_precision + num_qualified_tau_candidates * fsc_k + + return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num + + def many_to_one(local_sigma_table, local_tau_table, + local_accumulative_recall, local_accumulative_precision, + global_accumulative_recall, global_accumulative_precision, + gt_flag, det_flag, idy): + hit_str_num = 0 + for det_id in range(num_det): + # skip the following if the detection was matched + if det_flag[0, det_id] > 0: + continue + + non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0) + num_non_zero_in_tau = non_zero_in_tau[0].shape[0] + + if num_non_zero_in_tau >= k: + ####search for all detections that overlaps with this groundtruth + qualified_sigma_candidates = np.where(( + local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0)) + num_qualified_sigma_candidates = qualified_sigma_candidates[ + 0].shape[0] + + if num_qualified_sigma_candidates == 1: + if ((local_tau_table[qualified_sigma_candidates, det_id] >= + tp) and + (local_sigma_table[qualified_sigma_candidates, det_id] + >= tr)): + # became an one-to-one case + global_accumulative_recall = global_accumulative_recall + 1.0 + global_accumulative_precision = global_accumulative_precision + 1.0 + local_accumulative_recall = local_accumulative_recall + 1.0 + local_accumulative_precision = local_accumulative_precision + 1.0 + + gt_flag[0, qualified_sigma_candidates] = 1 + det_flag[0, det_id] = 1 + # recg start + pred_str_cur = global_pred_str[idy][det_id] + gt_len = len(qualified_sigma_candidates[0]) + for idx in range(gt_len): + ele_gt_id = qualified_sigma_candidates[0].tolist()[ + 
idx] + if ele_gt_id not in global_gt_str[idy]: + continue + gt_str_cur = global_gt_str[idy][ele_gt_id] + if pred_str_cur == gt_str_cur: + hit_str_num += 1 + break + else: + if pred_str_cur.lower() == gt_str_cur.lower(): + hit_str_num += 1 + break + # recg end + elif (np.sum(local_tau_table[qualified_sigma_candidates, + det_id]) >= tp): + det_flag[0, det_id] = 1 + gt_flag[0, qualified_sigma_candidates] = 1 + # recg start + pred_str_cur = global_pred_str[idy][det_id] + gt_len = len(qualified_sigma_candidates[0]) + for idx in range(gt_len): + ele_gt_id = qualified_sigma_candidates[0].tolist()[idx] + if ele_gt_id not in global_gt_str[idy]: + continue + gt_str_cur = global_gt_str[idy][ele_gt_id] + if pred_str_cur == gt_str_cur: + hit_str_num += 1 + break + else: + if pred_str_cur.lower() == gt_str_cur.lower(): + hit_str_num += 1 + break + # recg end + + global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k + global_accumulative_precision = global_accumulative_precision + fsc_k + + local_accumulative_recall = local_accumulative_recall + num_qualified_sigma_candidates * fsc_k + local_accumulative_precision = local_accumulative_precision + fsc_k + return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num + + for idx in range(len(global_sigma)): + local_sigma_table = np.array(global_sigma[idx]) + local_tau_table = global_tau[idx] + + num_gt = local_sigma_table.shape[0] + num_det = local_sigma_table.shape[1] + + total_num_gt = total_num_gt + num_gt + total_num_det = total_num_det + num_det + + local_accumulative_recall = 0 + local_accumulative_precision = 0 + gt_flag = np.zeros((1, num_gt)) + det_flag = np.zeros((1, num_det)) + + #######first check for one-to-one case########## + local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \ + gt_flag, det_flag, hit_str_num = one_to_one(local_sigma_table, local_tau_table, + local_accumulative_recall, local_accumulative_precision, + global_accumulative_recall, global_accumulative_precision, + gt_flag, det_flag, idx) + + hit_str_count += hit_str_num + #######then check for one-to-many case########## + local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \ + gt_flag, det_flag, hit_str_num = one_to_many(local_sigma_table, local_tau_table, + local_accumulative_recall, local_accumulative_precision, + global_accumulative_recall, global_accumulative_precision, + gt_flag, det_flag, idx) + hit_str_count += hit_str_num + #######then check for many-to-one case########## + local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \ + gt_flag, det_flag, hit_str_num = many_to_one(local_sigma_table, local_tau_table, + local_accumulative_recall, local_accumulative_precision, + global_accumulative_recall, global_accumulative_precision, + gt_flag, det_flag, idx) + hit_str_count += hit_str_num + + try: + recall = global_accumulative_recall / total_num_gt + except ZeroDivisionError: + recall = 0 + + try: + precision = global_accumulative_precision / total_num_det + except ZeroDivisionError: + precision = 0 + + try: + f_score = 2 * precision * recall / (precision + recall) + except ZeroDivisionError: + f_score = 0 + + try: + seqerr = 1 - float(hit_str_count) / global_accumulative_recall + except ZeroDivisionError: + seqerr = 1 + + try: + recall_e2e = 
float(hit_str_count) / total_num_gt + except ZeroDivisionError: + recall_e2e = 0 + + try: + precision_e2e = float(hit_str_count) / total_num_det + except ZeroDivisionError: + precision_e2e = 0 + + try: + f_score_e2e = 2 * precision_e2e * recall_e2e / ( + precision_e2e + recall_e2e) + except ZeroDivisionError: + f_score_e2e = 0 + + final = { + 'total_num_gt': total_num_gt, + 'total_num_det': total_num_det, + 'global_accumulative_recall': global_accumulative_recall, + 'hit_str_count': hit_str_count, + 'recall': recall, + 'precision': precision, + 'f_score': f_score, + 'seqerr': seqerr, + 'recall_e2e': recall_e2e, + 'precision_e2e': precision_e2e, + 'f_score_e2e': f_score_e2e + } + return final diff --git a/backend/ppocr/utils/e2e_metric/polygon_fast.py b/backend/ppocr/utils/e2e_metric/polygon_fast.py new file mode 100755 index 00000000..81c9ad70 --- /dev/null +++ b/backend/ppocr/utils/e2e_metric/polygon_fast.py @@ -0,0 +1,83 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +from shapely.geometry import Polygon +""" +:param det_x: [1, N] Xs of detection's vertices +:param det_y: [1, N] Ys of detection's vertices +:param gt_x: [1, N] Xs of groundtruth's vertices +:param gt_y: [1, N] Ys of groundtruth's vertices + +############## +All the calculation of 'AREA' in this script is handled by: +1) First generating a binary mask with the polygon area filled up with 1's +2) Summing up all the 1's +""" + + +def area(x, y): + polygon = Polygon(np.stack([x, y], axis=1)) + return float(polygon.area) + + +def approx_area_of_intersection(det_x, det_y, gt_x, gt_y): + """ + This helper determine if both polygons are intersecting with each others with an approximation method. 
+ Area of intersection represented by the minimum bounding rectangular [xmin, ymin, xmax, ymax] + """ + det_ymax = np.max(det_y) + det_xmax = np.max(det_x) + det_ymin = np.min(det_y) + det_xmin = np.min(det_x) + + gt_ymax = np.max(gt_y) + gt_xmax = np.max(gt_x) + gt_ymin = np.min(gt_y) + gt_xmin = np.min(gt_x) + + all_min_ymax = np.minimum(det_ymax, gt_ymax) + all_max_ymin = np.maximum(det_ymin, gt_ymin) + + intersect_heights = np.maximum(0.0, (all_min_ymax - all_max_ymin)) + + all_min_xmax = np.minimum(det_xmax, gt_xmax) + all_max_xmin = np.maximum(det_xmin, gt_xmin) + intersect_widths = np.maximum(0.0, (all_min_xmax - all_max_xmin)) + + return intersect_heights * intersect_widths + + +def area_of_intersection(det_x, det_y, gt_x, gt_y): + p1 = Polygon(np.stack([det_x, det_y], axis=1)).buffer(0) + p2 = Polygon(np.stack([gt_x, gt_y], axis=1)).buffer(0) + return float(p1.intersection(p2).area) + + +def area_of_union(det_x, det_y, gt_x, gt_y): + p1 = Polygon(np.stack([det_x, det_y], axis=1)).buffer(0) + p2 = Polygon(np.stack([gt_x, gt_y], axis=1)).buffer(0) + return float(p1.union(p2).area) + + +def iou(det_x, det_y, gt_x, gt_y): + return area_of_intersection(det_x, det_y, gt_x, gt_y) / ( + area_of_union(det_x, det_y, gt_x, gt_y) + 1.0) + + +def iod(det_x, det_y, gt_x, gt_y): + """ + This helper determine the fraction of intersection area over detection area + """ + return area_of_intersection(det_x, det_y, gt_x, gt_y) / ( + area(det_x, det_y) + 1.0) diff --git a/backend/ppocr/utils/e2e_utils/extract_batchsize.py b/backend/ppocr/utils/e2e_utils/extract_batchsize.py new file mode 100644 index 00000000..e99a833e --- /dev/null +++ b/backend/ppocr/utils/e2e_utils/extract_batchsize.py @@ -0,0 +1,87 @@ +import paddle +import numpy as np +import copy + + +def org_tcl_rois(batch_size, pos_lists, pos_masks, label_lists, tcl_bs): + """ + """ + pos_lists_, pos_masks_, label_lists_ = [], [], [] + img_bs = batch_size + ngpu = int(batch_size / img_bs) + img_ids = np.array(pos_lists, dtype=np.int32)[:, 0, 0].copy() + pos_lists_split, pos_masks_split, label_lists_split = [], [], [] + for i in range(ngpu): + pos_lists_split.append([]) + pos_masks_split.append([]) + label_lists_split.append([]) + + for i in range(img_ids.shape[0]): + img_id = img_ids[i] + gpu_id = int(img_id / img_bs) + img_id = img_id % img_bs + pos_list = pos_lists[i].copy() + pos_list[:, 0] = img_id + pos_lists_split[gpu_id].append(pos_list) + pos_masks_split[gpu_id].append(pos_masks[i].copy()) + label_lists_split[gpu_id].append(copy.deepcopy(label_lists[i])) + # repeat or delete + for i in range(ngpu): + vp_len = len(pos_lists_split[i]) + if vp_len <= tcl_bs: + for j in range(0, tcl_bs - vp_len): + pos_list = pos_lists_split[i][j].copy() + pos_lists_split[i].append(pos_list) + pos_mask = pos_masks_split[i][j].copy() + pos_masks_split[i].append(pos_mask) + label_list = copy.deepcopy(label_lists_split[i][j]) + label_lists_split[i].append(label_list) + else: + for j in range(0, vp_len - tcl_bs): + c_len = len(pos_lists_split[i]) + pop_id = np.random.permutation(c_len)[0] + pos_lists_split[i].pop(pop_id) + pos_masks_split[i].pop(pop_id) + label_lists_split[i].pop(pop_id) + # merge + for i in range(ngpu): + pos_lists_.extend(pos_lists_split[i]) + pos_masks_.extend(pos_masks_split[i]) + label_lists_.extend(label_lists_split[i]) + return pos_lists_, pos_masks_, label_lists_ + + +def pre_process(label_list, pos_list, pos_mask, max_text_length, max_text_nums, + pad_num, tcl_bs): + label_list = label_list.numpy() + batch, _, _, _ = label_list.shape 
+ pos_list = pos_list.numpy() + pos_mask = pos_mask.numpy() + pos_list_t = [] + pos_mask_t = [] + label_list_t = [] + for i in range(batch): + for j in range(max_text_nums): + if pos_mask[i, j].any(): + pos_list_t.append(pos_list[i][j]) + pos_mask_t.append(pos_mask[i][j]) + label_list_t.append(label_list[i][j]) + pos_list, pos_mask, label_list = org_tcl_rois(batch, pos_list_t, pos_mask_t, + label_list_t, tcl_bs) + label = [] + tt = [l.tolist() for l in label_list] + for i in range(tcl_bs): + k = 0 + for j in range(max_text_length): + if tt[i][j][0] != pad_num: + k += 1 + else: + break + label.append(k) + label = paddle.to_tensor(label) + label = paddle.cast(label, dtype='int64') + pos_list = paddle.to_tensor(pos_list) + pos_mask = paddle.to_tensor(pos_mask) + label_list = paddle.squeeze(paddle.to_tensor(label_list), axis=2) + label_list = paddle.cast(label_list, dtype='int32') + return pos_list, pos_mask, label_list, label diff --git a/backend/ppocr/utils/e2e_utils/extract_textpoint_fast.py b/backend/ppocr/utils/e2e_utils/extract_textpoint_fast.py new file mode 100644 index 00000000..787cd301 --- /dev/null +++ b/backend/ppocr/utils/e2e_utils/extract_textpoint_fast.py @@ -0,0 +1,457 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains various CTC decoders.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import cv2 +import math + +import numpy as np +from itertools import groupby +from skimage.morphology._skeletonize import thin + + +def get_dict(character_dict_path): + character_str = "" + with open(character_dict_path, "rb") as fin: + lines = fin.readlines() + for line in lines: + line = line.decode('utf-8').strip("\n").strip("\r\n") + character_str += line + dict_character = list(character_str) + return dict_character + + +def softmax(logits): + """ + logits: N x d + """ + max_value = np.max(logits, axis=1, keepdims=True) + exp = np.exp(logits - max_value) + exp_sum = np.sum(exp, axis=1, keepdims=True) + dist = exp / exp_sum + return dist + + +def get_keep_pos_idxs(labels, remove_blank=None): + """ + Remove duplicate and get pos idxs of keep items. + The value of keep_blank should be [None, 95]. + """ + duplicate_len_list = [] + keep_pos_idx_list = [] + keep_char_idx_list = [] + for k, v_ in groupby(labels): + current_len = len(list(v_)) + if k != remove_blank: + current_idx = int(sum(duplicate_len_list) + current_len // 2) + keep_pos_idx_list.append(current_idx) + keep_char_idx_list.append(k) + duplicate_len_list.append(current_len) + return keep_char_idx_list, keep_pos_idx_list + + +def remove_blank(labels, blank=0): + new_labels = [x for x in labels if x != blank] + return new_labels + + +def insert_blank(labels, blank=0): + new_labels = [blank] + for l in labels: + new_labels += [l, blank] + return new_labels + + +def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True): + """ + CTC greedy (best path) decoder. 
+ """ + raw_str = np.argmax(np.array(probs_seq), axis=1) + remove_blank_in_pos = None if keep_blank_in_idxs else blank + dedup_str, keep_idx_list = get_keep_pos_idxs( + raw_str, remove_blank=remove_blank_in_pos) + dst_str = remove_blank(dedup_str, blank=blank) + return dst_str, keep_idx_list + + +def instance_ctc_greedy_decoder(gather_info, logits_map, pts_num=4): + _, _, C = logits_map.shape + ys, xs = zip(*gather_info) + logits_seq = logits_map[list(ys), list(xs)] + probs_seq = logits_seq + labels = np.argmax(probs_seq, axis=1) + dst_str = [k for k, v_ in groupby(labels) if k != C - 1] + detal = len(gather_info) // (pts_num - 1) + keep_idx_list = [0] + [detal * (i + 1) for i in range(pts_num - 2)] + [-1] + keep_gather_list = [gather_info[idx] for idx in keep_idx_list] + return dst_str, keep_gather_list + + +def ctc_decoder_for_image(gather_info_list, + logits_map, + Lexicon_Table, + pts_num=6): + """ + CTC decoder using multiple processes. + """ + decoder_str = [] + decoder_xys = [] + for gather_info in gather_info_list: + if len(gather_info) < pts_num: + continue + dst_str, xys_list = instance_ctc_greedy_decoder( + gather_info, logits_map, pts_num=pts_num) + dst_str_readable = ''.join([Lexicon_Table[idx] for idx in dst_str]) + if len(dst_str_readable) < 2: + continue + decoder_str.append(dst_str_readable) + decoder_xys.append(xys_list) + return decoder_str, decoder_xys + + +def sort_with_direction(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[y, x], [y, x], [y, x] ...] + """ + + def sort_part_with_direction(pos_list, point_direction): + pos_list = np.array(pos_list).reshape(-1, 2) + point_direction = np.array(point_direction).reshape(-1, 2) + average_direction = np.mean(point_direction, axis=0, keepdims=True) + pos_proj_leng = np.sum(pos_list * average_direction, axis=1) + sorted_list = pos_list[np.argsort(pos_proj_leng)].tolist() + sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist() + return sorted_list, sorted_direction + + pos_list = np.array(pos_list).reshape(-1, 2) + point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y + point_direction = point_direction[:, ::-1] # x, y -> y, x + sorted_point, sorted_direction = sort_part_with_direction(pos_list, + point_direction) + + point_num = len(sorted_point) + if point_num >= 16: + middle_num = point_num // 2 + first_part_point = sorted_point[:middle_num] + first_point_direction = sorted_direction[:middle_num] + sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction( + first_part_point, first_point_direction) + + last_part_point = sorted_point[middle_num:] + last_point_direction = sorted_direction[middle_num:] + sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction( + last_part_point, last_point_direction) + sorted_point = sorted_fist_part_point + sorted_last_part_point + sorted_direction = sorted_fist_part_direction + sorted_last_part_direction + + return sorted_point, np.array(sorted_direction) + + +def add_id(pos_list, image_id=0): + """ + Add id for gather feature, for inference. + """ + new_list = [] + for item in pos_list: + new_list.append((image_id, item[0], item[1])) + return new_list + + +def sort_and_expand_with_direction(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[y, x], [y, x], [y, x] ...] 
+ """ + h, w, _ = f_direction.shape + sorted_list, point_direction = sort_with_direction(pos_list, f_direction) + + point_num = len(sorted_list) + sub_direction_len = max(point_num // 3, 2) + left_direction = point_direction[:sub_direction_len, :] + right_dirction = point_direction[point_num - sub_direction_len:, :] + + left_average_direction = -np.mean(left_direction, axis=0, keepdims=True) + left_average_len = np.linalg.norm(left_average_direction) + left_start = np.array(sorted_list[0]) + left_step = left_average_direction / (left_average_len + 1e-6) + + right_average_direction = np.mean(right_dirction, axis=0, keepdims=True) + right_average_len = np.linalg.norm(right_average_direction) + right_step = right_average_direction / (right_average_len + 1e-6) + right_start = np.array(sorted_list[-1]) + + append_num = max( + int((left_average_len + right_average_len) / 2.0 * 0.15), 1) + left_list = [] + right_list = [] + for i in range(append_num): + ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype( + 'int32').tolist() + if ly < h and lx < w and (ly, lx) not in left_list: + left_list.append((ly, lx)) + ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype( + 'int32').tolist() + if ry < h and rx < w and (ry, rx) not in right_list: + right_list.append((ry, rx)) + + all_list = left_list[::-1] + sorted_list + right_list + return all_list + + +def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map): + """ + f_direction: h x w x 2 + pos_list: [[y, x], [y, x], [y, x] ...] + binary_tcl_map: h x w + """ + h, w, _ = f_direction.shape + sorted_list, point_direction = sort_with_direction(pos_list, f_direction) + + point_num = len(sorted_list) + sub_direction_len = max(point_num // 3, 2) + left_direction = point_direction[:sub_direction_len, :] + right_dirction = point_direction[point_num - sub_direction_len:, :] + + left_average_direction = -np.mean(left_direction, axis=0, keepdims=True) + left_average_len = np.linalg.norm(left_average_direction) + left_start = np.array(sorted_list[0]) + left_step = left_average_direction / (left_average_len + 1e-6) + + right_average_direction = np.mean(right_dirction, axis=0, keepdims=True) + right_average_len = np.linalg.norm(right_average_direction) + right_step = right_average_direction / (right_average_len + 1e-6) + right_start = np.array(sorted_list[-1]) + + append_num = max( + int((left_average_len + right_average_len) / 2.0 * 0.15), 1) + max_append_num = 2 * append_num + + left_list = [] + right_list = [] + for i in range(max_append_num): + ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype( + 'int32').tolist() + if ly < h and lx < w and (ly, lx) not in left_list: + if binary_tcl_map[ly, lx] > 0.5: + left_list.append((ly, lx)) + else: + break + + for i in range(max_append_num): + ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype( + 'int32').tolist() + if ry < h and rx < w and (ry, rx) not in right_list: + if binary_tcl_map[ry, rx] > 0.5: + right_list.append((ry, rx)) + else: + break + + all_list = left_list[::-1] + sorted_list + right_list + return all_list + + +def point_pair2poly(point_pair_list): + """ + Transfer vertical point_pairs into poly point in clockwise. 
+ """ + point_num = len(point_pair_list) * 2 + point_list = [0] * point_num + for idx, point_pair in enumerate(point_pair_list): + point_list[idx] = point_pair[0] + point_list[point_num - 1 - idx] = point_pair[1] + return np.array(point_list).reshape(-1, 2) + + +def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.): + ratio_pair = np.array( + [[begin_width_ratio], [end_width_ratio]], dtype=np.float32) + p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair + p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair + return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]]) + + +def expand_poly_along_width(poly, shrink_ratio_of_width=0.3): + """ + expand poly along width. + """ + point_num = poly.shape[0] + left_quad = np.array( + [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32) + left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \ + (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6) + left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0) + right_quad = np.array( + [ + poly[point_num // 2 - 2], poly[point_num // 2 - 1], + poly[point_num // 2], poly[point_num // 2 + 1] + ], + dtype=np.float32) + right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \ + (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6) + right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio) + poly[0] = left_quad_expand[0] + poly[-1] = left_quad_expand[-1] + poly[point_num // 2 - 1] = right_quad_expand[1] + poly[point_num // 2] = right_quad_expand[2] + return poly + + +def restore_poly(instance_yxs_list, seq_strs, p_border, ratio_w, ratio_h, src_w, + src_h, valid_set): + poly_list = [] + keep_str_list = [] + for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs): + if len(keep_str) < 2: + print('--> too short, {}'.format(keep_str)) + continue + + offset_expand = 1.0 + if valid_set == 'totaltext': + offset_expand = 1.2 + + point_pair_list = [] + for y, x in yx_center_line: + offset = p_border[:, y, x].reshape(2, 2) * offset_expand + ori_yx = np.array([y, x], dtype=np.float32) + point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array( + [ratio_w, ratio_h]).reshape(-1, 2) + point_pair_list.append(point_pair) + + detected_poly = point_pair2poly(point_pair_list) + detected_poly = expand_poly_along_width( + detected_poly, shrink_ratio_of_width=0.2) + detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w) + detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h) + + keep_str_list.append(keep_str) + if valid_set == 'partvgg': + middle_point = len(detected_poly) // 2 + detected_poly = detected_poly[ + [0, middle_point - 1, middle_point, -1], :] + poly_list.append(detected_poly) + elif valid_set == 'totaltext': + poly_list.append(detected_poly) + else: + print('--> Not supported format.') + exit(-1) + return poly_list, keep_str_list + + +def generate_pivot_list_fast(p_score, + p_char_maps, + f_direction, + Lexicon_Table, + score_thresh=0.5): + """ + return center point and end point of TCL instance; filter with the char maps; + """ + p_score = p_score[0] + f_direction = f_direction.transpose(1, 2, 0) + p_tcl_map = (p_score > score_thresh) * 1.0 + skeleton_map = thin(p_tcl_map.astype(np.uint8)) + instance_count, instance_label_map = cv2.connectedComponents( + skeleton_map.astype(np.uint8), connectivity=8) + + # get TCL Instance + all_pos_yxs = [] + if instance_count > 0: + for instance_id in range(1, instance_count): + pos_list = [] + ys, xs = 
np.where(instance_label_map == instance_id) + pos_list = list(zip(ys, xs)) + + if len(pos_list) < 3: + continue + + pos_list_sorted = sort_and_expand_with_direction_v2( + pos_list, f_direction, p_tcl_map) + all_pos_yxs.append(pos_list_sorted) + + p_char_maps = p_char_maps.transpose([1, 2, 0]) + decoded_str, keep_yxs_list = ctc_decoder_for_image( + all_pos_yxs, logits_map=p_char_maps, Lexicon_Table=Lexicon_Table) + return keep_yxs_list, decoded_str + + +def extract_main_direction(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[y, x], [y, x], [y, x] ...] + """ + pos_list = np.array(pos_list) + point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] + point_direction = point_direction[:, ::-1] # x, y -> y, x + average_direction = np.mean(point_direction, axis=0, keepdims=True) + average_direction = average_direction / ( + np.linalg.norm(average_direction) + 1e-6) + return average_direction + + +def sort_by_direction_with_image_id_deprecated(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[id, y, x], [id, y, x], [id, y, x] ...] + """ + pos_list_full = np.array(pos_list).reshape(-1, 3) + pos_list = pos_list_full[:, 1:] + point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y + point_direction = point_direction[:, ::-1] # x, y -> y, x + average_direction = np.mean(point_direction, axis=0, keepdims=True) + pos_proj_leng = np.sum(pos_list * average_direction, axis=1) + sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist() + return sorted_list + + +def sort_by_direction_with_image_id(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[y, x], [y, x], [y, x] ...] + """ + + def sort_part_with_direction(pos_list_full, point_direction): + pos_list_full = np.array(pos_list_full).reshape(-1, 3) + pos_list = pos_list_full[:, 1:] + point_direction = np.array(point_direction).reshape(-1, 2) + average_direction = np.mean(point_direction, axis=0, keepdims=True) + pos_proj_leng = np.sum(pos_list * average_direction, axis=1) + sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist() + sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist() + return sorted_list, sorted_direction + + pos_list = np.array(pos_list).reshape(-1, 3) + point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]] # x, y + point_direction = point_direction[:, ::-1] # x, y -> y, x + sorted_point, sorted_direction = sort_part_with_direction(pos_list, + point_direction) + + point_num = len(sorted_point) + if point_num >= 16: + middle_num = point_num // 2 + first_part_point = sorted_point[:middle_num] + first_point_direction = sorted_direction[:middle_num] + sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction( + first_part_point, first_point_direction) + + last_part_point = sorted_point[middle_num:] + last_point_direction = sorted_direction[middle_num:] + sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction( + last_part_point, last_point_direction) + sorted_point = sorted_fist_part_point + sorted_last_part_point + sorted_direction = sorted_fist_part_direction + sorted_last_part_direction + + return sorted_point diff --git a/backend/ppocr/utils/e2e_utils/extract_textpoint_slow.py b/backend/ppocr/utils/e2e_utils/extract_textpoint_slow.py new file mode 100644 index 00000000..ace46fba --- /dev/null +++ b/backend/ppocr/utils/e2e_utils/extract_textpoint_slow.py @@ -0,0 +1,592 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains various CTC decoders.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import cv2 +import math + +import numpy as np +from itertools import groupby +from skimage.morphology._skeletonize import thin + + +def get_dict(character_dict_path): + character_str = "" + with open(character_dict_path, "rb") as fin: + lines = fin.readlines() + for line in lines: + line = line.decode('utf-8').strip("\n").strip("\r\n") + character_str += line + dict_character = list(character_str) + return dict_character + + +def point_pair2poly(point_pair_list): + """ + Transfer vertical point_pairs into poly point in clockwise. + """ + pair_length_list = [] + for point_pair in point_pair_list: + pair_length = np.linalg.norm(point_pair[0] - point_pair[1]) + pair_length_list.append(pair_length) + pair_length_list = np.array(pair_length_list) + pair_info = (pair_length_list.max(), pair_length_list.min(), + pair_length_list.mean()) + + point_num = len(point_pair_list) * 2 + point_list = [0] * point_num + for idx, point_pair in enumerate(point_pair_list): + point_list[idx] = point_pair[0] + point_list[point_num - 1 - idx] = point_pair[1] + return np.array(point_list).reshape(-1, 2), pair_info + + +def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.): + """ + Generate shrink_quad_along_width. + """ + ratio_pair = np.array( + [[begin_width_ratio], [end_width_ratio]], dtype=np.float32) + p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair + p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair + return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]]) + + +def expand_poly_along_width(poly, shrink_ratio_of_width=0.3): + """ + expand poly along width. + """ + point_num = poly.shape[0] + left_quad = np.array( + [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32) + left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \ + (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6) + left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0) + right_quad = np.array( + [ + poly[point_num // 2 - 2], poly[point_num // 2 - 1], + poly[point_num // 2], poly[point_num // 2 + 1] + ], + dtype=np.float32) + right_ratio = 1.0 + \ + shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \ + (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6) + right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio) + poly[0] = left_quad_expand[0] + poly[-1] = left_quad_expand[-1] + poly[point_num // 2 - 1] = right_quad_expand[1] + poly[point_num // 2] = right_quad_expand[2] + return poly + + +def softmax(logits): + """ + logits: N x d + """ + max_value = np.max(logits, axis=1, keepdims=True) + exp = np.exp(logits - max_value) + exp_sum = np.sum(exp, axis=1, keepdims=True) + dist = exp / exp_sum + return dist + + +def get_keep_pos_idxs(labels, remove_blank=None): + """ + Remove duplicate and get pos idxs of keep items. 
+ The value of keep_blank should be [None, 95]. + """ + duplicate_len_list = [] + keep_pos_idx_list = [] + keep_char_idx_list = [] + for k, v_ in groupby(labels): + current_len = len(list(v_)) + if k != remove_blank: + current_idx = int(sum(duplicate_len_list) + current_len // 2) + keep_pos_idx_list.append(current_idx) + keep_char_idx_list.append(k) + duplicate_len_list.append(current_len) + return keep_char_idx_list, keep_pos_idx_list + + +def remove_blank(labels, blank=0): + new_labels = [x for x in labels if x != blank] + return new_labels + + +def insert_blank(labels, blank=0): + new_labels = [blank] + for l in labels: + new_labels += [l, blank] + return new_labels + + +def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True): + """ + CTC greedy (best path) decoder. + """ + raw_str = np.argmax(np.array(probs_seq), axis=1) + remove_blank_in_pos = None if keep_blank_in_idxs else blank + dedup_str, keep_idx_list = get_keep_pos_idxs( + raw_str, remove_blank=remove_blank_in_pos) + dst_str = remove_blank(dedup_str, blank=blank) + return dst_str, keep_idx_list + + +def instance_ctc_greedy_decoder(gather_info, + logits_map, + keep_blank_in_idxs=True): + """ + gather_info: [[x, y], [x, y] ...] + logits_map: H x W X (n_chars + 1) + """ + _, _, C = logits_map.shape + ys, xs = zip(*gather_info) + logits_seq = logits_map[list(ys), list(xs)] # n x 96 + probs_seq = softmax(logits_seq) + dst_str, keep_idx_list = ctc_greedy_decoder( + probs_seq, blank=C - 1, keep_blank_in_idxs=keep_blank_in_idxs) + keep_gather_list = [gather_info[idx] for idx in keep_idx_list] + return dst_str, keep_gather_list + + +def ctc_decoder_for_image(gather_info_list, logits_map, + keep_blank_in_idxs=True): + """ + CTC decoder using multiple processes. + """ + decoder_results = [] + for gather_info in gather_info_list: + res = instance_ctc_greedy_decoder( + gather_info, logits_map, keep_blank_in_idxs=keep_blank_in_idxs) + decoder_results.append(res) + return decoder_results + + +def sort_with_direction(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[y, x], [y, x], [y, x] ...] 
+ """ + + def sort_part_with_direction(pos_list, point_direction): + pos_list = np.array(pos_list).reshape(-1, 2) + point_direction = np.array(point_direction).reshape(-1, 2) + average_direction = np.mean(point_direction, axis=0, keepdims=True) + pos_proj_leng = np.sum(pos_list * average_direction, axis=1) + sorted_list = pos_list[np.argsort(pos_proj_leng)].tolist() + sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist() + return sorted_list, sorted_direction + + pos_list = np.array(pos_list).reshape(-1, 2) + point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y + point_direction = point_direction[:, ::-1] # x, y -> y, x + sorted_point, sorted_direction = sort_part_with_direction(pos_list, + point_direction) + + point_num = len(sorted_point) + if point_num >= 16: + middle_num = point_num // 2 + first_part_point = sorted_point[:middle_num] + first_point_direction = sorted_direction[:middle_num] + sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction( + first_part_point, first_point_direction) + + last_part_point = sorted_point[middle_num:] + last_point_direction = sorted_direction[middle_num:] + sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction( + last_part_point, last_point_direction) + sorted_point = sorted_fist_part_point + sorted_last_part_point + sorted_direction = sorted_fist_part_direction + sorted_last_part_direction + + return sorted_point, np.array(sorted_direction) + + +def add_id(pos_list, image_id=0): + """ + Add id for gather feature, for inference. + """ + new_list = [] + for item in pos_list: + new_list.append((image_id, item[0], item[1])) + return new_list + + +def sort_and_expand_with_direction(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[y, x], [y, x], [y, x] ...] + """ + h, w, _ = f_direction.shape + sorted_list, point_direction = sort_with_direction(pos_list, f_direction) + + # expand along + point_num = len(sorted_list) + sub_direction_len = max(point_num // 3, 2) + left_direction = point_direction[:sub_direction_len, :] + right_dirction = point_direction[point_num - sub_direction_len:, :] + + left_average_direction = -np.mean(left_direction, axis=0, keepdims=True) + left_average_len = np.linalg.norm(left_average_direction) + left_start = np.array(sorted_list[0]) + left_step = left_average_direction / (left_average_len + 1e-6) + + right_average_direction = np.mean(right_dirction, axis=0, keepdims=True) + right_average_len = np.linalg.norm(right_average_direction) + right_step = right_average_direction / (right_average_len + 1e-6) + right_start = np.array(sorted_list[-1]) + + append_num = max( + int((left_average_len + right_average_len) / 2.0 * 0.15), 1) + left_list = [] + right_list = [] + for i in range(append_num): + ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype( + 'int32').tolist() + if ly < h and lx < w and (ly, lx) not in left_list: + left_list.append((ly, lx)) + ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype( + 'int32').tolist() + if ry < h and rx < w and (ry, rx) not in right_list: + right_list.append((ry, rx)) + + all_list = left_list[::-1] + sorted_list + right_list + return all_list + + +def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map): + """ + f_direction: h x w x 2 + pos_list: [[y, x], [y, x], [y, x] ...] 
+ binary_tcl_map: h x w + """ + h, w, _ = f_direction.shape + sorted_list, point_direction = sort_with_direction(pos_list, f_direction) + + # expand along + point_num = len(sorted_list) + sub_direction_len = max(point_num // 3, 2) + left_direction = point_direction[:sub_direction_len, :] + right_dirction = point_direction[point_num - sub_direction_len:, :] + + left_average_direction = -np.mean(left_direction, axis=0, keepdims=True) + left_average_len = np.linalg.norm(left_average_direction) + left_start = np.array(sorted_list[0]) + left_step = left_average_direction / (left_average_len + 1e-6) + + right_average_direction = np.mean(right_dirction, axis=0, keepdims=True) + right_average_len = np.linalg.norm(right_average_direction) + right_step = right_average_direction / (right_average_len + 1e-6) + right_start = np.array(sorted_list[-1]) + + append_num = max( + int((left_average_len + right_average_len) / 2.0 * 0.15), 1) + max_append_num = 2 * append_num + + left_list = [] + right_list = [] + for i in range(max_append_num): + ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype( + 'int32').tolist() + if ly < h and lx < w and (ly, lx) not in left_list: + if binary_tcl_map[ly, lx] > 0.5: + left_list.append((ly, lx)) + else: + break + + for i in range(max_append_num): + ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype( + 'int32').tolist() + if ry < h and rx < w and (ry, rx) not in right_list: + if binary_tcl_map[ry, rx] > 0.5: + right_list.append((ry, rx)) + else: + break + + all_list = left_list[::-1] + sorted_list + right_list + return all_list + + +def generate_pivot_list_curved(p_score, + p_char_maps, + f_direction, + score_thresh=0.5, + is_expand=True, + is_backbone=False, + image_id=0): + """ + return center point and end point of TCL instance; filter with the char maps; + """ + p_score = p_score[0] + f_direction = f_direction.transpose(1, 2, 0) + p_tcl_map = (p_score > score_thresh) * 1.0 + skeleton_map = thin(p_tcl_map) + instance_count, instance_label_map = cv2.connectedComponents( + skeleton_map.astype(np.uint8), connectivity=8) + + # get TCL Instance + all_pos_yxs = [] + center_pos_yxs = [] + end_points_yxs = [] + instance_center_pos_yxs = [] + pred_strs = [] + if instance_count > 0: + for instance_id in range(1, instance_count): + pos_list = [] + ys, xs = np.where(instance_label_map == instance_id) + pos_list = list(zip(ys, xs)) + + ### FIX-ME, eliminate outlier + if len(pos_list) < 3: + continue + + if is_expand: + pos_list_sorted = sort_and_expand_with_direction_v2( + pos_list, f_direction, p_tcl_map) + else: + pos_list_sorted, _ = sort_with_direction(pos_list, f_direction) + all_pos_yxs.append(pos_list_sorted) + + # use decoder to filter backgroud points. 
+ p_char_maps = p_char_maps.transpose([1, 2, 0]) + decode_res = ctc_decoder_for_image( + all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True) + for decoded_str, keep_yxs_list in decode_res: + if is_backbone: + keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id) + instance_center_pos_yxs.append(keep_yxs_list_with_id) + pred_strs.append(decoded_str) + else: + end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1])) + center_pos_yxs.extend(keep_yxs_list) + + if is_backbone: + return pred_strs, instance_center_pos_yxs + else: + return center_pos_yxs, end_points_yxs + + +def generate_pivot_list_horizontal(p_score, + p_char_maps, + f_direction, + score_thresh=0.5, + is_backbone=False, + image_id=0): + """ + return center point and end point of TCL instance; filter with the char maps; + """ + p_score = p_score[0] + f_direction = f_direction.transpose(1, 2, 0) + p_tcl_map_bi = (p_score > score_thresh) * 1.0 + instance_count, instance_label_map = cv2.connectedComponents( + p_tcl_map_bi.astype(np.uint8), connectivity=8) + + # get TCL Instance + all_pos_yxs = [] + center_pos_yxs = [] + end_points_yxs = [] + instance_center_pos_yxs = [] + + if instance_count > 0: + for instance_id in range(1, instance_count): + pos_list = [] + ys, xs = np.where(instance_label_map == instance_id) + pos_list = list(zip(ys, xs)) + + ### FIX-ME, eliminate outlier + if len(pos_list) < 5: + continue + + # add rule here + main_direction = extract_main_direction(pos_list, + f_direction) # y x + reference_directin = np.array([0, 1]).reshape([-1, 2]) # y x + is_h_angle = abs(np.sum( + main_direction * reference_directin)) < math.cos(math.pi / 180 * + 70) + + point_yxs = np.array(pos_list) + max_y, max_x = np.max(point_yxs, axis=0) + min_y, min_x = np.min(point_yxs, axis=0) + is_h_len = (max_y - min_y) < 1.5 * (max_x - min_x) + + pos_list_final = [] + if is_h_len: + xs = np.unique(xs) + for x in xs: + ys = instance_label_map[:, x].copy().reshape((-1, )) + y = int(np.where(ys == instance_id)[0].mean()) + pos_list_final.append((y, x)) + else: + ys = np.unique(ys) + for y in ys: + xs = instance_label_map[y, :].copy().reshape((-1, )) + x = int(np.where(xs == instance_id)[0].mean()) + pos_list_final.append((y, x)) + + pos_list_sorted, _ = sort_with_direction(pos_list_final, + f_direction) + all_pos_yxs.append(pos_list_sorted) + + # use decoder to filter backgroud points. + p_char_maps = p_char_maps.transpose([1, 2, 0]) + decode_res = ctc_decoder_for_image( + all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True) + for decoded_str, keep_yxs_list in decode_res: + if is_backbone: + keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id) + instance_center_pos_yxs.append(keep_yxs_list_with_id) + else: + end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1])) + center_pos_yxs.extend(keep_yxs_list) + + if is_backbone: + return instance_center_pos_yxs + else: + return center_pos_yxs, end_points_yxs + + +def generate_pivot_list_slow(p_score, + p_char_maps, + f_direction, + score_thresh=0.5, + is_backbone=False, + is_curved=True, + image_id=0): + """ + Warp all the function together. 
+ """ + if is_curved: + return generate_pivot_list_curved( + p_score, + p_char_maps, + f_direction, + score_thresh=score_thresh, + is_expand=True, + is_backbone=is_backbone, + image_id=image_id) + else: + return generate_pivot_list_horizontal( + p_score, + p_char_maps, + f_direction, + score_thresh=score_thresh, + is_backbone=is_backbone, + image_id=image_id) + + +# for refine module +def extract_main_direction(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[y, x], [y, x], [y, x] ...] + """ + pos_list = np.array(pos_list) + point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] + point_direction = point_direction[:, ::-1] # x, y -> y, x + average_direction = np.mean(point_direction, axis=0, keepdims=True) + average_direction = average_direction / ( + np.linalg.norm(average_direction) + 1e-6) + return average_direction + + +def sort_by_direction_with_image_id_deprecated(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[id, y, x], [id, y, x], [id, y, x] ...] + """ + pos_list_full = np.array(pos_list).reshape(-1, 3) + pos_list = pos_list_full[:, 1:] + point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y + point_direction = point_direction[:, ::-1] # x, y -> y, x + average_direction = np.mean(point_direction, axis=0, keepdims=True) + pos_proj_leng = np.sum(pos_list * average_direction, axis=1) + sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist() + return sorted_list + + +def sort_by_direction_with_image_id(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[y, x], [y, x], [y, x] ...] + """ + + def sort_part_with_direction(pos_list_full, point_direction): + pos_list_full = np.array(pos_list_full).reshape(-1, 3) + pos_list = pos_list_full[:, 1:] + point_direction = np.array(point_direction).reshape(-1, 2) + average_direction = np.mean(point_direction, axis=0, keepdims=True) + pos_proj_leng = np.sum(pos_list * average_direction, axis=1) + sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist() + sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist() + return sorted_list, sorted_direction + + pos_list = np.array(pos_list).reshape(-1, 3) + point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]] # x, y + point_direction = point_direction[:, ::-1] # x, y -> y, x + sorted_point, sorted_direction = sort_part_with_direction(pos_list, + point_direction) + + point_num = len(sorted_point) + if point_num >= 16: + middle_num = point_num // 2 + first_part_point = sorted_point[:middle_num] + first_point_direction = sorted_direction[:middle_num] + sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction( + first_part_point, first_point_direction) + + last_part_point = sorted_point[middle_num:] + last_point_direction = sorted_direction[middle_num:] + sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction( + last_part_point, last_point_direction) + sorted_point = sorted_fist_part_point + sorted_last_part_point + sorted_direction = sorted_fist_part_direction + sorted_last_part_direction + + return sorted_point + + +def generate_pivot_list_tt_inference(p_score, + p_char_maps, + f_direction, + score_thresh=0.5, + is_backbone=False, + is_curved=True, + image_id=0): + """ + return center point and end point of TCL instance; filter with the char maps; + """ + p_score = p_score[0] + f_direction = f_direction.transpose(1, 2, 0) + p_tcl_map = (p_score > score_thresh) * 1.0 + skeleton_map = thin(p_tcl_map) + instance_count, instance_label_map = 
cv2.connectedComponents( + skeleton_map.astype(np.uint8), connectivity=8) + + # get TCL Instance + all_pos_yxs = [] + if instance_count > 0: + for instance_id in range(1, instance_count): + pos_list = [] + ys, xs = np.where(instance_label_map == instance_id) + pos_list = list(zip(ys, xs)) + ### FIX-ME, eliminate outlier + if len(pos_list) < 3: + continue + pos_list_sorted = sort_and_expand_with_direction_v2( + pos_list, f_direction, p_tcl_map) + pos_list_sorted_with_id = add_id(pos_list_sorted, image_id=image_id) + all_pos_yxs.append(pos_list_sorted_with_id) + return all_pos_yxs diff --git a/backend/ppocr/utils/e2e_utils/pgnet_pp_utils.py b/backend/ppocr/utils/e2e_utils/pgnet_pp_utils.py new file mode 100644 index 00000000..a15503c0 --- /dev/null +++ b/backend/ppocr/utils/e2e_utils/pgnet_pp_utils.py @@ -0,0 +1,162 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import paddle +import os +import sys + +__dir__ = os.path.dirname(__file__) +sys.path.append(__dir__) +sys.path.append(os.path.join(__dir__, '..')) +from extract_textpoint_slow import * +from extract_textpoint_fast import generate_pivot_list_fast, restore_poly + + +class PGNet_PostProcess(object): + # two different post-process + def __init__(self, character_dict_path, valid_set, score_thresh, outs_dict, + shape_list): + self.Lexicon_Table = get_dict(character_dict_path) + self.valid_set = valid_set + self.score_thresh = score_thresh + self.outs_dict = outs_dict + self.shape_list = shape_list + + def pg_postprocess_fast(self): + p_score = self.outs_dict['f_score'] + p_border = self.outs_dict['f_border'] + p_char = self.outs_dict['f_char'] + p_direction = self.outs_dict['f_direction'] + if isinstance(p_score, paddle.Tensor): + p_score = p_score[0].numpy() + p_border = p_border[0].numpy() + p_direction = p_direction[0].numpy() + p_char = p_char[0].numpy() + else: + p_score = p_score[0] + p_border = p_border[0] + p_direction = p_direction[0] + p_char = p_char[0] + + src_h, src_w, ratio_h, ratio_w = self.shape_list[0] + instance_yxs_list, seq_strs = generate_pivot_list_fast( + p_score, + p_char, + p_direction, + self.Lexicon_Table, + score_thresh=self.score_thresh) + poly_list, keep_str_list = restore_poly(instance_yxs_list, seq_strs, + p_border, ratio_w, ratio_h, + src_w, src_h, self.valid_set) + data = { + 'points': poly_list, + 'texts': keep_str_list, + } + return data + + def pg_postprocess_slow(self): + p_score = self.outs_dict['f_score'] + p_border = self.outs_dict['f_border'] + p_char = self.outs_dict['f_char'] + p_direction = self.outs_dict['f_direction'] + if isinstance(p_score, paddle.Tensor): + p_score = p_score[0].numpy() + p_border = p_border[0].numpy() + p_direction = p_direction[0].numpy() + p_char = p_char[0].numpy() + else: + p_score = p_score[0] + p_border = p_border[0] + p_direction = p_direction[0] + p_char = p_char[0] + src_h, src_w, 
ratio_h, ratio_w = self.shape_list[0] + is_curved = self.valid_set == "totaltext" + char_seq_idx_set, instance_yxs_list = generate_pivot_list_slow( + p_score, + p_char, + p_direction, + score_thresh=self.score_thresh, + is_backbone=True, + is_curved=is_curved) + seq_strs = [] + for char_idx_set in char_seq_idx_set: + pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set]) + seq_strs.append(pr_str) + poly_list = [] + keep_str_list = [] + all_point_list = [] + all_point_pair_list = [] + for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs): + if len(yx_center_line) == 1: + yx_center_line.append(yx_center_line[-1]) + + offset_expand = 1.0 + if self.valid_set == 'totaltext': + offset_expand = 1.2 + + point_pair_list = [] + for batch_id, y, x in yx_center_line: + offset = p_border[:, y, x].reshape(2, 2) + if offset_expand != 1.0: + offset_length = np.linalg.norm( + offset, axis=1, keepdims=True) + expand_length = np.clip( + offset_length * (offset_expand - 1), + a_min=0.5, + a_max=3.0) + offset_detal = offset / offset_length * expand_length + offset = offset + offset_detal + ori_yx = np.array([y, x], dtype=np.float32) + point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array( + [ratio_w, ratio_h]).reshape(-1, 2) + point_pair_list.append(point_pair) + + all_point_list.append([ + int(round(x * 4.0 / ratio_w)), + int(round(y * 4.0 / ratio_h)) + ]) + all_point_pair_list.append(point_pair.round().astype(np.int32) + .tolist()) + + detected_poly, pair_length_info = point_pair2poly(point_pair_list) + detected_poly = expand_poly_along_width( + detected_poly, shrink_ratio_of_width=0.2) + detected_poly[:, 0] = np.clip( + detected_poly[:, 0], a_min=0, a_max=src_w) + detected_poly[:, 1] = np.clip( + detected_poly[:, 1], a_min=0, a_max=src_h) + + if len(keep_str) < 2: + continue + + keep_str_list.append(keep_str) + detected_poly = np.round(detected_poly).astype('int32') + if self.valid_set == 'partvgg': + middle_point = len(detected_poly) // 2 + detected_poly = detected_poly[ + [0, middle_point - 1, middle_point, -1], :] + poly_list.append(detected_poly) + elif self.valid_set == 'totaltext': + poly_list.append(detected_poly) + else: + print('--> Not supported format.') + exit(-1) + data = { + 'points': poly_list, + 'texts': keep_str_list, + } + return data diff --git a/backend/ppocr/utils/e2e_utils/visual.py b/backend/ppocr/utils/e2e_utils/visual.py new file mode 100644 index 00000000..e6e4fd06 --- /dev/null +++ b/backend/ppocr/utils/e2e_utils/visual.py @@ -0,0 +1,162 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
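pg_postprocess_slow above turns every centre-line point plus its two f_border offsets into a pair of border points in source-image coordinates. A minimal standalone restatement of that step, using numpy only and invented numbers (no real model output is involved):

    import numpy as np

    # One centre-line point on the 1/4-resolution score map and its two predicted
    # border offsets (dy, dx); all values here are illustrative.
    y, x = 10, 42
    offset = np.array([[-6.0, 0.5],    # offset towards the upper border
                       [ 6.0, -0.5]])  # offset towards the lower border
    ratio_h, ratio_w = 0.8, 0.8        # resize ratios taken from shape_list
    offset_expand = 1.2                # the totaltext branch widens each pair slightly

    offset_length = np.linalg.norm(offset, axis=1, keepdims=True)
    expand_length = np.clip(offset_length * (offset_expand - 1), a_min=0.5, a_max=3.0)
    offset = offset + offset / offset_length * expand_length

    ori_yx = np.array([y, x], dtype=np.float32)
    # swap (y, x) -> (x, y), scale by 4 back to input resolution, undo the resize ratio
    point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array([ratio_w, ratio_h]).reshape(-1, 2)
    print(point_pair)  # two (x, y) border points in source-image coordinates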
+import numpy as np +import cv2 +import time + + +def resize_image(im, max_side_len=512): + """ + resize image to a size multiple of max_stride which is required by the network + :param im: the resized image + :param max_side_len: limit of max image size to avoid out of memory in gpu + :return: the resized image and the resize ratio + """ + h, w, _ = im.shape + + resize_w = w + resize_h = h + + if resize_h > resize_w: + ratio = float(max_side_len) / resize_h + else: + ratio = float(max_side_len) / resize_w + + resize_h = int(resize_h * ratio) + resize_w = int(resize_w * ratio) + + max_stride = 128 + resize_h = (resize_h + max_stride - 1) // max_stride * max_stride + resize_w = (resize_w + max_stride - 1) // max_stride * max_stride + im = cv2.resize(im, (int(resize_w), int(resize_h))) + ratio_h = resize_h / float(h) + ratio_w = resize_w / float(w) + + return im, (ratio_h, ratio_w) + + +def resize_image_min(im, max_side_len=512): + """ + """ + h, w, _ = im.shape + + resize_w = w + resize_h = h + + if resize_h < resize_w: + ratio = float(max_side_len) / resize_h + else: + ratio = float(max_side_len) / resize_w + + resize_h = int(resize_h * ratio) + resize_w = int(resize_w * ratio) + + max_stride = 128 + resize_h = (resize_h + max_stride - 1) // max_stride * max_stride + resize_w = (resize_w + max_stride - 1) // max_stride * max_stride + im = cv2.resize(im, (int(resize_w), int(resize_h))) + ratio_h = resize_h / float(h) + ratio_w = resize_w / float(w) + return im, (ratio_h, ratio_w) + + +def resize_image_for_totaltext(im, max_side_len=512): + """ + """ + h, w, _ = im.shape + + resize_w = w + resize_h = h + ratio = 1.25 + if h * ratio > max_side_len: + ratio = float(max_side_len) / resize_h + + resize_h = int(resize_h * ratio) + resize_w = int(resize_w * ratio) + + max_stride = 128 + resize_h = (resize_h + max_stride - 1) // max_stride * max_stride + resize_w = (resize_w + max_stride - 1) // max_stride * max_stride + im = cv2.resize(im, (int(resize_w), int(resize_h))) + ratio_h = resize_h / float(h) + ratio_w = resize_w / float(w) + return im, (ratio_h, ratio_w) + + +def point_pair2poly(point_pair_list): + """ + Transfer vertical point_pairs into poly point in clockwise. + """ + pair_length_list = [] + for point_pair in point_pair_list: + pair_length = np.linalg.norm(point_pair[0] - point_pair[1]) + pair_length_list.append(pair_length) + pair_length_list = np.array(pair_length_list) + pair_info = (pair_length_list.max(), pair_length_list.min(), + pair_length_list.mean()) + + point_num = len(point_pair_list) * 2 + point_list = [0] * point_num + for idx, point_pair in enumerate(point_pair_list): + point_list[idx] = point_pair[0] + point_list[point_num - 1 - idx] = point_pair[1] + return np.array(point_list).reshape(-1, 2), pair_info + + +def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.): + """ + Generate shrink_quad_along_width. + """ + ratio_pair = np.array( + [[begin_width_ratio], [end_width_ratio]], dtype=np.float32) + p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair + p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair + return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]]) + + +def expand_poly_along_width(poly, shrink_ratio_of_width=0.3): + """ + expand poly along width. 
+ """ + point_num = poly.shape[0] + left_quad = np.array( + [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32) + left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \ + (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6) + left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0) + right_quad = np.array( + [ + poly[point_num // 2 - 2], poly[point_num // 2 - 1], + poly[point_num // 2], poly[point_num // 2 + 1] + ], + dtype=np.float32) + right_ratio = 1.0 + \ + shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \ + (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6) + right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio) + poly[0] = left_quad_expand[0] + poly[-1] = left_quad_expand[-1] + poly[point_num // 2 - 1] = right_quad_expand[1] + poly[point_num // 2] = right_quad_expand[2] + return poly + + +def norm2(x, axis=None): + if axis: + return np.sqrt(np.sum(x**2, axis=axis)) + return np.sqrt(np.sum(x**2)) + + +def cos(p1, p2): + return (p1 * p2).sum() / (norm2(p1) * norm2(p2)) diff --git a/backend/ppocr/utils/gen_label.py b/backend/ppocr/utils/gen_label.py deleted file mode 100644 index 43afe9dd..00000000 --- a/backend/ppocr/utils/gen_label.py +++ /dev/null @@ -1,79 +0,0 @@ -#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. -# -#Licensed under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. -#You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -#Unless required by applicable law or agreed to in writing, software -#distributed under the License is distributed on an "AS IS" BASIS, -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -#See the License for the specific language governing permissions and -#limitations under the License. 
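point_pair2poly above stitches the predicted (top, bottom) border pairs into a single clockwise polygon: top points in reading order first, then bottom points in reverse. The same ordering, restated standalone with an invented three-pair text line:

    import numpy as np

    # Three (top, bottom) point pairs along a horizontal text line, as (x, y).
    point_pair_list = [
        np.array([[0.0, 0.0], [0.0, 10.0]]),
        np.array([[20.0, 0.0], [20.0, 10.0]]),
        np.array([[40.0, 0.0], [40.0, 10.0]]),
    ]

    point_num = len(point_pair_list) * 2
    point_list = [None] * point_num
    for idx, point_pair in enumerate(point_pair_list):
        point_list[idx] = point_pair[0]                  # top edge, left to right
        point_list[point_num - 1 - idx] = point_pair[1]  # bottom edge, right to left
    poly = np.array(point_list).reshape(-1, 2)
    print(poly)
    # [[ 0.  0.] [20.  0.] [40.  0.] [40. 10.] [20. 10.] [ 0. 10.]]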
-import os -import argparse -import json - - -def gen_rec_label(input_path, out_label): - with open(out_label, 'w') as out_file: - with open(input_path, 'r') as f: - for line in f.readlines(): - tmp = line.strip('\n').replace(" ", "").split(',') - img_path, label = tmp[0], tmp[1] - label = label.replace("\"", "") - out_file.write(img_path + '\t' + label + '\n') - - -def gen_det_label(root_path, input_dir, out_label): - with open(out_label, 'w') as out_file: - for label_file in os.listdir(input_dir): - img_path = root_path + label_file[3:-4] + ".jpg" - label = [] - with open(os.path.join(input_dir, label_file), 'r') as f: - for line in f.readlines(): - tmp = line.strip("\n\r").replace("\xef\xbb\xbf", - "").split(',') - points = tmp[:8] - s = [] - for i in range(0, len(points), 2): - b = points[i:i + 2] - b = [int(t) for t in b] - s.append(b) - result = {"transcription": tmp[8], "points": s} - label.append(result) - - out_file.write(img_path + '\t' + json.dumps( - label, ensure_ascii=False) + '\n') - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - '--mode', - type=str, - default="rec", - help='Generate rec_label or det_label, can be set rec or det') - parser.add_argument( - '--root_path', - type=str, - default=".", - help='The root directory of images.Only takes effect when mode=det ') - parser.add_argument( - '--input_path', - type=str, - default=".", - help='Input_label or input path to be converted') - parser.add_argument( - '--output_label', - type=str, - default="out_label.txt", - help='Output file name') - - args = parser.parse_args() - if args.mode == "rec": - print("Generate rec label") - gen_rec_label(args.input_path, args.output_label) - elif args.mode == "det": - gen_det_label(args.root_path, args.input_path, args.output_label) diff --git a/backend/ppocr/utils/iou.py b/backend/ppocr/utils/iou.py new file mode 100644 index 00000000..35459f5f --- /dev/null +++ b/backend/ppocr/utils/iou.py @@ -0,0 +1,54 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
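The removed gen_det_label wrote one line per image in the form image_path<TAB>json_list, each entry holding a transcription and its corner points; the same line can still be produced with only the standard library (the path, text and box below are illustrative, not taken from the patch):

    import json

    points = [[377, 117], [463, 117], [465, 130], [378, 130]]
    label = [{"transcription": "Genaxis Theatre", "points": points}]
    line = "train_images/img_1.jpg" + "\t" + json.dumps(label, ensure_ascii=False)
    print(line)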
+""" +This code is refer from: +https://github.com/whai362/PSENet/blob/python3/models/loss/iou.py +""" + +import paddle + +EPS = 1e-6 + + +def iou_single(a, b, mask, n_class): + valid = mask == 1 + a = a.masked_select(valid) + b = b.masked_select(valid) + miou = [] + for i in range(n_class): + if a.shape == [0] and a.shape == b.shape: + inter = paddle.to_tensor(0.0) + union = paddle.to_tensor(0.0) + else: + inter = ((a == i).logical_and(b == i)).astype('float32') + union = ((a == i).logical_or(b == i)).astype('float32') + miou.append(paddle.sum(inter) / (paddle.sum(union) + EPS)) + miou = sum(miou) / len(miou) + return miou + + +def iou(a, b, mask, n_class=2, reduce=True): + batch_size = a.shape[0] + + a = a.reshape([batch_size, -1]) + b = b.reshape([batch_size, -1]) + mask = mask.reshape([batch_size, -1]) + + iou = paddle.zeros((batch_size, ), dtype='float32') + for i in range(batch_size): + iou[i] = iou_single(a[i], b[i], mask[i], n_class) + + if reduce: + iou = paddle.mean(iou) + return iou diff --git a/backend/ppocr/utils/loggers/__init__.py b/backend/ppocr/utils/loggers/__init__.py new file mode 100644 index 00000000..b1e92f73 --- /dev/null +++ b/backend/ppocr/utils/loggers/__init__.py @@ -0,0 +1,3 @@ +from .vdl_logger import VDLLogger +from .wandb_logger import WandbLogger +from .loggers import Loggers diff --git a/backend/ppocr/utils/loggers/base_logger.py b/backend/ppocr/utils/loggers/base_logger.py new file mode 100644 index 00000000..3a7fc359 --- /dev/null +++ b/backend/ppocr/utils/loggers/base_logger.py @@ -0,0 +1,15 @@ +import os +from abc import ABC, abstractmethod + +class BaseLogger(ABC): + def __init__(self, save_dir): + self.save_dir = save_dir + os.makedirs(self.save_dir, exist_ok=True) + + @abstractmethod + def log_metrics(self, metrics, prefix=None): + pass + + @abstractmethod + def close(self): + pass \ No newline at end of file diff --git a/backend/ppocr/utils/loggers/loggers.py b/backend/ppocr/utils/loggers/loggers.py new file mode 100644 index 00000000..26014662 --- /dev/null +++ b/backend/ppocr/utils/loggers/loggers.py @@ -0,0 +1,18 @@ +from .wandb_logger import WandbLogger + +class Loggers(object): + def __init__(self, loggers): + super().__init__() + self.loggers = loggers + + def log_metrics(self, metrics, prefix=None, step=None): + for logger in self.loggers: + logger.log_metrics(metrics, prefix=prefix, step=step) + + def log_model(self, is_best, prefix, metadata=None): + for logger in self.loggers: + logger.log_model(is_best=is_best, prefix=prefix, metadata=metadata) + + def close(self): + for logger in self.loggers: + logger.close() \ No newline at end of file diff --git a/backend/ppocr/utils/loggers/vdl_logger.py b/backend/ppocr/utils/loggers/vdl_logger.py new file mode 100644 index 00000000..c345f932 --- /dev/null +++ b/backend/ppocr/utils/loggers/vdl_logger.py @@ -0,0 +1,21 @@ +from .base_logger import BaseLogger +from visualdl import LogWriter + +class VDLLogger(BaseLogger): + def __init__(self, save_dir): + super().__init__(save_dir) + self.vdl_writer = LogWriter(logdir=save_dir) + + def log_metrics(self, metrics, prefix=None, step=None): + if not prefix: + prefix = "" + updated_metrics = {prefix + "/" + k: v for k, v in metrics.items()} + + for k, v in updated_metrics.items(): + self.vdl_writer.add_scalar(k, v, step) + + def log_model(self, is_best, prefix, metadata=None): + pass + + def close(self): + self.vdl_writer.close() \ No newline at end of file diff --git a/backend/ppocr/utils/loggers/wandb_logger.py 
b/backend/ppocr/utils/loggers/wandb_logger.py new file mode 100644 index 00000000..5c805f4e --- /dev/null +++ b/backend/ppocr/utils/loggers/wandb_logger.py @@ -0,0 +1,78 @@ +import os +from .base_logger import BaseLogger + +class WandbLogger(BaseLogger): + def __init__(self, + project=None, + name=None, + id=None, + entity=None, + save_dir=None, + config=None, + **kwargs): + try: + import wandb + self.wandb = wandb + except ModuleNotFoundError: + raise ModuleNotFoundError( + "Please install wandb using `pip install wandb`" + ) + + self.project = project + self.name = name + self.id = id + self.save_dir = save_dir + self.config = config + self.kwargs = kwargs + self.entity = entity + self._run = None + self._wandb_init = dict( + project=self.project, + name=self.name, + id=self.id, + entity=self.entity, + dir=self.save_dir, + resume="allow" + ) + self._wandb_init.update(**kwargs) + + _ = self.run + + if self.config: + self.run.settings_config.update(self.config) + + @property + def run(self): + if self._run is None: + if self.wandb.run is not None: + logger.info( + "There is a wandb run already in progress " + "and newly created instances of `WandbLogger` will reuse" + " this run. If this is not desired, call `wandb.finish()`" + "before instantiating `WandbLogger`." + ) + self._run = self.wandb.run + else: + self._run = self.wandb.init(**self._wandb_init) + return self._run + + def log_metrics(self, metrics, prefix=None, step=None): + if not prefix: + prefix = "" + updated_metrics = {prefix.lower() + "/" + k: v for k, v in metrics.items()} + + self.run.log(updated_metrics, step=step) + + def log_model(self, is_best, prefix, metadata=None): + model_path = os.path.join(self.save_dir, prefix + '.pdparams') + artifact = self.wandb.Artifact('model-{}'.format(self.run.id), type='model', metadata=metadata) + artifact.add_file(model_path, name="model_ckpt.pdparams") + + aliases = [prefix] + if is_best: + aliases.append("best") + + self.run.log_artifact(artifact, aliases=aliases) + + def close(self): + self.run.finish() \ No newline at end of file diff --git a/backend/ppocr/utils/logging.py b/backend/ppocr/utils/logging.py index 951141db..1eac8f35 100644 --- a/backend/ppocr/utils/logging.py +++ b/backend/ppocr/utils/logging.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +This code is refer from: +https://github.com/WenmuZhou/PytorchOCR/blob/master/torchocr/utils/logging.py +""" import os import sys @@ -22,7 +26,7 @@ @functools.lru_cache() -def get_logger(name='root', log_file=None, log_level=logging.INFO): +def get_logger(name='ppocr', log_file=None, log_level=logging.DEBUG): """Initialize and get a logger by name. 
If the logger has not been initialized, this method will initialize the logger by adding one or two handlers, otherwise the initialized logger will @@ -63,4 +67,5 @@ def get_logger(name='root', log_file=None, log_level=logging.INFO): else: logger.setLevel(logging.ERROR) logger_initialized[name] = True + logger.propagate = False return logger diff --git a/backend/ppocr/utils/network.py b/backend/ppocr/utils/network.py new file mode 100644 index 00000000..118d1be3 --- /dev/null +++ b/backend/ppocr/utils/network.py @@ -0,0 +1,84 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import tarfile +import requests +from tqdm import tqdm + +from ppocr.utils.logging import get_logger + + +def download_with_progressbar(url, save_path): + logger = get_logger() + response = requests.get(url, stream=True) + if response.status_code == 200: + total_size_in_bytes = int(response.headers.get('content-length', 1)) + block_size = 1024 # 1 Kibibyte + progress_bar = tqdm( + total=total_size_in_bytes, unit='iB', unit_scale=True) + with open(save_path, 'wb') as file: + for data in response.iter_content(block_size): + progress_bar.update(len(data)) + file.write(data) + progress_bar.close() + else: + logger.error("Something went wrong while downloading models") + sys.exit(0) + + +def maybe_download(model_storage_directory, url): + # using custom model + tar_file_name_list = [ + 'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel' + ] + if not os.path.exists( + os.path.join(model_storage_directory, 'inference.pdiparams') + ) or not os.path.exists( + os.path.join(model_storage_directory, 'inference.pdmodel')): + assert url.endswith('.tar'), 'Only supports tar compressed package' + tmp_path = os.path.join(model_storage_directory, url.split('/')[-1]) + print('download {} to {}'.format(url, tmp_path)) + os.makedirs(model_storage_directory, exist_ok=True) + download_with_progressbar(url, tmp_path) + with tarfile.open(tmp_path, 'r') as tarObj: + for member in tarObj.getmembers(): + filename = None + for tar_file_name in tar_file_name_list: + if tar_file_name in member.name: + filename = tar_file_name + if filename is None: + continue + file = tarObj.extractfile(member) + with open( + os.path.join(model_storage_directory, filename), + 'wb') as f: + f.write(file.read()) + os.remove(tmp_path) + + +def is_link(s): + return s is not None and s.startswith('http') + + +def confirm_model_dir_url(model_dir, default_model_dir, default_url): + url = default_url + if model_dir is None or is_link(model_dir): + if is_link(model_dir): + url = model_dir + file_name = url.split('/')[-1][:-4] + model_dir = default_model_dir + model_dir = os.path.join(model_dir, file_name) + return model_dir, url diff --git a/backend/ppocr/utils/poly_nms.py b/backend/ppocr/utils/poly_nms.py new file mode 100644 index 00000000..9dcb3d2c --- /dev/null +++ b/backend/ppocr/utils/poly_nms.py @@ -0,0 +1,146 @@ +# copyright (c) 2022 PaddlePaddle Authors. 
All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from shapely.geometry import Polygon + + +def points2polygon(points): + """Convert k points to 1 polygon. + + Args: + points (ndarray or list): A ndarray or a list of shape (2k) + that indicates k points. + + Returns: + polygon (Polygon): A polygon object. + """ + if isinstance(points, list): + points = np.array(points) + + assert isinstance(points, np.ndarray) + assert (points.size % 2 == 0) and (points.size >= 8) + + point_mat = points.reshape([-1, 2]) + return Polygon(point_mat) + + +def poly_intersection(poly_det, poly_gt, buffer=0.0001): + """Calculate the intersection area between two polygon. + + Args: + poly_det (Polygon): A polygon predicted by detector. + poly_gt (Polygon): A gt polygon. + + Returns: + intersection_area (float): The intersection area between two polygons. + """ + assert isinstance(poly_det, Polygon) + assert isinstance(poly_gt, Polygon) + + if buffer == 0: + poly_inter = poly_det & poly_gt + else: + poly_inter = poly_det.buffer(buffer) & poly_gt.buffer(buffer) + return poly_inter.area, poly_inter + + +def poly_union(poly_det, poly_gt): + """Calculate the union area between two polygon. + + Args: + poly_det (Polygon): A polygon predicted by detector. + poly_gt (Polygon): A gt polygon. + + Returns: + union_area (float): The union area between two polygons. + """ + assert isinstance(poly_det, Polygon) + assert isinstance(poly_gt, Polygon) + + area_det = poly_det.area + area_gt = poly_gt.area + area_inters, _ = poly_intersection(poly_det, poly_gt) + return area_det + area_gt - area_inters + + +def valid_boundary(x, with_score=True): + num = len(x) + if num < 8: + return False + if num % 2 == 0 and (not with_score): + return True + if num % 2 == 1 and with_score: + return True + + return False + + +def boundary_iou(src, target): + """Calculate the IOU between two boundaries. + + Args: + src (list): Source boundary. + target (list): Target boundary. + + Returns: + iou (float): The iou between two boundaries. + """ + assert valid_boundary(src, False) + assert valid_boundary(target, False) + src_poly = points2polygon(src) + target_poly = points2polygon(target) + + return poly_iou(src_poly, target_poly) + + +def poly_iou(poly_det, poly_gt): + """Calculate the IOU between two polygons. + + Args: + poly_det (Polygon): A polygon predicted by detector. + poly_gt (Polygon): A gt polygon. + + Returns: + iou (float): The IOU between two polygons. 
+ """ + assert isinstance(poly_det, Polygon) + assert isinstance(poly_gt, Polygon) + area_inters, _ = poly_intersection(poly_det, poly_gt) + area_union = poly_union(poly_det, poly_gt) + if area_union == 0: + return 0.0 + return area_inters / area_union + + +def poly_nms(polygons, threshold): + assert isinstance(polygons, list) + + polygons = np.array(sorted(polygons, key=lambda x: x[-1])) + + keep_poly = [] + index = [i for i in range(polygons.shape[0])] + + while len(index) > 0: + keep_poly.append(polygons[index[-1]].tolist()) + A = polygons[index[-1]][:-1] + index = np.delete(index, -1) + iou_list = np.zeros((len(index), )) + for i in range(len(index)): + B = polygons[index[i]][:-1] + iou_list[i] = boundary_iou(A, B) + remove_index = np.where(iou_list > threshold) + index = np.delete(index, remove_index) + + return keep_poly diff --git a/backend/ppocr/utils/profiler.py b/backend/ppocr/utils/profiler.py new file mode 100644 index 00000000..c4e28bc6 --- /dev/null +++ b/backend/ppocr/utils/profiler.py @@ -0,0 +1,110 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import paddle + +# A global variable to record the number of calling times for profiler +# functions. It is used to specify the tracing range of training steps. +_profiler_step_id = 0 + +# A global variable to avoid parsing from string every time. +_profiler_options = None + + +class ProfilerOptions(object): + ''' + Use a string to initialize a ProfilerOptions. + The string should be in the format: "key1=value1;key2=value;key3=value3". + For example: + "profile_path=model.profile" + "batch_range=[50, 60]; profile_path=model.profile" + "batch_range=[50, 60]; tracer_option=OpDetail; profile_path=model.profile" + ProfilerOptions supports following key-value pair: + batch_range - a integer list, e.g. [100, 110]. + state - a string, the optional values are 'CPU', 'GPU' or 'All'. + sorted_key - a string, the optional values are 'calls', 'total', + 'max', 'min' or 'ave. + tracer_option - a string, the optional values are 'Default', 'OpDetail', + 'AllOpDetail'. + profile_path - a string, the path to save the serialized profile data, + which can be used to generate a timeline. + exit_on_finished - a boolean. 
+ ''' + + def __init__(self, options_str): + assert isinstance(options_str, str) + + self._options = { + 'batch_range': [10, 20], + 'state': 'All', + 'sorted_key': 'total', + 'tracer_option': 'Default', + 'profile_path': '/tmp/profile', + 'exit_on_finished': True + } + self._parse_from_string(options_str) + + def _parse_from_string(self, options_str): + for kv in options_str.replace(' ', '').split(';'): + key, value = kv.split('=') + if key == 'batch_range': + value_list = value.replace('[', '').replace(']', '').split(',') + value_list = list(map(int, value_list)) + if len(value_list) >= 2 and value_list[0] >= 0 and value_list[ + 1] > value_list[0]: + self._options[key] = value_list + elif key == 'exit_on_finished': + self._options[key] = value.lower() in ("yes", "true", "t", "1") + elif key in [ + 'state', 'sorted_key', 'tracer_option', 'profile_path' + ]: + self._options[key] = value + + def __getitem__(self, name): + if self._options.get(name, None) is None: + raise ValueError( + "ProfilerOptions does not have an option named %s." % name) + return self._options[name] + + +def add_profiler_step(options_str=None): + ''' + Enable the operator-level timing using PaddlePaddle's profiler. + The profiler uses a independent variable to count the profiler steps. + One call of this function is treated as a profiler step. + + Args: + profiler_options - a string to initialize the ProfilerOptions. + Default is None, and the profiler is disabled. + ''' + if options_str is None: + return + + global _profiler_step_id + global _profiler_options + + if _profiler_options is None: + _profiler_options = ProfilerOptions(options_str) + + if _profiler_step_id == _profiler_options['batch_range'][0]: + paddle.utils.profiler.start_profiler( + _profiler_options['state'], _profiler_options['tracer_option']) + elif _profiler_step_id == _profiler_options['batch_range'][1]: + paddle.utils.profiler.stop_profiler(_profiler_options['sorted_key'], + _profiler_options['profile_path']) + if _profiler_options['exit_on_finished']: + sys.exit(0) + + _profiler_step_id += 1 diff --git a/backend/ppocr/utils/save_load.py b/backend/ppocr/utils/save_load.py index 02814d62..b09f1db6 100644 --- a/backend/ppocr/utils/save_load.py +++ b/backend/ppocr/utils/save_load.py @@ -23,7 +23,9 @@ import paddle -__all__ = ['init_model', 'save_model', 'load_dygraph_pretrain'] +from ppocr.utils.logging import get_logger + +__all__ = ['load_model'] def _mkdir_if_not_exist(path, logger): @@ -42,58 +44,74 @@ def _mkdir_if_not_exist(path, logger): raise OSError('Failed to mkdir {}'.format(path)) -def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False): - if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')): - raise ValueError("Model pretrain path {} does not " - "exists.".format(path)) - if load_static_weights: - pre_state_dict = paddle.static.load_program_state(path) - param_state_dict = {} - model_dict = model.state_dict() - for key in model_dict.keys(): - weight_name = model_dict[key].name - weight_name = weight_name.replace('binarize', '').replace( - 'thresh', '') # for DB - if weight_name in pre_state_dict.keys(): - # logger.info('Load weight: {}, shape: {}'.format( - # weight_name, pre_state_dict[weight_name].shape)) - if 'encoder_rnn' in key: - # delete axis which is 1 - pre_state_dict[weight_name] = pre_state_dict[ - weight_name].squeeze() - # change axis - if len(pre_state_dict[weight_name].shape) > 1: - pre_state_dict[weight_name] = pre_state_dict[ - weight_name].transpose((1, 0)) - param_state_dict[key] = 
pre_state_dict[weight_name] - else: - param_state_dict[key] = model_dict[key] - model.set_state_dict(param_state_dict) - return - - param_state_dict = paddle.load(path + '.pdparams') - model.set_state_dict(param_state_dict) - return - - -def init_model(config, model, logger, optimizer=None, lr_scheduler=None): +def load_model(config, model, optimizer=None, model_type='det'): """ load model from checkpoint or pretrained_model """ - gloabl_config = config['Global'] - checkpoints = gloabl_config.get('checkpoints') - pretrained_model = gloabl_config.get('pretrained_model') + logger = get_logger() + global_config = config['Global'] + checkpoints = global_config.get('checkpoints') + pretrained_model = global_config.get('pretrained_model') best_model_dict = {} + + if model_type == 'vqa': + checkpoints = config['Architecture']['Backbone']['checkpoints'] + # load vqa method metric + if checkpoints: + if os.path.exists(os.path.join(checkpoints, 'metric.states')): + with open(os.path.join(checkpoints, 'metric.states'), + 'rb') as f: + states_dict = pickle.load(f) if six.PY2 else pickle.load( + f, encoding='latin1') + best_model_dict = states_dict.get('best_model_dict', {}) + if 'epoch' in states_dict: + best_model_dict['start_epoch'] = states_dict['epoch'] + 1 + logger.info("resume from {}".format(checkpoints)) + + if optimizer is not None: + if checkpoints[-1] in ['/', '\\']: + checkpoints = checkpoints[:-1] + if os.path.exists(checkpoints + '.pdopt'): + optim_dict = paddle.load(checkpoints + '.pdopt') + optimizer.set_state_dict(optim_dict) + else: + logger.warning( + "{}.pdopt is not exists, params of optimizer is not loaded". + format(checkpoints)) + return best_model_dict + if checkpoints: + if checkpoints.endswith('.pdparams'): + checkpoints = checkpoints.replace('.pdparams', '') assert os.path.exists(checkpoints + ".pdparams"), \ - "Given dir {}.pdparams not exist.".format(checkpoints) - assert os.path.exists(checkpoints + ".pdopt"), \ - "Given dir {}.pdopt not exist.".format(checkpoints) - para_dict = paddle.load(checkpoints + '.pdparams') - opti_dict = paddle.load(checkpoints + '.pdopt') - model.set_state_dict(para_dict) + "The {}.pdparams does not exists!".format(checkpoints) + + # load params from trained model + params = paddle.load(checkpoints + '.pdparams') + state_dict = model.state_dict() + new_state_dict = {} + for key, value in state_dict.items(): + if key not in params: + logger.warning("{} not in loaded params {} !".format( + key, params.keys())) + continue + pre_value = params[key] + if list(value.shape) == list(pre_value.shape): + new_state_dict[key] = pre_value + else: + logger.warning( + "The shape of model params {} {} not matched with loaded params shape {} !". + format(key, value.shape, pre_value.shape)) + model.set_state_dict(new_state_dict) + if optimizer is not None: - optimizer.set_state_dict(opti_dict) + if os.path.exists(checkpoints + '.pdopt'): + optim_dict = paddle.load(checkpoints + '.pdopt') + optimizer.set_state_dict(optim_dict) + else: + logger.warning( + "{}.pdopt is not exists, params of optimizer is not loaded". 
+ format(checkpoints)) if os.path.exists(checkpoints + '.states'): with open(checkpoints + '.states', 'rb') as f: @@ -102,29 +120,44 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None): best_model_dict = states_dict.get('best_model_dict', {}) if 'epoch' in states_dict: best_model_dict['start_epoch'] = states_dict['epoch'] + 1 - logger.info("resume from {}".format(checkpoints)) elif pretrained_model: - load_static_weights = gloabl_config.get('load_static_weights', False) - if not isinstance(pretrained_model, list): - pretrained_model = [pretrained_model] - if not isinstance(load_static_weights, list): - load_static_weights = [load_static_weights] * len(pretrained_model) - for idx, pretrained in enumerate(pretrained_model): - load_static = load_static_weights[idx] - load_dygraph_pretrain( - model, logger, path=pretrained, load_static_weights=load_static) - logger.info("load pretrained model from {}".format( - pretrained_model)) + load_pretrained_params(model, pretrained_model) else: logger.info('train from scratch') return best_model_dict -def save_model(net, +def load_pretrained_params(model, path): + logger = get_logger() + if path.endswith('.pdparams'): + path = path.replace('.pdparams', '') + assert os.path.exists(path + ".pdparams"), \ + "The {}.pdparams does not exists!".format(path) + + params = paddle.load(path + '.pdparams') + state_dict = model.state_dict() + new_state_dict = {} + for k1 in params.keys(): + if k1 not in state_dict.keys(): + logger.warning("The pretrained params {} not in model".format(k1)) + else: + if list(state_dict[k1].shape) == list(params[k1].shape): + new_state_dict[k1] = params[k1] + else: + logger.warning( + "The shape of model params {} {} not matched with loaded params {} {} !". + format(k1, state_dict[k1].shape, k1, params[k1].shape)) + model.set_state_dict(new_state_dict) + logger.info("load pretrain successful from {}".format(path)) + return model + + +def save_model(model, optimizer, model_path, logger, + config, is_best=False, prefix='ppocr', **kwargs): @@ -133,13 +166,20 @@ def save_model(net, """ _mkdir_if_not_exist(model_path, logger) model_prefix = os.path.join(model_path, prefix) - paddle.save(net.state_dict(), model_prefix + '.pdparams') paddle.save(optimizer.state_dict(), model_prefix + '.pdopt') - + if config['Architecture']["model_type"] != 'vqa': + paddle.save(model.state_dict(), model_prefix + '.pdparams') + metric_prefix = model_prefix + else: + if config['Global']['distributed']: + model._layers.backbone.model.save_pretrained(model_prefix) + else: + model.backbone.model.save_pretrained(model_prefix) + metric_prefix = os.path.join(model_prefix, 'metric') # save metric and config - with open(model_prefix + '.states', 'wb') as f: - pickle.dump(kwargs, f, protocol=2) if is_best: + with open(metric_prefix + '.states', 'wb') as f: + pickle.dump(kwargs, f, protocol=2) logger.info('save best model is to {}'.format(model_prefix)) else: logger.info("save model in {}".format(model_prefix)) diff --git a/backend/ppocr/utils/utility.py b/backend/ppocr/utils/utility.py index 6a746314..4a25ff8b 100755 --- a/backend/ppocr/utils/utility.py +++ b/backend/ppocr/utils/utility.py @@ -14,7 +14,11 @@ import logging import os +import imghdr import cv2 +import random +import numpy as np +import paddle def print_dict(d, logger, delimiter=0): @@ -45,23 +49,27 @@ def get_check_global_params(mode): return check_params +def _check_image_file(path): + img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif'} + return 
any([path.lower().endswith(e) for e in img_end]) + + def get_image_file_list(img_file): imgs_lists = [] if img_file is None or not os.path.exists(img_file): raise Exception("not found any img file in {}".format(img_file)) img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif'} - if os.path.isfile(img_file) and os.path.splitext(img_file)[-1][1:].lower( - ) in img_end: + if os.path.isfile(img_file) and _check_image_file(img_file): imgs_lists.append(img_file) elif os.path.isdir(img_file): for single_file in os.listdir(img_file): file_path = os.path.join(img_file, single_file) - if os.path.isfile(file_path) and os.path.splitext(file_path)[-1][ - 1:].lower() in img_end: + if os.path.isfile(file_path) and _check_image_file(file_path): imgs_lists.append(file_path) if len(imgs_lists) == 0: raise Exception("not found any img file in {}".format(img_file)) + imgs_lists = sorted(imgs_lists) return imgs_lists @@ -78,3 +86,46 @@ def check_and_read_gif(img_path): imgvalue = frame[:, :, ::-1] return imgvalue, True return None, False + + +def load_vqa_bio_label_maps(label_map_path): + with open(label_map_path, "r", encoding='utf-8') as fin: + lines = fin.readlines() + lines = [line.strip() for line in lines] + if "O" not in lines: + lines.insert(0, "O") + labels = [] + for line in lines: + if line == "O": + labels.append("O") + else: + labels.append("B-" + line) + labels.append("I-" + line) + label2id_map = {label: idx for idx, label in enumerate(labels)} + id2label_map = {idx: label for idx, label in enumerate(labels)} + return label2id_map, id2label_map + + +def set_seed(seed=1024): + random.seed(seed) + np.random.seed(seed) + paddle.seed(seed) + + +class AverageMeter: + def __init__(self): + self.reset() + + def reset(self): + """reset""" + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + """update""" + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count diff --git a/backend/ppocr/utils/visual.py b/backend/ppocr/utils/visual.py new file mode 100644 index 00000000..7a8c1674 --- /dev/null +++ b/backend/ppocr/utils/visual.py @@ -0,0 +1,98 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
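load_vqa_bio_label_maps above expands every class from the label-map file into B-/I- tags and guarantees an 'O' class at index 0. The same expansion on an in-memory list (HEADER/QUESTION/ANSWER are example class names, not taken from the patch):

    # Mirrors the B-/I- expansion in load_vqa_bio_label_maps, but on a plain list
    # instead of a label-map file.
    lines = ["HEADER", "QUESTION", "ANSWER"]
    if "O" not in lines:
        lines.insert(0, "O")
    labels = []
    for line in lines:
        if line == "O":
            labels.append("O")
        else:
            labels.append("B-" + line)
            labels.append("I-" + line)
    label2id_map = {label: idx for idx, label in enumerate(labels)}
    print(label2id_map)
    # {'O': 0, 'B-HEADER': 1, 'I-HEADER': 2, 'B-QUESTION': 3,
    #  'I-QUESTION': 4, 'B-ANSWER': 5, 'I-ANSWER': 6}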
+import os +import numpy as np +from PIL import Image, ImageDraw, ImageFont + + +def draw_ser_results(image, + ocr_results, + font_path="doc/fonts/simfang.ttf", + font_size=18): + np.random.seed(2021) + color = (np.random.permutation(range(255)), + np.random.permutation(range(255)), + np.random.permutation(range(255))) + color_map = { + idx: (color[0][idx], color[1][idx], color[2][idx]) + for idx in range(1, 255) + } + if isinstance(image, np.ndarray): + image = Image.fromarray(image) + elif isinstance(image, str) and os.path.isfile(image): + image = Image.open(image).convert('RGB') + img_new = image.copy() + draw = ImageDraw.Draw(img_new) + + font = ImageFont.truetype(font_path, font_size, encoding="utf-8") + for ocr_info in ocr_results: + if ocr_info["pred_id"] not in color_map: + continue + color = color_map[ocr_info["pred_id"]] + text = "{}: {}".format(ocr_info["pred"], ocr_info["text"]) + + draw_box_txt(ocr_info["bbox"], text, draw, font, font_size, color) + + img_new = Image.blend(image, img_new, 0.5) + return np.array(img_new) + + +def draw_box_txt(bbox, text, draw, font, font_size, color): + # draw ocr results outline + bbox = ((bbox[0], bbox[1]), (bbox[2], bbox[3])) + draw.rectangle(bbox, fill=color) + + # draw ocr results + start_y = max(0, bbox[0][1] - font_size) + tw = font.getsize(text)[0] + draw.rectangle( + [(bbox[0][0] + 1, start_y), (bbox[0][0] + tw + 1, start_y + font_size)], + fill=(0, 0, 255)) + draw.text((bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font) + + +def draw_re_results(image, + result, + font_path="doc/fonts/simfang.ttf", + font_size=18): + np.random.seed(0) + if isinstance(image, np.ndarray): + image = Image.fromarray(image) + elif isinstance(image, str) and os.path.isfile(image): + image = Image.open(image).convert('RGB') + img_new = image.copy() + draw = ImageDraw.Draw(img_new) + + font = ImageFont.truetype(font_path, font_size, encoding="utf-8") + color_head = (0, 0, 255) + color_tail = (255, 0, 0) + color_line = (0, 255, 0) + + for ocr_info_head, ocr_info_tail in result: + draw_box_txt(ocr_info_head["bbox"], ocr_info_head["text"], draw, font, + font_size, color_head) + draw_box_txt(ocr_info_tail["bbox"], ocr_info_tail["text"], draw, font, + font_size, color_tail) + + center_head = ( + (ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2, + (ocr_info_head['bbox'][1] + ocr_info_head['bbox'][3]) // 2) + center_tail = ( + (ocr_info_tail['bbox'][0] + ocr_info_tail['bbox'][2]) // 2, + (ocr_info_tail['bbox'][1] + ocr_info_tail['bbox'][3]) // 2) + + draw.line([center_head, center_tail], fill=color_line, width=5) + + img_new = Image.blend(image, img_new, 0.5) + return np.array(img_new) diff --git a/backend/subfinder/linux/VideoSubFinderCli b/backend/subfinder/linux/VideoSubFinderCli new file mode 100755 index 00000000..74ff3846 Binary files /dev/null and b/backend/subfinder/linux/VideoSubFinderCli differ diff --git a/backend/subfinder/linux/VideoSubFinderCli.run b/backend/subfinder/linux/VideoSubFinderCli.run new file mode 100755 index 00000000..9f2ad6a5 --- /dev/null +++ b/backend/subfinder/linux/VideoSubFinderCli.run @@ -0,0 +1,5 @@ +#!/bin/sh +cd ${0%/*} +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$PWD:/lib64 +chmod +x ./VideoSubFinderCli +./VideoSubFinderCli "$@" diff --git a/backend/subfinder/linux/settings/general.cfg b/backend/subfinder/linux/settings/general.cfg new file mode 100644 index 00000000..da029178 --- /dev/null +++ b/backend/subfinder/linux/settings/general.cfg @@ -0,0 +1,35 @@ +prefered_locale = eng 
+dont_delete_unrecognized_images1 = 1 +dont_delete_unrecognized_images2 = 1 +generate_cleared_text_images_on_test = 1 +dump_debug_images = 0 +dump_debug_second_filtration_images = 0 +clear_test_images_folder = 1 +show_transformed_images_only = 0 +moderate_threshold = 0.4 +moderate_threshold_for_NEdges = 0.3 +segment_width = 8 +segment_height = 3 +minimum_segments_count = 2 +min_sum_color_diff = 500 +between_text_distace = 0.05 +text_centre_offset = 0.1 +min_points_number = 30 +min_points_density = 0.3 +min_symbol_height = 0.02 +min_symbol_density = 0.2 +min_NEdges_points_density = 0.25 +threads = 4 +sub_frame_length = 6 +text_procent = 0.3 +min_text_len_(in_procent) = 0.022 +sub_square_error = 0.3 +vedges_points_line_error = 0.35 +clear_image_logical = 0 +clean_rgb_images_after_run = 0 +def_string_for_empty_sub = sub duration: %sub_duration% +min_sub_duration = 0 +txt_dw = 5 +txt_dy = 5 +fount_size_ocr_lbl = 8 +fount_size_ocr_btn = 10 diff --git a/backend/subfinder/windows/VideoSubFinderWXW.exe b/backend/subfinder/windows/VideoSubFinderWXW.exe new file mode 100644 index 00000000..3234c138 Binary files /dev/null and b/backend/subfinder/windows/VideoSubFinderWXW.exe differ diff --git a/backend/subfinder/windows/avcodec-58.dll b/backend/subfinder/windows/avcodec-58.dll new file mode 100644 index 00000000..cd120d4f Binary files /dev/null and b/backend/subfinder/windows/avcodec-58.dll differ diff --git a/backend/subfinder/windows/avdevice-58.dll b/backend/subfinder/windows/avdevice-58.dll new file mode 100644 index 00000000..66093493 Binary files /dev/null and b/backend/subfinder/windows/avdevice-58.dll differ diff --git a/backend/subfinder/windows/avfilter-7.dll b/backend/subfinder/windows/avfilter-7.dll new file mode 100644 index 00000000..21f6cf3e Binary files /dev/null and b/backend/subfinder/windows/avfilter-7.dll differ diff --git a/backend/subfinder/windows/avformat-58.dll b/backend/subfinder/windows/avformat-58.dll new file mode 100644 index 00000000..dcc948fa Binary files /dev/null and b/backend/subfinder/windows/avformat-58.dll differ diff --git a/backend/subfinder/windows/avutil-56.dll b/backend/subfinder/windows/avutil-56.dll new file mode 100644 index 00000000..b90eee56 Binary files /dev/null and b/backend/subfinder/windows/avutil-56.dll differ diff --git a/backend/subfinder/windows/bitmaps/left_na.bmp b/backend/subfinder/windows/bitmaps/left_na.bmp new file mode 100644 index 00000000..a30b40d2 Binary files /dev/null and b/backend/subfinder/windows/bitmaps/left_na.bmp differ diff --git a/backend/subfinder/windows/bitmaps/left_od.bmp b/backend/subfinder/windows/bitmaps/left_od.bmp new file mode 100644 index 00000000..3c26b886 Binary files /dev/null and b/backend/subfinder/windows/bitmaps/left_od.bmp differ diff --git a/backend/subfinder/windows/bitmaps/right_na.bmp b/backend/subfinder/windows/bitmaps/right_na.bmp new file mode 100644 index 00000000..35112740 Binary files /dev/null and b/backend/subfinder/windows/bitmaps/right_na.bmp differ diff --git a/backend/subfinder/windows/bitmaps/right_od.bmp b/backend/subfinder/windows/bitmaps/right_od.bmp new file mode 100644 index 00000000..3e3ce81e Binary files /dev/null and b/backend/subfinder/windows/bitmaps/right_od.bmp differ diff --git a/backend/subfinder/windows/bitmaps/sb_la.bmp b/backend/subfinder/windows/bitmaps/sb_la.bmp new file mode 100644 index 00000000..5bb65be4 Binary files /dev/null and b/backend/subfinder/windows/bitmaps/sb_la.bmp differ diff --git a/backend/subfinder/windows/bitmaps/sb_lc.bmp 
b/backend/subfinder/windows/bitmaps/sb_lc.bmp new file mode 100644 index 00000000..d92e411d Binary files /dev/null and b/backend/subfinder/windows/bitmaps/sb_lc.bmp differ diff --git a/backend/subfinder/windows/bitmaps/sb_ra.bmp b/backend/subfinder/windows/bitmaps/sb_ra.bmp new file mode 100644 index 00000000..d55a4758 Binary files /dev/null and b/backend/subfinder/windows/bitmaps/sb_ra.bmp differ diff --git a/backend/subfinder/windows/bitmaps/sb_rc.bmp b/backend/subfinder/windows/bitmaps/sb_rc.bmp new file mode 100644 index 00000000..609256d3 Binary files /dev/null and b/backend/subfinder/windows/bitmaps/sb_rc.bmp differ diff --git a/backend/subfinder/windows/bitmaps/sb_t.bmp b/backend/subfinder/windows/bitmaps/sb_t.bmp new file mode 100644 index 00000000..ecdd46d7 Binary files /dev/null and b/backend/subfinder/windows/bitmaps/sb_t.bmp differ diff --git a/backend/subfinder/windows/bitmaps/tb_pause.bmp b/backend/subfinder/windows/bitmaps/tb_pause.bmp new file mode 100644 index 00000000..2a15732a Binary files /dev/null and b/backend/subfinder/windows/bitmaps/tb_pause.bmp differ diff --git a/backend/subfinder/windows/bitmaps/tb_run.bmp b/backend/subfinder/windows/bitmaps/tb_run.bmp new file mode 100644 index 00000000..69253adc Binary files /dev/null and b/backend/subfinder/windows/bitmaps/tb_run.bmp differ diff --git a/backend/subfinder/windows/bitmaps/tb_stop.bmp b/backend/subfinder/windows/bitmaps/tb_stop.bmp new file mode 100644 index 00000000..d73a638f Binary files /dev/null and b/backend/subfinder/windows/bitmaps/tb_stop.bmp differ diff --git a/backend/subfinder/windows/cudart64_110.dll b/backend/subfinder/windows/cudart64_110.dll new file mode 100644 index 00000000..6e795cf5 Binary files /dev/null and b/backend/subfinder/windows/cudart64_110.dll differ diff --git a/backend/subfinder/windows/finished.wav b/backend/subfinder/windows/finished.wav new file mode 100644 index 00000000..d6e8baa8 Binary files /dev/null and b/backend/subfinder/windows/finished.wav differ diff --git a/backend/subfinder/windows/nppc64_11.dll b/backend/subfinder/windows/nppc64_11.dll new file mode 100644 index 00000000..064adce4 Binary files /dev/null and b/backend/subfinder/windows/nppc64_11.dll differ diff --git a/backend/subfinder/windows/nppicc64_11.dll b/backend/subfinder/windows/nppicc64_11.dll new file mode 100644 index 00000000..3abe17b7 Binary files /dev/null and b/backend/subfinder/windows/nppicc64_11.dll differ diff --git a/backend/subfinder/windows/nppig64_11.dll b/backend/subfinder/windows/nppig64_11.dll new file mode 100644 index 00000000..0859df57 Binary files /dev/null and b/backend/subfinder/windows/nppig64_11.dll differ diff --git a/backend/subfinder/windows/opencv_videoio_ffmpeg430_64.dll b/backend/subfinder/windows/opencv_videoio_ffmpeg430_64.dll new file mode 100644 index 00000000..af1ae6a5 Binary files /dev/null and b/backend/subfinder/windows/opencv_videoio_ffmpeg430_64.dll differ diff --git a/backend/subfinder/windows/opencv_world430.dll b/backend/subfinder/windows/opencv_world430.dll new file mode 100644 index 00000000..2e47847f Binary files /dev/null and b/backend/subfinder/windows/opencv_world430.dll differ diff --git a/backend/subfinder/windows/postproc-55.dll b/backend/subfinder/windows/postproc-55.dll new file mode 100644 index 00000000..5692b480 Binary files /dev/null and b/backend/subfinder/windows/postproc-55.dll differ diff --git a/backend/subfinder/windows/previous_video.inf b/backend/subfinder/windows/previous_video.inf new file mode 100644 index 00000000..05936eb7 --- /dev/null 
+++ b/backend/subfinder/windows/previous_video.inf @@ -0,0 +1,4 @@ +C:\Users\fangyao\Downloads\test.mp4 +0 +53766 +0 \ No newline at end of file diff --git a/backend/subfinder/windows/settings/eng/locale.cfg b/backend/subfinder/windows/settings/eng/locale.cfg new file mode 100644 index 00000000..d7d24fe7 --- /dev/null +++ b/backend/subfinder/windows/settings/eng/locale.cfg @@ -0,0 +1,102 @@ +label_text_alignment = Text Alignment +ocr_label_msd_text = Min Sub Duration +ocr_label_jsact_text = Join Subs And Correct Time +ocr_label_clear_txt_folders = Clear TXT Folders Before Run +ocr_button_ccti_text = Create Cleared TXTImages +ocr_button_csftr_text = Create Sub From TXTResults +ocr_button_cesfcti_text = Create Empty Sub From Cleared TXTImages +ocr_button_ces_text = Create Empty Sub From RGBImages +ocr_button_join_text = Join TXTImages +ocr_button_test_text = Test +ocr_label_save_each_substring_separately = Save Each Substring Separately +ocr_label_save_scaled_images = Save Scaled Images +ssp_label_parameters_influencing_image_processing = Parameters Influencing Image Processing +ssp_label_ocl_and_multiframe_image_stream_processing = OCR and Multiframe Image Stream Processing +ssp_oi_group_global_image_processing_settings = Global Image Processing Settings +ssp_oi_property_use_ocl = Use OCL In OpenCV +ssp_oi_property_use_cuda_gpu = Use CUDA GPU Acceleration +ssp_oi_property_image_scale_for_clear_image = Image Scale For Clear Image +ssp_oi_property_cpu_kmeans_initial_loop_iterations = CPU kmeans initial loop iterations +ssp_oi_property_cpu_kmeans_loop_iterations = CPU kmeans loop iterations +ssp_oi_property_cuda_kmeans_initial_loop_iterations = CUDA kmeans initial loop iterations +ssp_oi_property_cuda_kmeans_loop_iterations = CUDA kmeans loop iterations +ssp_oi_property_generate_cleared_text_images_on_test = Generate Cleared Text Images On Test Button +ssp_oi_property_dump_debug_images = Dump Debug Images +ssp_oi_property_dump_debug_second_filtration_images = Dump Debug Secondary Processing Images +ssp_oi_property_clear_test_images_folder = Clear Test Images Folder +ssp_oi_property_show_transformed_images_only = Show Transformed Images Only +ssp_oi_group_initial_image_processing = Initial Image Processing +ssp_oi_sub_group_settings_for_sobel_operators = Settings For Sobel Operators +ssp_oi_property_moderate_threshold = Moderate Threshold +ssp_oi_property_moderate_nedges_threshold = Moderate NEdges Threshold +ssp_oi_sub_group_settings_for_color_filtering = Settings For Color Filtering +ssp_oi_property_segment_width = Line Segment Width +ssp_oi_property_min_segments_count = Min Segments Count +ssp_oi_property_min_sum_color_difference = Min Sum Color Difference +ssp_oi_group_secondary_image_processing = Secondary Image Processing +ssp_oi_sub_group_settings_for_linear_filtering = Settings For Linear Filtering +ssp_oi_property_line_height = Line Segment Height +ssp_oi_property_max_between_text_distance = Max Between Text Distance +ssp_oi_property_max_text_center_offset = Max Text Offset +ssp_oi_property_max_text_center_percent_offset = Max Text Center Percent Offset +ssp_oi_sub_group_settings_for_color_border_points = Settings For Color Border Points +ssp_oi_property_min_points_number = Min Points Number +ssp_oi_property_min_points_density = Min Points Density +ssp_oi_property_min_symbol_height = Min Symbol Height (in % to Full Image Height) +ssp_oi_property_min_symbol_density = Min Symbol Density (in % to Its Size) +ssp_oi_property_min_vedges_points_density = Min VEdges points density 
+ssp_oi_property_min_nedges_points_density = Min NEdges points density +ssp_oi_property_min_sum_multiple_color_difference = Min Sum Multiple Color Difference +ssp_oi_group_tertiary_image_processing = Tertiary Image Processing +ssp_oi_property_min_vedges_points_density_per_half_line = Min VEdges points density (per half line) +ssp_oi_property_min_hedges_points_density_per_half_line = Min HEdges points density (per half line) +ssp_oi_property_min_nedges_points_density_per_half_line = Min NEdges points density (per half line) +ssp_oim_group_ocr_settings = OCR settings +ssp_oim_property_clear_images_logical = Clear Images Logical;(don't use on Hieroglyph or Arabic subs) +ssp_oim_property_clear_rgbimages_after_search_subtitles = Clear RGBImages after search subtitles +ssp_oim_property_using_isaimages_for_getting_txt_areas = Use ISAImages for getting TXT areas +ssp_oim_property_using_ilaimages_for_getting_txt_areas = Use ILAImages for getting TXT areas +label_ILA_images_for_getting_txt_symbols_areas = Use ILAImages for getting TXT symbols areas +label_use_ILA_images_before_clear_txt_images_from_borders = Use ILAImages before clear TXT images from borders +ssp_oim_property_validate_and_compare_cleared_txt_images = (NotRealized) Validate And Compare Cleared TXTImages +ssp_oim_property_dont_delete_unrecognized_images_first = Don't Delete Unrecognized Images (First) +ssp_oim_property_dont_delete_unrecognized_images_second = Don't Delete Unrecognized Images (Second) +ssp_oim_property_default_string_for_empty_sub = Default string for empty sub +ssp_oim_group_settings_for_multiframe_image_processing = Settings For Multi-Frame Image Processing +ssp_oim_sub_group_settings_for_sub_detection = Settings For Sub Detection +ssp_oim_property_threads = Number Of Parallel Tasks;(For Run Search) +ssp_ocr_threads = Number Of Parallel Tasks;(For Create Cleared TXTImages) +ssp_oim_property_sub_frames_length = Sub Frames Length +ssp_oim_property_use_ILA_images_for_search_subtitles = Use ILAImages for search subtitles +ssp_oim_property_use_ISA_images_for_search_subtitles = Analyze ISAImages for sub presence +ssp_oim_property_replace_ISA_by_filtered_version = Replace ISAImages by filtered version +ssp_oim_property_max_dl_down = Max luminance diff from down for IL image generation +ssp_oim_property_max_dl_up = Max luminance diff from up for IL image generation +ssp_oim_sub_group_settings_for_comparing_subs = Settings For Comparing Subs +ssp_oim_property_vedges_points_line_error = VEdges Points line error +ssp_oim_property_ila_points_line_error = ILA Points line error +ssp_oim_sub_group_settings_for_checking_sub = Settings For Checking Sub +ssp_oim_property_text_percent = Text Percent +ssp_oim_property_min_text_length = Min Text Length +ssp_oim_property_use_gradient_images_for_clear_txt_images = Use Gradient Images For Clear TXTImages +ssp_oim_property_use_ILA_images_for_clear_txt_images = Use ILAImages For Clear TXTImages +ssp_oim_property_clear_txt_images_by_main_color = Clear TXTImages By Main Color +ssp_oi_property_moderate_threshold_for_scaled_image = Moderate Threshold For Scaled Image +ssp_oim_property_remove_wide_symbols = Remove too wide symbols;(don't use for Arabic or handwritten subs) +ssp_hw_device = FFMPEG HW Devices +label_filter_descr = FFMPEG Video Filters +label_settings_file = Current Settings File +label_playback_sound = Playback Sound On Task Finished +label_border_is_darker = Characters Border Is Darker +label_extend_by_grey_color = Extend By Grey Color;(try to use in case of subs with unstable 
luminance) +label_allow_min_luminance = Allow Min Luminance;(used only if "Extend By Grey Color" is set) +ssp_oim_sub_group_settings_for_update_video_color = Settings For Update Video Color +label_video_contrast = Video Contrast +label_video_gamma = Video Gamma +label_pixel_color = Pixel Color;(By 'Left Mouse Click' in Video Box) +label_use_filter_color = Use Filter Colors;(Use 'Ctrl+Enter' for add New Line);(Press and hold 'T'/'R'/'Y'/'U' button in Video Box for check) +label_use_outline_filter_color = Use Outline Filter Colors;(Use 'Ctrl+Enter' for add New Line);(Press and hold 'T'/'R'/'I'/'U' button in Video Box for check) +label_dL_color = default dL For RGB and Lab Filter Colors +label_dA_color = default dA For Lab Filter Colors +label_dB_color = default dB For Lab Filter Colors +label_combine_to_single_cluster = Combine To Single Cluster;(can be used in case of multiple colors in single line) diff --git a/backend/subfinder/windows/settings/general.cfg b/backend/subfinder/windows/settings/general.cfg new file mode 100644 index 00000000..5e965749 --- /dev/null +++ b/backend/subfinder/windows/settings/general.cfg @@ -0,0 +1,82 @@ +prefered_locale = eng +ocr_join_txt_images_split_line = [begin_time] --> [end_time] +process_affinity_mask = -1 +fount_size_lbl = 10 +fount_size_btn = 13 +dont_delete_unrecognized_images1 = 0 +dont_delete_unrecognized_images2 = 1 +generate_cleared_text_images_on_test = 1 +dump_debug_images = 0 +dump_debug_second_filtration_images = 0 +clear_test_images_folder = 1 +show_transformed_images_only = 0 +use_ocl = 1 +use_cuda_gpu = 0 +use_filter_color = none +use_outline_filter_color = none +dL_color = 40 +dA_color = 30 +dB_color = 30 +combine_to_single_cluster = 0 +cuda_kmeans_initial_loop_iterations = 20 +cuda_kmeans_loop_iterations = 30 +cpu_kmeans_initial_loop_iterations = 20 +cpu_kmeans_loop_iterations = 30 +moderate_threshold_for_scaled_image = 0.25 +moderate_threshold = 0.25 +moderate_threshold_for_NEdges = 0.25 +segment_width = 8 +segment_height = 3 +minimum_segments_count = 2 +min_sum_color_diff = 0 +between_text_distace = 0.07 +text_centre_offset = 0.2 +image_scale_for_clear_image = 4 +use_ISA_images = 1 +use_ILA_images = 1 +use_ILA_images_for_getting_txt_symbols_areas = 0 +use_ILA_images_before_clear_txt_images_from_borders = 0 +use_gradient_images_for_clear_txt_images = 1 +clear_txt_images_by_main_color = 1 +use_ILA_images_for_clear_txt_images = 1 +min_points_number = 30 +min_points_density = 0.3 +min_symbol_height = 0.02 +min_symbol_density = 0.2 +min_NEdges_points_density = 0.2 +threads = -1 +ocr_threads = -1 +sub_frame_length = 6 +text_percent = 0.3 +min_text_len_in_percent = 0.022 +vedges_points_line_error = 0.3 +ila_points_line_error = 0.3 +video_contrast = 1 +video_gamma = 1 +clear_txt_folders = 1 +join_subs_and_correct_time = 1 +clear_image_logical = 0 +clean_rgb_images_after_run = 0 +def_string_for_empty_sub = sub duration: %sub_duration% +min_sub_duration = 0 +txt_dw = 5 +txt_dy = 5 +use_ISA_images_for_search_subtitles = 1 +use_ILA_images_for_search_subtitles = 1 +replace_ISA_by_filtered_version = 1 +max_dl_down = 20 +max_dl_up = 40 +remove_wide_symbols = 0 +hw_device = cpu +filter_descr = none +text_alignment = Center +save_each_substring_separately = 0 +save_scaled_images = 1 +playback_sound = 0 +border_is_darker = 1 +extend_by_grey_color = 0 +allow_min_luminance = 100 +bottom_video_image_percent_end = 0 +top_video_image_percent_end = 0.3 +left_video_image_percent_end = 0 +right_video_image_percent_end = 1 diff --git 
a/backend/subfinder/windows/swresample-3.dll b/backend/subfinder/windows/swresample-3.dll new file mode 100644 index 00000000..51134cf1 Binary files /dev/null and b/backend/subfinder/windows/swresample-3.dll differ diff --git a/backend/subfinder/windows/swscale-5.dll b/backend/subfinder/windows/swscale-5.dll new file mode 100644 index 00000000..40e591e9 Binary files /dev/null and b/backend/subfinder/windows/swscale-5.dll differ diff --git a/backend/tools/NotoSansCJK-Bold.otf b/backend/tools/NotoSansCJK-Bold.otf new file mode 100644 index 00000000..7f666ddb Binary files /dev/null and b/backend/tools/NotoSansCJK-Bold.otf differ diff --git a/backend/tools/__init__.py b/backend/tools/__init__.py new file mode 100644 index 00000000..d56c9dba --- /dev/null +++ b/backend/tools/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/backend/tools/constant.py b/backend/tools/constant.py new file mode 100644 index 00000000..6b8f6637 --- /dev/null +++ b/backend/tools/constant.py @@ -0,0 +1,26 @@ +from enum import Enum + + +# Approximate region where subtitles are expected to appear by default +class SubtitleArea(Enum): + # Subtitle area appears in the lower half of the frame + LOWER_PART = 0 + # Subtitle area appears in the upper half of the frame + UPPER_PART = 1 + # Position of the subtitle area is unknown + UNKNOWN = 2 + # Position of the subtitle area is explicitly specified + CUSTOM = 3 + + +class BackgroundColor(Enum): + # Subtitle background color + WHITE = 0 + DARK = 1 + UNKNOWN = 2 + + +BGR_COLOR_GREEN = (0, 0xff, 0) +BGR_COLOR_BLUE = (0xff, 0, 0) +BGR_COLOR_RED = (0, 0, 0xff) +BGR_COLOR_WHITE = (0xff, 0xff, 0xff) diff --git a/backend/tools/eval.py b/backend/tools/eval.py index 9817fa75..cab28334 100755 --- a/backend/tools/eval.py +++ b/backend/tools/eval.py @@ -20,15 +20,14 @@ import sys __dir__ = os.path.dirname(os.path.abspath(__file__)) -sys.path.append(__dir__) -sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) +sys.path.insert(0, __dir__) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) from ppocr.data import build_dataloader from ppocr.modeling.architectures import build_model from ppocr.postprocess import build_post_process from ppocr.metrics import build_metric -from ppocr.utils.save_load import init_model -from ppocr.utils.utility import print_dict +from ppocr.utils.save_load import load_model import tools.program as program @@ -44,12 +43,51 @@ def main(): # build model # for rec algorithm if hasattr(post_process_class, 'character'): - config['Architecture']["Head"]['out_channels'] = len( - getattr(post_process_class, 'character')) + char_num = len(getattr(post_process_class, 'character')) + if config['Architecture']["algorithm"] in ["Distillation", + ]: # distillation model + for key in config['Architecture']["Models"]: + if config['Architecture']['Models'][key]['Head'][ + 'name'] == 'MultiHead': # for multi head + out_channels_list = {} + if config['PostProcess'][ + 'name'] == 'DistillationSARLabelDecode': + char_num = char_num - 2 + out_channels_list['CTCLabelDecode'] = char_num +
out_channels_list['SARLabelDecode'] = char_num + 2 + config['Architecture']['Models'][key]['Head'][ + 'out_channels_list'] = out_channels_list + else: + config['Architecture']["Models"][key]["Head"][ + 'out_channels'] = char_num + elif config['Architecture']['Head'][ + 'name'] == 'MultiHead': # for multi head + out_channels_list = {} + if config['PostProcess']['name'] == 'SARLabelDecode': + char_num = char_num - 2 + out_channels_list['CTCLabelDecode'] = char_num + out_channels_list['SARLabelDecode'] = char_num + 2 + config['Architecture']['Head'][ + 'out_channels_list'] = out_channels_list + else: # base rec model + config['Architecture']["Head"]['out_channels'] = char_num + model = build_model(config['Architecture']) - use_srn = config['Architecture']['algorithm'] == "SRN" + extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"] + extra_input = False + if config['Architecture']['algorithm'] == 'Distillation': + for key in config['Architecture']["Models"]: + extra_input = extra_input or config['Architecture']['Models'][key][ + 'algorithm'] in extra_input_models + else: + extra_input = config['Architecture']['algorithm'] in extra_input_models + if "model_type" in config['Architecture'].keys(): + model_type = config['Architecture']['model_type'] + else: + model_type = None - best_model_dict = init_model(config, model, logger) + best_model_dict = load_model( + config, model, model_type=config['Architecture']["model_type"]) if len(best_model_dict): logger.info('metric in ckpt ***************') for k, v in best_model_dict.items(): @@ -57,10 +95,9 @@ def main(): # build metric eval_class = build_metric(config['Metric']) - # start eval metric = program.eval(model, valid_dataloader, post_process_class, - eval_class, use_srn) + eval_class, model_type, extra_input) logger.info('metric eval ***************') for k, v in metric.items(): logger.info('{}:{}'.format(k, v)) diff --git a/backend/tools/export_center.py b/backend/tools/export_center.py new file mode 100644 index 00000000..9a6372f1 --- /dev/null +++ b/backend/tools/export_center.py @@ -0,0 +1,76 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys +import pickle + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) + +from ppocr.data import build_dataloader +from ppocr.modeling.architectures import build_model +from ppocr.postprocess import build_post_process +from ppocr.utils.save_load import load_model +import tools.program as program + + +def main(): + global_config = config['Global'] + # build dataloader + config['Eval']['dataset']['name'] = config['Train']['dataset']['name'] + config['Eval']['dataset']['data_dir'] = config['Train']['dataset'][ + 'data_dir'] + config['Eval']['dataset']['label_file_list'] = config['Train']['dataset'][ + 'label_file_list'] + eval_dataloader = build_dataloader(config, 'Eval', device, logger) + + # build post process + post_process_class = build_post_process(config['PostProcess'], + global_config) + + # build model + # for rec algorithm + if hasattr(post_process_class, 'character'): + char_num = len(getattr(post_process_class, 'character')) + config['Architecture']["Head"]['out_channels'] = char_num + + #set return_features = True + config['Architecture']["Head"]["return_feats"] = True + + model = build_model(config['Architecture']) + + best_model_dict = load_model(config, model) + if len(best_model_dict): + logger.info('metric in ckpt ***************') + for k, v in best_model_dict.items(): + logger.info('{}:{}'.format(k, v)) + + # get features from train data + char_center = program.get_center(model, eval_dataloader, post_process_class) + + #serialize to disk + with open("train_center.pkl", 'wb') as f: + pickle.dump(char_center, f) + return + + +if __name__ == '__main__': + config, device, logger, vdl_writer = program.preprocess() + main() diff --git a/backend/tools/export_model.py b/backend/tools/export_model.py index 1e9526e0..76c716e0 100755 --- a/backend/tools/export_model.py +++ b/backend/tools/export_model.py @@ -17,7 +17,7 @@ __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) -sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) +sys.path.append(os.path.abspath(os.path.join(__dir__, ".."))) import argparse @@ -26,75 +26,146 @@ from ppocr.modeling.architectures import build_model from ppocr.postprocess import build_post_process -from ppocr.utils.save_load import init_model +from ppocr.utils.save_load import load_model from ppocr.utils.logging import get_logger from tools.program import load_config, merge_config, ArgsParser -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("-c", "--config", help="configuration file to use") - parser.add_argument( - "-o", "--output_path", type=str, default='./output/infer/') - return parser.parse_args() - - -def main(): - FLAGS = ArgsParser().parse_args() - config = load_config(FLAGS.config) - merge_config(FLAGS.opt) - logger = get_logger() - # build post process - - post_process_class = build_post_process(config['PostProcess'], - config['Global']) - - # build model - # for rec algorithm - if hasattr(post_process_class, 'character'): - char_num = len(getattr(post_process_class, 'character')) - config['Architecture']["Head"]['out_channels'] = char_num - model = build_model(config['Architecture']) - init_model(config, model, logger) - model.eval() - - save_path = '{}/inference'.format(config['Global']['save_inference_dir']) - - if config['Architecture']['algorithm'] == 
"SRN": +def export_single_model(model, arch_config, save_path, logger, quanter=None): + if arch_config["algorithm"] == "SRN": + max_text_length = arch_config["Head"]["max_text_length"] other_shape = [ paddle.static.InputSpec( - shape=[None, 1, 64, 256], dtype='float32'), [ + shape=[None, 1, 64, 256], dtype="float32"), [ paddle.static.InputSpec( shape=[None, 256, 1], dtype="int64"), paddle.static.InputSpec( - shape=[None, 25, 1], - dtype="int64"), paddle.static.InputSpec( - shape=[None, 8, 25, 25], dtype="int64"), + shape=[None, max_text_length, 1], dtype="int64"), paddle.static.InputSpec( - shape=[None, 8, 25, 25], dtype="int64") + shape=[None, 8, max_text_length, max_text_length], + dtype="int64"), paddle.static.InputSpec( + shape=[None, 8, max_text_length, max_text_length], + dtype="int64") ] ] model = to_static(model, input_spec=other_shape) + elif arch_config["algorithm"] == "SAR": + other_shape = [ + paddle.static.InputSpec( + shape=[None, 3, 48, 160], dtype="float32"), + ] + model = to_static(model, input_spec=other_shape) + elif arch_config["algorithm"] == "SVTR": + if arch_config["Head"]["name"] == 'MultiHead': + other_shape = [ + paddle.static.InputSpec( + shape=[None, 3, 48, -1], dtype="float32"), + ] + else: + other_shape = [ + paddle.static.InputSpec( + shape=[None, 3, 64, 256], dtype="float32"), + ] + model = to_static(model, input_spec=other_shape) + elif arch_config["algorithm"] == "PREN": + other_shape = [ + paddle.static.InputSpec( + shape=[None, 3, 64, 512], dtype="float32"), + ] + model = to_static(model, input_spec=other_shape) else: infer_shape = [3, -1, -1] - if config['Architecture']['model_type'] == "rec": + if arch_config["model_type"] == "rec": infer_shape = [3, 32, -1] # for rec model, H must be 32 - if 'Transform' in config['Architecture'] and config['Architecture'][ - 'Transform'] is not None and config['Architecture'][ - 'Transform']['name'] == 'TPS': + if "Transform" in arch_config and arch_config[ + "Transform"] is not None and arch_config["Transform"][ + "name"] == "TPS": logger.info( - 'When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training' + "When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training" ) infer_shape[-1] = 100 + if arch_config["algorithm"] == "NRTR": + infer_shape = [1, 32, 100] + elif arch_config["model_type"] == "table": + infer_shape = [3, 488, 488] model = to_static( model, input_spec=[ paddle.static.InputSpec( - shape=[None] + infer_shape, dtype='float32') + shape=[None] + infer_shape, dtype="float32") ]) - paddle.jit.save(model, save_path) - logger.info('inference model is saved to {}'.format(save_path)) + if quanter is None: + paddle.jit.save(model, save_path) + else: + quanter.save_quantized_model(model, save_path) + logger.info("inference model is saved to {}".format(save_path)) + return + + +def main(): + FLAGS = ArgsParser().parse_args() + config = load_config(FLAGS.settings_config) + config = merge_config(config, FLAGS.opt) + logger = get_logger() + # build post process + + post_process_class = build_post_process(config["PostProcess"], + config["Global"]) + + # build model + # for rec algorithm + if hasattr(post_process_class, "character"): + char_num = len(getattr(post_process_class, "character")) + if config["Architecture"]["algorithm"] in ["Distillation", + ]: # distillation model + for key in config["Architecture"]["Models"]: + if 
config["Architecture"]["Models"][key]["Head"][ + "name"] == 'MultiHead': # multi head + out_channels_list = {} + if config['PostProcess'][ + 'name'] == 'DistillationSARLabelDecode': + char_num = char_num - 2 + out_channels_list['CTCLabelDecode'] = char_num + out_channels_list['SARLabelDecode'] = char_num + 2 + config['Architecture']['Models'][key]['Head'][ + 'out_channels_list'] = out_channels_list + else: + config["Architecture"]["Models"][key]["Head"][ + "out_channels"] = char_num + # just one final tensor needs to exported for inference + config["Architecture"]["Models"][key][ + "return_all_feats"] = False + elif config['Architecture']['Head'][ + 'name'] == 'MultiHead': # multi head + out_channels_list = {} + char_num = len(getattr(post_process_class, 'character')) + if config['PostProcess']['name'] == 'SARLabelDecode': + char_num = char_num - 2 + out_channels_list['CTCLabelDecode'] = char_num + out_channels_list['SARLabelDecode'] = char_num + 2 + config['Architecture']['Head'][ + 'out_channels_list'] = out_channels_list + else: # base rec model + config["Architecture"]["Head"]["out_channels"] = char_num + + model = build_model(config["Architecture"]) + load_model(config, model) + model.eval() + + save_path = config["Global"]["save_inference_dir"] + + arch_config = config["Architecture"] + + if arch_config["algorithm"] in ["Distillation", ]: # distillation model + archs = list(arch_config["Models"].values()) + for idx, name in enumerate(model.model_name_list): + sub_model_save_path = os.path.join(save_path, name, "inference") + export_single_model(model.model_list[idx], archs[idx], + sub_model_save_path, logger) + else: + save_path = os.path.join(save_path, "inference") + export_single_model(model, arch_config, save_path, logger) if __name__ == "__main__": diff --git a/backend/tools/infer/predict_cls.py b/backend/tools/infer/predict_cls.py index 074172cc..ed2f47c0 100755 --- a/backend/tools/infer/predict_cls.py +++ b/backend/tools/infer/predict_cls.py @@ -16,7 +16,7 @@ __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) -sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..'))) os.environ["FLAGS_allocator_strategy"] = 'auto_growth' @@ -45,8 +45,9 @@ def __init__(self, args): "label_list": args.label_list, } self.postprocess_op = build_post_process(postprocess_params) - self.predictor, self.input_tensor, self.output_tensors = \ + self.predictor, self.input_tensor, self.output_tensors, _ = \ utility.create_predictor(args, 'cls', logger) + self.use_onnx = args.use_onnx def resize_norm_img(self, img): imgC, imgH, imgW = self.cls_image_shape @@ -84,9 +85,11 @@ def __call__(self, img_list): batch_num = self.cls_batch_num elapse = 0 for beg_img_no in range(0, img_num, batch_num): + end_img_no = min(img_num, beg_img_no + batch_num) norm_img_batch = [] max_wh_ratio = 0 + starttime = time.time() for ino in range(beg_img_no, end_img_no): h, w = img_list[indices[ino]].shape[0:2] wh_ratio = w * 1.0 / h @@ -97,11 +100,17 @@ def __call__(self, img_list): norm_img_batch.append(norm_img) norm_img_batch = np.concatenate(norm_img_batch) norm_img_batch = norm_img_batch.copy() - starttime = time.time() - self.input_tensor.copy_from_cpu(norm_img_batch) - self.predictor.run() - prob_out = self.output_tensors[0].copy_to_cpu() + if self.use_onnx: + input_dict = {} + input_dict[self.input_tensor.name] = norm_img_batch + outputs = self.predictor.run(self.output_tensors, input_dict) + prob_out = outputs[0] + else: 
+ self.input_tensor.copy_from_cpu(norm_img_batch) + self.predictor.run() + prob_out = self.output_tensors[0].copy_to_cpu() + self.predictor.try_shrink_memory() cls_result = self.postprocess_op(prob_out) elapse += time.time() - starttime for rno in range(len(cls_result)): @@ -129,20 +138,13 @@ def main(args): img_list.append(img) try: img_list, cls_res, predict_time = text_classifier(img_list) - except: + except Exception as E: logger.info(traceback.format_exc()) - logger.info( - "ERROR!!!! \n" - "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n" - "If your model has tps module: " - "TPS does not support variable shape.\n" - "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ") + logger.info(E) exit() for ino in range(len(img_list)): logger.info("Predicts of {}:{}".format(valid_image_file_list[ino], cls_res[ino])) - logger.info("Total predict time for {} images, cost: {:.3f}".format( - len(img_list), predict_time)) if __name__ == "__main__": diff --git a/backend/tools/infer/predict_det.py b/backend/tools/infer/predict_det.py index b14825bd..5f2675d6 100755 --- a/backend/tools/infer/predict_det.py +++ b/backend/tools/infer/predict_det.py @@ -16,7 +16,7 @@ __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) -sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..'))) os.environ["FLAGS_allocator_strategy"] = 'auto_growth' @@ -30,7 +30,7 @@ from ppocr.utils.utility import get_image_file_list, check_and_read_gif from ppocr.data import create_operators, transform from ppocr.postprocess import build_post_process - +import json logger = get_logger() @@ -38,8 +38,12 @@ class TextDetector(object): def __init__(self, args): self.args = args self.det_algorithm = args.det_algorithm + self.use_onnx = args.use_onnx pre_process_list = [{ - 'DetResizeForTest': None + 'DetResizeForTest': { + 'limit_side_len': args.det_limit_side_len, + 'limit_type': args.det_limit_type, + } }, { 'NormalizeImage': { 'std': [0.229, 0.224, 0.225], @@ -62,6 +66,7 @@ def __init__(self, args): postprocess_params["max_candidates"] = 1000 postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio postprocess_params["use_dilation"] = args.use_dilation + postprocess_params["score_mode"] = args.det_db_score_mode elif self.det_algorithm == "EAST": postprocess_params['name'] = 'EASTPostProcess' postprocess_params["score_thresh"] = args.det_east_score_thresh @@ -85,38 +90,73 @@ def __init__(self, args): postprocess_params["sample_pts_num"] = 2 postprocess_params["expand_scale"] = 1.0 postprocess_params["shrink_ratio_of_width"] = 0.3 + elif self.det_algorithm == "PSE": + postprocess_params['name'] = 'PSEPostProcess' + postprocess_params["thresh"] = args.det_pse_thresh + postprocess_params["box_thresh"] = args.det_pse_box_thresh + postprocess_params["min_area"] = args.det_pse_min_area + postprocess_params["box_type"] = args.det_pse_box_type + postprocess_params["scale"] = args.det_pse_scale + self.det_pse_box_type = args.det_pse_box_type + elif self.det_algorithm == "FCE": + pre_process_list[0] = { + 'DetResizeForTest': { + 'rescale_img': [1080, 736] + } + } + postprocess_params['name'] = 'FCEPostProcess' + postprocess_params["scales"] = args.scales + postprocess_params["alpha"] = args.alpha + postprocess_params["beta"] = args.beta + postprocess_params["fourier_degree"] = args.fourier_degree + postprocess_params["box_type"] = args.det_fce_box_type else: logger.info("unknown 
det_algorithm:{}".format(self.det_algorithm)) sys.exit(0) self.preprocess_op = create_operators(pre_process_list) self.postprocess_op = build_post_process(postprocess_params) - self.predictor, self.input_tensor, self.output_tensors = utility.create_predictor( - args, 'det', logger) # paddle.jit.load(args.det_model_dir) - # self.predictor.eval() + self.predictor, self.input_tensor, self.output_tensors, self.config = utility.create_predictor( + args, 'det', logger) + + if self.use_onnx: + img_h, img_w = self.input_tensor.shape[2:] + if img_h is not None and img_w is not None and img_h > 0 and img_w > 0: + pre_process_list[0] = { + 'DetResizeForTest': { + 'image_shape': [img_h, img_w] + } + } + self.preprocess_op = create_operators(pre_process_list) + + if args.benchmark: + import auto_log + pid = os.getpid() + gpu_id = utility.get_infer_gpuid() + self.autolog = auto_log.AutoLogger( + model_name="det", + model_precision=args.precision, + batch_size=1, + data_shape="dynamic", + save_path=None, + inference_config=self.config, + pids=pid, + process_name=None, + gpu_ids=gpu_id if args.use_gpu else None, + time_keys=[ + 'preprocess_time', 'inference_time', 'postprocess_time' + ], + warmup=2, + logger=logger) def order_points_clockwise(self, pts): - """ - reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py - # sort the points based on their x-coordinates - """ - xSorted = pts[np.argsort(pts[:, 0]), :] - - # grab the left-most and right-most points from the sorted - # x-roodinate points - leftMost = xSorted[:2, :] - rightMost = xSorted[2:, :] - - # now, sort the left-most coordinates according to their - # y-coordinates so we can grab the top-left and bottom-left - # points, respectively - leftMost = leftMost[np.argsort(leftMost[:, 1]), :] - (tl, bl) = leftMost - - rightMost = rightMost[np.argsort(rightMost[:, 1]), :] - (tr, br) = rightMost - - rect = np.array([tl, tr, br, bl], dtype="float32") + rect = np.zeros((4, 2), dtype="float32") + s = pts.sum(axis=1) + rect[0] = pts[np.argmin(s)] + rect[2] = pts[np.argmax(s)] + diff = np.diff(pts, axis=1) + rect[1] = pts[np.argmin(diff)] + rect[3] = pts[np.argmax(diff)] return rect def clip_det_res(self, points, img_height, img_width): @@ -151,6 +191,12 @@ def filter_tag_det_res_only_clip(self, dt_boxes, image_shape): def __call__(self, img): ori_im = img.copy() data = {'image': img} + + st = time.time() + + if self.args.benchmark: + self.autolog.times.start() + data = transform(data, self.preprocess_op) img, shape_list = data if img is None: @@ -158,14 +204,22 @@ def __call__(self, img): img = np.expand_dims(img, axis=0) shape_list = np.expand_dims(shape_list, axis=0) img = img.copy() - starttime = time.time() - self.input_tensor.copy_from_cpu(img) - self.predictor.run() - outputs = [] - for output_tensor in self.output_tensors: - output = output_tensor.copy_to_cpu() - outputs.append(output) + if self.args.benchmark: + self.autolog.times.stamp() + if self.use_onnx: + input_dict = {} + input_dict[self.input_tensor.name] = img + outputs = self.predictor.run(self.output_tensors, input_dict) + else: + self.input_tensor.copy_from_cpu(img) + self.predictor.run() + outputs = [] + for output_tensor in self.output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) + if self.args.benchmark: + self.autolog.times.stamp() preds = {} if self.det_algorithm == "EAST": @@ -176,19 +230,28 @@ def __call__(self, img): preds['f_score'] = outputs[1] preds['f_tco'] = outputs[2] preds['f_tvo'] = outputs[3] - elif 
self.det_algorithm == 'DB': + elif self.det_algorithm in ['DB', 'PSE']: preds['maps'] = outputs[0] + elif self.det_algorithm == 'FCE': + for i, output in enumerate(outputs): + preds['level_{}'.format(i)] = output else: raise NotImplementedError + #self.predictor.try_shrink_memory() post_result = self.postprocess_op(preds, shape_list) dt_boxes = post_result[0]['points'] - if self.det_algorithm == "SAST" and self.det_sast_polygon: + if (self.det_algorithm == "SAST" and self.det_sast_polygon) or ( + self.det_algorithm in ["PSE", "FCE"] and + self.postprocess_op.box_type == 'poly'): dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape) else: dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape) - elapse = time.time() - starttime - return dt_boxes, elapse + + if self.args.benchmark: + self.autolog.times.end(stamp=True) + et = time.time() + return dt_boxes, et - st if __name__ == "__main__": @@ -198,8 +261,15 @@ def __call__(self, img): count = 0 total_time = 0 draw_img_save = "./inference_results" + + if args.warmup: + img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8) + for i in range(2): + res = text_detector(img) + if not os.path.exists(draw_img_save): os.makedirs(draw_img_save) + save_results = [] for image_file in image_file_list: img, flag = check_and_read_gif(image_file) if not flag: @@ -207,16 +277,26 @@ def __call__(self, img): if img is None: logger.info("error in loading image:{}".format(image_file)) continue - dt_boxes, elapse = text_detector(img) + st = time.time() + dt_boxes, _ = text_detector(img) + elapse = time.time() - st if count > 0: total_time += elapse count += 1 - logger.info("Predict time of {}: {}".format(image_file, elapse)) + save_pred = os.path.basename(image_file) + "\t" + str( + json.dumps([x.tolist() for x in dt_boxes])) + "\n" + save_results.append(save_pred) + logger.info(save_pred) + logger.info("The predict time of {}: {}".format(image_file, elapse)) src_im = utility.draw_text_det_res(dt_boxes, image_file) img_name_pure = os.path.split(image_file)[-1] img_path = os.path.join(draw_img_save, "det_res_{}".format(img_name_pure)) cv2.imwrite(img_path, src_im) logger.info("The visualized image saved in {}".format(img_path)) - if count > 1: - logger.info("Avg Time: {}".format(total_time / (count - 1))) + + with open(os.path.join(draw_img_save, "det_results.txt"), 'w') as f: + f.writelines(save_results) + f.close() + if args.benchmark: + text_detector.autolog.report() diff --git a/backend/tools/infer/predict_e2e.py b/backend/tools/infer/predict_e2e.py new file mode 100755 index 00000000..fb2859f0 --- /dev/null +++ b/backend/tools/infer/predict_e2e.py @@ -0,0 +1,169 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..'))) + +os.environ["FLAGS_allocator_strategy"] = 'auto_growth' + +import cv2 +import numpy as np +import time +import sys + +import tools.infer.utility as utility +from ppocr.utils.logging import get_logger +from ppocr.utils.utility import get_image_file_list, check_and_read_gif +from ppocr.data import create_operators, transform +from ppocr.postprocess import build_post_process + +logger = get_logger() + + +class TextE2E(object): + def __init__(self, args): + self.args = args + self.e2e_algorithm = args.e2e_algorithm + self.use_onnx = args.use_onnx + pre_process_list = [{ + 'E2EResizeForTest': {} + }, { + 'NormalizeImage': { + 'std': [0.229, 0.224, 0.225], + 'mean': [0.485, 0.456, 0.406], + 'scale': '1./255.', + 'order': 'hwc' + } + }, { + 'ToCHWImage': None + }, { + 'KeepKeys': { + 'keep_keys': ['image', 'shape'] + } + }] + postprocess_params = {} + if self.e2e_algorithm == "PGNet": + pre_process_list[0] = { + 'E2EResizeForTest': { + 'max_side_len': args.e2e_limit_side_len, + 'valid_set': 'totaltext' + } + } + postprocess_params['name'] = 'PGPostProcess' + postprocess_params["score_thresh"] = args.e2e_pgnet_score_thresh + postprocess_params["character_dict_path"] = args.e2e_char_dict_path + postprocess_params["valid_set"] = args.e2e_pgnet_valid_set + postprocess_params["mode"] = args.e2e_pgnet_mode + else: + logger.info("unknown e2e_algorithm:{}".format(self.e2e_algorithm)) + sys.exit(0) + + self.preprocess_op = create_operators(pre_process_list) + self.postprocess_op = build_post_process(postprocess_params) + self.predictor, self.input_tensor, self.output_tensors, _ = utility.create_predictor( + args, 'e2e', logger) # paddle.jit.load(args.det_model_dir) + # self.predictor.eval() + + def clip_det_res(self, points, img_height, img_width): + for pno in range(points.shape[0]): + points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1)) + points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1)) + return points + + def filter_tag_det_res_only_clip(self, dt_boxes, image_shape): + img_height, img_width = image_shape[0:2] + dt_boxes_new = [] + for box in dt_boxes: + box = self.clip_det_res(box, img_height, img_width) + dt_boxes_new.append(box) + dt_boxes = np.array(dt_boxes_new) + return dt_boxes + + def __call__(self, img): + + ori_im = img.copy() + data = {'image': img} + data = transform(data, self.preprocess_op) + img, shape_list = data + if img is None: + return None, 0 + img = np.expand_dims(img, axis=0) + shape_list = np.expand_dims(shape_list, axis=0) + img = img.copy() + starttime = time.time() + + if self.use_onnx: + input_dict = {} + input_dict[self.input_tensor.name] = img + outputs = self.predictor.run(self.output_tensors, input_dict) + preds = {} + preds['f_border'] = outputs[0] + preds['f_char'] = outputs[1] + preds['f_direction'] = outputs[2] + preds['f_score'] = outputs[3] + else: + self.input_tensor.copy_from_cpu(img) + self.predictor.run() + outputs = [] + for output_tensor in self.output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) + + preds = {} + if self.e2e_algorithm == 'PGNet': + preds['f_border'] = outputs[0] + preds['f_char'] = outputs[1] + preds['f_direction'] = outputs[2] + preds['f_score'] = outputs[3] + else: + raise NotImplementedError + post_result = self.postprocess_op(preds, shape_list) + points, strs = post_result['points'], post_result['texts'] + 
dt_boxes = self.filter_tag_det_res_only_clip(points, ori_im.shape) + elapse = time.time() - starttime + return dt_boxes, strs, elapse + + +if __name__ == "__main__": + args = utility.parse_args() + image_file_list = get_image_file_list(args.image_dir) + text_detector = TextE2E(args) + count = 0 + total_time = 0 + draw_img_save = "./inference_results" + if not os.path.exists(draw_img_save): + os.makedirs(draw_img_save) + for image_file in image_file_list: + img, flag = check_and_read_gif(image_file) + if not flag: + img = cv2.imread(image_file) + if img is None: + logger.info("error in loading image:{}".format(image_file)) + continue + points, strs, elapse = text_detector(img) + if count > 0: + total_time += elapse + count += 1 + logger.info("Predict time of {}: {}".format(image_file, elapse)) + src_im = utility.draw_e2e_res(points, strs, image_file) + img_name_pure = os.path.split(image_file)[-1] + img_path = os.path.join(draw_img_save, + "e2e_res_{}".format(img_name_pure)) + cv2.imwrite(img_path, src_im) + logger.info("The visualized image saved in {}".format(img_path)) + if count > 1: + logger.info("Avg Time: {}".format(total_time / (count - 1))) diff --git a/backend/tools/infer/predict_rec.py b/backend/tools/infer/predict_rec.py index b3d9d490..3664ef2c 100755 --- a/backend/tools/infer/predict_rec.py +++ b/backend/tools/infer/predict_rec.py @@ -13,10 +13,10 @@ # limitations under the License. import os import sys - +from PIL import Image __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) -sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..'))) os.environ["FLAGS_allocator_strategy"] = 'auto_growth' @@ -38,44 +38,91 @@ class TextRecognizer(object): def __init__(self, args): self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")] - self.character_type = args.rec_char_type self.rec_batch_num = args.rec_batch_num self.rec_algorithm = args.rec_algorithm postprocess_params = { 'name': 'CTCLabelDecode', - "character_type": args.rec_char_type, "character_dict_path": args.rec_char_dict_path, "use_space_char": args.use_space_char } if self.rec_algorithm == "SRN": postprocess_params = { 'name': 'SRNLabelDecode', - "character_type": args.rec_char_type, "character_dict_path": args.rec_char_dict_path, "use_space_char": args.use_space_char } elif self.rec_algorithm == "RARE": postprocess_params = { 'name': 'AttnLabelDecode', - "character_type": args.rec_char_type, + "character_dict_path": args.rec_char_dict_path, + "use_space_char": args.use_space_char + } + elif self.rec_algorithm == 'NRTR': + postprocess_params = { + 'name': 'NRTRLabelDecode', + "character_dict_path": args.rec_char_dict_path, + "use_space_char": args.use_space_char + } + elif self.rec_algorithm == "SAR": + postprocess_params = { + 'name': 'SARLabelDecode', "character_dict_path": args.rec_char_dict_path, "use_space_char": args.use_space_char } self.postprocess_op = build_post_process(postprocess_params) - self.predictor, self.input_tensor, self.output_tensors = \ + self.predictor, self.input_tensor, self.output_tensors, self.config = \ utility.create_predictor(args, 'rec', logger) + self.benchmark = args.benchmark + self.use_onnx = args.use_onnx + if args.benchmark: + import auto_log + pid = os.getpid() + gpu_id = utility.get_infer_gpuid() + self.autolog = auto_log.AutoLogger( + model_name="rec", + model_precision=args.precision, + batch_size=args.rec_batch_num, + data_shape="dynamic", + save_path=None, 
#args.save_log_path, + inference_config=self.config, + pids=pid, + process_name=None, + gpu_ids=gpu_id if args.use_gpu else None, + time_keys=[ + 'preprocess_time', 'inference_time', 'postprocess_time' + ], + warmup=0, + logger=logger) def resize_norm_img(self, img, max_wh_ratio): imgC, imgH, imgW = self.rec_image_shape + if self.rec_algorithm == 'NRTR': + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + # return padding_im + image_pil = Image.fromarray(np.uint8(img)) + img = image_pil.resize([100, 32], Image.ANTIALIAS) + img = np.array(img) + norm_img = np.expand_dims(img, -1) + norm_img = norm_img.transpose((2, 0, 1)) + return norm_img.astype(np.float32) / 128. - 1. + assert imgC == img.shape[2] - if self.character_type == "ch": - imgW = int((32 * max_wh_ratio)) + imgW = int((imgH * max_wh_ratio)) + if self.use_onnx: + w = self.input_tensor.shape[3:][0] + if w is not None and w > 0: + imgW = w + h, w = img.shape[:2] ratio = w / float(h) if math.ceil(imgH * ratio) > imgW: resized_w = imgW else: resized_w = int(math.ceil(imgH * ratio)) + if self.rec_algorithm == 'RARE': + if resized_w > self.rec_image_shape[2]: + resized_w = self.rec_image_shape[2] + imgW = self.rec_image_shape[2] resized_image = cv2.resize(img, (resized_w, imgH)) resized_image = resized_image.astype('float32') resized_image = resized_image.transpose((2, 0, 1)) / 255 @@ -85,6 +132,17 @@ def resize_norm_img(self, img, max_wh_ratio): padding_im[:, :, 0:resized_w] = resized_image return padding_im + def resize_norm_img_svtr(self, img, image_shape): + + imgC, imgH, imgW = image_shape + resized_image = cv2.resize( + img, (imgW, imgH), interpolation=cv2.INTER_LINEAR) + resized_image = resized_image.astype('float32') + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + return resized_image + def resize_norm_img_srn(self, img, image_shape): imgC, imgH, imgW = image_shape @@ -157,6 +215,41 @@ def process_image_srn(self, img, image_shape, num_heads, max_text_length): return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2) + def resize_norm_img_sar(self, img, image_shape, + width_downsample_ratio=0.25): + imgC, imgH, imgW_min, imgW_max = image_shape + h = img.shape[0] + w = img.shape[1] + valid_ratio = 1.0 + # make sure new_width is an integral multiple of width_divisor. 
+ width_divisor = int(1 / width_downsample_ratio) + # resize + ratio = w / float(h) + resize_w = math.ceil(imgH * ratio) + if resize_w % width_divisor != 0: + resize_w = round(resize_w / width_divisor) * width_divisor + if imgW_min is not None: + resize_w = max(imgW_min, resize_w) + if imgW_max is not None: + valid_ratio = min(1.0, 1.0 * resize_w / imgW_max) + resize_w = min(imgW_max, resize_w) + resized_image = cv2.resize(img, (resize_w, imgH)) + resized_image = resized_image.astype('float32') + # norm + if image_shape[0] == 1: + resized_image = resized_image / 255 + resized_image = resized_image[np.newaxis, :] + else: + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + resize_shape = resized_image.shape + padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32) + padding_im[:, :, 0:resize_w] = resized_image + pad_shape = padding_im.shape + + return padding_im, resize_shape, pad_shape, valid_ratio + def __call__(self, img_list): img_num = len(img_list) # Calculate the aspect ratio of all text bars @@ -165,27 +258,32 @@ def __call__(self, img_list): width_list.append(img.shape[1] / float(img.shape[0])) # Sorting can speed up the recognition process indices = np.argsort(np.array(width_list)) - - # rec_res = [] rec_res = [['', 0.0]] * img_num batch_num = self.rec_batch_num - elapse = 0 + st = time.time() + if self.benchmark: + self.autolog.times.start() for beg_img_no in range(0, img_num, batch_num): end_img_no = min(img_num, beg_img_no + batch_num) norm_img_batch = [] - max_wh_ratio = 0 + imgC, imgH, imgW = self.rec_image_shape + max_wh_ratio = imgW / imgH + # max_wh_ratio = 0 for ino in range(beg_img_no, end_img_no): - # h, w = img_list[ino].shape[0:2] h, w = img_list[indices[ino]].shape[0:2] wh_ratio = w * 1.0 / h max_wh_ratio = max(max_wh_ratio, wh_ratio) for ino in range(beg_img_no, end_img_no): - if self.rec_algorithm != "SRN": - norm_img = self.resize_norm_img(img_list[indices[ino]], - max_wh_ratio) + + if self.rec_algorithm == "SAR": + norm_img, _, _, valid_ratio = self.resize_norm_img_sar( + img_list[indices[ino]], self.rec_image_shape) norm_img = norm_img[np.newaxis, :] + valid_ratio = np.expand_dims(valid_ratio, axis=0) + valid_ratios = [] + valid_ratios.append(valid_ratio) norm_img_batch.append(norm_img) - else: + elif self.rec_algorithm == "SRN": norm_img = self.process_image_srn( img_list[indices[ino]], self.rec_image_shape, 8, 25) encoder_word_pos_list = [] @@ -197,11 +295,22 @@ def __call__(self, img_list): gsrm_slf_attn_bias1_list.append(norm_img[3]) gsrm_slf_attn_bias2_list.append(norm_img[4]) norm_img_batch.append(norm_img[0]) + elif self.rec_algorithm == "SVTR": + norm_img = self.resize_norm_img_svtr(img_list[indices[ino]], + self.rec_image_shape) + norm_img = norm_img[np.newaxis, :] + norm_img_batch.append(norm_img) + else: + norm_img = self.resize_norm_img(img_list[indices[ino]], + max_wh_ratio) + norm_img = norm_img[np.newaxis, :] + norm_img_batch.append(norm_img) norm_img_batch = np.concatenate(norm_img_batch) norm_img_batch = norm_img_batch.copy() + if self.benchmark: + self.autolog.times.stamp() if self.rec_algorithm == "SRN": - starttime = time.time() encoder_word_pos_list = np.concatenate(encoder_word_pos_list) gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list) gsrm_slf_attn_bias1_list = np.concatenate( @@ -216,33 +325,78 @@ def __call__(self, img_list): gsrm_slf_attn_bias1_list, gsrm_slf_attn_bias2_list, ] - input_names = self.predictor.get_input_names() - for i in range(len(input_names)): - 
input_tensor = self.predictor.get_input_handle(input_names[ - i]) - input_tensor.copy_from_cpu(inputs[i]) - self.predictor.run() - outputs = [] - for output_tensor in self.output_tensors: - output = output_tensor.copy_to_cpu() - outputs.append(output) - preds = {"predict": outputs[2]} + if self.use_onnx: + input_dict = {} + input_dict[self.input_tensor.name] = norm_img_batch + outputs = self.predictor.run(self.output_tensors, + input_dict) + preds = {"predict": outputs[2]} + else: + input_names = self.predictor.get_input_names() + for i in range(len(input_names)): + input_tensor = self.predictor.get_input_handle( + input_names[i]) + input_tensor.copy_from_cpu(inputs[i]) + self.predictor.run() + outputs = [] + for output_tensor in self.output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) + if self.benchmark: + self.autolog.times.stamp() + preds = {"predict": outputs[2]} + elif self.rec_algorithm == "SAR": + valid_ratios = np.concatenate(valid_ratios) + inputs = [ + norm_img_batch, + valid_ratios, + ] + if self.use_onnx: + input_dict = {} + input_dict[self.input_tensor.name] = norm_img_batch + outputs = self.predictor.run(self.output_tensors, + input_dict) + preds = outputs[0] + else: + input_names = self.predictor.get_input_names() + for i in range(len(input_names)): + input_tensor = self.predictor.get_input_handle( + input_names[i]) + input_tensor.copy_from_cpu(inputs[i]) + self.predictor.run() + outputs = [] + for output_tensor in self.output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) + if self.benchmark: + self.autolog.times.stamp() + preds = outputs[0] else: - starttime = time.time() - self.input_tensor.copy_from_cpu(norm_img_batch) - self.predictor.run() - - outputs = [] - for output_tensor in self.output_tensors: - output = output_tensor.copy_to_cpu() - outputs.append(output) - preds = outputs[0] - + if self.use_onnx: + input_dict = {} + input_dict[self.input_tensor.name] = norm_img_batch + outputs = self.predictor.run(self.output_tensors, + input_dict) + preds = outputs[0] + else: + self.input_tensor.copy_from_cpu(norm_img_batch) + self.predictor.run() + outputs = [] + for output_tensor in self.output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) + if self.benchmark: + self.autolog.times.stamp() + if len(outputs) != 1: + preds = outputs + else: + preds = outputs[0] rec_result = self.postprocess_op(preds) for rno in range(len(rec_result)): rec_res[indices[beg_img_no + rno]] = rec_result[rno] - elapse += time.time() - starttime - return rec_res, elapse + if self.benchmark: + self.autolog.times.end(stamp=True) + return rec_res, time.time() - st def main(args): @@ -250,6 +404,17 @@ def main(args): text_recognizer = TextRecognizer(args) valid_image_file_list = [] img_list = [] + + logger.info( + "In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320', " + "if you are using recognition model with PP-OCRv2 or an older version, please set --rec_image_shape='3,32,320" + ) + # warmup 2 times + if args.warmup: + img = np.random.uniform(0, 255, [48, 320, 3]).astype(np.uint8) + for i in range(2): + res = text_recognizer([img] * int(args.rec_batch_num)) + for image_file in image_file_list: img, flag = check_and_read_gif(image_file) if not flag: @@ -260,21 +425,17 @@ def main(args): valid_image_file_list.append(image_file) img_list.append(img) try: - rec_res, predict_time = text_recognizer(img_list) - except: + rec_res, _ = text_recognizer(img_list) + + except Exception as E: 
logger.info(traceback.format_exc()) - logger.info( - "ERROR!!!! \n" - "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n" - "If your model has tps module: " - "TPS does not support variable shape.\n" - "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ") + logger.info(E) exit() for ino in range(len(img_list)): logger.info("Predicts of {}:{}".format(valid_image_file_list[ino], rec_res[ino])) - logger.info("Total predict time for {} images, cost: {:.3f}".format( - len(img_list), predict_time)) + if args.benchmark: + text_recognizer.autolog.report() if __name__ == "__main__": diff --git a/backend/tools/infer/predict_system.py b/backend/tools/infer/predict_system.py index de7ee9d3..4af3da70 100755 --- a/backend/tools/infer/predict_system.py +++ b/backend/tools/infer/predict_system.py @@ -13,17 +13,20 @@ # limitations under the License. import os import sys +import subprocess __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) -sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..'))) os.environ["FLAGS_allocator_strategy"] = 'auto_growth' import cv2 import copy import numpy as np +import json import time +import logging from PIL import Image import tools.infer.utility as utility import tools.infer.predict_rec as predict_rec @@ -31,13 +34,15 @@ import tools.infer.predict_cls as predict_cls from ppocr.utils.utility import get_image_file_list, check_and_read_gif from ppocr.utils.logging import get_logger -from tools.infer.utility import draw_ocr_box_txt - +from tools.infer.utility import draw_ocr_box_txt, get_rotate_crop_image logger = get_logger() class TextSystem(object): def __init__(self, args): + if not args.show_log: + logger.setLevel(logging.INFO) + self.text_detector = predict_det.TextDetector(args) self.text_recognizer = predict_rec.TextRecognizer(args) self.use_angle_cls = args.use_angle_cls @@ -45,50 +50,24 @@ def __init__(self, args): if self.use_angle_cls: self.text_classifier = predict_cls.TextClassifier(args) - def get_rotate_crop_image(self, img, points): - ''' - img_height, img_width = img.shape[0:2] - left = int(np.min(points[:, 0])) - right = int(np.max(points[:, 0])) - top = int(np.min(points[:, 1])) - bottom = int(np.max(points[:, 1])) - img_crop = img[top:bottom, left:right, :].copy() - points[:, 0] = points[:, 0] - left - points[:, 1] = points[:, 1] - top - ''' - img_crop_width = int( - max( - np.linalg.norm(points[0] - points[1]), - np.linalg.norm(points[2] - points[3]))) - img_crop_height = int( - max( - np.linalg.norm(points[0] - points[3]), - np.linalg.norm(points[1] - points[2]))) - pts_std = np.float32([[0, 0], [img_crop_width, 0], - [img_crop_width, img_crop_height], - [0, img_crop_height]]) - M = cv2.getPerspectiveTransform(points, pts_std) - dst_img = cv2.warpPerspective( - img, - M, (img_crop_width, img_crop_height), - borderMode=cv2.BORDER_REPLICATE, - flags=cv2.INTER_CUBIC) - dst_img_height, dst_img_width = dst_img.shape[0:2] - if dst_img_height * 1.0 / dst_img_width >= 1.5: - dst_img = np.rot90(dst_img) - return dst_img - - def print_draw_crop_rec_res(self, img_crop_list, rec_res): + self.args = args + self.crop_image_res_index = 0 + + def draw_crop_rec_res(self, output_dir, img_crop_list, rec_res): + os.makedirs(output_dir, exist_ok=True) bbox_num = len(img_crop_list) for bno in range(bbox_num): - cv2.imwrite("./output/img_crop_%d.jpg" % bno, img_crop_list[bno]) - logger.info(bno, rec_res[bno]) + cv2.imwrite( + 
os.path.join(output_dir, + f"mg_crop_{bno+self.crop_image_res_index}.jpg"), + img_crop_list[bno]) + logger.debug(f"{bno}, {rec_res[bno]}") + self.crop_image_res_index += bbox_num - def __call__(self, img): + def __call__(self, img, cls=True): ori_im = img.copy() dt_boxes, elapse = self.text_detector(img) - logger.info("dt_boxes num : {}, elapse : {}".format( - len(dt_boxes), elapse)) + if dt_boxes is None: return None, None img_crop_list = [] @@ -97,24 +76,23 @@ def __call__(self, img): for bno in range(len(dt_boxes)): tmp_box = copy.deepcopy(dt_boxes[bno]) - img_crop = self.get_rotate_crop_image(ori_im, tmp_box) + img_crop = get_rotate_crop_image(ori_im, tmp_box) img_crop_list.append(img_crop) - if self.use_angle_cls: + if self.use_angle_cls and cls: img_crop_list, angle_list, elapse = self.text_classifier( img_crop_list) - logger.info("cls num : {}, elapse : {}".format( - len(img_crop_list), elapse)) + rec_res, elapse = self.text_recognizer(img_crop_list) - logger.info("rec_res num : {}, elapse : {}".format( - len(rec_res), elapse)) - # self.print_draw_crop_rec_res(img_crop_list, rec_res) + if self.args.save_crop_res: + self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list, + rec_res) filter_boxes, filter_rec_res = [], [] - for box, rec_reuslt in zip(dt_boxes, rec_res): - text, score = rec_reuslt + for box, rec_result in zip(dt_boxes, rec_res): + text, score = rec_result if score >= self.drop_score: filter_boxes.append(box) - filter_rec_res.append(rec_reuslt) + filter_rec_res.append(rec_result) return filter_boxes, filter_rec_res @@ -141,24 +119,49 @@ def sorted_boxes(dt_boxes): def main(args): image_file_list = get_image_file_list(args.image_dir) + image_file_list = image_file_list[args.process_id::args.total_process_num] text_sys = TextSystem(args) is_visualize = True font_path = args.vis_font_path drop_score = args.drop_score - for image_file in image_file_list: + draw_img_save_dir = args.draw_img_save_dir + os.makedirs(draw_img_save_dir, exist_ok=True) + save_results = [] + + logger.info("In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320', " + "if you are using recognition model with PP-OCRv2 or an older version, please set --rec_image_shape='3,32,320") + + # warm up 10 times + if args.warmup: + img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8) + for i in range(10): + res = text_sys(img) + + total_time = 0 + cpu_mem, gpu_mem, gpu_util = 0, 0, 0 + _st = time.time() + count = 0 + for idx, image_file in enumerate(image_file_list): + img, flag = check_and_read_gif(image_file) if not flag: img = cv2.imread(image_file) if img is None: - logger.info("error in loading image:{}".format(image_file)) + logger.debug("error in loading image:{}".format(image_file)) continue starttime = time.time() dt_boxes, rec_res = text_sys(img) elapse = time.time() - starttime - logger.info("Predict time of %s: %.3fs" % (image_file, elapse)) + total_time += elapse + - for text, score in rec_res: - logger.info("{}, {:.3f}".format(text, score)) + res = [{ + "transcription": rec_res[idx][0], + "points": np.array(dt_boxes[idx]).astype(np.int32).tolist(), + } for idx in range(len(dt_boxes))] + save_pred = os.path.basename(image_file) + "\t" + json.dumps( + res, ensure_ascii=False) + "\n" + save_results.append(save_pred) if is_visualize: image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) @@ -173,15 +176,35 @@ def main(args): scores, drop_score=drop_score, font_path=font_path) - draw_img_save = "./inference_results/" - if not os.path.exists(draw_img_save): - 
os.makedirs(draw_img_save) + if flag: + image_file = image_file[:-3] + "png" cv2.imwrite( - os.path.join(draw_img_save, os.path.basename(image_file)), + os.path.join(draw_img_save_dir, os.path.basename(image_file)), draw_img[:, :, ::-1]) - logger.info("The visualized image saved in {}".format( - os.path.join(draw_img_save, os.path.basename(image_file)))) + + + logger.info("The predict total time is {}".format(time.time() - _st)) + if args.benchmark: + text_sys.text_detector.autolog.report() + text_sys.text_recognizer.autolog.report() + + with open(os.path.join(draw_img_save_dir, "system_results.txt"), 'w', encoding='utf-8') as f: + f.writelines(save_results) if __name__ == "__main__": - main(utility.parse_args()) + args = utility.parse_args() + if args.use_mp: + p_list = [] + total_process_num = args.total_process_num + for process_id in range(total_process_num): + cmd = [sys.executable, "-u"] + sys.argv + [ + "--process_id={}".format(process_id), + "--use_mp={}".format(False) + ] + p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stdout) + p_list.append(p) + for p in p_list: + p.wait() + else: + main(args) diff --git a/backend/tools/infer/utility.py b/backend/tools/infer/utility.py old mode 100755 new mode 100644 index 92f3e745..29b3755e --- a/backend/tools/infer/utility.py +++ b/backend/tools/infer/utility.py @@ -15,24 +15,29 @@ import argparse import os import sys +import platform import cv2 import numpy as np -import json +import paddle from PIL import Image, ImageDraw, ImageFont import math from paddle import inference +import time +from ppocr.utils.logging import get_logger -def parse_args(): - def str2bool(v): - return v.lower() in ("true", "t", "1") +def str2bool(v): + return v.lower() in ("true", "t", "1") + +def init_args(): parser = argparse.ArgumentParser() # params for prediction engine parser.add_argument("--use_gpu", type=str2bool, default=True) parser.add_argument("--ir_optim", type=str2bool, default=True) parser.add_argument("--use_tensorrt", type=str2bool, default=False) - parser.add_argument("--use_fp16", type=str2bool, default=False) + parser.add_argument("--min_subgraph_size", type=int, default=15) + parser.add_argument("--precision", type=str, default="fp32") parser.add_argument("--gpu_mem", type=int, default=500) # params for text detector @@ -44,11 +49,11 @@ def str2bool(v): # DB parmas parser.add_argument("--det_db_thresh", type=float, default=0.3) - parser.add_argument("--det_db_box_thresh", type=float, default=0.5) - parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6) + parser.add_argument("--det_db_box_thresh", type=float, default=0.6) + parser.add_argument("--det_db_unclip_ratio", type=float, default=1.5) parser.add_argument("--max_batch_size", type=int, default=10) - parser.add_argument("--use_dilation", type=bool, default=False) - + parser.add_argument("--use_dilation", type=str2bool, default=False) + parser.add_argument("--det_db_score_mode", type=str, default="fast") # EAST parmas parser.add_argument("--det_east_score_thresh", type=float, default=0.8) parser.add_argument("--det_east_cover_thresh", type=float, default=0.1) @@ -57,13 +62,26 @@ def str2bool(v): # SAST parmas parser.add_argument("--det_sast_score_thresh", type=float, default=0.5) parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2) - parser.add_argument("--det_sast_polygon", type=bool, default=False) + parser.add_argument("--det_sast_polygon", type=str2bool, default=False) + + # PSE parmas + parser.add_argument("--det_pse_thresh", type=float, default=0) 
+ parser.add_argument("--det_pse_box_thresh", type=float, default=0.85) + parser.add_argument("--det_pse_min_area", type=float, default=16) + parser.add_argument("--det_pse_box_type", type=str, default='quad') + parser.add_argument("--det_pse_scale", type=int, default=1) + + # FCE parmas + parser.add_argument("--scales", type=list, default=[8, 16, 32]) + parser.add_argument("--alpha", type=float, default=1.0) + parser.add_argument("--beta", type=float, default=1.0) + parser.add_argument("--fourier_degree", type=int, default=5) + parser.add_argument("--det_fce_box_type", type=str, default='poly') # params for text recognizer parser.add_argument("--rec_algorithm", type=str, default='CRNN') parser.add_argument("--rec_model_dir", type=str) - parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320") - parser.add_argument("--rec_char_type", type=str, default='ch') + parser.add_argument("--rec_image_shape", type=str, default="3, 48, 320") parser.add_argument("--rec_batch_num", type=int, default=6) parser.add_argument("--max_text_length", type=int, default=25) parser.add_argument( @@ -75,6 +93,19 @@ def str2bool(v): "--vis_font_path", type=str, default="./doc/fonts/simfang.ttf") parser.add_argument("--drop_score", type=float, default=0.5) + # params for e2e + parser.add_argument("--e2e_algorithm", type=str, default='PGNet') + parser.add_argument("--e2e_model_dir", type=str) + parser.add_argument("--e2e_limit_side_len", type=float, default=768) + parser.add_argument("--e2e_limit_type", type=str, default='max') + + # PGNet parmas + parser.add_argument("--e2e_pgnet_score_thresh", type=float, default=0.5) + parser.add_argument( + "--e2e_char_dict_path", type=str, default="./ppocr/utils/ic15_dict.txt") + parser.add_argument("--e2e_pgnet_valid_set", type=str, default='totaltext') + parser.add_argument("--e2e_pgnet_mode", type=str, default='fast') + # params for text classifier parser.add_argument("--use_angle_cls", type=str2bool, default=False) parser.add_argument("--cls_model_dir", type=str) @@ -84,8 +115,31 @@ def str2bool(v): parser.add_argument("--cls_thresh", type=float, default=0.9) parser.add_argument("--enable_mkldnn", type=str2bool, default=False) + parser.add_argument("--cpu_threads", type=int, default=10) parser.add_argument("--use_pdserving", type=str2bool, default=False) + parser.add_argument("--warmup", type=str2bool, default=False) + + # + parser.add_argument( + "--draw_img_save_dir", type=str, default="./inference_results") + parser.add_argument("--save_crop_res", type=str2bool, default=False) + parser.add_argument("--crop_res_save_dir", type=str, default="./output") + + # multi-process + parser.add_argument("--use_mp", type=str2bool, default=False) + parser.add_argument("--total_process_num", type=int, default=1) + parser.add_argument("--process_id", type=int, default=0) + + parser.add_argument("--benchmark", type=str2bool, default=False) + parser.add_argument("--save_log_path", type=str, default="./log_output/") + parser.add_argument("--show_log", type=str2bool, default=False) + parser.add_argument("--use_onnx", type=str2bool, default=False) + return parser + + +def parse_args(): + parser = init_args() return parser.parse_args() @@ -94,59 +148,224 @@ def create_predictor(args, mode, logger): model_dir = args.det_model_dir elif mode == 'cls': model_dir = args.cls_model_dir - else: + elif mode == 'rec': model_dir = args.rec_model_dir + elif mode == 'table': + model_dir = args.table_model_dir + else: + model_dir = args.e2e_model_dir if model_dir is None: logger.info("not 
find {} model file path {}".format(mode, model_dir)) sys.exit(0) - model_file_path = model_dir + "/inference.pdmodel" - params_file_path = model_dir + "/inference.pdiparams" - if not os.path.exists(model_file_path): - logger.info("not find model file path {}".format(model_file_path)) - sys.exit(0) - if not os.path.exists(params_file_path): - logger.info("not find params file path {}".format(params_file_path)) - sys.exit(0) + if args.use_onnx: + import onnxruntime as ort + model_file_path = model_dir + if not os.path.exists(model_file_path): + raise ValueError("not find model file path {}".format( + model_file_path)) + sess = ort.InferenceSession(model_file_path) + return sess, sess.get_inputs()[0], None, None - config = inference.Config(model_file_path, params_file_path) - - if args.use_gpu: - config.enable_use_gpu(args.gpu_mem, 0) - if args.use_tensorrt: - config.enable_tensorrt_engine( - precision_mode=inference.PrecisionType.Half - if args.use_fp16 else inference.PrecisionType.Float32, - max_batch_size=args.max_batch_size) else: - config.disable_gpu() - config.set_cpu_math_library_num_threads(6) - if args.enable_mkldnn: - # cache 10 different shapes for mkldnn to avoid memory leak - config.set_mkldnn_cache_capacity(10) - config.enable_mkldnn() - # TODO LDOUBLEV: fix mkldnn bug when bach_size > 1 - #config.set_mkldnn_op({'conv2d', 'depthwise_conv2d', 'pool2d', 'batch_norm'}) - args.rec_batch_num = 1 - - # enable memory optim - config.enable_memory_optim() - config.disable_glog_info() - - config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") - config.switch_use_feed_fetch_ops(False) - - # create predictor - predictor = inference.create_predictor(config) - input_names = predictor.get_input_names() - for name in input_names: - input_tensor = predictor.get_input_handle(name) + model_file_path = model_dir + "/inference.pdmodel" + params_file_path = model_dir + "/inference.pdiparams" + if not os.path.exists(model_file_path): + raise ValueError("not find model file path {}".format( + model_file_path)) + if not os.path.exists(params_file_path): + raise ValueError("not find params file path {}".format( + params_file_path)) + + config = inference.Config(model_file_path, params_file_path) + + if hasattr(args, 'precision'): + if args.precision == "fp16" and args.use_tensorrt: + precision = inference.PrecisionType.Half + elif args.precision == "int8": + precision = inference.PrecisionType.Int8 + else: + precision = inference.PrecisionType.Float32 + else: + precision = inference.PrecisionType.Float32 + + if args.use_gpu: + gpu_id = get_infer_gpuid() + if gpu_id is None: + logger.warning( + "GPU is not found in current device by nvidia-smi. Please check your device or ignore it if run on jetson." 
+ ) + config.enable_use_gpu(args.gpu_mem, 0) + if args.use_tensorrt: + config.enable_tensorrt_engine( + workspace_size=1 << 30, + precision_mode=precision, + max_batch_size=args.max_batch_size, + min_subgraph_size=args.min_subgraph_size) + # skip the minmum trt subgraph + use_dynamic_shape = True + if mode == "det": + min_input_shape = { + "x": [1, 3, 50, 50], + "conv2d_92.tmp_0": [1, 120, 20, 20], + "conv2d_91.tmp_0": [1, 24, 10, 10], + "conv2d_59.tmp_0": [1, 96, 20, 20], + "nearest_interp_v2_1.tmp_0": [1, 256, 10, 10], + "nearest_interp_v2_2.tmp_0": [1, 256, 20, 20], + "conv2d_124.tmp_0": [1, 256, 20, 20], + "nearest_interp_v2_3.tmp_0": [1, 64, 20, 20], + "nearest_interp_v2_4.tmp_0": [1, 64, 20, 20], + "nearest_interp_v2_5.tmp_0": [1, 64, 20, 20], + "elementwise_add_7": [1, 56, 2, 2], + "nearest_interp_v2_0.tmp_0": [1, 256, 2, 2] + } + max_input_shape = { + "x": [1, 3, 1536, 1536], + "conv2d_92.tmp_0": [1, 120, 400, 400], + "conv2d_91.tmp_0": [1, 24, 200, 200], + "conv2d_59.tmp_0": [1, 96, 400, 400], + "nearest_interp_v2_1.tmp_0": [1, 256, 200, 200], + "conv2d_124.tmp_0": [1, 256, 400, 400], + "nearest_interp_v2_2.tmp_0": [1, 256, 400, 400], + "nearest_interp_v2_3.tmp_0": [1, 64, 400, 400], + "nearest_interp_v2_4.tmp_0": [1, 64, 400, 400], + "nearest_interp_v2_5.tmp_0": [1, 64, 400, 400], + "elementwise_add_7": [1, 56, 400, 400], + "nearest_interp_v2_0.tmp_0": [1, 256, 400, 400] + } + opt_input_shape = { + "x": [1, 3, 640, 640], + "conv2d_92.tmp_0": [1, 120, 160, 160], + "conv2d_91.tmp_0": [1, 24, 80, 80], + "conv2d_59.tmp_0": [1, 96, 160, 160], + "nearest_interp_v2_1.tmp_0": [1, 256, 80, 80], + "nearest_interp_v2_2.tmp_0": [1, 256, 160, 160], + "conv2d_124.tmp_0": [1, 256, 160, 160], + "nearest_interp_v2_3.tmp_0": [1, 64, 160, 160], + "nearest_interp_v2_4.tmp_0": [1, 64, 160, 160], + "nearest_interp_v2_5.tmp_0": [1, 64, 160, 160], + "elementwise_add_7": [1, 56, 40, 40], + "nearest_interp_v2_0.tmp_0": [1, 256, 40, 40] + } + min_pact_shape = { + "nearest_interp_v2_26.tmp_0": [1, 256, 20, 20], + "nearest_interp_v2_27.tmp_0": [1, 64, 20, 20], + "nearest_interp_v2_28.tmp_0": [1, 64, 20, 20], + "nearest_interp_v2_29.tmp_0": [1, 64, 20, 20] + } + max_pact_shape = { + "nearest_interp_v2_26.tmp_0": [1, 256, 400, 400], + "nearest_interp_v2_27.tmp_0": [1, 64, 400, 400], + "nearest_interp_v2_28.tmp_0": [1, 64, 400, 400], + "nearest_interp_v2_29.tmp_0": [1, 64, 400, 400] + } + opt_pact_shape = { + "nearest_interp_v2_26.tmp_0": [1, 256, 160, 160], + "nearest_interp_v2_27.tmp_0": [1, 64, 160, 160], + "nearest_interp_v2_28.tmp_0": [1, 64, 160, 160], + "nearest_interp_v2_29.tmp_0": [1, 64, 160, 160] + } + min_input_shape.update(min_pact_shape) + max_input_shape.update(max_pact_shape) + opt_input_shape.update(opt_pact_shape) + elif mode == "rec": + if args.rec_algorithm != "CRNN": + use_dynamic_shape = False + imgH = int(args.rec_image_shape.split(',')[-2]) + min_input_shape = {"x": [1, 3, imgH, 10]} + max_input_shape = {"x": [args.rec_batch_num, 3, imgH, 1536]} + opt_input_shape = {"x": [args.rec_batch_num, 3, imgH, 320]} + elif mode == "cls": + min_input_shape = {"x": [1, 3, 48, 10]} + max_input_shape = {"x": [args.rec_batch_num, 3, 48, 1024]} + opt_input_shape = {"x": [args.rec_batch_num, 3, 48, 320]} + else: + use_dynamic_shape = False + if use_dynamic_shape: + config.set_trt_dynamic_shape_info( + min_input_shape, max_input_shape, opt_input_shape) + + else: + config.disable_gpu() + if hasattr(args, "cpu_threads"): + config.set_cpu_math_library_num_threads(args.cpu_threads) + else: + # default cpu 
threads as 10 + config.set_cpu_math_library_num_threads(10) + if args.enable_mkldnn: + # cache 10 different shapes for mkldnn to avoid memory leak + config.set_mkldnn_cache_capacity(10) + config.enable_mkldnn() + if args.precision == "fp16": + config.enable_mkldnn_bfloat16() + # enable memory optim + config.enable_memory_optim() + config.disable_glog_info() + config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") + config.delete_pass("matmul_transpose_reshape_fuse_pass") + if mode == 'table': + config.delete_pass("fc_fuse_pass") # not supported for table + config.switch_use_feed_fetch_ops(False) + config.switch_ir_optim(True) + + # create predictor + predictor = inference.create_predictor(config) + input_names = predictor.get_input_names() + for name in input_names: + input_tensor = predictor.get_input_handle(name) + output_tensors = get_output_tensors(args, mode, predictor) + return predictor, input_tensor, output_tensors, config + + +def get_output_tensors(args, mode, predictor): output_names = predictor.get_output_names() output_tensors = [] - for output_name in output_names: - output_tensor = predictor.get_output_handle(output_name) - output_tensors.append(output_tensor) - return predictor, input_tensor, output_tensors + if mode == "rec" and args.rec_algorithm == "CRNN": + output_name = 'softmax_0.tmp_0' + if output_name in output_names: + return [predictor.get_output_handle(output_name)] + else: + for output_name in output_names: + output_tensor = predictor.get_output_handle(output_name) + output_tensors.append(output_tensor) + else: + for output_name in output_names: + output_tensor = predictor.get_output_handle(output_name) + output_tensors.append(output_tensor) + return output_tensors + + +def get_infer_gpuid(): + sysstr = platform.system() + if sysstr == "Windows": + return 0 + + if not paddle.core.is_compiled_with_rocm(): + cmd = "env | grep CUDA_VISIBLE_DEVICES" + else: + cmd = "env | grep HIP_VISIBLE_DEVICES" + env_cuda = os.popen(cmd).readlines() + if len(env_cuda) == 0: + return 0 + else: + gpu_id = env_cuda[0].strip().split("=")[1] + return int(gpu_id[0]) + + +def draw_e2e_res(dt_boxes, strs, img_path): + src_im = cv2.imread(img_path) + for box, str in zip(dt_boxes, strs): + box = box.astype(np.int32).reshape((-1, 1, 2)) + cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2) + cv2.putText( + src_im, + str, + org=(int(box[0, 0, 0]), int(box[0, 0, 1])), + fontFace=cv2.FONT_HERSHEY_COMPLEX, + fontScale=0.7, + color=(0, 255, 0), + thickness=1) + return src_im def draw_text_det_res(dt_boxes, img_path): @@ -174,7 +393,7 @@ def draw_ocr(image, txts=None, scores=None, drop_score=0.5, - font_path="./doc/simfang.ttf"): + font_path="./doc/fonts/simfang.ttf"): """ Visualize the results of OCR detection and recognition args: @@ -216,7 +435,7 @@ def draw_ocr_box_txt(image, scores=None, drop_score=0.5, font_path="./doc/simfang.ttf"): - h, w = image.frame_height, image.frame_width + h, w = image.height, image.width img_left = image.copy() img_right = Image.new('RGB', (w, h), (255, 255, 255)) @@ -381,23 +600,46 @@ def draw_boxes(image, boxes, scores=None, drop_score=0.5): return image +def get_rotate_crop_image(img, points): + ''' + img_height, img_width = img.shape[0:2] + left = int(np.min(points[:, 0])) + right = int(np.max(points[:, 0])) + top = int(np.min(points[:, 1])) + bottom = int(np.max(points[:, 1])) + img_crop = img[top:bottom, left:right, :].copy() + points[:, 0] = points[:, 0] - left + points[:, 1] = points[:, 1] - top + ''' + assert len(points) == 4, 
"shape of points must be 4*2" + img_crop_width = int( + max( + np.linalg.norm(points[0] - points[1]), + np.linalg.norm(points[2] - points[3]))) + img_crop_height = int( + max( + np.linalg.norm(points[0] - points[3]), + np.linalg.norm(points[1] - points[2]))) + pts_std = np.float32([[0, 0], [img_crop_width, 0], + [img_crop_width, img_crop_height], + [0, img_crop_height]]) + M = cv2.getPerspectiveTransform(points, pts_std) + dst_img = cv2.warpPerspective( + img, + M, (img_crop_width, img_crop_height), + borderMode=cv2.BORDER_REPLICATE, + flags=cv2.INTER_CUBIC) + dst_img_height, dst_img_width = dst_img.shape[0:2] + if dst_img_height * 1.0 / dst_img_width >= 1.5: + dst_img = np.rot90(dst_img) + return dst_img + + +def check_gpu(use_gpu): + if use_gpu and not paddle.is_compiled_with_cuda(): + use_gpu = False + return use_gpu + + if __name__ == '__main__': - test_img = "./doc/test_v2" - predict_txt = "./doc/predict.txt" - f = open(predict_txt, 'r') - data = f.readlines() - img_path, anno = data[0].strip().split('\t') - img_name = os.path.basename(img_path) - img_path = os.path.join(test_img, img_name) - image = Image.open(img_path) - - data = json.loads(anno) - boxes, txts, scores = [], [], [] - for dic in data: - boxes.append(dic['points']) - txts.append(dic['transcription']) - scores.append(round(dic['scores'], 3)) - - new_img = draw_ocr(image, boxes, txts, scores) - - cv2.imwrite(img_name, new_img) + pass diff --git a/backend/tools/infer_cls.py b/backend/tools/infer_cls.py index 49696482..7fd6b536 100755 --- a/backend/tools/infer_cls.py +++ b/backend/tools/infer_cls.py @@ -23,7 +23,7 @@ __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) -sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) os.environ["FLAGS_allocator_strategy"] = 'auto_growth' @@ -32,7 +32,7 @@ from ppocr.data import create_operators, transform from ppocr.modeling.architectures import build_model from ppocr.postprocess import build_post_process -from ppocr.utils.save_load import init_model +from ppocr.utils.save_load import load_model from ppocr.utils.utility import get_image_file_list import tools.program as program @@ -47,7 +47,7 @@ def main(): # build model model = build_model(config['Architecture']) - init_model(config, model, logger) + load_model(config, model) # create data ops transforms = [] @@ -57,6 +57,8 @@ def main(): continue elif op_name == 'KeepKeys': op[op_name]['keep_keys'] = ['image'] + elif op_name == "SSLRotateResize": + op[op_name]["mode"] = "test" transforms.append(op) global_config['infer_mode'] = True ops = create_operators(transforms, global_config) @@ -73,8 +75,8 @@ def main(): images = paddle.to_tensor(images) preds = model(images) post_result = post_process_class(preds) - for rec_reuslt in post_result: - logger.info('\t result: {}'.format(rec_reuslt)) + for rec_result in post_result: + logger.info('\t result: {}'.format(rec_result)) logger.info("success!") diff --git a/backend/tools/infer_det.py b/backend/tools/infer_det.py index 913d617d..1acecedf 100755 --- a/backend/tools/infer_det.py +++ b/backend/tools/infer_det.py @@ -23,7 +23,7 @@ __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) -sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) os.environ["FLAGS_allocator_strategy"] = 'auto_growth' @@ -34,35 +34,33 @@ from ppocr.data import create_operators, transform from ppocr.modeling.architectures 
import build_model from ppocr.postprocess import build_post_process -from ppocr.utils.save_load import init_model +from ppocr.utils.save_load import load_model from ppocr.utils.utility import get_image_file_list import tools.program as program -def draw_det_res(dt_boxes, config, img, img_name): +def draw_det_res(dt_boxes, config, img, img_name, save_path): if len(dt_boxes) > 0: import cv2 src_im = img for box in dt_boxes: box = box.astype(np.int32).reshape((-1, 1, 2)) cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2) - save_det_path = os.path.dirname(config['Global'][ - 'save_res_path']) + "/det_results/" - if not os.path.exists(save_det_path): - os.makedirs(save_det_path) - save_path = os.path.join(save_det_path, os.path.basename(img_name)) + if not os.path.exists(save_path): + os.makedirs(save_path) + save_path = os.path.join(save_path, os.path.basename(img_name)) cv2.imwrite(save_path, src_im) logger.info("The detected Image saved in {}".format(save_path)) +@paddle.no_grad() def main(): global_config = config['Global'] # build model model = build_model(config['Architecture']) - init_model(config, model, logger) - + load_model(config, model) # build post process post_process_class = build_post_process(config['PostProcess']) @@ -96,20 +94,41 @@ def main(): images = paddle.to_tensor(images) preds = model(images) post_result = post_process_class(preds, shape_list) - boxes = post_result[0]['points'] - # write result + + src_img = cv2.imread(file) + dt_boxes_json = [] - for box in boxes: - tmp_json = {"transcription": ""} - tmp_json['points'] = box.tolist() - dt_boxes_json.append(tmp_json) + # parser boxes if post_result is dict + if isinstance(post_result, dict): + det_box_json = {} + for k in post_result.keys(): + boxes = post_result[k][0]['points'] + dt_boxes_list = [] + for box in boxes: + tmp_json = {"transcription": ""} + tmp_json['points'] = box.tolist() + dt_boxes_list.append(tmp_json) + det_box_json[k] = dt_boxes_list + save_det_path = os.path.dirname(config['Global'][ + 'save_res_path']) + "/det_results_{}/".format(k) + draw_det_res(boxes, config, src_img, file, save_det_path) + else: + boxes = post_result[0]['points'] + dt_boxes_json = [] + # write result + for box in boxes: + tmp_json = {"transcription": ""} + tmp_json['points'] = box.tolist() + dt_boxes_json.append(tmp_json) + save_det_path = os.path.dirname(config['Global'][ + 'save_res_path']) + "/det_results/" + draw_det_res(boxes, config, src_img, file, save_det_path) otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n" fout.write(otstr.encode()) - src_img = cv2.imread(file) - draw_det_res(boxes, config, src_img, file) + logger.info("success!") if __name__ == '__main__': config, device, logger, vdl_writer = program.preprocess() - main() \ No newline at end of file + main() diff --git a/backend/tools/infer_e2e.py b/backend/tools/infer_e2e.py new file mode 100755 index 00000000..d3e6b28f --- /dev/null +++ b/backend/tools/infer_e2e.py @@ -0,0 +1,122 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +import os +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) + +os.environ["FLAGS_allocator_strategy"] = 'auto_growth' + +import cv2 +import json +import paddle + +from ppocr.data import create_operators, transform +from ppocr.modeling.architectures import build_model +from ppocr.postprocess import build_post_process +from ppocr.utils.save_load import load_model +from ppocr.utils.utility import get_image_file_list +import tools.program as program + + +def draw_e2e_res(dt_boxes, strs, config, img, img_name): + if len(dt_boxes) > 0: + src_im = img + for box, str in zip(dt_boxes, strs): + box = box.astype(np.int32).reshape((-1, 1, 2)) + cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2) + cv2.putText( + src_im, + str, + org=(int(box[0, 0, 0]), int(box[0, 0, 1])), + fontFace=cv2.FONT_HERSHEY_COMPLEX, + fontScale=0.7, + color=(0, 255, 0), + thickness=1) + save_det_path = os.path.dirname(config['Global'][ + 'save_res_path']) + "/e2e_results/" + if not os.path.exists(save_det_path): + os.makedirs(save_det_path) + save_path = os.path.join(save_det_path, os.path.basename(img_name)) + cv2.imwrite(save_path, src_im) + logger.info("The e2e Image saved in {}".format(save_path)) + + +def main(): + global_config = config['Global'] + + # build model + model = build_model(config['Architecture']) + + load_model(config, model) + + # build post process + post_process_class = build_post_process(config['PostProcess'], + global_config) + + # create data ops + transforms = [] + for op in config['Eval']['dataset']['transforms']: + op_name = list(op)[0] + if 'Label' in op_name: + continue + elif op_name == 'KeepKeys': + op[op_name]['keep_keys'] = ['image', 'shape'] + transforms.append(op) + + ops = create_operators(transforms, global_config) + + save_res_path = config['Global']['save_res_path'] + if not os.path.exists(os.path.dirname(save_res_path)): + os.makedirs(os.path.dirname(save_res_path)) + + model.eval() + with open(save_res_path, "wb") as fout: + for file in get_image_file_list(config['Global']['infer_img']): + logger.info("infer_img: {}".format(file)) + with open(file, 'rb') as f: + img = f.read() + data = {'image': img} + batch = transform(data, ops) + images = np.expand_dims(batch[0], axis=0) + shape_list = np.expand_dims(batch[1], axis=0) + images = paddle.to_tensor(images) + preds = model(images) + post_result = post_process_class(preds, shape_list) + points, strs = post_result['points'], post_result['texts'] + # write result + dt_boxes_json = [] + for poly, str in zip(points, strs): + tmp_json = {"transcription": str} + tmp_json['points'] = poly.tolist() + dt_boxes_json.append(tmp_json) + otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n" + fout.write(otstr.encode()) + src_img = cv2.imread(file) + draw_e2e_res(points, strs, config, src_img, file) + logger.info("success!") + + +if __name__ == '__main__': + config, device, logger, vdl_writer = program.preprocess() + main() diff --git a/backend/tools/infer_kie.py b/backend/tools/infer_kie.py new file mode 100755 index 00000000..0cb0b870 --- /dev/null +++ b/backend/tools/infer_kie.py @@ -0,0 +1,153 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import paddle.nn.functional as F + +import os +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) + +os.environ["FLAGS_allocator_strategy"] = 'auto_growth' + +import cv2 +import paddle + +from ppocr.data import create_operators, transform +from ppocr.modeling.architectures import build_model +from ppocr.utils.save_load import load_model +import tools.program as program +import time + + +def read_class_list(filepath): + dict = {} + with open(filepath, "r") as f: + lines = f.readlines() + for line in lines: + key, value = line.split(" ") + dict[key] = value.rstrip() + return dict + + +def draw_kie_result(batch, node, idx_to_cls, count): + img = batch[6].copy() + boxes = batch[7] + h, w = img.shape[:2] + pred_img = np.ones((h, w * 2, 3), dtype=np.uint8) * 255 + max_value, max_idx = paddle.max(node, -1), paddle.argmax(node, -1) + node_pred_label = max_idx.numpy().tolist() + node_pred_score = max_value.numpy().tolist() + + for i, box in enumerate(boxes): + if i >= len(node_pred_label): + break + new_box = [[box[0], box[1]], [box[2], box[1]], [box[2], box[3]], + [box[0], box[3]]] + Pts = np.array([new_box], np.int32) + cv2.polylines( + img, [Pts.reshape((-1, 1, 2))], + True, + color=(255, 255, 0), + thickness=1) + x_min = int(min([point[0] for point in new_box])) + y_min = int(min([point[1] for point in new_box])) + + pred_label = str(node_pred_label[i]) + if pred_label in idx_to_cls: + pred_label = idx_to_cls[pred_label] + pred_score = '{:.2f}'.format(node_pred_score[i]) + text = pred_label + '(' + pred_score + ')' + cv2.putText(pred_img, text, (x_min * 2, y_min), + cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1) + vis_img = np.ones((h, w * 3, 3), dtype=np.uint8) * 255 + vis_img[:, :w] = img + vis_img[:, w:] = pred_img + save_kie_path = os.path.dirname(config['Global'][ + 'save_res_path']) + "/kie_results/" + if not os.path.exists(save_kie_path): + os.makedirs(save_kie_path) + save_path = os.path.join(save_kie_path, str(count) + ".png") + cv2.imwrite(save_path, vis_img) + logger.info("The Kie Image saved in {}".format(save_path)) + + +def main(): + global_config = config['Global'] + + # build model + model = build_model(config['Architecture']) + load_model(config, model) + + # create data ops + transforms = [] + for op in config['Eval']['dataset']['transforms']: + transforms.append(op) + + data_dir = config['Eval']['dataset']['data_dir'] + + ops = create_operators(transforms, global_config) + + save_res_path = config['Global']['save_res_path'] + class_path = config['Global']['class_path'] + idx_to_cls = read_class_list(class_path) + if not os.path.exists(os.path.dirname(save_res_path)): + os.makedirs(os.path.dirname(save_res_path)) + + model.eval() + + warmup_times = 0 + count_t = [] + with open(save_res_path, 
"wb") as fout: + with open(config['Global']['infer_img'], "rb") as f: + lines = f.readlines() + for index, data_line in enumerate(lines): + if index == 10: + warmup_t = time.time() + data_line = data_line.decode('utf-8') + substr = data_line.strip("\n").split("\t") + img_path, label = data_dir + "/" + substr[0], substr[1] + data = {'img_path': img_path, 'label': label} + with open(data['img_path'], 'rb') as f: + img = f.read() + data['image'] = img + st = time.time() + batch = transform(data, ops) + batch_pred = [0] * len(batch) + for i in range(len(batch)): + batch_pred[i] = paddle.to_tensor( + np.expand_dims( + batch[i], axis=0)) + st = time.time() + node, edge = model(batch_pred) + node = F.softmax(node, -1) + count_t.append(time.time() - st) + draw_kie_result(batch, node, idx_to_cls, index) + logger.info("success!") + logger.info("It took {} s for predict {} images.".format( + np.sum(count_t), len(count_t))) + ips = len(count_t[warmup_times:]) / np.sum(count_t[warmup_times:]) + logger.info("The ips is {} images/s".format(ips)) + + +if __name__ == '__main__': + config, device, logger, vdl_writer = program.preprocess() + main() diff --git a/backend/tools/infer_rec.py b/backend/tools/infer_rec.py index 075ec261..193e24a4 100755 --- a/backend/tools/infer_rec.py +++ b/backend/tools/infer_rec.py @@ -20,10 +20,11 @@ import os import sys +import json __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) -sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) os.environ["FLAGS_allocator_strategy"] = 'auto_growth' @@ -32,7 +33,7 @@ from ppocr.data import create_operators, transform from ppocr.modeling.architectures import build_model from ppocr.postprocess import build_post_process -from ppocr.utils.save_load import init_model +from ppocr.utils.save_load import load_model from ppocr.utils.utility import get_image_file_list import tools.program as program @@ -46,12 +47,38 @@ def main(): # build model if hasattr(post_process_class, 'character'): - config['Architecture']["Head"]['out_channels'] = len( - getattr(post_process_class, 'character')) + char_num = len(getattr(post_process_class, 'character')) + if config['Architecture']["algorithm"] in ["Distillation", + ]: # distillation model + for key in config['Architecture']["Models"]: + if config['Architecture']['Models'][key]['Head'][ + 'name'] == 'MultiHead': # for multi head + out_channels_list = {} + if config['PostProcess'][ + 'name'] == 'DistillationSARLabelDecode': + char_num = char_num - 2 + out_channels_list['CTCLabelDecode'] = char_num + out_channels_list['SARLabelDecode'] = char_num + 2 + config['Architecture']['Models'][key]['Head'][ + 'out_channels_list'] = out_channels_list + else: + config['Architecture']["Models"][key]["Head"][ + 'out_channels'] = char_num + elif config['Architecture']['Head'][ + 'name'] == 'MultiHead': # for multi head loss + out_channels_list = {} + if config['PostProcess']['name'] == 'SARLabelDecode': + char_num = char_num - 2 + out_channels_list['CTCLabelDecode'] = char_num + out_channels_list['SARLabelDecode'] = char_num + 2 + config['Architecture']['Head'][ + 'out_channels_list'] = out_channels_list + else: # base rec model + config['Architecture']["Head"]['out_channels'] = char_num model = build_model(config['Architecture']) - init_model(config, model, logger) + load_model(config, model) # create data ops transforms = [] @@ -67,41 +94,70 @@ def main(): 'image', 'encoder_word_pos', 'gsrm_word_pos', 
'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2' ] + elif config['Architecture']['algorithm'] == "SAR": + op[op_name]['keep_keys'] = ['image', 'valid_ratio'] else: op[op_name]['keep_keys'] = ['image'] transforms.append(op) global_config['infer_mode'] = True ops = create_operators(transforms, global_config) + save_res_path = config['Global'].get('save_res_path', + "./output/rec/predicts_rec.txt") + if not os.path.exists(os.path.dirname(save_res_path)): + os.makedirs(os.path.dirname(save_res_path)) + model.eval() - for file in get_image_file_list(config['Global']['infer_img']): - logger.info("infer_img: {}".format(file)) - with open(file, 'rb') as f: - img = f.read() - data = {'image': img} - batch = transform(data, ops) - if config['Architecture']['algorithm'] == "SRN": - encoder_word_pos_list = np.expand_dims(batch[1], axis=0) - gsrm_word_pos_list = np.expand_dims(batch[2], axis=0) - gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0) - gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0) - - others = [ - paddle.to_tensor(encoder_word_pos_list), - paddle.to_tensor(gsrm_word_pos_list), - paddle.to_tensor(gsrm_slf_attn_bias1_list), - paddle.to_tensor(gsrm_slf_attn_bias2_list) - ] - - images = np.expand_dims(batch[0], axis=0) - images = paddle.to_tensor(images) - if config['Architecture']['algorithm'] == "SRN": - preds = model(images, others) - else: - preds = model(images) - post_result = post_process_class(preds) - for rec_reuslt in post_result: - logger.info('\t result: {}'.format(rec_reuslt)) + + with open(save_res_path, "w") as fout: + for file in get_image_file_list(config['Global']['infer_img']): + logger.info("infer_img: {}".format(file)) + with open(file, 'rb') as f: + img = f.read() + data = {'image': img} + batch = transform(data, ops) + if config['Architecture']['algorithm'] == "SRN": + encoder_word_pos_list = np.expand_dims(batch[1], axis=0) + gsrm_word_pos_list = np.expand_dims(batch[2], axis=0) + gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0) + gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0) + + others = [ + paddle.to_tensor(encoder_word_pos_list), + paddle.to_tensor(gsrm_word_pos_list), + paddle.to_tensor(gsrm_slf_attn_bias1_list), + paddle.to_tensor(gsrm_slf_attn_bias2_list) + ] + if config['Architecture']['algorithm'] == "SAR": + valid_ratio = np.expand_dims(batch[-1], axis=0) + img_metas = [paddle.to_tensor(valid_ratio)] + + images = np.expand_dims(batch[0], axis=0) + images = paddle.to_tensor(images) + if config['Architecture']['algorithm'] == "SRN": + preds = model(images, others) + elif config['Architecture']['algorithm'] == "SAR": + preds = model(images, img_metas) + else: + preds = model(images) + post_result = post_process_class(preds) + info = None + if isinstance(post_result, dict): + rec_info = dict() + for key in post_result: + if len(post_result[key][0]) >= 2: + rec_info[key] = { + "label": post_result[key][0][0], + "score": float(post_result[key][0][1]), + } + info = json.dumps(rec_info, ensure_ascii=False) + else: + if len(post_result[0]) >= 2: + info = post_result[0][0] + "\t" + str(post_result[0][1]) + + if info is not None: + logger.info("\t result: {}".format(info)) + fout.write(file + "\t" + info) logger.info("success!") diff --git a/backend/tools/infer_table.py b/backend/tools/infer_table.py new file mode 100644 index 00000000..66c2da44 --- /dev/null +++ b/backend/tools/infer_table.py @@ -0,0 +1,107 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +import os +import sys +import json + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) + +os.environ["FLAGS_allocator_strategy"] = 'auto_growth' + +import paddle +from paddle.jit import to_static + +from ppocr.data import create_operators, transform +from ppocr.modeling.architectures import build_model +from ppocr.postprocess import build_post_process +from ppocr.utils.save_load import load_model +from ppocr.utils.utility import get_image_file_list +import tools.program as program +import cv2 + + +def main(config, device, logger, vdl_writer): + global_config = config['Global'] + + # build post process + post_process_class = build_post_process(config['PostProcess'], + global_config) + + # build model + if hasattr(post_process_class, 'character'): + config['Architecture']["Head"]['out_channels'] = len( + getattr(post_process_class, 'character')) + + model = build_model(config['Architecture']) + + load_model(config, model) + + # create data ops + transforms = [] + use_padding = False + for op in config['Eval']['dataset']['transforms']: + op_name = list(op)[0] + if 'Label' in op_name: + continue + if op_name == 'KeepKeys': + op[op_name]['keep_keys'] = ['image'] + if op_name == "ResizeTableImage": + use_padding = True + padding_max_len = op['ResizeTableImage']['max_len'] + transforms.append(op) + + global_config['infer_mode'] = True + ops = create_operators(transforms, global_config) + + model.eval() + for file in get_image_file_list(config['Global']['infer_img']): + logger.info("infer_img: {}".format(file)) + with open(file, 'rb') as f: + img = f.read() + data = {'image': img} + batch = transform(data, ops) + images = np.expand_dims(batch[0], axis=0) + images = paddle.to_tensor(images) + preds = model(images) + post_result = post_process_class(preds) + res_html_code = post_result['res_html_code'] + res_loc = post_result['res_loc'] + img = cv2.imread(file) + imgh, imgw = img.shape[0:2] + res_loc_final = [] + for rno in range(len(res_loc[0])): + x0, y0, x1, y1 = res_loc[0][rno] + left = max(int(imgw * x0), 0) + top = max(int(imgh * y0), 0) + right = min(int(imgw * x1), imgw - 1) + bottom = min(int(imgh * y1), imgh - 1) + cv2.rectangle(img, (left, top), (right, bottom), (0, 0, 255), 2) + res_loc_final.append([left, top, right, bottom]) + res_loc_str = json.dumps(res_loc_final) + logger.info("result: {}, {}".format(res_html_code, res_loc_final)) + logger.info("success!") + + +if __name__ == '__main__': + config, device, logger, vdl_writer = program.preprocess() + main(config, device, logger, vdl_writer) diff --git a/backend/tools/infer_vqa_token_ser.py b/backend/tools/infer_vqa_token_ser.py new file mode 100755 index 00000000..83ed72b3 --- /dev/null +++ b/backend/tools/infer_vqa_token_ser.py @@ -0,0 +1,135 @@ +# Copyright 
(c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +import os +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) + +os.environ["FLAGS_allocator_strategy"] = 'auto_growth' +import cv2 +import json +import paddle + +from ppocr.data import create_operators, transform +from ppocr.modeling.architectures import build_model +from ppocr.postprocess import build_post_process +from ppocr.utils.save_load import load_model +from ppocr.utils.visual import draw_ser_results +from ppocr.utils.utility import get_image_file_list, load_vqa_bio_label_maps +import tools.program as program + + +def to_tensor(data): + import numbers + from collections import defaultdict + data_dict = defaultdict(list) + to_tensor_idxs = [] + for idx, v in enumerate(data): + if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)): + if idx not in to_tensor_idxs: + to_tensor_idxs.append(idx) + data_dict[idx].append(v) + for idx in to_tensor_idxs: + data_dict[idx] = paddle.to_tensor(data_dict[idx]) + return list(data_dict.values()) + + +class SerPredictor(object): + def __init__(self, config): + global_config = config['Global'] + + # build post process + self.post_process_class = build_post_process(config['PostProcess'], + global_config) + + # build model + self.model = build_model(config['Architecture']) + + load_model( + config, self.model, model_type=config['Architecture']["model_type"]) + + from paddleocr import PaddleOCR + + self.ocr_engine = PaddleOCR(use_angle_cls=False, show_log=False) + + # create data ops + transforms = [] + for op in config['Eval']['dataset']['transforms']: + op_name = list(op)[0] + if 'Label' in op_name: + op[op_name]['ocr_engine'] = self.ocr_engine + elif op_name == 'KeepKeys': + op[op_name]['keep_keys'] = [ + 'input_ids', 'labels', 'bbox', 'image', 'attention_mask', + 'token_type_ids', 'segment_offset_id', 'ocr_info', + 'entities' + ] + + transforms.append(op) + global_config['infer_mode'] = True + self.ops = create_operators(config['Eval']['dataset']['transforms'], + global_config) + self.model.eval() + + def __call__(self, img_path): + with open(img_path, 'rb') as f: + img = f.read() + data = {'image': img} + batch = transform(data, self.ops) + batch = to_tensor(batch) + preds = self.model(batch) + post_result = self.post_process_class( + preds, + attention_masks=batch[4], + segment_offset_ids=batch[6], + ocr_infos=batch[7]) + return post_result, batch + + +if __name__ == '__main__': + config, device, logger, vdl_writer = program.preprocess() + os.makedirs(config['Global']['save_res_path'], exist_ok=True) + + ser_engine = SerPredictor(config) + + infer_imgs = get_image_file_list(config['Global']['infer_img']) + with open( + os.path.join(config['Global']['save_res_path'], + "infer_results.txt"), + "w", + 
encoding='utf-8') as fout: + for idx, img_path in enumerate(infer_imgs): + save_img_path = os.path.join( + config['Global']['save_res_path'], + os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg") + logger.info("process: [{}/{}], save result to {}".format( + idx, len(infer_imgs), save_img_path)) + + result, _ = ser_engine(img_path) + result = result[0] + fout.write(img_path + "\t" + json.dumps( + { + "ocr_info": result, + }, ensure_ascii=False) + "\n") + img_res = draw_ser_results(img_path, result) + cv2.imwrite(save_img_path, img_res) diff --git a/backend/tools/infer_vqa_token_ser_re.py b/backend/tools/infer_vqa_token_ser_re.py new file mode 100755 index 00000000..40f1dd5c --- /dev/null +++ b/backend/tools/infer_vqa_token_ser_re.py @@ -0,0 +1,199 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +import os +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) + +os.environ["FLAGS_allocator_strategy"] = 'auto_growth' +import cv2 +import json +import paddle +import paddle.distributed as dist + +from ppocr.data import create_operators, transform +from ppocr.modeling.architectures import build_model +from ppocr.postprocess import build_post_process +from ppocr.utils.save_load import load_model +from ppocr.utils.visual import draw_re_results +from ppocr.utils.logging import get_logger +from ppocr.utils.utility import get_image_file_list, load_vqa_bio_label_maps, print_dict +from tools.program import ArgsParser, load_config, merge_config, check_gpu +from tools.infer_vqa_token_ser import SerPredictor + + +class ReArgsParser(ArgsParser): + def __init__(self): + super(ReArgsParser, self).__init__() + self.add_argument( + "-c_ser", "--config_ser", help="ser configuration file to use") + self.add_argument( + "-o_ser", + "--opt_ser", + nargs='+', + help="set ser configuration options ") + + def parse_args(self, argv=None): + args = super(ReArgsParser, self).parse_args(argv) + assert args.config_ser is not None, \ + "Please specify --config_ser=ser_configure_file_path." 
+ args.opt_ser = self._parse_opt(args.opt_ser) + return args + + +def make_input(ser_inputs, ser_results): + entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2} + + entities = ser_inputs[8][0] + ser_results = ser_results[0] + assert len(entities) == len(ser_results) + + # entities + start = [] + end = [] + label = [] + entity_idx_dict = {} + for i, (res, entity) in enumerate(zip(ser_results, entities)): + if res['pred'] == 'O': + continue + entity_idx_dict[len(start)] = i + start.append(entity['start']) + end.append(entity['end']) + label.append(entities_labels[res['pred']]) + entities = dict(start=start, end=end, label=label) + + # relations + head = [] + tail = [] + for i in range(len(entities["label"])): + for j in range(len(entities["label"])): + if entities["label"][i] == 1 and entities["label"][j] == 2: + head.append(i) + tail.append(j) + + relations = dict(head=head, tail=tail) + + batch_size = ser_inputs[0].shape[0] + entities_batch = [] + relations_batch = [] + entity_idx_dict_batch = [] + for b in range(batch_size): + entities_batch.append(entities) + relations_batch.append(relations) + entity_idx_dict_batch.append(entity_idx_dict) + + ser_inputs[8] = entities_batch + ser_inputs.append(relations_batch) + # remove ocr_info segment_offset_id and label in ser input + ser_inputs.pop(7) + ser_inputs.pop(6) + ser_inputs.pop(1) + return ser_inputs, entity_idx_dict_batch + + +class SerRePredictor(object): + def __init__(self, config, ser_config): + self.ser_engine = SerPredictor(ser_config) + + # init re model + global_config = config['Global'] + + # build post process + self.post_process_class = build_post_process(config['PostProcess'], + global_config) + + # build model + self.model = build_model(config['Architecture']) + + load_model( + config, self.model, model_type=config['Architecture']["model_type"]) + + self.model.eval() + + def __call__(self, img_path): + ser_results, ser_inputs = self.ser_engine(img_path) + paddle.save(ser_inputs, 'ser_inputs.npy') + paddle.save(ser_results, 'ser_results.npy') + re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results) + preds = self.model(re_input) + post_result = self.post_process_class( + preds, + ser_results=ser_results, + entity_idx_dict_batch=entity_idx_dict_batch) + return post_result + + +def preprocess(): + FLAGS = ReArgsParser().parse_args() + config = load_config(FLAGS.settings_config) + config = merge_config(config, FLAGS.opt) + + ser_config = load_config(FLAGS.config_ser) + ser_config = merge_config(ser_config, FLAGS.opt_ser) + + logger = get_logger() + + # check if set use_gpu=True in paddlepaddle cpu version + use_gpu = config['Global']['use_gpu'] + check_gpu(use_gpu) + + device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu' + device = paddle.set_device(device) + + logger.info('{} re config {}'.format('*' * 10, '*' * 10)) + print_dict(config, logger) + logger.info('\n') + logger.info('{} ser config {}'.format('*' * 10, '*' * 10)) + print_dict(ser_config, logger) + logger.info('train with paddle {} and device {}'.format(paddle.__version__, + device)) + return config, ser_config, device, logger + + +if __name__ == '__main__': + config, ser_config, device, logger = preprocess() + os.makedirs(config['Global']['save_res_path'], exist_ok=True) + + ser_re_engine = SerRePredictor(config, ser_config) + + infer_imgs = get_image_file_list(config['Global']['infer_img']) + with open( + os.path.join(config['Global']['save_res_path'], + "infer_results.txt"), + "w", + encoding='utf-8') as fout: + for idx, 
img_path in enumerate(infer_imgs): + save_img_path = os.path.join( + config['Global']['save_res_path'], + os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg") + logger.info("process: [{}/{}], save result to {}".format( + idx, len(infer_imgs), save_img_path)) + + result = ser_re_engine(img_path) + result = result[0] + fout.write(img_path + "\t" + json.dumps( + { + "ser_result": result, + }, ensure_ascii=False) + "\n") + img_res = draw_re_results(img_path, result) + cv2.imwrite(save_img_path, img_res) diff --git a/backend/tools/makedist.py b/backend/tools/makedist.py new file mode 100644 index 00000000..c8fc9435 --- /dev/null +++ b/backend/tools/makedist.py @@ -0,0 +1,12 @@ +if __name__ == '__main__': + # Import QPT + from qpt.executor import CreateExecutableModule as CEM + import os + WORK_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + print(WORK_DIR) + LAUNCH_PATH = os.path.join(WORK_DIR, 'gui.py') + SAVE_PATH = os.path.join(os.path.dirname(WORK_DIR), 'vse_out') + ICON_PATH = os.path.join(WORK_DIR, "design", "vse.ico") + module = CEM(work_dir=WORK_DIR, launcher_py_path=LAUNCH_PATH, save_path=SAVE_PATH, icon=ICON_PATH, hidden_terminal=False) + # Start packaging + module.make() diff --git a/backend/tools/ocr.py b/backend/tools/ocr.py new file mode 100644 index 00000000..018140db --- /dev/null +++ b/backend/tools/ocr.py @@ -0,0 +1,124 @@ +from tools.infer import utility +from tools.infer.predict_system import TextSystem +import config +import importlib + + +# Load the text detection + recognition model +class OcrRecogniser: + def __init__(self): + # Get the argument object + importlib.reload(config) + self.args = utility.parse_args() + self.recogniser = self.init_model() + + @staticmethod + def y_round(y): + y_min = y + 10 - y % 10 + y_max = y - y % 10 + if abs(y - y_min) < abs(y - y_max): + return y_min + else: + return y_max + + def predict(self, image): + detection_box, recognise_result = self.recogniser(image) + if len(detection_box) > 0: + coordinate_list = list() + if isinstance(detection_box, list): + for i in detection_box: + i = list(i) + (x1, y1) = int(i[0][0]), int(i[0][1]) + (x2, y2) = int(i[1][0]), int(i[1][1]) + (x3, y3) = int(i[2][0]), int(i[2][1]) + (x4, y4) = int(i[3][0]), int(i[3][1]) + xmin = max(x1, x4) + xmax = min(x2, x3) + ymin = max(y1, y2) + ymax = min(y3, y4) + coordinate_list.append([xmin, xmax, ymin, ymax]) + + # Count how many subtitle rows there are and put the smallest ymin of each row into lines + lines = [] + for i in coordinate_list: + if len(lines) < 1: + lines.append(self.y_round(i[2])) + else: + if self.y_round(i[2]) not in lines \ + and self.y_round(i[2]) + 10 not in lines \ + and self.y_round(i[2]) - 10 not in lines: + lines.append(self.y_round(i[2])) + lines = sorted(lines) + + for i in coordinate_list: + for j in lines: + if abs(j - self.y_round(i[2])) <= 10: + i[2] = j + + to_rank_res = list(zip(coordinate_list, recognise_result)) + ranked_res = [] + for line in lines: + tmp_list = [] + for i in to_rank_res: + if i[0][2] == line: + tmp_list.append(i) + # Sort by the vertical coordinate first + for k in range(1, len(tmp_list)): + for j in range(0, len(tmp_list) - k): + if tmp_list[j][0][2] > tmp_list[j + 1][0][2]: + print(tmp_list[j][0][2]) + tmp_list[j], tmp_list[j + 1] = tmp_list[j + 1], tmp_list[j] + # Then sort by the horizontal coordinate + for l in range(1, len(tmp_list)): + for j in range(0, len(tmp_list) - l): + if tmp_list[j][0][0] > tmp_list[j + 1][0][0]: + tmp_list[j], tmp_list[j + 1] = tmp_list[j + 1], tmp_list[j] + for m in tmp_list: + ranked_res.append(m) + dt_box = [] + for i in [j[0] for j in ranked_res]: + dt_box.append([(i[0], i[2]), (i[1], i[2]), (i[1], i[3]),
(i[0], i[3])]) + res = [i[1] for i in ranked_res] + return dt_box, res + else: + return detection_box, recognise_result + + def init_model(self): + self.args.use_gpu = config.USE_GPU + if not config.USE_GPU: + import paddle + paddle.set_device('cpu') + # Set the text detection model path + self.args.det_model_dir = config.DET_MODEL_PATH + # Set the text recognition model path + self.args.rec_model_dir = config.REC_MODEL_PATH + self.args.rec_char_dict_path = config.DICT_PATH + self.args.rec_image_shape = config.REC_IMAGE_SHAPE + # Set the type of text to recognise + self.args.rec_char_type = config.REC_CHAR_TYPE + # Set the number of text boxes processed per batch for each image + self.args.rec_batch_num = config.REC_BATCH_NUM + self.args.max_batch_size = config.MAX_BATCH_SIZE + return TextSystem(self.args) + + +def get_coordinates(dt_box): + """ + Get the coordinates from the returned detection boxes + :param dt_box detection box results + :return list list of coordinate points + """ + coordinate_list = list() + if isinstance(dt_box, list): + for i in dt_box: + i = list(i) + (x1, y1) = int(i[0][0]), int(i[0][1]) + (x2, y2) = int(i[1][0]), int(i[1][1]) + (x3, y3) = int(i[2][0]), int(i[2][1]) + (x4, y4) = int(i[3][0]), int(i[3][1]) + xmin = max(x1, x4) + xmax = min(x2, x3) + ymin = max(y1, y2) + ymax = min(y3, y4) + coordinate_list.append((xmin, xmax, ymin, ymax)) + return coordinate_list diff --git a/backend/tools/program.py b/backend/tools/program.py index ae649176..7c02dc01 100755 --- a/backend/tools/program.py +++ b/backend/tools/program.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,9 +18,10 @@ import os import sys +import platform import yaml import time -import shutil +import datetime import paddle import paddle.distributed as dist from tqdm import tqdm @@ -28,10 +29,11 @@ from ppocr.utils.stats import TrainingStats from ppocr.utils.save_load import save_model -from ppocr.utils.utility import print_dict +from ppocr.utils.utility import print_dict, AverageMeter from ppocr.utils.logging import get_logger +from ppocr.utils.loggers import VDLLogger, WandbLogger, Loggers +from ppocr.utils import profiler from ppocr.data import build_dataloader -import numpy as np class ArgsParser(ArgumentParser): @@ -41,6 +43,14 @@ def __init__(self): self.add_argument("-c", "--config", help="configuration file to use") self.add_argument( "-o", "--opt", nargs='+', help="set configuration options") + self.add_argument( + '-p', + '--profiler_options', + type=str, + default=None, + help='The option of profiler, which should be in format ' \ + '\"key1=value1;key2=value2;key3=value3\".' + ) def parse_args(self, argv=None): args = super(ArgsParser, self).parse_args(argv) @@ -60,24 +70,6 @@ def _parse_opt(self, opts): return config -class AttrDict(dict): - """Single level attribute dict, NOT recursive""" - - def __init__(self, **kwargs): - super(AttrDict, self).__init__() - super(AttrDict, self).update(kwargs) - - def __getattr__(self, key): - if key in self: - return self[key] - raise AttributeError("object has no attribute '{}'".format(key)) - - -global_config = AttrDict() - -default_config = {'Global': {'debug': False, }} - - def load_config(file_path): """ Load config from yml/yaml file. @@ -85,38 +77,39 @@ file_path (str): Path of the config file to be loaded.
Returns: global config """ - merge_config(default_config) _, ext = os.path.splitext(file_path) assert ext in ['.yml', '.yaml'], "only support yaml files for now" - merge_config(yaml.load(open(file_path, 'rb'), Loader=yaml.Loader)) - return global_config + config = yaml.load(open(file_path, 'rb'), Loader=yaml.Loader) + return config -def merge_config(config): +def merge_config(config, opts): """ Merge config into global config. Args: config (dict): Config to be merged. Returns: global config """ - for key, value in config.items(): + for key, value in opts.items(): if "." not in key: - if isinstance(value, dict) and key in global_config: - global_config[key].update(value) + if isinstance(value, dict) and key in config: + config[key].update(value) else: - global_config[key] = value + config[key] = value else: sub_keys = key.split('.') assert ( - sub_keys[0] in global_config - ), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format( - global_config.keys(), sub_keys[0]) - cur = global_config[sub_keys[0]] + sub_keys[0] in config + ), "the sub_keys can only be one of global_config: {}, but get: " \ + "{}, please check your running command".format( + config.keys(), sub_keys[0]) + cur = config[sub_keys[0]] for idx, sub_key in enumerate(sub_keys[1:]): if idx == len(sub_keys) - 2: cur[sub_key] = value else: cur = cur[sub_key] + return config def check_gpu(use_gpu): @@ -138,6 +131,25 @@ def check_gpu(use_gpu): pass +def check_xpu(use_xpu): + """ + Log error and exit when set use_xpu=true in paddlepaddle + cpu/gpu version. + """ + err = "Config use_xpu cannot be set as true while you are " \ + "using paddlepaddle cpu/gpu version ! \nPlease try: \n" \ + "\t1. Install paddlepaddle-xpu to run model on XPU \n" \ + "\t2. Set use_xpu as false in config file to run " \ + "model on CPU/GPU" + + try: + if use_xpu and not paddle.is_compiled_with_xpu(): + print(err) + sys.exit(1) + except Exception as e: + pass + + def train(config, train_dataloader, valid_dataloader, @@ -150,26 +162,33 @@ def train(config, eval_class, pre_best_model_dict, logger, - vdl_writer=None): + log_writer=None, + scaler=None): cal_metric_during_train = config['Global'].get('cal_metric_during_train', False) + calc_epoch_interval = config['Global'].get('calc_epoch_interval', 1) log_smooth_window = config['Global']['log_smooth_window'] epoch_num = config['Global']['epoch_num'] print_batch_step = config['Global']['print_batch_step'] eval_batch_step = config['Global']['eval_batch_step'] + profiler_options = config['profiler_options'] global_step = 0 + if 'global_step' in pre_best_model_dict: + global_step = pre_best_model_dict['global_step'] start_eval_step = 0 if type(eval_batch_step) == list and len(eval_batch_step) >= 2: start_eval_step = eval_batch_step[0] eval_batch_step = eval_batch_step[1] if len(valid_dataloader) == 0: logger.info( - 'No Images in eval dataset, evaluation during training will be disabled' + 'No Images in eval dataset, evaluation during training ' \ + 'will be disabled' ) start_eval_step = 1e111 logger.info( - "During the training process, after the {}th iteration, an evaluation is run every {} iterations". + "During the training process, after the {}th iteration, " \ + "an evaluation is run every {} iterations". 
format(start_eval_step, eval_batch_step)) save_epoch_step = config['Global']['save_epoch_step'] save_model_dir = config['Global']['save_model_dir'] @@ -183,39 +202,96 @@ def train(config, model.train() use_srn = config['Architecture']['algorithm'] == "SRN" - - if 'start_epoch' in best_model_dict: - start_epoch = best_model_dict['start_epoch'] + extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"] + extra_input = False + if config['Architecture']['algorithm'] == 'Distillation': + for key in config['Architecture']["Models"]: + extra_input = extra_input or config['Architecture']['Models'][key][ + 'algorithm'] in extra_input_models else: - start_epoch = 1 + extra_input = config['Architecture']['algorithm'] in extra_input_models + try: + model_type = config['Architecture']['model_type'] + except: + model_type = None + + algorithm = config['Architecture']['algorithm'] + + start_epoch = best_model_dict[ + 'start_epoch'] if 'start_epoch' in best_model_dict else 1 + + total_samples = 0 + train_reader_cost = 0.0 + train_batch_cost = 0.0 + reader_start = time.time() + eta_meter = AverageMeter() + + max_iter = len(train_dataloader) - 1 if platform.system( + ) == "Windows" else len(train_dataloader) for epoch in range(start_epoch, epoch_num + 1): - train_dataloader = build_dataloader( - config, 'Train', device, logger, seed=epoch) - train_batch_cost = 0.0 - train_reader_cost = 0.0 - batch_sum = 0 - batch_start = time.time() + if train_dataloader.dataset.need_reset: + train_dataloader = build_dataloader( + config, 'Train', device, logger, seed=epoch) + max_iter = len(train_dataloader) - 1 if platform.system( + ) == "Windows" else len(train_dataloader) for idx, batch in enumerate(train_dataloader): - train_reader_cost += time.time() - batch_start - if idx >= len(train_dataloader): + profiler.add_profiler_step(profiler_options) + train_reader_cost += time.time() - reader_start + if idx >= max_iter: break lr = optimizer.get_lr() images = batch[0] if use_srn: - others = batch[-4:] - preds = model(images, others) model_average = True + + # use amp + if scaler: + with paddle.amp.auto_cast(): + if model_type == 'table' or extra_input: + preds = model(images, data=batch[1:]) + else: + preds = model(images) else: - preds = model(images) + if model_type == 'table' or extra_input: + preds = model(images, data=batch[1:]) + elif model_type in ["kie", 'vqa']: + preds = model(batch) + else: + preds = model(images) + loss = loss_class(preds, batch) avg_loss = loss['loss'] - avg_loss.backward() - optimizer.step() + + if scaler: + scaled_avg_loss = scaler.scale(avg_loss) + scaled_avg_loss.backward() + scaler.minimize(optimizer, scaled_avg_loss) + else: + avg_loss.backward() + optimizer.step() optimizer.clear_grad() - train_batch_cost += time.time() - batch_start - batch_sum += len(images) + if cal_metric_during_train and epoch % calc_epoch_interval == 0: # only rec and cls need + batch = [item.numpy() for item in batch] + if model_type in ['table', 'kie']: + eval_class(preds, batch) + else: + if config['Loss']['name'] in ['MultiLoss', 'MultiLoss_v2' + ]: # for multi head loss + post_result = post_process_class( + preds['ctc'], batch[1]) # for CTC head out + else: + post_result = post_process_class(preds, batch[1]) + eval_class(post_result, batch) + metric = eval_class.get_metric() + train_stats.update(metric) + + train_batch_time = time.time() - reader_start + train_batch_cost += train_batch_time + eta_meter.update(train_batch_time) + global_step += 1 + total_samples += len(images) if not isinstance(lr_scheduler, 
float): lr_scheduler.step() @@ -225,32 +301,34 @@ def train(config, stats['lr'] = lr train_stats.update(stats) - if cal_metric_during_train: # only rec and cls need - batch = [item.numpy() for item in batch] - post_result = post_process_class(preds, batch[1]) - eval_class(post_result, batch) - metric = eval_class.get_metric() - train_stats.update(metric) + if log_writer is not None and dist.get_rank() == 0: + log_writer.log_metrics(metrics=train_stats.get(), prefix="TRAIN", step=global_step) - if vdl_writer is not None and dist.get_rank() == 0: - for k, v in train_stats.get().items(): - vdl_writer.add_scalar('TRAIN/{}'.format(k), v, global_step) - vdl_writer.add_scalar('TRAIN/lr', lr, global_step) - - if dist.get_rank( - ) == 0 and global_step > 0 and global_step % print_batch_step == 0: + if dist.get_rank() == 0 and ( + (global_step > 0 and global_step % print_batch_step == 0) or + (idx >= len(train_dataloader) - 1)): logs = train_stats.log() - strs = 'epoch: [{}/{}], iter: {}, {}, reader_cost: {:.5f} s, batch_cost: {:.5f} s, samples: {}, ips: {:.5f}'.format( - epoch, epoch_num, global_step, logs, train_reader_cost / - print_batch_step, train_batch_cost / print_batch_step, - batch_sum, batch_sum / train_batch_cost) + + eta_sec = ((epoch_num + 1 - epoch) * \ + len(train_dataloader) - idx - 1) * eta_meter.avg + eta_sec_format = str(datetime.timedelta(seconds=int(eta_sec))) + strs = 'epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: ' \ + '{:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ' \ + 'ips: {:.5f} samples/s, eta: {}'.format( + epoch, epoch_num, global_step, logs, + train_reader_cost / print_batch_step, + train_batch_cost / print_batch_step, + total_samples / print_batch_step, + total_samples / train_batch_cost, eta_sec_format) logger.info(strs) - train_batch_cost = 0.0 + + total_samples = 0 train_reader_cost = 0.0 - batch_sum = 0 + train_batch_cost = 0.0 # eval if global_step > start_eval_step and \ - (global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0: + (global_step - start_eval_step) % eval_batch_step == 0 \ + and dist.get_rank() == 0: if model_average: Model_Average = paddle.incubate.optimizer.ModelAverage( 0.15, @@ -263,17 +341,16 @@ def train(config, valid_dataloader, post_process_class, eval_class, - use_srn=use_srn) + model_type, + extra_input=extra_input) cur_metric_str = 'cur metric, {}'.format(', '.join( ['{}: {}'.format(k, v) for k, v in cur_metric.items()])) logger.info(cur_metric_str) # logger metric - if vdl_writer is not None: - for k, v in cur_metric.items(): - if isinstance(v, (float, int)): - vdl_writer.add_scalar('EVAL/{}'.format(k), - cur_metric[k], global_step) + if log_writer is not None: + log_writer.log_metrics(metrics=cur_metric, prefix="EVAL", step=global_step) + if cur_metric[main_indicator] >= best_model_dict[ main_indicator]: best_model_dict.update(cur_metric) @@ -283,75 +360,111 @@ def train(config, optimizer, save_model_dir, logger, + config, is_best=True, prefix='best_accuracy', best_model_dict=best_model_dict, - epoch=epoch) + epoch=epoch, + global_step=global_step) best_str = 'best metric, {}'.format(', '.join([ '{}: {}'.format(k, v) for k, v in best_model_dict.items() ])) logger.info(best_str) # logger best metric - if vdl_writer is not None: - vdl_writer.add_scalar('EVAL/best_{}'.format(main_indicator), - best_model_dict[main_indicator], - global_step) - global_step += 1 - optimizer.clear_grad() - batch_start = time.time() + if log_writer is not None: + log_writer.log_metrics(metrics={ + 
"best_{}".format(main_indicator): best_model_dict[main_indicator] + }, prefix="EVAL", step=global_step) + + log_writer.log_model(is_best=True, prefix="best_accuracy", metadata=best_model_dict) + + reader_start = time.time() if dist.get_rank() == 0: save_model( model, optimizer, save_model_dir, logger, + config, is_best=False, prefix='latest', best_model_dict=best_model_dict, - epoch=epoch) + epoch=epoch, + global_step=global_step) + + if log_writer is not None: + log_writer.log_model(is_best=False, prefix="latest") + if dist.get_rank() == 0 and epoch > 0 and epoch % save_epoch_step == 0: save_model( model, optimizer, save_model_dir, logger, + config, is_best=False, prefix='iter_epoch_{}'.format(epoch), best_model_dict=best_model_dict, - epoch=epoch) + epoch=epoch, + global_step=global_step) + if log_writer is not None: + log_writer.log_model(is_best=False, prefix='iter_epoch_{}'.format(epoch)) + best_str = 'best metric, {}'.format(', '.join( ['{}: {}'.format(k, v) for k, v in best_model_dict.items()])) logger.info(best_str) - if dist.get_rank() == 0 and vdl_writer is not None: - vdl_writer.close() + if dist.get_rank() == 0 and log_writer is not None: + log_writer.close() return -def eval(model, valid_dataloader, post_process_class, eval_class, - use_srn=False): +def eval(model, + valid_dataloader, + post_process_class, + eval_class, + model_type=None, + extra_input=False): model.eval() with paddle.no_grad(): total_frame = 0.0 total_time = 0.0 - pbar = tqdm(total=len(valid_dataloader), desc='eval model:') + pbar = tqdm( + total=len(valid_dataloader), + desc='eval model:', + position=0, + leave=True) + max_iter = len(valid_dataloader) - 1 if platform.system( + ) == "Windows" else len(valid_dataloader) for idx, batch in enumerate(valid_dataloader): - if idx >= len(valid_dataloader): + if idx >= max_iter: break images = batch[0] start = time.time() - - if use_srn: - others = batch[-4:] - preds = model(images, others) + if model_type == 'table' or extra_input: + preds = model(images, data=batch[1:]) + elif model_type in ["kie", 'vqa']: + preds = model(batch) else: preds = model(images) - batch = [item.numpy() for item in batch] + batch_numpy = [] + for item in batch: + if isinstance(item, paddle.Tensor): + batch_numpy.append(item.numpy()) + else: + batch_numpy.append(item) # Obtain usable results from post-processing methods - post_result = post_process_class(preds, batch[1]) total_time += time.time() - start # Evaluate the results of the current batch - eval_class(post_result, batch) + if model_type in ['table', 'kie']: + eval_class(preds, batch_numpy) + elif model_type in ['vqa']: + post_result = post_process_class(preds, batch_numpy) + eval_class(post_result, batch_numpy) + else: + post_result = post_process_class(preds, batch_numpy[1]) + eval_class(post_result, batch_numpy) + pbar.update(1) total_frame += len(images) # Get final metric,eg. 
acc or hmean @@ -363,44 +476,127 @@ def eval(model, valid_dataloader, post_process_class, eval_class, return metric +def update_center(char_center, post_result, preds): + result, label = post_result + feats, logits = preds + logits = paddle.argmax(logits, axis=-1) + feats = feats.numpy() + logits = logits.numpy() + + for idx_sample in range(len(label)): + if result[idx_sample][0] == label[idx_sample][0]: + feat = feats[idx_sample] + logit = logits[idx_sample] + for idx_time in range(len(logit)): + index = logit[idx_time] + if index in char_center.keys(): + char_center[index][0] = ( + char_center[index][0] * char_center[index][1] + + feat[idx_time]) / (char_center[index][1] + 1) + char_center[index][1] += 1 + else: + char_center[index] = [feat[idx_time], 1] + return char_center + + +def get_center(model, eval_dataloader, post_process_class): + pbar = tqdm(total=len(eval_dataloader), desc='get center:') + max_iter = len(eval_dataloader) - 1 if platform.system( + ) == "Windows" else len(eval_dataloader) + char_center = dict() + for idx, batch in enumerate(eval_dataloader): + if idx >= max_iter: + break + images = batch[0] + start = time.time() + preds = model(images) + + batch = [item.numpy() for item in batch] + # Obtain usable results from post-processing methods + post_result = post_process_class(preds, batch[1]) + + #update char_center + char_center = update_center(char_center, post_result, preds) + pbar.update(1) + + pbar.close() + for key in char_center.keys(): + char_center[key] = char_center[key][0] + return char_center + + def preprocess(is_train=False): FLAGS = ArgsParser().parse_args() + profiler_options = FLAGS.profiler_options config = load_config(FLAGS.config) - merge_config(FLAGS.opt) + config = merge_config(config, FLAGS.opt) + profile_dic = {"profiler_options": FLAGS.profiler_options} + config = merge_config(config, profile_dic) + + if is_train: + # save_config + save_model_dir = config['Global']['save_model_dir'] + os.makedirs(save_model_dir, exist_ok=True) + with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f: + yaml.dump( + dict(config), f, default_flow_style=False, sort_keys=False) + log_file = '{}/train.log'.format(save_model_dir) + else: + log_file = None + logger = get_logger(log_file=log_file) # check if set use_gpu=True in paddlepaddle cpu version use_gpu = config['Global']['use_gpu'] check_gpu(use_gpu) + # check if set use_xpu=True in paddlepaddle cpu/gpu version + use_xpu = False + if 'use_xpu' in config['Global']: + use_xpu = config['Global']['use_xpu'] + check_xpu(use_xpu) + alg = config['Architecture']['algorithm'] assert alg in [ - 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS' + 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', + 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE', + 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR' ] - device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu' + device = 'cpu' + if use_gpu: + device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) + if use_xpu: + device = 'xpu' device = paddle.set_device(device) config['Global']['distributed'] = dist.get_world_size() != 1 - if is_train: - # save_config - save_model_dir = config['Global']['save_model_dir'] - os.makedirs(save_model_dir, exist_ok=True) - with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f: - yaml.dump( - dict(config), f, default_flow_style=False, sort_keys=False) - log_file = '{}/train.log'.format(save_model_dir) - else: - log_file = None - logger = 
get_logger(name='root', log_file=log_file) - if config['Global']['use_visualdl']: - from visualdl import LogWriter + + loggers = [] + + if 'use_visualdl' in config['Global'] and config['Global']['use_visualdl']: save_model_dir = config['Global']['save_model_dir'] vdl_writer_path = '{}/vdl/'.format(save_model_dir) - os.makedirs(vdl_writer_path, exist_ok=True) - vdl_writer = LogWriter(logdir=vdl_writer_path) + log_writer = VDLLogger(save_model_dir) + loggers.append(log_writer) + if ('use_wandb' in config['Global'] and config['Global']['use_wandb']) or 'wandb' in config: + save_dir = config['Global']['save_model_dir'] + wandb_writer_path = "{}/wandb".format(save_dir) + if "wandb" in config: + wandb_params = config['wandb'] + else: + wandb_params = dict() + wandb_params.update({'save_dir': save_model_dir}) + log_writer = WandbLogger(**wandb_params, config=config) + loggers.append(log_writer) else: - vdl_writer = None + log_writer = None print_dict(config, logger) + + if loggers: + log_writer = Loggers(loggers) + else: + log_writer = None + logger.info('train with paddle {} and device {}'.format(paddle.__version__, device)) - return config, device, logger, vdl_writer + return config, device, logger, log_writer diff --git a/backend/tools/reformat.py b/backend/tools/reformat.py new file mode 100644 index 00000000..a044073d --- /dev/null +++ b/backend/tools/reformat.py @@ -0,0 +1,153 @@ +# -*- coding: UTF-8 -*- +""" +@author: eritpchy +@file : reformat.py +@time : 2021/12/17 15:43 +@desc : 将连起来的英文单词切分 +""" +import json +import os +import sys + +import pysrt +import wordsegment as ws +import re + + +def execute(path, lang='en'): + # fix "RecursionError: maximum recursion depth exceeded in comparison" in wordsegment.segment call + if sys.getrecursionlimit() < 100000: + sys.setrecursionlimit(100000) + + wordsegment = ws.Segmenter() + wordsegment.load() + subs = pysrt.open(path) + verb_forms = ["I'm", "you're", "he's", "she's", "we're", "it's", "isn't", "aren't", "they're", "there's", "wasn't", + "weren't", "I've", "you've", "we've", "they've", "hasn't", "haven't", "I'd", "you'd", "he'd", "she'd", + "it'd", "we'd", "they'd", "doesn't", "don't", "didn't", "I'll", "you'll", "he'll", "she'll", "we'll", + "they'll", "there'll", "there'd", "can't", "couldn't", "daren't", "hadn't", "mightn't", "mustn't", + "needn't", "oughtn't", "shan't", "shouldn't", "usedn't", "won't", "wouldn't", "that's", "what's", "it'll"] + verb_form_map = {} + + with open(os.path.join(os.path.dirname(os.path.dirname(__file__)), 'configs', 'typoMap.json'), 'r', encoding='utf-8') as load_f: + typo_map = json.load(load_f) + + for verb in verb_forms: + verb_form_map[verb.replace("'", "").lower()] = verb + + def format_seg_list(seg_list): + new_seg = [] + for seg in seg_list: + if seg in verb_form_map: + new_seg.append([seg, verb_form_map[seg]]) + else: + new_seg.append([seg]) + return new_seg + + def typo_fix(text): + for k, v in typo_map.items(): + text = re.sub(re.compile(k, re.I), v, text) + return text + + # 逆向过滤seg + def remove_invalid_segment(seg, text): + seg_len = len(seg) + span = None + new_seg = [] + for i in range(seg_len - 1, -1, -1): + s = seg[i] + if len(s) > 1: + regex = re.compile(f"({s[0]}|{s[1]})", re.I) + else: + regex = re.compile(f"({s[0]})", re.I) + try: + ss = [(i) for i in re.finditer(regex, text)][-1] + except IndexError: + ss = None + if ss is None: + continue + text = text[:ss.span()[0]] + if span is None: + span = ss.span() + new_seg.append(s) + continue + if span > ss.span(): + new_seg.append(s) + span = 
ss.span() + return list(reversed(new_seg)) + + for sub in subs: + sub.text = typo_fix(sub.text) + seg = wordsegment.segment(sub.text) + if len(seg) == 1: + seg = wordsegment.segment(re.sub(re.compile(f"(\ni)([^\\s])", re.I), "\\1 \\2", sub.text)) + seg = format_seg_list(seg) + + # 替换中文前的多个空格成单个空格, 避免中英文分行出错 + sub.text = re.sub(' +([\\u4e00-\\u9fa5])', ' \\1', sub.text) + # 中英文分行 + if lang in ["ch", "ch_tra"]: + sub.text = sub.text.replace(" ", "\n") + lines = [] + remain = sub.text + seg = remove_invalid_segment(seg, sub.text) + seg_len = len(seg) + for i in range(0, seg_len): + s = seg[i] + global regex + if len(s) > 1: + regex = re.compile(f"(.*?)({s[0]}|{s[1]})", re.I) + else: + regex = re.compile(f"(.*?)({s[0]})", re.I) + ss = re.search(regex, remain) + if ss is None: + if i == seg_len - 1: + lines.append(remain.strip()) + continue + + lines.append(remain[:ss.span()[1]].strip()) + remain = remain[ss.span()[1]:].strip() + if i == seg_len - 1: + lines.append(remain) + if seg_len > 0: + ss = " ".join(lines) + else: + ss = remain + # again + ss = typo_fix(ss) + # 非大写字母的大写字母前加空格 + ss = re.sub("([^\\sA-Z\\-])([A-Z])", "\\1 \\2", ss) + # 删除重复空格 + ss = ss.replace(" ", " ") + ss = ss.replace("。", ".") + # 删除,?!,前的多个空格 + ss = re.sub(" *([\\.\\?\\!\\,])", "\\1", ss) + # 删除'的前后多个空格 + ss = re.sub(" *([\\']) *", "\\1", ss) + # 删除换行后的多个空格, 通常时第二行的开始的多个空格 + ss = re.sub('\n\\s*', '\n', ss) + # 删除开始的多个空格 + ss = re.sub('^\\s*', '', ss) + # 删除-左侧空格 + ss = re.sub("([A-Za-z0-9]) (\\-[A-Za-z0-9])", '\\1\\2', ss) + # 删除%左侧空格 + ss = re.sub("([A-Za-z0-9]) %", '\\1%', ss) + # 结尾·改成. + ss = re.sub('·$', '.', ss) + # 移除Dr.后的空格 + ss = re.sub(r'\bDr\. *\b', "Dr.", ss) + # 中文引号转英文 + ss = re.sub(r'[“”]', "\"", ss) + # 中文逗号转英文 + ss = re.sub(r',', ",", ss) + # .,?后面加空格 + ss = re.sub('([\\.,\\!\\?])([A-Za-z0-9\\u4e00-\\u9fa5])', '\\1 \\2', ss) + ss = ss.replace("\n\n", "\n") + sub.text = ss.strip() + subs.save(path, encoding='utf-8') + + +if __name__ == '__main__': + path = "/home/yao/Videos/null.srt" + execute(path) + diff --git a/backend/tools/subtitle_ocr.py b/backend/tools/subtitle_ocr.py new file mode 100644 index 00000000..5b452dd8 --- /dev/null +++ b/backend/tools/subtitle_ocr.py @@ -0,0 +1,280 @@ +import os +import re +from multiprocessing import Queue, Process +import cv2 +from PIL import ImageFont, ImageDraw, Image +from tqdm import tqdm +from tools.ocr import OcrRecogniser, get_coordinates +from tools.constant import SubtitleArea +from tools import constant +from threading import Thread +import queue +from shapely.geometry import Polygon +from types import SimpleNamespace +import shutil +import numpy as np +from collections import namedtuple + + +def extract_subtitles(data, text_recogniser, img, raw_subtitle_file, + sub_area, options, dt_box_arg, rec_res_arg, ocr_loss_debug_path): + """ + 提取视频帧中的字幕信息 + """ + # 从参数中获取检测框与检测结果 + dt_box = dt_box_arg + rec_res = rec_res_arg + # 如果没有检测结果,则获取检测结果 + if dt_box is None or rec_res is None: + dt_box, rec_res = text_recogniser.predict(img) + # rec_res格式为: ("hello", 0.997) + # 获取文本坐标 + coordinates = get_coordinates(dt_box) + # 将结果写入txt文本中 + if options.REC_CHAR_TYPE == 'en': + # 如果识别语言为英文,则去除中文 + text_res = [(re.sub('[\u4e00-\u9fa5]', '', res[0]), res[1]) for res in rec_res] + else: + text_res = [(res[0], res[1]) for res in rec_res] + line = '' + loss_list = [] + for content, coordinate in zip(text_res, coordinates): + text = content[0] + prob = content[1] + if sub_area is not None: + selected = False + # 初始化超界偏差为0 + overflow_area_rate = 0 + # 用户指定的字幕区域 + sub_area_polygon = 
sub_area_to_polygon(sub_area) + # 识别出的字幕区域 + coordinate_polygon = coordinate_to_polygon(coordinate) + # 计算两个区域是否有交集 + intersection = sub_area_polygon.intersection(coordinate_polygon) + # 如果有交集 + if not intersection.is_empty: + # 计算越界允许偏差 + overflow_area_rate = ((sub_area_polygon.area + coordinate_polygon.area - intersection.area) / sub_area_polygon.area) - 1 + # 如果越界比例低于设定阈值且该行文本识别的置信度高于设定阈值 + if overflow_area_rate <= options.SUB_AREA_DEVIATION_RATE and prob > options.DROP_SCORE: + # 保留该帧 + selected = True + line += f'{str(data["i"]).zfill(8)}\t{coordinate}\t{text}\n' + raw_subtitle_file.write(f'{str(data["i"]).zfill(8)}\t{coordinate}\t{text}\n') + # 保存丢掉的识别结果 + loss_info = namedtuple('loss_info', 'text prob overflow_area_rate coordinate selected') + loss_list.append(loss_info(text, prob, overflow_area_rate, coordinate, selected)) + else: + raw_subtitle_file.write(f'{str(data["i"]).zfill(8)}\t{coordinate}\t{text}\n') + # 输出调试信息 + dump_debug_info(options, line, img, loss_list, ocr_loss_debug_path, sub_area, data) + + +def dump_debug_info(options, line, img, loss_list, ocr_loss_debug_path, sub_area, data): + loss = False + if options.DEBUG_OCR_LOSS and options.REC_CHAR_TYPE in ('ch', 'japan', 'korea', 'ch_tra'): + loss = len(line) > 0 and re.search(r'[\u4e00-\u9fa5\u3400-\u4db5\u3130-\u318F\uAC00-\uD7A3\u0800-\u4e00]', line) is None + if loss: + if not os.path.exists(ocr_loss_debug_path): + os.makedirs(ocr_loss_debug_path, mode=0o777, exist_ok=True) + img = cv2.rectangle(img, (sub_area[2], sub_area[0]), (sub_area[3], sub_area[1]), constant.BGR_COLOR_BLUE, 2) + for loss_info in loss_list: + coordinate = loss_info.coordinate + color = constant.BGR_COLOR_GREEN if loss_info.selected else constant.BGR_COLOR_RED + text = f"[{loss_info.text}] prob:{loss_info.prob:.4f} or:{loss_info.overflow_area_rate:.2f}" + img = paint_chinese_opencv(img, text, pos=(coordinate[0], coordinate[2] - 30), color=color) + img = cv2.rectangle(img, (coordinate[0], coordinate[2]), (coordinate[1], coordinate[3]), color, 2) + cv2.imwrite(os.path.join(os.path.abspath(ocr_loss_debug_path), f'{str(data["i"]).zfill(8)}.png'), img) + + +def sub_area_to_polygon(sub_area): + s_ymin = sub_area[0] + s_ymax = sub_area[1] + s_xmin = sub_area[2] + s_xmax = sub_area[3] + return Polygon([[s_xmin, s_ymin], [s_xmax, s_ymin], [s_xmax, s_ymax], [s_xmin, s_ymax]]) + + +def coordinate_to_polygon(coordinate): + xmin = coordinate[0] + xmax = coordinate[1] + ymin = coordinate[2] + ymax = coordinate[3] + return Polygon([[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]]) + + +FONT_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'NotoSansCJK-Bold.otf') +FONT = ImageFont.truetype(FONT_PATH, 20) + + +def paint_chinese_opencv(im, chinese, pos, color): + img_pil = Image.fromarray(im) + fill_color = color # (color[2], color[1], color[0]) + position = pos + draw = ImageDraw.Draw(img_pil) + draw.text(position, chinese, font=FONT, fill=fill_color) + img = np.asarray(img_pil) + return img + + +def ocr_task_consumer(ocr_queue, raw_subtitle_path, sub_area, video_path, options): + """ + 消费者: 消费ocr_queue,将ocr队列中的数据取出,进行ocr识别,写入字幕文件中 + :param ocr_queue (current_frame_no当前帧帧号, frame 视频帧, dt_box检测框, rec_res识别结果) + :param raw_subtitle_path + :param sub_area + :param video_path + :param options + """ + data = {'i': 1} + # 初始化文本识别对象 + text_recogniser = OcrRecogniser() + # 丢失字幕的存储路径 + ocr_loss_debug_path = os.path.join(os.path.abspath(os.path.splitext(video_path)[0]), 'loss') + # 删除之前的缓存垃圾 + if os.path.exists(ocr_loss_debug_path): + 
shutil.rmtree(ocr_loss_debug_path, True) + + with open(raw_subtitle_path, mode='w+', encoding='utf-8') as raw_subtitle_file: + while True: + try: + frame_no, frame, dt_box, rec_res = ocr_queue.get(block=True) + if frame_no == -1: + return + data['i'] = frame_no + extract_subtitles(data, text_recogniser, frame, raw_subtitle_file, sub_area, options, dt_box, + rec_res, ocr_loss_debug_path) + except Exception as e: + print(e) + break + + +def ocr_task_producer(ocr_queue, task_queue, progress_queue, video_path, raw_subtitle_path): + """ + 生产者:负责生产用于OCR识别的数据,将需要进行ocr识别的数据加入ocr_queue中 + :param ocr_queue (current_frame_no当前帧帧号, frame 视频帧, dt_box检测框, rec_res识别结果) + :param task_queue (total_frame_count总帧数, current_frame_no当前帧帧号, dt_box检测框, rec_res识别结果, subtitle_area字幕区域) + :param progress_queue + :param video_path + :param raw_subtitle_path + """ + cap = cv2.VideoCapture(video_path) + tbar = None + while True: + try: + # 从任务队列中提取任务信息 + total_frame_count, current_frame_no, dt_box, rec_res, total_ms, default_subtitle_area = task_queue.get(block=True) + progress_queue.put(current_frame_no) + if tbar is None: + tbar = tqdm(total=round(total_frame_count), position=1) + # current_frame 等于-1说明所有视频帧已经读完 + if current_frame_no == -1: + # ocr识别队列加入结束标志 + ocr_queue.put((-1, None, None, None)) + # 更新进度条 + tbar.update(tbar.total - tbar.n) + break + tbar.update(round(current_frame_no - tbar.n)) + # 设置当前视频帧 + # 如果total_ms不为空,则使用了VSF提取字幕 + if total_ms is not None: + cap.set(cv2.CAP_PROP_POS_MSEC, total_ms) + else: + cap.set(cv2.CAP_PROP_POS_FRAMES, current_frame_no - 1) + # 读取视频帧 + ret, frame = cap.read() + ocr = OcrRecogniser() + dt_box, rec_res = ocr.predict(frame) + # 如果读取成功 + if ret: + # 根据默认字幕位置,则对视频帧进行裁剪,裁剪后处理 + if default_subtitle_area is not None: + frame = frame_preprocess(default_subtitle_area, frame) + ocr_queue.put((current_frame_no, frame, dt_box, rec_res)) + except Exception as e: + print(e) + break + cap.release() + + +def subtitle_extract_handler(task_queue, progress_queue, video_path, raw_subtitle_path, sub_area, options): + """ + 创建并开启一个视频帧提取线程与一个ocr识别线程 + :param task_queue 任务队列,(total_frame_count总帧数, current_frame_no当前帧, dt_box检测框, rec_res识别结果, subtitle_area字幕区域) + :param progress_queue 进度队列 + :param video_path 视频路径 + :param raw_subtitle_path 原始字幕文件路径 + :param sub_area 字幕区域 + :param options 选项 + """ + # 删除缓存 + if os.path.exists(raw_subtitle_path): + os.remove(raw_subtitle_path) + # 创建一个OCR队列,大小建议值8-20 + ocr_queue = queue.Queue(20) + # 创建一个OCR事件生产者线程 + ocr_event_producer_thread = Thread(target=ocr_task_producer, + args=(ocr_queue, task_queue, progress_queue, video_path, raw_subtitle_path,), + daemon=True) + # 创建一个OCR事件消费者线程 + ocr_event_consumer_thread = Thread(target=ocr_task_consumer, + args=(ocr_queue, raw_subtitle_path, sub_area, video_path, options,), + daemon=True) + # 开启生产者线程 + ocr_event_producer_thread.start() + # 开启消费者线程 + ocr_event_consumer_thread.start() + # join方法让主线程任务结束之后,进入阻塞状态,一直等待其他的子线程执行结束之后,主线程再终止 + ocr_event_producer_thread.join() + ocr_event_consumer_thread.join() + + +def async_start(video_path, raw_subtitle_path, sub_area, options): + """ + 开始进程处理异步任务 + options.REC_CHAR_TYPE + options.DROP_SCORE + options.SUB_AREA_DEVIATION_RATE + options.DEBUG_OCR_LOSS + """ + assert 'REC_CHAR_TYPE' in options, "options缺少参数: REC_CHAR_TYPE" + assert 'DROP_SCORE' in options, "options缺少参数: DROP_SCORE" + assert 'SUB_AREA_DEVIATION_RATE' in options, "options缺少参数: SUB_AREA_DEVIATION_RATE" + assert 'DEBUG_OCR_LOSS' in options, "options缺少参数: DEBUG_OCR_LOSS" + # 创建一个任务队列 + # 
任务格式为:(total_frame_count总帧数, current_frame_no当前帧, dt_box检测框, rec_res识别结果, subtitle_area字幕区域) + task_queue = Queue() + # 创建一个进度更新队列 + progress_queue = Queue() + # 新建一个进程 + p = Process(target=subtitle_extract_handler, + args=(task_queue, progress_queue, video_path, raw_subtitle_path, sub_area, SimpleNamespace(**options),)) + # 启动进程 + p.start() + return p, task_queue, progress_queue + + +def frame_preprocess(subtitle_area, frame): + """ + 将视频帧进行裁剪 + """ + # 对于分辨率大于1920*1080的视频,将其视频帧进行等比缩放至1280*720进行识别 + # paddlepaddle会将图像压缩为640*640 + # if self.frame_width > 1280: + # scale_rate = round(float(1280 / self.frame_width), 2) + # frames = cv2.resize(frames, None, fx=scale_rate, fy=scale_rate, interpolation=cv2.INTER_AREA) + # 如果字幕出现的区域在下半部分 + if subtitle_area == SubtitleArea.LOWER_PART: + cropped = int(frame.shape[0] // 2) + # 将视频帧切割为下半部分 + frame = frame[cropped:] + # 如果字幕出现的区域在上半部分 + elif subtitle_area == SubtitleArea.UPPER_PART: + cropped = int(frame.shape[0] // 2) + # 将视频帧切割为上半部分 + frame = frame[:cropped] + return frame + + +if __name__ == "__main__": + pass diff --git a/backend/tools/test_hubserving.py b/backend/tools/test_hubserving.py index 3beb4965..ec17a941 100755 --- a/backend/tools/test_hubserving.py +++ b/backend/tools/test_hubserving.py @@ -25,7 +25,9 @@ import time from PIL import Image from ppocr.utils.utility import get_image_file_list -from tools.infer.utility import draw_ocr, draw_boxes +from tools.infer.utility import draw_ocr, draw_boxes, str2bool +from ppstructure.utility import draw_structure_result +from ppstructure.predict_system import to_excel import requests import json @@ -64,12 +66,38 @@ def draw_server_result(image_file, res): scores.append(res[dno]['confidence']) boxes = np.array(boxes) scores = np.array(scores) - draw_img = draw_ocr(image, boxes, texts, scores, drop_score=0.5) + draw_img = draw_ocr( + image, boxes, texts, scores, draw_txt=True, drop_score=0.5) return draw_img -def main(url, image_path): - image_file_list = get_image_file_list(image_path) +def save_structure_res(res, save_folder, image_file): + img = cv2.imread(image_file) + excel_save_folder = os.path.join(save_folder, os.path.basename(image_file)) + os.makedirs(excel_save_folder, exist_ok=True) + # save res + with open( + os.path.join(excel_save_folder, 'res.txt'), 'w', + encoding='utf8') as f: + for region in res: + if region['type'] == 'Table': + excel_path = os.path.join(excel_save_folder, + '{}.xlsx'.format(region['bbox'])) + to_excel(region['res'], excel_path) + elif region['type'] == 'Figure': + x1, y1, x2, y2 = region['bbox'] + print(region['bbox']) + roi_img = img[y1:y2, x1:x2, :] + img_path = os.path.join(excel_save_folder, + '{}.jpg'.format(region['bbox'])) + cv2.imwrite(img_path, roi_img) + else: + for text_result in region['res']: + f.write('{}\n'.format(json.dumps(text_result))) + + +def main(args): + image_file_list = get_image_file_list(args.image_dir) is_visualize = False headers = {"Content-type": "application/json"} cnt = 0 @@ -79,38 +107,51 @@ def main(url, image_path): if img is None: logger.info("error in loading image:{}".format(image_file)) continue - - # 发送HTTP请求 + img_name = os.path.basename(image_file) + # send http request starttime = time.time() data = {'images': [cv2_to_base64(img)]} - r = requests.post(url=url, headers=headers, data=json.dumps(data)) + r = requests.post( + url=args.server_url, headers=headers, data=json.dumps(data)) elapse = time.time() - starttime total_time += elapse logger.info("Predict time of %s: %.3fs" % (image_file, elapse)) res = 
r.json()["results"][0] logger.info(res) - if is_visualize: - draw_img = draw_server_result(image_file, res) + if args.visualize: + draw_img = None + if 'structure_table' in args.server_url: + to_excel(res['html'], './{}.xlsx'.format(img_name)) + elif 'structure_system' in args.server_url: + save_structure_res(res['regions'], args.output, image_file) + else: + draw_img = draw_server_result(image_file, res) if draw_img is not None: - draw_img_save = "./server_results/" - if not os.path.exists(draw_img_save): - os.makedirs(draw_img_save) + if not os.path.exists(args.output): + os.makedirs(args.output) cv2.imwrite( - os.path.join(draw_img_save, os.path.basename(image_file)), + os.path.join(args.output, os.path.basename(image_file)), draw_img[:, :, ::-1]) logger.info("The visualized image saved in {}".format( - os.path.join(draw_img_save, os.path.basename(image_file)))) + os.path.join(args.output, os.path.basename(image_file)))) cnt += 1 if cnt % 100 == 0: logger.info("{} processed".format(cnt)) logger.info("avg time cost: {}".format(float(total_time) / cnt)) +def parse_args(): + import argparse + parser = argparse.ArgumentParser(description="args for hub serving") + parser.add_argument("--server_url", type=str, required=True) + parser.add_argument("--image_dir", type=str, required=True) + parser.add_argument("--visualize", type=str2bool, default=False) + parser.add_argument("--output", type=str, default='./hubserving_result') + args = parser.parse_args() + return args + + if __name__ == '__main__': - if len(sys.argv) != 3: - logger.info("Usage: %s server_url image_path" % sys.argv[0]) - else: - server_url = sys.argv[1] - image_path = sys.argv[2] - main(server_url, image_path) + args = parse_args() + main(args) diff --git a/backend/tools/train.py b/backend/tools/train.py index fab10b64..42aba548 100755 --- a/backend/tools/train.py +++ b/backend/tools/train.py @@ -21,21 +21,20 @@ __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) -sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) import yaml import paddle import paddle.distributed as dist -paddle.seed(2) - from ppocr.data import build_dataloader from ppocr.modeling.architectures import build_model from ppocr.losses import build_loss from ppocr.optimizer import build_optimizer from ppocr.postprocess import build_post_process from ppocr.metrics import build_metric -from ppocr.utils.save_load import init_model +from ppocr.utils.save_load import load_model +from ppocr.utils.utility import set_seed import tools.program as program dist.get_world_size() @@ -52,7 +51,10 @@ def main(config, device, logger, vdl_writer): train_dataloader = build_dataloader(config, 'Train', device, logger) if len(train_dataloader) == 0: logger.error( - 'No Images in train dataset, please check annotation file and path in the configuration file' + "No Images in train dataset, please ensure\n" + + "\t1. The images num in the train label_file_list should be larger than or equal with batch size.\n" + + + "\t2. The annotation file and path in the configuration file are provided normally." 
) return @@ -69,7 +71,52 @@ def main(config, device, logger, vdl_writer): # for rec algorithm if hasattr(post_process_class, 'character'): char_num = len(getattr(post_process_class, 'character')) - config['Architecture']["Head"]['out_channels'] = char_num + if config['Architecture']["algorithm"] in ["Distillation", + ]: # distillation model + for key in config['Architecture']["Models"]: + if config['Architecture']['Models'][key]['Head'][ + 'name'] == 'MultiHead': # for multi head + if config['PostProcess'][ + 'name'] == 'DistillationSARLabelDecode': + char_num = char_num - 2 + # update SARLoss params + assert list(config['Loss']['loss_config_list'][-1].keys())[ + 0] == 'DistillationSARLoss' + config['Loss']['loss_config_list'][-1][ + 'DistillationSARLoss']['ignore_index'] = char_num + 1 + out_channels_list = {} + out_channels_list['CTCLabelDecode'] = char_num + out_channels_list['SARLabelDecode'] = char_num + 2 + config['Architecture']['Models'][key]['Head'][ + 'out_channels_list'] = out_channels_list + else: + config['Architecture']["Models"][key]["Head"][ + 'out_channels'] = char_num + elif config['Architecture']['Head'][ + 'name'] == 'MultiHead': # for multi head + if config['PostProcess']['name'] == 'SARLabelDecode': + char_num = char_num - 2 + # update SARLoss params + assert list(config['Loss']['loss_config_list'][1].keys())[ + 0] == 'SARLoss' + if config['Loss']['loss_config_list'][1]['SARLoss'] is None: + config['Loss']['loss_config_list'][1]['SARLoss'] = { + 'ignore_index': char_num + 1 + } + else: + config['Loss']['loss_config_list'][1]['SARLoss'][ + 'ignore_index'] = char_num + 1 + out_channels_list = {} + out_channels_list['CTCLabelDecode'] = char_num + out_channels_list['SARLabelDecode'] = char_num + 2 + config['Architecture']['Head'][ + 'out_channels_list'] = out_channels_list + else: # base rec model + config['Architecture']["Head"]['out_channels'] = char_num + + if config['PostProcess']['name'] == 'SARLabelDecode': # for SAR model + config['Loss']['ignore_index'] = char_num - 1 + model = build_model(config['Architecture']) if config['Global']['distributed']: model = paddle.DataParallel(model) @@ -82,19 +129,38 @@ def main(config, device, logger, vdl_writer): config['Optimizer'], epochs=config['Global']['epoch_num'], step_each_epoch=len(train_dataloader), - parameters=model.parameters()) + model=model) # build metric eval_class = build_metric(config['Metric']) # load pretrain model - pre_best_model_dict = init_model(config, model, logger, optimizer) + pre_best_model_dict = load_model(config, model, optimizer, + config['Architecture']["model_type"]) + logger.info('train dataloader has {} iters'.format(len(train_dataloader))) + if valid_dataloader is not None: + logger.info('valid dataloader has {} iters'.format( + len(valid_dataloader))) + + use_amp = config["Global"].get("use_amp", False) + if use_amp: + AMP_RELATED_FLAGS_SETTING = { + 'FLAGS_cudnn_batchnorm_spatial_persistent': 1, + 'FLAGS_max_inplace_grad_add': 8, + } + paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING) + scale_loss = config["Global"].get("scale_loss", 1.0) + use_dynamic_loss_scaling = config["Global"].get( + "use_dynamic_loss_scaling", False) + scaler = paddle.amp.GradScaler( + init_loss_scaling=scale_loss, + use_dynamic_loss_scaling=use_dynamic_loss_scaling) + else: + scaler = None - logger.info('train dataloader has {} iters, valid dataloader has {} iters'. 
- format(len(train_dataloader), len(valid_dataloader))) # start train program.train(config, train_dataloader, valid_dataloader, device, model, loss_class, optimizer, lr_scheduler, post_process_class, - eval_class, pre_best_model_dict, logger, vdl_writer) + eval_class, pre_best_model_dict, logger, vdl_writer, scaler) def test_reader(config, device, logger): @@ -117,5 +183,7 @@ def test_reader(config, device, logger): if __name__ == '__main__': config, device, logger, vdl_writer = program.preprocess(is_train=True) + seed = config['Global']['seed'] if 'seed' in config['Global'] else 1024 + set_seed(seed) main(config, device, logger, vdl_writer) # test_reader(config, device, logger) diff --git a/config.py b/config.py deleted file mode 100644 index 80a930c6..00000000 --- a/config.py +++ /dev/null @@ -1,75 +0,0 @@ -# -*- coding: utf-8 -*- -""" -@Author : Fang Yao -@Time : 2021/3/24 9:36 上午 -@FileName: config.py -@desc: 项目配置文件,可以在这里调参,牺牲时间换取精确度,或者牺牲准确度换取时间 -""" -import os -from pathlib import Path -from enum import Enum -from fsplit.filesplit import Filesplit - -# --------------------- 请你不要改 start----------------------------- -# 项目的base目录 -BASE_DIR = str(Path(os.path.abspath(__file__)).parent) - -# 模型文件目录 -# 文本检测模型 -DET_MODEL_PATH = os.path.join(BASE_DIR, 'backend', 'models', 'ch_det') -# 文本识别模型 -REC_MODEL_PATH = os.path.join(BASE_DIR, 'backend', 'models', 'ch_rec') - -# 查看该路径下是否有文本模型识别完整文件,没有的话合并小文件生成完整文件 -if 'inference.pdiparams' not in (os.listdir(REC_MODEL_PATH)): - fs = Filesplit() - fs.merge(input_dir=REC_MODEL_PATH) - -# 字典路径 -DICT_PATH = os.path.join(BASE_DIR, 'backend', 'ppocr', 'utils', 'ppocr_keys_v1.txt') - - -# 默认字幕出现的大致区域 -class SubtitleArea(Enum): - # 字幕区域出现在下半部分 - LOWER_PART = 0 - # 字幕区域出现在上半部分 - UPPER_PART = 1 - # 不知道字幕区域可能出现的位置 - UNKNOWN = 2 - # 明确知道字幕区域出现的位置 - CUSTOM = 3 -# --------------------- 请你不要改 end----------------------------- - - -# --------------------- 请根据自己的实际情况改 start----------------------------- -# 是否使用GPU -# 使用GPU可以提速20倍+,你要是有N卡你就改成 True -USE_GPU = False - -# 默认字幕出现区域为下方 -SUBTITLE_AREA = SubtitleArea.LOWER_PART - -# 余弦相似度阈值 -# 数值越小生成的视频帧越少,相对提取速度更快但生成的字幕越不精准 -# 1表示最精准,每一帧视频帧都进行字幕检测与提取,生成的字幕最精准 -# 0.925表示,当视频帧1与视频帧2相似度高达92.5%时,视频帧2将直接pass,不字检测与提取视频帧2的字幕 -COSINE_SIMILARITY_THRESHOLD = 0.95 if SUBTITLE_AREA == SubtitleArea.UNKNOWN else 0.91 - -# 欧式距离相似值 -EUCLIDEAN_SIMILARITY_THRESHOLD = 0.9 - -# 容忍的像素点偏差 -PIXEL_TOLERANCE_Y = 50 # 允许检测框纵向偏差50个像素点 -PIXEL_TOLERANCE_X = 100 # 允许检测框横向偏差100个像素点 - -# 字幕区域偏移量 -SUBTITLE_AREA_DEVIATION_PIXEL = 50 - -# 最有可能出现的水印区域 -WATERMARK_AREA_NUM = 5 - -# 文本相似度阈值 -# 用于去重时判断两行字幕是不是统一行 -TEXT_SIMILARITY_THRESHOLD = 0.95 -# --------------------- 请根据自己的实际情况改 end----------------------------- diff --git a/design/UI design.png b/design/UI design.png new file mode 100644 index 00000000..fb2e0ed3 Binary files /dev/null and b/design/UI design.png differ diff --git a/design/bg.png b/design/bg.png new file mode 100644 index 00000000..9223cb14 Binary files /dev/null and b/design/bg.png differ diff --git a/design/demo.gif b/design/demo.gif new file mode 100644 index 00000000..276874e8 Binary files /dev/null and b/design/demo.gif differ diff --git a/design/demo.png b/design/demo.png new file mode 100644 index 00000000..a9c20e8b Binary files /dev/null and b/design/demo.png differ diff --git a/design/demo2.gif b/design/demo2.gif new file mode 100644 index 00000000..d96b9927 Binary files /dev/null and b/design/demo2.gif differ diff --git a/design/gui.spec b/design/gui.spec new file mode 100644 index 00000000..4fe25080 --- /dev/null +++ b/design/gui.spec @@ 
-0,0 +1,41 @@ +# -*- mode: python ; coding: utf-8 -*- + + +block_cipher = None + + +a = Analysis(['gui.py'], + pathex=['/Users/yao/anaconda3/envs/subEnv/lib/python3.7/site-packages', '/Users/yao/Github/video-subtitle-extractor'], + binaries=[('/Users/yao/Github/video-subtitle-extractor/dylib/libgeos_c.dylib', '.')], + datas=[('/Users/yao/Github/video-subtitle-extractor/backend', 'backend'), + ('/Users/yao/Github/video-subtitle-extractor/vse.ico', '.') + ], + hiddenimports=['imgaug', 'skimage.filters.rank.core_cy_3d', + 'pyclipper', 'lmdb'], + hookspath=[], + runtime_hooks=[], + excludes=[], + win_no_prefer_redirects=False, + win_private_assemblies=False, + cipher=block_cipher, + noarchive=False) +pyz = PYZ(a.pure, a.zipped_data, + cipher=block_cipher) +exe = EXE(pyz, + a.scripts, + a.binaries, + a.zipfiles, + a.datas, + [], + name='vse', + debug=False, + bootloader_ignore_signals=False, + strip=False, + upx=True, + upx_exclude=[], + runtime_tmpdir=None, + console=False , icon='vse.ico') +app = BUNDLE(exe, + name='Subtitle Extractor.app', + icon='vse.ico', + bundle_identifier=None) diff --git a/design/paper (2020).pdf b/design/paper (2020).pdf new file mode 100644 index 00000000..440a065a Binary files /dev/null and b/design/paper (2020).pdf differ diff --git a/design/vse.ico b/design/vse.ico new file mode 100644 index 00000000..84a782b7 Binary files /dev/null and b/design/vse.ico differ diff --git a/google_colab.ipynb b/google_colab.ipynb new file mode 100644 index 00000000..61c4789e --- /dev/null +++ b/google_colab.ipynb @@ -0,0 +1,164 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "WAJ7lA2wuvR8" + }, + "source": [ + "# 运行教程\n", + "\n", + "1. 点击“修改” -> \"笔记本设置\" -> \"硬件加速器GPU\" -> 保存\n", + "\n", + "\n", + "2. 点左侧文件夹图标,上传视频文件,复制上传的视频路径\n", + "\n", + "\n", + "\n", + "3. 
运行代码, 输入粘贴的视频路径\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_jPi_FBwyZyr" + }, + "source": [ + "查看是否有GPU" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "eHPHc_Bheo-j" + }, + "outputs": [], + "source": [ + "!nvidia-smi" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "TkQKKGKZkkT2" + }, + "outputs": [], + "source": [ + "!nvcc -V" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_85O6zgPyhto" + }, + "source": [ + "# 安装依赖" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ICeq0T1FeqjT" + }, + "outputs": [], + "source": [ + "!git clone --depth=1 https://github.com/YaoFANGUK/video-subtitle-extractor.git" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "GHutEWynkMKR" + }, + "outputs": [], + "source": [ + "cd video-subtitle-extractor" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ynJydzo1kMKR" + }, + "outputs": [], + "source": [ + "!pip install -r requirements_gpu.txt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3-GdvmaGl-aF" + }, + "outputs": [], + "source": [ + "!pip install paddlepaddle-gpu==2.4.2.post117 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SGb0i3tPyw9Q" + }, + "source": [ + "# 运行程序" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "输入视频路径,如:/content/video-subtitle-extractor/test/test_cn2.mp4\n", + "\n", + "输入字幕区域,如:842 1069 72 1368" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "B2MPjMOOgGbD" + }, + "outputs": [], + "source": [ + "!python ./backend/main.py" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "video-subtitle-extractor.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/google_colab_en.ipynb b/google_colab_en.ipynb new file mode 100644 index 00000000..aac8014e --- /dev/null +++ b/google_colab_en.ipynb @@ -0,0 +1,177 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Set Up\n", + "\n", + "1. click “Edit” -> \"Notebook Settings\" -> \"Hardware accelerator, GPU\" -> Save\n", + "\n", + "\n", + "2. Click the folder icon on the left, upload your video file, and copy the uploaded video path\n", + "\n", + "\n", + "\n", + "3. 
Run the code, enter the pasted video path\n", + "\n", + "\n" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_jPi_FBwyZyr" + }, + "source": [ + "check whether GPU works" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "eHPHc_Bheo-j" + }, + "outputs": [], + "source": [ + "!nvidia-smi" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "TkQKKGKZkkT2" + }, + "outputs": [], + "source": [ + "!nvcc -V" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_85O6zgPyhto" + }, + "source": [ + "# Install Dependencies" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ICeq0T1FeqjT" + }, + "outputs": [], + "source": [ + "!git clone --depth=1 https://github.com/YaoFANGUK/video-subtitle-extractor.git" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "GHutEWynkMKR" + }, + "outputs": [], + "source": [ + "cd video-subtitle-extractor" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ynJydzo1kMKR" + }, + "outputs": [], + "source": [ + "!pip install -r requirements_gpu.txt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "!echo -e '[DEFAULT]\\nInterface = English\\nLanguage = en\\nMode = fast' > /content/video-subtitle-extractor/settings.ini" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3-GdvmaGl-aF" + }, + "outputs": [], + "source": [ + "!pip install paddlepaddle-gpu==2.4.2.post117 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SGb0i3tPyw9Q" + }, + "source": [ + "# Run" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here is an example:\n", + "\n", + "input video path: /content/video-subtitle-extractor/test/test_en.mp4\n", + "\n", + "input subtitle area: 612 717 90 1191" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "B2MPjMOOgGbD" + }, + "outputs": [], + "source": [ + "!python ./backend/main.py" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "video-subtitle-extractor.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/gui.py b/gui.py new file mode 100644 index 00000000..b9b9dfc8 --- /dev/null +++ b/gui.py @@ -0,0 +1,592 @@ +# -*- coding: utf-8 -*- +""" +@Author : Fang Yao +@Time : 2021/4/1 6:07 下午 +@FileName: gui.py +@desc: 字幕提取器图形化界面 +""" +import backend.main +import os +import configparser +import PySimpleGUI as sg +import cv2 +from threading import Thread +import multiprocessing + + +class SubtitleExtractorGUI: + def _load_config(self): + self.config_file = os.path.join(os.path.dirname(__file__), 'settings.ini') + self.subtitle_config_file = os.path.join(os.path.dirname(__file__), 'subtitle.ini') + self.config = configparser.ConfigParser() + self.interface_config = configparser.ConfigParser() + if not 
os.path.exists(self.config_file): + # 如果没有配置文件,默认弹出语言选择界面 + LanguageModeGUI(self).run() + self.INTERFACE_KEY_NAME_MAP = { + '简体中文': 'ch', + '繁體中文': 'chinese_cht', + 'English': 'en', + '한국어': 'ko', + '日本語': 'japan', + 'Tiếng Việt': 'vi', + 'Español': 'es' + } + self.config.read(self.config_file, encoding='utf-8') + self.interface_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'backend', 'interface', + f"{self.INTERFACE_KEY_NAME_MAP[self.config['DEFAULT']['Interface']]}.ini") + self.interface_config.read(self.interface_file, encoding='utf-8') + + def __init__(self): + # 初次运行检查运行环境是否正常 + from paddle import utils + utils.run_check() + self.font = 'Arial 10' + self.theme = 'LightBrown12' + sg.theme(self.theme) + self.icon = os.path.join(os.path.dirname(__file__), 'design', 'vse.ico') + self._load_config() + self.screen_width, self.screen_height = sg.Window.get_screen_size() + print(self.screen_width, self.screen_height) + # 设置视频预览区域大小 + self.video_preview_width = 960 + self.video_preview_height = self.video_preview_width * 9 // 16 + # 默认组件大小 + self.horizontal_slider_size = (120, 20) + self.output_size = (100, 10) + self.progressbar_size = (60, 20) + # 分辨率低于1080 + if self.screen_width // 2 < 960: + self.video_preview_width = 640 + self.video_preview_height = self.video_preview_width * 9 // 16 + self.horizontal_slider_size = (60, 20) + self.output_size = (58, 10) + self.progressbar_size = (28, 20) + # 字幕提取器布局 + self.layout = None + # 字幕提取其窗口 + self.window = None + # 视频路径 + self.video_path = None + # 视频cap + self.video_cap = None + # 视频的帧率 + self.fps = None + # 视频的帧数 + self.frame_count = None + # 视频的宽 + self.frame_width = None + # 视频的高 + self.frame_height = None + # 设置字幕区域高宽 + self.xmin = None + self.xmax = None + self.ymin = None + self.ymax = None + # 字幕提取器 + self.se = None + + def run(self): + # 创建布局 + self._create_layout() + # 创建窗口 + self.window = sg.Window(title=self.interface_config['SubtitleExtractorGUI']['Title'], layout=self.layout, + icon=self.icon) + while True: + # 循环读取事件 + event, values = self.window.read(timeout=10) + # 处理【打开】事件 + self._file_event_handler(event, values) + # 处理【滑动】事件 + self._slide_event_handler(event, values) + # 处理【识别语言】事件 + self._language_mode_event_handler(event) + # 处理【运行】事件 + self._run_event_handler(event, values) + # 如果关闭软件,退出 + if event == sg.WIN_CLOSED: + break + # 更新进度条 + if self.se is not None: + self.window['-PROG-'].update(self.se.progress_total) + if self.se.isFinished: + # 1) 打开修改字幕滑块区域按钮 + self.window['-Y-SLIDER-'].update(disabled=False) + self.window['-X-SLIDER-'].update(disabled=False) + self.window['-Y-SLIDER-H-'].update(disabled=False) + self.window['-X-SLIDER-W-'].update(disabled=False) + # 2) 打开【运行】、【打开】和【识别语言】按钮 + self.window['-RUN-'].update(disabled=False) + self.window['-FILE-'].update(disabled=False) + self.window['-FILE_BTN-'].update(disabled=False) + self.window['-LANGUAGE-MODE-'].update(disabled=False) + self.se = None + if len(self.video_paths) >= 1: + # 1) 关闭修改字幕滑块区域按钮 + self.window['-Y-SLIDER-'].update(disabled=True) + self.window['-X-SLIDER-'].update(disabled=True) + self.window['-Y-SLIDER-H-'].update(disabled=True) + self.window['-X-SLIDER-W-'].update(disabled=True) + # 2) 关闭【运行】、【打开】和【识别语言】按钮 + self.window['-RUN-'].update(disabled=True) + self.window['-FILE-'].update(disabled=True) + self.window['-FILE_BTN-'].update(disabled=True) + self.window['-LANGUAGE-MODE-'].update(disabled=True) + + def update_interface_text(self): + self._load_config() + 
self.window.set_title(self.interface_config['SubtitleExtractorGUI']['Title']) + self.window['-FILE_BTN-'].Update(self.interface_config['SubtitleExtractorGUI']['Open']) + self.window['-FRAME1-'].Update(self.interface_config['SubtitleExtractorGUI']['Vertical']) + self.window['-FRAME2-'].Update(self.interface_config['SubtitleExtractorGUI']['Horizontal']) + self.window['-RUN-'].Update(self.interface_config['SubtitleExtractorGUI']['Run']) + self.window['-LANGUAGE-MODE-'].Update(self.interface_config['SubtitleExtractorGUI']['Setting']) + + def _create_layout(self): + """ + 创建字幕提取器布局 + """ + garbage = os.path.join(os.path.dirname(__file__), 'output') + if os.path.exists(garbage): + import shutil + shutil.rmtree(garbage, True) + self.layout = [ + # 显示视频预览 + [sg.Image(size=(self.video_preview_width, self.video_preview_height), background_color='black', + key='-DISPLAY-')], + # 打开按钮 + 快进快退条 + [sg.Input(key='-FILE-', visible=False, enable_events=True), + sg.FilesBrowse(button_text=self.interface_config['SubtitleExtractorGUI']['Open'], file_types=(( + self.interface_config['SubtitleExtractorGUI']['AllFile'], '*.*'), ('mp4', '*.mp4'), + ('flv', '*.flv'), + ('wmv', '*.wmv'), + ('avi', '*.avi')), + key='-FILE_BTN-', size=(10, 1), font=self.font), + sg.Slider(size=self.horizontal_slider_size, range=(1, 1), key='-SLIDER-', orientation='h', + enable_events=True, font=self.font, + disable_number_display=True), + ], + # 输出区域 + [sg.Output(size=self.output_size, font=self.font), + sg.Frame(title=self.interface_config['SubtitleExtractorGUI']['Vertical'], font=self.font, key='-FRAME1-', + layout=[[ + sg.Slider(range=(0, 0), orientation='v', size=(10, 20), + disable_number_display=True, + enable_events=True, font=self.font, + pad=((10, 10), (20, 20)), + default_value=0, key='-Y-SLIDER-'), + sg.Slider(range=(0, 0), orientation='v', size=(10, 20), + disable_number_display=True, + enable_events=True, font=self.font, + pad=((10, 10), (20, 20)), + default_value=0, key='-Y-SLIDER-H-'), + ]], pad=((15, 5), (0, 0))), + sg.Frame(title=self.interface_config['SubtitleExtractorGUI']['Horizontal'], font=self.font, key='-FRAME2-', + layout=[[ + sg.Slider(range=(0, 0), orientation='v', size=(10, 20), + disable_number_display=True, + pad=((10, 10), (20, 20)), + enable_events=True, font=self.font, + default_value=0, key='-X-SLIDER-'), + sg.Slider(range=(0, 0), orientation='v', size=(10, 20), + disable_number_display=True, + pad=((10, 10), (20, 20)), + enable_events=True, font=self.font, + default_value=0, key='-X-SLIDER-W-'), + ]], pad=((15, 5), (0, 0))) + ], + + # 运行按钮 + 进度条 + [sg.Button(button_text=self.interface_config['SubtitleExtractorGUI']['Run'], key='-RUN-', + font=self.font, size=(20, 1)), + sg.Button(button_text=self.interface_config['SubtitleExtractorGUI']['Setting'], key='-LANGUAGE-MODE-', + font=self.font, size=(20, 1)), + sg.ProgressBar(100, orientation='h', size=self.progressbar_size, key='-PROG-', auto_size_text=True) + ], + ] + + def _file_event_handler(self, event, values): + """ + 当点击打开按钮时: + 1)打开视频文件,将画布显示视频帧 + 2)获取视频信息,初始化进度条滑块范围 + """ + if event == '-FILE-': + self.video_paths = values['-FILE-'].split(';') + self.video_path = self.video_paths[0] + if self.video_path != '': + self.video_cap = cv2.VideoCapture(self.video_path) + if self.video_cap is None: + return + if self.video_cap.isOpened(): + ret, frame = self.video_cap.read() + if ret: + for video in self.video_paths: + print(f"{self.interface_config['SubtitleExtractorGUI']['OpenVideoSuccess']}:{video}") + # 获取视频的帧数 + self.frame_count = 
self.video_cap.get(cv2.CAP_PROP_FRAME_COUNT) + # 获取视频的高度 + self.frame_height = self.video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT) + # 获取视频的宽度 + self.frame_width = self.video_cap.get(cv2.CAP_PROP_FRAME_WIDTH) + # 获取视频的帧率 + self.fps = self.video_cap.get(cv2.CAP_PROP_FPS) + # 调整视频帧大小,使播放器能够显示 + resized_frame = self._img_resize(frame) + # resized_frame = cv2.resize(src=frame, dsize=(self.video_preview_width, self.video_preview_height)) + # 显示视频帧 + self.window['-DISPLAY-'].update(data=cv2.imencode('.png', resized_frame)[1].tobytes()) + # 更新视频进度条滑块range + self.window['-SLIDER-'].update(range=(1, self.frame_count)) + self.window['-SLIDER-'].update(1) + # 预设字幕区域位置 + y_p, h_p, x_p, w_p = self.parse_subtitle_config() + y = self.frame_height * y_p + h = self.frame_height * h_p + x = self.frame_width * x_p + w = self.frame_width * w_p + # 更新视频字幕位置滑块range + # 更新Y-SLIDER范围 + self.window['-Y-SLIDER-'].update(range=(0, self.frame_height), disabled=False) + # 更新Y-SLIDER默认值 + self.window['-Y-SLIDER-'].update(y) + # 更新X-SLIDER范围 + self.window['-X-SLIDER-'].update(range=(0, self.frame_width), disabled=False) + # 更新X-SLIDER默认值 + self.window['-X-SLIDER-'].update(x) + # 更新Y-SLIDER-H范围 + self.window['-Y-SLIDER-H-'].update(range=(0, self.frame_height - y)) + # 更新Y-SLIDER-H默认值 + self.window['-Y-SLIDER-H-'].update(h) + # 更新X-SLIDER-W范围 + self.window['-X-SLIDER-W-'].update(range=(0, self.frame_width - x)) + # 更新X-SLIDER-W默认值 + self.window['-X-SLIDER-W-'].update(w) + self._update_preview(frame, (y, h, x, w)) + + def _language_mode_event_handler(self, event): + if event != '-LANGUAGE-MODE-': + return + if 'OK' == LanguageModeGUI(self).run(): + # 重新加载config + pass + + def _run_event_handler(self, event, values): + """ + 当点击运行按钮时: + 1) 禁止修改字幕滑块区域 + 2) 禁止再次点击【运行】和【打开】按钮 + 3) 设定字幕区域位置 + """ + if event == '-RUN-': + if self.video_cap is None: + print(self.interface_config['SubtitleExtractorGUI']['OpenVideoFirst']) + else: + # 1) 禁止修改字幕滑块区域 + self.window['-Y-SLIDER-'].update(disabled=True) + self.window['-X-SLIDER-'].update(disabled=True) + self.window['-Y-SLIDER-H-'].update(disabled=True) + self.window['-X-SLIDER-W-'].update(disabled=True) + # 2) 禁止再次点击【运行】、【打开】和【识别语言】按钮 + self.window['-RUN-'].update(disabled=True) + self.window['-FILE-'].update(disabled=True) + self.window['-FILE_BTN-'].update(disabled=True) + self.window['-LANGUAGE-MODE-'].update(disabled=True) + # 3) 设定字幕区域位置 + self.xmin = int(values['-X-SLIDER-']) + self.xmax = int(values['-X-SLIDER-'] + values['-X-SLIDER-W-']) + self.ymin = int(values['-Y-SLIDER-']) + self.ymax = int(values['-Y-SLIDER-'] + values['-Y-SLIDER-H-']) + if self.ymax > self.frame_height: + self.ymax = self.frame_height + if self.xmax > self.frame_width: + self.xmax = self.frame_width + print(f"{self.interface_config['SubtitleExtractorGUI']['SubtitleArea']}:({self.ymin},{self.ymax},{self.xmin},{self.xmax})") + subtitle_area = (self.ymin, self.ymax, self.xmin, self.xmax) + y_p = self.ymin / self.frame_height + h_p = (self.ymax - self.ymin) / self.frame_height + x_p = self.xmin / self.frame_width + w_p = (self.xmax - self.xmin) / self.frame_width + self.set_subtitle_config(y_p, h_p, x_p, w_p) + + def task(): + while self.video_paths: + video_path = self.video_paths.pop() + self.se = backend.main.SubtitleExtractor(video_path, subtitle_area) + self.se.run() + Thread(target=task, daemon=True).start() + self.video_cap.release() + self.video_cap = None + + def _slide_event_handler(self, event, values): + """ + 当滑动视频进度条/滑动字幕选择区域滑块时: + 1) 判断视频是否存在,如果存在则显示对应的视频帧 + 2) 绘制rectangle + """ + if event == 
'-SLIDER-' or event == '-Y-SLIDER-' or event == '-Y-SLIDER-H-' or event == '-X-SLIDER-' or event \ + == '-X-SLIDER-W-': + if self.video_cap is not None and self.video_cap.isOpened(): + frame_no = int(values['-SLIDER-']) + self.video_cap.set(cv2.CAP_PROP_POS_FRAMES, frame_no) + ret, frame = self.video_cap.read() + if ret: + self.window['-Y-SLIDER-H-'].update(range=(0, self.frame_height-values['-Y-SLIDER-'])) + self.window['-X-SLIDER-W-'].update(range=(0, self.frame_width-values['-X-SLIDER-'])) + # 画字幕框 + y = int(values['-Y-SLIDER-']) + h = int(values['-Y-SLIDER-H-']) + x = int(values['-X-SLIDER-']) + w = int(values['-X-SLIDER-W-']) + self._update_preview(frame, (y, h, x, w)) + + def _update_preview(self, frame, y_h_x_w): + y, h, x, w = y_h_x_w + # 画字幕框 + draw = cv2.rectangle(img=frame, pt1=(int(x), int(y)), pt2=(int(x) + int(w), int(y) + int(h)), + color=(0, 255, 0), thickness=3) + # 调整视频帧大小,使播放器能够显示 + resized_frame = self._img_resize(draw) + # 显示视频帧 + self.window['-DISPLAY-'].update(data=cv2.imencode('.png', resized_frame)[1].tobytes()) + + + def _img_resize(self, image): + top, bottom, left, right = (0, 0, 0, 0) + height, width = image.shape[0], image.shape[1] + # 对长短不想等的图片,找到最长的一边 + longest_edge = height + # 计算短边需要增加多少像素宽度使其与长边等长 + if width < longest_edge: + dw = longest_edge - width + left = dw // 2 + right = dw - left + else: + pass + # 给图像增加边界 + constant = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=[0, 0, 0]) + return cv2.resize(constant, (self.video_preview_width, self.video_preview_height)) + + def set_subtitle_config(self, y, h, x, w): + # 写入配置文件 + with open(self.subtitle_config_file, mode='w', encoding='utf-8') as f: + f.write('[AREA]\n') + f.write(f'Y = {y}\n') + f.write(f'H = {h}\n') + f.write(f'X = {x}\n') + f.write(f'W = {w}\n') + + def parse_subtitle_config(self): + y_p, h_p, x_p, w_p = .78, .21, .05, .9 + # 如果配置文件不存在,则写入配置文件 + if not os.path.exists(self.subtitle_config_file): + self.set_subtitle_config(y_p, h_p, x_p, w_p) + return y_p, h_p, x_p, w_p + else: + try: + config = configparser.ConfigParser() + config.read(self.subtitle_config_file, encoding='utf-8') + conf_y_p, conf_h_p, conf_x_p, conf_w_p = float(config['AREA']['Y']), float(config['AREA']['H']), float(config['AREA']['X']), float(config['AREA']['W']) + return conf_y_p, conf_h_p, conf_x_p, conf_w_p + except Exception: + self.set_subtitle_config(y_p, h_p, x_p, w_p) + return y_p, h_p, x_p, w_p + + +class LanguageModeGUI: + def __init__(self, subtitle_extractor_gui): + self.subtitle_extractor_gui = subtitle_extractor_gui + self.icon = os.path.join(os.path.dirname(__file__), 'design', 'vse.ico') + self.config_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'settings.ini') + # 设置界面 + self.INTERFACE_DEF = '简体中文' + if not os.path.exists(self.config_file): + self.interface_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'backend', 'interface', + "ch.ini") + self.interface_config = configparser.ConfigParser() + # 设置语言 + self.INTERFACE_KEY_NAME_MAP = { + '简体中文': 'ch', + '繁體中文': 'chinese_cht', + 'English': 'en', + '한국어': 'ko', + '日本語': 'japan', + 'Tiếng Việt': 'vi', + 'Español': 'es' + } + # 设置语言 + self.LANGUAGE_DEF = 'ch' + self.LANGUAGE_NAME_KEY_MAP = None + self.LANGUAGE_KEY_NAME_MAP = None + self.MODE_DEF = 'fast' + self.MODE_NAME_KEY_MAP = None + self.MODE_KEY_NAME_MAP = None + # 语言选择布局 + self.layout = None + # 语言选择窗口 + self.window = None + + def run(self): + # 创建布局 + title = self._create_layout() + # 创建窗口 + self.window = sg.Window(title=title, 
layout=self.layout, icon=self.icon) + while True: + # 循环读取事件 + event, values = self.window.read(timeout=10) + # 处理【OK】事件 + self._ok_event_handler(event, values) + # 处理【切换界面语言】事件 + self._interface_event_handler(event, values) + # 如果关闭软件,退出 + if event == sg.WIN_CLOSED: + if os.path.exists(self.config_file): + break + else: + exit(0) + if event == 'Cancel': + if os.path.exists(self.config_file): + self.window.close() + break + else: + exit(0) + + def _load_interface_text(self): + self.interface_config.read(self.interface_file, encoding='utf-8') + config_language_mode_gui = self.interface_config["LanguageModeGUI"] + # 设置界面 + self.INTERFACE_DEF = config_language_mode_gui["InterfaceDefault"] + + self.LANGUAGE_DEF = config_language_mode_gui["LanguageCH"] + self.LANGUAGE_NAME_KEY_MAP = {} + for lang in backend.main.config.MULTI_LANG: + self.LANGUAGE_NAME_KEY_MAP[config_language_mode_gui[f"Language{lang.upper()}"]] = lang + self.LANGUAGE_NAME_KEY_MAP = dict(sorted(self.LANGUAGE_NAME_KEY_MAP.items(), key=lambda item: item[1])) + self.LANGUAGE_KEY_NAME_MAP = {v: k for k, v in self.LANGUAGE_NAME_KEY_MAP.items()} + self.MODE_DEF = config_language_mode_gui['ModeFast'] + self.MODE_NAME_KEY_MAP = { + config_language_mode_gui['ModeAuto']: 'auto', + config_language_mode_gui['ModeFast']: 'fast', + config_language_mode_gui['ModeAccurate']: 'accurate', + } + self.MODE_KEY_NAME_MAP = {v: k for k, v in self.MODE_NAME_KEY_MAP.items()} + + def _create_layout(self): + interface_def, language_def, mode_def = self.parse_config(self.config_file) + # 加载界面文本 + self._load_interface_text() + choose_language_text = self.interface_config["LanguageModeGUI"]["InterfaceLanguage"] + choose_sub_lang_text = self.interface_config["LanguageModeGUI"]["SubtitleLanguage"] + choose_mode_text = self.interface_config["LanguageModeGUI"]["Mode"] + self.layout = [ + # 显示选择界面语言 + [sg.Text(choose_language_text), + sg.DropDown(values=list(self.INTERFACE_KEY_NAME_MAP.keys()), size=(30, 20), + pad=(0, 20), + key='-INTERFACE-', readonly=True, + default_value=interface_def), + sg.OK(key='-INTERFACE-OK-')], + # 显示选择字幕语言 + [sg.Text(choose_sub_lang_text), + sg.DropDown(values=list(self.LANGUAGE_NAME_KEY_MAP.keys()), size=(30, 20), + pad=(0, 20), + key='-LANGUAGE-', readonly=True, default_value=language_def)], + # 显示识别模式 + [sg.Text(choose_mode_text), + sg.DropDown(values=list(self.MODE_NAME_KEY_MAP.keys()), size=(30, 20), pad=(0, 20), + key='-MODE-', readonly=True, default_value=mode_def)], + # 显示确认关闭按钮 + [sg.OK(), sg.Cancel()] + ] + return self.interface_config["LanguageModeGUI"]["Title"] + + def _ok_event_handler(self, event, values): + if event == 'OK': + # 设置模型语言配置 + interface = None + language = None + mode = None + # 设置界面语言 + interface_str = values['-INTERFACE-'] + if interface_str in self.INTERFACE_KEY_NAME_MAP: + interface = interface_str + language_str = values['-LANGUAGE-'] + # 设置字幕语言 + print(self.interface_config["LanguageModeGUI"]["SubtitleLanguage"], language_str) + if language_str in self.LANGUAGE_NAME_KEY_MAP: + language = self.LANGUAGE_NAME_KEY_MAP[language_str] + # 设置模型语言配置 + mode_str = values['-MODE-'] + print(self.interface_config["LanguageModeGUI"]["Mode"], mode_str) + if mode_str in self.MODE_NAME_KEY_MAP: + mode = self.MODE_NAME_KEY_MAP[mode_str] + self.set_config(self.config_file, interface, language, mode) + if self.subtitle_extractor_gui is not None: + self.subtitle_extractor_gui.update_interface_text() + self.window.close() + + def _interface_event_handler(self, event, values): + if event == '-INTERFACE-OK-': + 
self.interface_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'backend', 'interface', + f"{self.INTERFACE_KEY_NAME_MAP[values['-INTERFACE-']]}.ini") + self.interface_config.read(self.interface_file, encoding='utf-8') + config = configparser.ConfigParser() + if os.path.exists(self.config_file): + config.read(self.config_file, encoding='utf-8') + self.set_config(self.config_file, values['-INTERFACE-'], config['DEFAULT']['Language'], + config['DEFAULT']['Mode']) + self.window.close() + title = self._create_layout() + self.window = sg.Window(title=title, layout=self.layout, icon=self.icon) + + @staticmethod + def set_config(config_file, interface, language_code, mode): + # 写入配置文件 + with open(config_file, mode='w', encoding='utf-8') as f: + f.write('[DEFAULT]\n') + f.write(f'Interface = {interface}\n') + f.write(f'Language = {language_code}\n') + f.write(f'Mode = {mode}\n') + + def parse_config(self, config_file): + if not os.path.exists(config_file): + self.interface_config.read(self.interface_file, encoding='utf-8') + interface_def = self.interface_config['LanguageModeGUI']['InterfaceDefault'] + language_def = self.interface_config['LanguageModeGUI']['InterfaceDefault'] + mode_def = self.interface_config['LanguageModeGUI']['ModeFast'] + return interface_def, language_def, mode_def + config = configparser.ConfigParser() + config.read(config_file, encoding='utf-8') + interface = config['DEFAULT']['Interface'] + language = config['DEFAULT']['Language'] + mode = config['DEFAULT']['Mode'] + self.interface_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'backend', 'interface', + f"{self.INTERFACE_KEY_NAME_MAP[interface]}.ini") + self._load_interface_text() + interface_def = interface if interface in self.INTERFACE_KEY_NAME_MAP else \ + self.INTERFACE_DEF + language_def = self.LANGUAGE_KEY_NAME_MAP[language] if language in self.LANGUAGE_KEY_NAME_MAP else \ + self.LANGUAGE_DEF + mode_def = self.MODE_KEY_NAME_MAP[mode] if mode in self.MODE_KEY_NAME_MAP else self.MODE_DEF + return interface_def, language_def, mode_def + + +if __name__ == '__main__': + try: + multiprocessing.set_start_method("spawn") + # 运行图形化界面 + subtitleExtractorGUI = SubtitleExtractorGUI() + subtitleExtractorGUI.run() + except Exception as e: + print(f'[{type(e)}] {e}') + import traceback + traceback.print_exc() + msg = traceback.format_exc() + err_log_path = os.path.join(os.path.expanduser('~'), 'VSE-Error-Message.log') + with open(err_log_path, 'w', encoding='utf-8') as f: + f.writelines(msg) + import platform + if platform.system() == 'Windows': + os.system('pause') + else: + input() diff --git a/main.py b/main.py deleted file mode 100644 index ad1c98ad..00000000 --- a/main.py +++ /dev/null @@ -1,545 +0,0 @@ -# -*- coding: utf-8 -*- -""" -@Author : Fang Yao -@Time : 2021/3/24 9:28 上午 -@FileName: main.py -@desc: 主程序入口文件 -""" -import config -from config import SubtitleArea -from backend.tools.infer.predict_system import TextSystem -from backend.tools.infer import utility -import cv2 -import random -import os -import math -from collections import Counter -import numpy as np -from PIL import Image -from numpy import average, dot, linalg -from Levenshtein import ratio - - -# 加载文本检测+识别模型 -def load_model(): - # 获取参数对象 - args = utility.parse_args() - # 设置文本检测模型路径 - args.det_model_dir = config.DET_MODEL_PATH - # 设置文本识别模型路径 - args.rec_model_dir = config.REC_MODEL_PATH - # 设置字典路径 - args.rec_char_dict_path = config.DICT_PATH - # 是否使用GPU加速 - args.use_gpu = config.USE_GPU - return TextSystem(args) - - -class 
SubtitleExtractor: - """ - 视频字幕提取类 - """ - - def __init__(self, vd_path): - # 视频路径 - self.video_path = vd_path - self.video_cap = cv2.VideoCapture(vd_path) - # 视频帧总数 - self.frame_count = self.video_cap.get(cv2.CAP_PROP_FRAME_COUNT) - # 视频帧率 - self.fps = self.video_cap.get(cv2.CAP_PROP_FPS) - # 视频尺寸 - self.frame_height = int(self.video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) - self.frame_width = int(self.video_cap.get(cv2.CAP_PROP_FRAME_WIDTH)) - # 字幕出现区域 - self.subtitle_area = config.SUBTITLE_AREA - print(f'帧数:{self.frame_count},帧率:{self.fps}') - # 临时存储文件夹 - self.temp_output_dir = os.path.join(config.BASE_DIR, 'output') - # 提取的视频帧储存目录 - self.frame_output_dir = os.path.join(self.temp_output_dir, 'frames') - # 提取的字幕文件存储目录 - self.subtitle_output_dir = os.path.join(self.temp_output_dir, 'subtitle') - # 不存在则创建文件夹 - if not os.path.exists(self.frame_output_dir): - os.makedirs(self.frame_output_dir) - if not os.path.exists(self.subtitle_output_dir): - os.makedirs(self.subtitle_output_dir) - # 提取的原始字幕文本存储路径 - self.raw_subtitle_path = os.path.join(self.subtitle_output_dir, 'raw.txt') - - def run(self): - """ - 运行整个提取视频的步骤 - """ - self.extract_frame() - self.extract_subtitles() - self.detect_watermark_area() - self.filter_watermark() - self.detect_subtitle_area() - self.filter_scene_text() - self.generate_subtitle_file() - - def extract_frame(self): - """ - 根据视频的分辨率,将高分辨的视频帧缩放到1280*720p - 根据字幕区域位置,将该图像区域截取出来 - """ - # 当前视频帧的帧号 - frame_no = 0 - - while self.video_cap.isOpened(): - ret, frame = self.video_cap.read() - # 如果读取视频帧失败(视频读到最后一帧) - if not ret: - break - # 读取视频帧成功 - else: - frame_no += 1 - # 对于分辨率大于1920*1080的视频,将其视频帧进行等比缩放至1280*720进行识别 - # paddlepaddle会将图像压缩为640*640 - if self.frame_width > 1280: - scale_rate = round(float(1280 / self.frame_width), 2) - frame = cv2.resize(frame, None, fx=scale_rate, fy=scale_rate, interpolation=cv2.INTER_AREA) - - cropped = int(frame.shape[0] // 2) - - # 如果字幕出现的区域在下部分 - if self.subtitle_area == SubtitleArea.LOWER_PART: - # 将视频帧切割为下半部分 - frame = frame[cropped:] - # 如果字幕出现的区域在上半部分 - elif self.subtitle_area == SubtitleArea.UPPER_PART: - # 将视频帧切割为下半部分 - frame = frame[:cropped] - - # 帧名往前补零,后续用于排序与时间戳转换,补足8位 - # 一部10h电影,fps120帧最多也才1*60*60*120=432000 6位,所以8位足够 - filename = os.path.join(self.frame_output_dir, str(frame_no).zfill(8) + '.jpg') - # 保存视频帧 - cv2.imwrite(filename, frame) - - # 将当前帧与接下来的帧进行比较,计算余弦相似度 - while self.video_cap.isOpened(): - ret, frame_next = self.video_cap.read() - if ret: - frame_no += 1 - cosine_distance = self._compute_image_similarity(Image.fromarray(frame), - Image.fromarray(frame_next)) - print(cosine_distance) - if cosine_distance > config.COSINE_SIMILARITY_THRESHOLD: - # 如果下一帧与当前帧的相似度大于设定阈值,则略过该帧 - continue - # 如果相似度小于设定阈值,停止该while循环 - else: - break - else: - break - - self.video_cap.release() - cv2.destroyAllWindows() - - def extract_subtitles(self): - """ - 提取视频帧中的字幕信息,生成一个txt文件 - """ - # 初始化文本识别对象 - text_recogniser = load_model() - # 视频帧列表 - frame_list = [i for i in sorted(os.listdir(self.frame_output_dir)) if i.endswith('.jpg')] - - # 新建文件 - f = open(self.raw_subtitle_path, mode='w+', encoding='utf-8') - - for frame in frame_list: - # 读取视频帧 - img = cv2.imread(os.path.join(self.frame_output_dir, frame)) - # 获取检测结果 - dt_box, rec_res = text_recogniser(img) - # 获取文本坐标 - coordinates = self.__get_coordinates(dt_box) - # 将结果写入txt文本中 - for content, coordinate in zip(([res[0] for res in rec_res]), coordinates): - f.write(f'{os.path.splitext(frame)[0]}\t' - f'{coordinate}\t' - f'{content}\n') - # 关闭文件 - f.close() - - def 
detect_watermark_area(self): - """ - 根据识别出来的raw txt文件中的坐标点信息,查找水印区域 - 假定:水印区域(台标)的坐标在水平和垂直方向都是固定的,也就是具有(xmin, xmax, ymin, ymax)相对固定 - 根据坐标点信息,进行统计,将一直具有固定坐标的文本区域选出 - :return 返回最有可能的水印区域 - """ - f = open(self.raw_subtitle_path, mode='r', encoding='utf-8') # 打开txt文件,以‘utf-8’编码读取 - line = f.readline() # 以行的形式进行读取文件 - # 坐标点列表 - coordinates_list = [] - # 帧列表 - frame_no_list = [] - # 内容列表 - content_list = [] - while line: - frame_no = line.split('\t')[0] - text_position = line.split('\t')[1].split('(')[1].split(')')[0].split(', ') - content = line.split('\t')[2] - frame_no_list.append(frame_no) - coordinates_list.append((int(text_position[0]), - int(text_position[1]), - int(text_position[2]), - int(text_position[3]))) - content_list.append(content) - line = f.readline() - f.close() - # 将坐标列表的相似值统一 - coordinates_list = self._unite_coordinates(coordinates_list) - - # 将原txt文件的坐标更新为归一后的坐标 - with open(self.raw_subtitle_path, mode='w', encoding='utf-8') as f: - for frame_no, coordinate, content in zip(frame_no_list, coordinates_list, content_list): - f.write(f'{frame_no}\t{coordinate}\t{content}') - - if len(Counter(coordinates_list).most_common()) > config.WATERMARK_AREA_NUM: - # 读取配置文件,返回可能为水印区域的坐标列表 - return Counter(coordinates_list).most_common(config.WATERMARK_AREA_NUM) - else: - # 不够则有几个返回几个 - return Counter(coordinates_list).most_common() - - def filter_watermark(self): - """ - 去除原始字幕文本中的水印区域的文本 - """ - # 获取潜在水印区域 - watermark_areas = self.detect_watermark_area() - - # 从frame目录随机读取一张图片,将所水印区域标记出来,用户看图判断是否是水印区域 - frame_path = os.path.join(self.frame_output_dir, - random.choice( - [i for i in sorted(os.listdir(self.frame_output_dir)) if i.endswith('.jpg')])) - sample_frame = cv2.imread(frame_path) - - # 给潜在的水印区域编号 - area_num = ['E', 'D', 'C', 'B', 'A'] - - for watermark_area in watermark_areas: - ymin = watermark_area[0][2] - ymax = watermark_area[0][3] - xmin = watermark_area[0][0] - xmax = watermark_area[0][1] - cover = sample_frame[ymin:ymax, xmin:xmax] - cover = cv2.blur(cover, (10, 10)) - cv2.rectangle(cover, pt1=(0, cover.shape[0]), pt2=(cover.shape[1], 0), color=(0, 0, 255), thickness=3) - sample_frame[watermark_area[0][2]:watermark_area[0][3], watermark_area[0][0]:watermark_area[0][1]] = cover - position = ((xmin + xmax) // 2, ymax) - - cv2.putText(sample_frame, text=area_num.pop(), org=position, fontFace=cv2.FONT_HERSHEY_SIMPLEX, - fontScale=1, color=(255, 0, 0), thickness=2, lineType=cv2.LINE_AA) - - sample_frame_file_path = os.path.join(os.path.dirname(self.frame_output_dir), 'watermark_area.jpg') - cv2.imwrite(sample_frame_file_path, sample_frame) - print(f'请查看图片, 确定水印区域: {sample_frame_file_path}') - - area_num = ['E', 'D', 'C', 'B', 'A'] - for watermark_area in watermark_areas: - user_input = input(f'是否去除区域{area_num.pop()}{str(watermark_area)}中的字幕?' 
- f'\n输入 "y" 或 "回车" 表示去除,输入"n"或其他表示不去除: ').strip() - if user_input == 'y' or user_input == '\n': - with open(self.raw_subtitle_path, mode='r+', encoding='utf-8') as f: - content = f.readlines() - f.seek(0) - for i in content: - if i.find(str(watermark_area[0])) == -1: - f.write(i) - f.truncate() - print(f'已经删除该区域字幕...') - print('水印区域字幕过滤完毕...') - - def detect_subtitle_area(self): - """ - 读取过滤水印区域后的raw txt文件,根据坐标信息,查找字幕区域 - 假定:字幕区域在y轴上有一个相对固定的坐标范围,相对于场景文本,这个范围出现频率更高 - :return 返回字幕的区域位置 - """ - # 打开去水印区域处理过的raw txt - f = open(self.raw_subtitle_path, mode='r', encoding='utf-8') # 打开txt文件,以‘utf-8’编码读取 - line = f.readline() # 以行的形式进行读取文件 - # y坐标点列表 - y_coordinates_list = [] - while line: - text_position = line.split('\t')[1].split('(')[1].split(')')[0].split(', ') - y_coordinates_list.append((int(text_position[2]), int(text_position[3]))) - line = f.readline() - f.close() - return Counter(y_coordinates_list).most_common(1) - - def filter_scene_text(self): - """ - 将场景里提取的文字过滤,仅保留字幕区域 - """ - # 获取潜在字幕区域 - subtitle_area = self.detect_subtitle_area()[0][0] - - # 从frame目录随机读取一张图片,将所水印区域标记出来,用户看图判断是否是水印区域 - frame_path = os.path.join(self.frame_output_dir, - random.choice( - [i for i in sorted(os.listdir(self.frame_output_dir)) if i.endswith('.jpg')])) - sample_frame = cv2.imread(frame_path) - - # 为了防止有双行字幕,根据容忍度,将字幕区域y范围加高 - ymin = abs(subtitle_area[0] - config.SUBTITLE_AREA_DEVIATION_PIXEL) - ymax = subtitle_area[1] + config.SUBTITLE_AREA_DEVIATION_PIXEL - # 画出字幕框的区域 - cv2.rectangle(sample_frame, pt1=(0, ymin), pt2=(sample_frame.shape[1], ymax), color=(0, 0, 255), thickness=3) - sample_frame_file_path = os.path.join(os.path.dirname(self.frame_output_dir), 'subtitle_area.jpg') - cv2.imwrite(sample_frame_file_path, sample_frame) - print(f'请查看图片, 确定字幕区域是否正确: {sample_frame_file_path}') - - user_input = input(f'是否去除红色框区域外{(ymin, ymax)}的字幕?' - f'\n输入 "y" 或 "回车" 表示去除,输入"n"或其他表示不去除: ').strip() - if user_input == 'y' or user_input == '\n': - with open(self.raw_subtitle_path, mode='r+', encoding='utf-8') as f: - content = f.readlines() - f.seek(0) - for i in content: - i_ymin = int(i.split('\t')[1].split('(')[1].split(')')[0].split(', ')[2]) - i_ymax = int(i.split('\t')[1].split('(')[1].split(')')[0].split(', ')[3]) - if ymin <= i_ymin and i_ymax <= ymax: - f.write(i) - f.truncate() - - def generate_subtitle_file(self): - """ - 生成srt格式的字幕文件 - """ - subtitle_content = self._remove_duplicate_subtitle() - print(os.path.splitext(self.video_path)[0]) - srt_filename = os.path.join(os.path.splitext(self.video_path)[0] + '.srt') - with open(srt_filename, mode='w', encoding='utf-8') as f: - for index, content in enumerate(subtitle_content): - line_code = index + 1 - frame_start = self._frame_to_timecode(int(content[0])) - frame_end = self._frame_to_timecode(int(content[1])) - frame_content = content[2] - subtitle_line = f'{line_code}\n{frame_start} --> {frame_end}\n{frame_content}\n' - f.write(subtitle_line) - - def _frame_to_timecode(self, frame_no): - """ - 将视频帧转换成时间 - :param frame_no: 视频的帧号,i.e. 
第几帧视频帧 - :param frame_rate: 视频的帧率 - :param drop: 帧率不为整数时,是否添加drop frame进行补帧 - :returns: SMPTE格式时间戳 as string, 如'01:02:12:32' 或者 '01:02:12;32' - """ - # 将小数点后两位的数字丢弃 - tmp = str(self.fps).split('.') - tmp[1] = tmp[1][:2] - frame_rate = float('.'.join(tmp)) - - drop = False - - if frame_rate in [29.97, 59.94]: - drop = True - - # 将fps就近取整,如29.97或59.94取整为30和60 - fps_int = int(round(frame_rate)) - # 然后添加drop frames进行补偿 - - if drop: - # drop-frame-mode - # 每分钟添加两个fake frames,每十分钟的时候不添加 - # 1分钟内,non-drop和drop的时间戳对比 - # frame: 1795 non-drop: 00:00:59:25 drop: 00:00:59;25 - # frame: 1796 non-drop: 00:00:59:26 drop: 00:00:59;26 - # frame: 1797 non-drop: 00:00:59:27 drop: 00:00:59;27 - # frame: 1798 non-drop: 00:00:59:28 drop: 00:00:59;28 - # frame: 1799 non-drop: 00:00:59:29 drop: 00:00:59;29 - # frame: 1800 non-drop: 00:01:00:00 drop: 00:01:00;02 - # frame: 1801 non-drop: 00:01:00:01 drop: 00:01:00;03 - # frame: 1802 non-drop: 00:01:00:02 drop: 00:01:00;04 - # frame: 1803 non-drop: 00:01:00:03 drop: 00:01:00;05 - # frame: 1804 non-drop: 00:01:00:04 drop: 00:01:00;06 - # frame: 1805 non-drop: 00:01:00:05 drop: 00:01:00;07 - # - # 10分钟内,non-drop和drop的时间戳对比 - # - # frame: 17977 non-drop: 00:09:59:07 drop: 00:09:59;25 - # frame: 17978 non-drop: 00:09:59:08 drop: 00:09:59;26 - # frame: 17979 non-drop: 00:09:59:09 drop: 00:09:59;27 - # frame: 17980 non-drop: 00:09:59:10 drop: 00:09:59;28 - # frame: 17981 non-drop: 00:09:59:11 drop: 00:09:59;29 - # frame: 17982 non-drop: 00:09:59:12 drop: 00:10:00;00 - # frame: 17983 non-drop: 00:09:59:13 drop: 00:10:00;01 - # frame: 17984 non-drop: 00:09:59:14 drop: 00:10:00;02 - # frame: 17985 non-drop: 00:09:59:15 drop: 00:10:00;03 - # frame: 17986 non-drop: 00:09:59:16 drop: 00:10:00;04 - # frame: 17987 non-drop: 00:09:59:17 drop: 00:10:00;05 - - # 计算29.97 std NTSC工作流程的丢帧数。1分钟一共有30*60 = 1800 frames - - FRAMES_IN_ONE_MINUTE = 1800 - 2 - - FRAMES_IN_TEN_MINUTES = (FRAMES_IN_ONE_MINUTE * 10) - 2 - - ten_minute_chunks = frame_no / FRAMES_IN_TEN_MINUTES - one_minute_chunks = frame_no % FRAMES_IN_TEN_MINUTES - - ten_minute_part = 18 * ten_minute_chunks - one_minute_part = 2 * ((one_minute_chunks - 2) / FRAMES_IN_ONE_MINUTE) - - if one_minute_part < 0: - one_minute_part = 0 - - # 添加额外的帧 - frame_no += ten_minute_part + one_minute_part - - # 对于60 fps的drop frame计算, 添加两倍的帧数 - if fps_int == 60: - frame_no = frame_no * 2 - - # time codes are on the form 12:12:12;12 - smpte_token = ";" - - else: - # time codes are on the form 12:12:12:12 - smpte_token = "," - - # 将视频帧转化为时间戳 - hours = int(frame_no / (3600 * fps_int)) - minutes = int(frame_no / (60 * fps_int) % 60) - seconds = int(frame_no / fps_int % 60) - frames = int(frame_no % fps_int) - return "%02d:%02d:%02d%s%02d" % (hours, minutes, seconds, smpte_token, frames) - - def _remove_duplicate_subtitle(self): - """ - 读取原始的raw txt,去除重复行,返回去除了重复后的字幕列表 - """ - with open(self.raw_subtitle_path, 'r') as r: - lines = r.readlines() - content_list = [] - for line in lines: - frame_no = line.split('\t')[0] - content = line.split('\t')[2] - content_list.append((frame_no, content)) - # 循环遍历每行字幕,记录开始时间与结束时间 - index = 0 - # 去重后的字幕列表 - unique_subtitle_list = [] - for i in content_list: - # TODO: 时间复杂度非常高,有待优化 - # 定义字幕开始帧帧号 - start_frame = i[0] - for j in content_list[index:]: - # 计算当前行与下一行的Levenshtein距离 - distance = ratio(i[1], j[1]) - if distance < config.TEXT_SIMILARITY_THRESHOLD or j == content_list[-1]: - # 定义字幕结束帧帧号 - end_frame = content_list[content_list.index(j) - 1][0] - if end_frame == start_frame: - end_frame = j[0] - if 
str(unique_subtitle_list).find(i[1].replace('\n', '')) == -1: - unique_subtitle_list.append((start_frame, end_frame, i[1])) - index += 1 - break - else: - continue - return unique_subtitle_list - - def _unite_coordinates(self, coordinates_list): - """ - 给定一个坐标列表,将这个列表中相似的坐标统一为一个值 - e.g. 由于检测框检测的结果不是一直的,相同位置文字的坐标可能一次检测为(255,123,456,789),另一次检测为(253,122,456,799) - 因此要对相似的坐标进行值的统一 - :param coordinates_list 包含坐标点的列表 - :return: 返回一个统一值后的坐标列表 - """ - # 将相似的坐标统一为一个 - index = 0 - for coordinate in coordinates_list: # TODO:时间复杂度n^2,待优化 - for i in coordinates_list: - if self.__is_coordinate_similar(coordinate, i): - coordinates_list[index] = i - index += 1 - return coordinates_list - - def _compute_image_similarity(self, image1, image2): - """ - 计算两张图片的余弦相似度 - """ - image1 = self.__get_thum(image1) - image2 = self.__get_thum(image2) - images = [image1, image2] - vectors = [] - norms = [] - for image in images: - vector = [] - for pixel_tuple in image.getdata(): - vector.append(average(pixel_tuple)) - vectors.append(vector) - # linalg=linear(线性)+algebra(代数),norm则表示范数 - # 求图片的范数 - norms.append(linalg.norm(vector, 2)) - a, b = vectors - a_norm, b_norm = norms - # dot返回的是点积,对二维数组(矩阵)进行计算 - res = dot(a / a_norm, b / b_norm) - return res - - @staticmethod - def __get_coordinates(dt_box): - """ - 从返回的检测框中获取坐标 - :param dt_box 检测框返回结果 - :return list 坐标点列表 - """ - coordinate_list = list() - if isinstance(dt_box, list): - for i in dt_box: - i = list(i) - (x1, y1) = int(i[0][0]), int(i[0][1]) - (x2, y2) = int(i[1][0]), int(i[1][1]) - (x3, y3) = int(i[2][0]), int(i[2][1]) - (x4, y4) = int(i[3][0]), int(i[3][1]) - xmin = max(x1, x4) - xmax = min(x2, x3) - ymin = max(y1, y2) - ymax = min(y3, y4) - coordinate_list.append((xmin, xmax, ymin, ymax)) - return coordinate_list - - @staticmethod - def __is_coordinate_similar(coordinate1, coordinate2): - """ - 计算两个坐标是否相似,如果两个坐标点的xmin,xmax,ymin,ymax的差值都在像素点容忍度内 - 则认为这两个坐标点相似 - """ - return abs(coordinate1[0] - coordinate2[0]) < config.PIXEL_TOLERANCE_X and \ - abs(coordinate1[1] - coordinate2[1]) < config.PIXEL_TOLERANCE_X and \ - abs(coordinate1[2] - coordinate2[2]) < config.PIXEL_TOLERANCE_Y and \ - abs(coordinate1[3] - coordinate2[3]) < config.PIXEL_TOLERANCE_Y - - @staticmethod - def __get_thum(image, size=(64, 64), greyscale=False): - """ - 对图片进行统一化处理 - """ - # 利用image对图像大小重新设置, Image.ANTIALIAS为高质量的 - image = image.resize(size, Image.ANTIALIAS) - if greyscale: - # 将图片转换为L模式,其为灰度图,其每个像素用8个bit表示 - image = image.convert('L') - return image - - -if __name__ == '__main__': - # 提示用户输入视频路径 - video_path = input("请输入视频完整路径:").strip() - # 新建字幕提取对象 - se = SubtitleExtractor(video_path) - # 开始提取字幕 - se.run() - diff --git a/requirements.txt b/requirements.txt index c324df0d..3b952f7a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,14 +1,13 @@ -shapely~=1.7.1 -scikit-image==0.17.2 +opencv-python==4.10.0.84 +python-Levenshtein==0.26.0 +pillow==10.4.0 +tqdm==4.66.5 +filesplit==3.0.2 +pysrt==1.1.2 +wordsegment==1.3.1 +scikit-image==0.24.0 +lmdb==1.5.1 imgaug==0.4.0 -pyclipper~=1.2.1 -lmdb~=1.1.1 -opencv-python==4.2.0.32 -tqdm~=4.59.0 -numpy~=1.19.0 -visualdl -python-Levenshtein -six~=1.15.0 -pillow~=8.1.2 -pyyaml~=5.4.1 -requests~=2.25.1 \ No newline at end of file +pyclipper==1.3.0.post5 +PySimpleGUI==4.70.1 +numpy==1.26.4 \ No newline at end of file diff --git a/test/test_ar.flv b/test/test_ar.flv new file mode 100644 index 00000000..1edac235 Binary files /dev/null and b/test/test_ar.flv differ diff --git a/test/test_chinese_cht.flv b/test/test_chinese_cht.flv new file mode 
100644 index 00000000..09818b91 Binary files /dev/null and b/test/test_chinese_cht.flv differ diff --git a/test/test_cn.mp4 b/test/test_cn.mp4 new file mode 100644 index 00000000..e7899c05 Binary files /dev/null and b/test/test_cn.mp4 differ diff --git a/test/test_cn2.mp4 b/test/test_cn2.mp4 new file mode 100644 index 00000000..c0ff2d44 Binary files /dev/null and b/test/test_cn2.mp4 differ diff --git a/test/test_en.mp4 b/test/test_en.mp4 new file mode 100644 index 00000000..2d435b56 Binary files /dev/null and b/test/test_en.mp4 differ diff --git a/test/test_en_ch.mp4 b/test/test_en_ch.mp4 new file mode 100644 index 00000000..2b058dc5 Binary files /dev/null and b/test/test_en_ch.mp4 differ diff --git a/test/test_es.flv b/test/test_es.flv new file mode 100644 index 00000000..8e8243df Binary files /dev/null and b/test/test_es.flv differ diff --git a/test/test_german.mp4 b/test/test_german.mp4 new file mode 100644 index 00000000..f31f9ae5 Binary files /dev/null and b/test/test_german.mp4 differ diff --git a/test/test_it.flv b/test/test_it.flv new file mode 100644 index 00000000..de8b917e Binary files /dev/null and b/test/test_it.flv differ diff --git a/test/test_japan.mp4 b/test/test_japan.mp4 new file mode 100644 index 00000000..f0f854db Binary files /dev/null and b/test/test_japan.mp4 differ diff --git a/test/test_korean.flv b/test/test_korean.flv new file mode 100644 index 00000000..d0f0e86b Binary files /dev/null and b/test/test_korean.flv differ diff --git a/test/test_ru.flv b/test/test_ru.flv new file mode 100644 index 00000000..cfba4cf9 Binary files /dev/null and b/test/test_ru.flv differ
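For reference, a minimal sketch of the two INI files the new GUI code reads and writes. The [DEFAULT] layout mirrors LanguageModeGUI.set_config() (written to settings.ini); the [AREA] layout mirrors SubtitleExtractorGUI.set_subtitle_config(), with Y/H/X/W stored as fractions of the frame height and width. The helper names and the subtitle.ini file name below are illustrative assumptions; only settings.ini is named explicitly in this hunk.

# Sketch only, not part of the patch. Mirrors the key layout used by the GUI above.
import configparser

def write_settings(path='settings.ini', interface='简体中文', language='ch', mode='fast'):
    # Same [DEFAULT] keys as LanguageModeGUI.set_config()
    with open(path, 'w', encoding='utf-8') as f:
        f.write('[DEFAULT]\n')
        f.write(f'Interface = {interface}\n')
        f.write(f'Language = {language}\n')
        f.write(f'Mode = {mode}\n')

def read_subtitle_area(path='subtitle.ini'):
    # Same defaults and [AREA] keys as SubtitleExtractorGUI.parse_subtitle_config();
    # values are fractions of the frame height/width, not pixels.
    y_p, h_p, x_p, w_p = .78, .21, .05, .9
    config = configparser.ConfigParser()
    if config.read(path, encoding='utf-8') and 'AREA' in config:
        area = config['AREA']
        return float(area['Y']), float(area['H']), float(area['X']), float(area['W'])
    return y_p, h_p, x_p, w_p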
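Similarly, a simplified sketch of what the -RUN- handler does with the slider values: it clamps them into a (ymin, ymax, xmin, xmax) subtitle area and then works through the selected videos one at a time on a daemon thread, so only one backend.main.SubtitleExtractor runs at a time and the UI loop can poll its progress. The standalone function below is an assumption for illustration; in the patch this logic lives inside SubtitleExtractorGUI._run_event_handler.

# Sketch only, not part of the patch. Assumes the repository's backend package is importable.
from threading import Thread
import backend.main

def start_batch(video_paths, values, frame_height, frame_width):
    # Convert slider positions into pixel bounds, clamped to the frame size,
    # exactly as the -RUN- handler does before starting extraction.
    xmin = int(values['-X-SLIDER-'])
    xmax = min(int(values['-X-SLIDER-'] + values['-X-SLIDER-W-']), int(frame_width))
    ymin = int(values['-Y-SLIDER-'])
    ymax = min(int(values['-Y-SLIDER-'] + values['-Y-SLIDER-H-']), int(frame_height))
    subtitle_area = (ymin, ymax, xmin, xmax)

    def task():
        # Pop one path, run it to completion, then move on to the next selected video.
        # In the real handler the current extractor is kept on self.se so the event
        # loop can read se.progress_total for the progress bar.
        while video_paths:
            video_path = video_paths.pop()
            se = backend.main.SubtitleExtractor(video_path, subtitle_area)
            se.run()

    Thread(target=task, daemon=True).start()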