'<unk>', '<s>', '</s>'] + dict_character
+ return dict_character
+
+
+class CTCLabelEncode(BaseRecLabelEncode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ **kwargs):
+ super(CTCLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char)
+
+ def __call__(self, data):
+ text = data['label']
+ text = self.encode(text)
+ if text is None:
+ return None
+ data['length'] = np.array(len(text))
+ text = text + [0] * (self.max_text_len - len(text))
+ data['label'] = np.array(text)
+
+ label = [0] * len(self.character)
+ for x in text:
+ label[x] += 1
+ data['label_ace'] = np.array(label)
+ return data
+
+ def add_special_char(self, dict_character):
+ dict_character = ['blank'] + dict_character
+ return dict_character
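+
+ # Illustrative example (toy dictionary, not the real dict file): if the
+ # character dict yields ['a', 'b', 'c'], add_special_char gives
+ # ['blank', 'a', 'b', 'c'], so "ab" with max_text_length=5 is encoded as
+ #   data['label']     -> [1, 2, 0, 0, 0]   # padded with the blank index 0
+ #   data['length']    -> 2
+ #   data['label_ace'] -> per-index counts over the padded label (for the ACE loss)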
+
+
+class E2ELabelEncodeTest(BaseRecLabelEncode):
+ def __init__(self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ **kwargs):
+ super(E2ELabelEncodeTest, self).__init__(
+ max_text_length, character_dict_path, use_space_char)
+
+ def __call__(self, data):
+ import json
+ padnum = len(self.dict)
+ label = data['label']
+ label = json.loads(label)
+ nBox = len(label)
+ boxes, txts, txt_tags = [], [], []
+ for bno in range(0, nBox):
+ box = label[bno]['points']
+ txt = label[bno]['transcription']
+ boxes.append(box)
+ txts.append(txt)
+ if txt in ['*', '###']:
+ txt_tags.append(True)
+ else:
+ txt_tags.append(False)
+ boxes = np.array(boxes, dtype=np.float32)
+ txt_tags = np.array(txt_tags, dtype=bool)
+ data['polys'] = boxes
+ data['ignore_tags'] = txt_tags
+ temp_texts = []
+ for text in txts:
+ text = text.lower()
+ text = self.encode(text)
+ if text is None:
+ return None
+ text = text + [padnum] * (self.max_text_len - len(text)
+ ) # use 36 to pad
+ temp_texts.append(text)
+ data['texts'] = np.array(temp_texts)
+ return data
+
+
+class E2ELabelEncodeTrain(object):
+ def __init__(self, **kwargs):
+ pass
+
+ def __call__(self, data):
+ import json
+ label = data['label']
+ label = json.loads(label)
+ nBox = len(label)
+ boxes, txts, txt_tags = [], [], []
+ for bno in range(0, nBox):
+ box = label[bno]['points']
+ txt = label[bno]['transcription']
+ boxes.append(box)
+ txts.append(txt)
+ if txt in ['*', '###']:
+ txt_tags.append(True)
+ else:
+ txt_tags.append(False)
+ boxes = np.array(boxes, dtype=np.float32)
+ txt_tags = np.array(txt_tags, dtype=bool)
+
+ data['polys'] = boxes
+ data['texts'] = txts
+ data['ignore_tags'] = txt_tags
+ return data
+
+
+class KieLabelEncode(object):
+ def __init__(self, character_dict_path, norm=10, directed=False, **kwargs):
+ super(KieLabelEncode, self).__init__()
+ self.dict = dict({'': 0})
+ with open(character_dict_path, 'r', encoding='utf-8') as fr:
+ idx = 1
+ for line in fr:
+ char = line.strip()
+ self.dict[char] = idx
+ idx += 1
+ self.norm = norm
+ self.directed = directed
+
+ def compute_relation(self, boxes):
+ """Compute relation between every two boxes."""
+ x1s, y1s = boxes[:, 0:1], boxes[:, 1:2]
+ x2s, y2s = boxes[:, 4:5], boxes[:, 5:6]
+ ws, hs = x2s - x1s + 1, np.maximum(y2s - y1s + 1, 1)
+ dxs = (x1s[:, 0][None] - x1s) / self.norm
+ dys = (y1s[:, 0][None] - y1s) / self.norm
+ xhhs, xwhs = hs[:, 0][None] / hs, ws[:, 0][None] / hs
+ whs = ws / hs + np.zeros_like(xhhs)
+ relations = np.stack([dxs, dys, whs, xhhs, xwhs], -1)
+ bboxes = np.concatenate([x1s, y1s, x2s, y2s], -1).astype(np.float32)
+ return relations, bboxes
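+
+ # For every ordered box pair (i, j), the five relation features stacked above are
+ # the normalized offsets (dx, dy), the aspect ratio w_i / h_i, and the cross
+ # ratios h_j / h_i and w_j / h_i (SDMGR-style relation embedding).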
+
+ def pad_text_indices(self, text_inds):
+ """Pad text index to same length."""
+ max_len = 300
+ recoder_len = max([len(text_ind) for text_ind in text_inds])
+ padded_text_inds = -np.ones((len(text_inds), max_len), np.int32)
+ for idx, text_ind in enumerate(text_inds):
+ padded_text_inds[idx, :len(text_ind)] = np.array(text_ind)
+ return padded_text_inds, recoder_len
+
+ def list_to_numpy(self, ann_infos):
+ """Convert bboxes, relations, texts and labels to ndarray."""
+ boxes, text_inds = ann_infos['points'], ann_infos['text_inds']
+ boxes = np.array(boxes, np.int32)
+ relations, bboxes = self.compute_relation(boxes)
+
+ labels = ann_infos.get('labels', None)
+ if labels is not None:
+ labels = np.array(labels, np.int32)
+ edges = ann_infos.get('edges', None)
+ if edges is not None:
+ labels = labels[:, None]
+ edges = np.array(edges)
+ edges = (edges[:, None] == edges[None, :]).astype(np.int32)
+ if self.directed:
+ edges = (edges & labels == 1).astype(np.int32)
+ np.fill_diagonal(edges, -1)
+ labels = np.concatenate([labels, edges], -1)
+ padded_text_inds, recoder_len = self.pad_text_indices(text_inds)
+ max_num = 300
+ temp_bboxes = np.zeros([max_num, 4])
+ h, _ = bboxes.shape
+ temp_bboxes[:h, :] = bboxes
+
+ temp_relations = np.zeros([max_num, max_num, 5])
+ temp_relations[:h, :h, :] = relations
+
+ temp_padded_text_inds = np.zeros([max_num, max_num])
+ temp_padded_text_inds[:h, :] = padded_text_inds
+
+ temp_labels = np.zeros([max_num, max_num])
+ temp_labels[:h, :h + 1] = labels
+
+ tag = np.array([h, recoder_len])
+ return dict(
+ image=ann_infos['image'],
+ points=temp_bboxes,
+ relations=temp_relations,
+ texts=temp_padded_text_inds,
+ labels=temp_labels,
+ tag=tag)
+
+ def convert_canonical(self, points_x, points_y):
+
+ assert len(points_x) == 4
+ assert len(points_y) == 4
+
+ points = [Point(points_x[i], points_y[i]) for i in range(4)]
+
+ polygon = Polygon([(p.x, p.y) for p in points])
+ min_x, min_y, _, _ = polygon.bounds
+ points_to_lefttop = [
+ LineString([points[i], Point(min_x, min_y)]) for i in range(4)
+ ]
+ distances = np.array([line.length for line in points_to_lefttop])
+ sort_dist_idx = np.argsort(distances)
+ lefttop_idx = sort_dist_idx[0]
+
+ if lefttop_idx == 0:
+ point_orders = [0, 1, 2, 3]
+ elif lefttop_idx == 1:
+ point_orders = [1, 2, 3, 0]
+ elif lefttop_idx == 2:
+ point_orders = [2, 3, 0, 1]
+ else:
+ point_orders = [3, 0, 1, 2]
+
+ sorted_points_x = [points_x[i] for i in point_orders]
+ sorted_points_y = [points_y[j] for j in point_orders]
+
+ return sorted_points_x, sorted_points_y
+
+ def sort_vertex(self, points_x, points_y):
+
+ assert len(points_x) == 4
+ assert len(points_y) == 4
+
+ x = np.array(points_x)
+ y = np.array(points_y)
+ center_x = np.sum(x) * 0.25
+ center_y = np.sum(y) * 0.25
+
+ x_arr = np.array(x - center_x)
+ y_arr = np.array(y - center_y)
+
+ angle = np.arctan2(y_arr, x_arr) * 180.0 / np.pi
+ sort_idx = np.argsort(angle)
+
+ sorted_points_x, sorted_points_y = [], []
+ for i in range(4):
+ sorted_points_x.append(points_x[sort_idx[i]])
+ sorted_points_y.append(points_y[sort_idx[i]])
+
+ return self.convert_canonical(sorted_points_x, sorted_points_y)
+
+ def __call__(self, data):
+ import json
+ label = data['label']
+ annotations = json.loads(label)
+ boxes, texts, text_inds, labels, edges = [], [], [], [], []
+ for ann in annotations:
+ box = ann['points']
+ x_list = [box[i][0] for i in range(4)]
+ y_list = [box[i][1] for i in range(4)]
+ sorted_x_list, sorted_y_list = self.sort_vertex(x_list, y_list)
+ sorted_box = []
+ for x, y in zip(sorted_x_list, sorted_y_list):
+ sorted_box.append(x)
+ sorted_box.append(y)
+ boxes.append(sorted_box)
+ text = ann['transcription']
+ texts.append(ann['transcription'])
+ text_ind = [self.dict[c] for c in text if c in self.dict]
+ text_inds.append(text_ind)
+ labels.append(ann['label'])
+ edges.append(ann.get('edge', 0))
+ ann_infos = dict(
+ image=data['image'],
+ points=boxes,
+ texts=texts,
+ text_inds=text_inds,
+ edges=edges,
+ labels=labels)
+
+ return self.list_to_numpy(ann_infos)
+
+
+class AttnLabelEncode(BaseRecLabelEncode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ **kwargs):
+ super(AttnLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char)
+
+ def add_special_char(self, dict_character):
+ self.beg_str = "sos"
+ self.end_str = "eos"
+ dict_character = [self.beg_str] + dict_character + [self.end_str]
+ return dict_character
+
+ def __call__(self, data):
+ text = data['label']
+ text = self.encode(text)
+ if text is None:
+ return None
+ if len(text) >= self.max_text_len:
+ return None
+ data['length'] = np.array(len(text))
+ text = [0] + text + [len(self.character) - 1] + [0] * (self.max_text_len
+ - len(text) - 2)
+ data['label'] = np.array(text)
+ return data
+
+ def get_ignored_tokens(self):
+ beg_idx = self.get_beg_end_flag_idx("beg")
+ end_idx = self.get_beg_end_flag_idx("end")
+ return [beg_idx, end_idx]
+
+ def get_beg_end_flag_idx(self, beg_or_end):
+ if beg_or_end == "beg":
+ idx = np.array(self.dict[self.beg_str])
+ elif beg_or_end == "end":
+ idx = np.array(self.dict[self.end_str])
+ else:
+ assert False, "Unsupport type %s in get_beg_end_flag_idx" \
+ % beg_or_end
+ return idx
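+
+ # Illustrative example (toy dictionary): with base characters ['a', 'b', 'c'] the
+ # expanded dict is ['sos', 'a', 'b', 'c', 'eos'], so "ab" with max_text_length=6
+ # becomes [0, 1, 2, 4, 0, 0]  # sos, 'a', 'b', eos, then zero padding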
+
+
+class SEEDLabelEncode(BaseRecLabelEncode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ **kwargs):
+ super(SEEDLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char)
+
+ def add_special_char(self, dict_character):
+ self.padding = "padding"
+ self.end_str = "eos"
+ self.unknown = "unknown"
+ dict_character = dict_character + [
+ self.end_str, self.padding, self.unknown
+ ]
+ return dict_character
+
+ def __call__(self, data):
+ text = data['label']
+ text = self.encode(text)
+ if text is None:
+ return None
+ if len(text) >= self.max_text_len:
+ return None
+ data['length'] = np.array(len(text)) + 1  # include eos
+ text = text + [len(self.character) - 3] + [len(self.character) - 2] * (
+ self.max_text_len - len(text) - 1)
+ data['label'] = np.array(text)
+ return data
+
+
+class SRNLabelEncode(BaseRecLabelEncode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self,
+ max_text_length=25,
+ character_dict_path=None,
+ use_space_char=False,
+ **kwargs):
+ super(SRNLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char)
+
+ def add_special_char(self, dict_character):
+ dict_character = dict_character + [self.beg_str, self.end_str]
+ return dict_character
+
+ def __call__(self, data):
+ text = data['label']
+ text = self.encode(text)
+ char_num = len(self.character)
+ if text is None:
+ return None
+ if len(text) > self.max_text_len:
+ return None
+ data['length'] = np.array(len(text))
+ text = text + [char_num - 1] * (self.max_text_len - len(text))
+ data['label'] = np.array(text)
+ return data
+
+ def get_ignored_tokens(self):
+ beg_idx = self.get_beg_end_flag_idx("beg")
+ end_idx = self.get_beg_end_flag_idx("end")
+ return [beg_idx, end_idx]
+
+ def get_beg_end_flag_idx(self, beg_or_end):
+ if beg_or_end == "beg":
+ idx = np.array(self.dict[self.beg_str])
+ elif beg_or_end == "end":
+ idx = np.array(self.dict[self.end_str])
+ else:
+ assert False, "Unsupport type %s in get_beg_end_flag_idx" \
+ % beg_or_end
+ return idx
+
+
+class TableLabelEncode(object):
+ """ Convert between text-label and text-index """
+
+ def __init__(self,
+ max_text_length,
+ max_elem_length,
+ max_cell_num,
+ character_dict_path,
+ span_weight=1.0,
+ **kwargs):
+ self.max_text_length = max_text_length
+ self.max_elem_length = max_elem_length
+ self.max_cell_num = max_cell_num
+ list_character, list_elem = self.load_char_elem_dict(
+ character_dict_path)
+ list_character = self.add_special_char(list_character)
+ list_elem = self.add_special_char(list_elem)
+ self.dict_character = {}
+ for i, char in enumerate(list_character):
+ self.dict_character[char] = i
+ self.dict_elem = {}
+ for i, elem in enumerate(list_elem):
+ self.dict_elem[elem] = i
+ self.span_weight = span_weight
+
+ def load_char_elem_dict(self, character_dict_path):
+ list_character = []
+ list_elem = []
+ with open(character_dict_path, "rb") as fin:
+ lines = fin.readlines()
+ substr = lines[0].decode('utf-8').strip("\r\n").split("\t")
+ character_num = int(substr[0])
+ elem_num = int(substr[1])
+ for cno in range(1, 1 + character_num):
+ character = lines[cno].decode('utf-8').strip("\r\n")
+ list_character.append(character)
+ for eno in range(1 + character_num, 1 + character_num + elem_num):
+ elem = lines[eno].decode('utf-8').strip("\r\n")
+ list_elem.append(elem)
+ return list_character, list_elem
+
+ def add_special_char(self, list_character):
+ self.beg_str = "sos"
+ self.end_str = "eos"
+ list_character = [self.beg_str] + list_character + [self.end_str]
+ return list_character
+
+ def get_span_idx_list(self):
+ span_idx_list = []
+ for elem in self.dict_elem:
+ if 'span' in elem:
+ span_idx_list.append(self.dict_elem[elem])
+ return span_idx_list
+
+ def __call__(self, data):
+ cells = data['cells']
+ structure = data['structure']['tokens']
+ structure = self.encode(structure, 'elem')
+ if structure is None:
+ return None
+ elem_num = len(structure)
+ structure = [0] + structure + [len(self.dict_elem) - 1]
+ structure = structure + [0] * (self.max_elem_length + 2 - len(structure)
+ )
+ structure = np.array(structure)
+ data['structure'] = structure
+ elem_char_idx1 = self.dict_elem['<td>']
+ elem_char_idx2 = self.dict_elem['<td']
+ span_idx_list = self.get_span_idx_list()
+ td_idx_list = np.logical_or(structure == elem_char_idx1,
+ structure == elem_char_idx2)
+ td_idx_list = np.where(td_idx_list)[0]
+ structure_mask = np.ones(
+ (self.max_elem_length + 2, 1), dtype=np.float32)
+ bbox_list = np.zeros((self.max_elem_length + 2, 4), dtype=np.float32)
+ bbox_list_mask = np.zeros(
+ (self.max_elem_length + 2, 1), dtype=np.float32)
+ img_height, img_width, img_ch = data['image'].shape
+ if len(span_idx_list) > 0:
+ span_weight = len(td_idx_list) * 1.0 / len(span_idx_list)
+ span_weight = min(max(span_weight, 1.0), self.span_weight)
+ for cno in range(len(cells)):
+ if 'bbox' in cells[cno]:
+ bbox = cells[cno]['bbox'].copy()
+ bbox[0] = bbox[0] * 1.0 / img_width
+ bbox[1] = bbox[1] * 1.0 / img_height
+ bbox[2] = bbox[2] * 1.0 / img_width
+ bbox[3] = bbox[3] * 1.0 / img_height
+ td_idx = td_idx_list[cno]
+ bbox_list[td_idx] = bbox
+ bbox_list_mask[td_idx] = 1.0
+ cand_span_idx = td_idx + 1
+ if cand_span_idx < (self.max_elem_length + 2):
+ if structure[cand_span_idx] in span_idx_list:
+ structure_mask[cand_span_idx] = span_weight
+
+ data['bbox_list'] = bbox_list
+ data['bbox_list_mask'] = bbox_list_mask
+ data['structure_mask'] = structure_mask
+ char_beg_idx = self.get_beg_end_flag_idx('beg', 'char')
+ char_end_idx = self.get_beg_end_flag_idx('end', 'char')
+ elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
+ elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
+ data['sp_tokens'] = np.array([
+ char_beg_idx, char_end_idx, elem_beg_idx, elem_end_idx,
+ elem_char_idx1, elem_char_idx2, self.max_text_length,
+ self.max_elem_length, self.max_cell_num, elem_num
+ ])
+ return data
+
+ def encode(self, text, char_or_elem):
+ """convert text-label into text-index.
+ """
+ if char_or_elem == "char":
+ max_len = self.max_text_length
+ current_dict = self.dict_character
+ else:
+ max_len = self.max_elem_length
+ current_dict = self.dict_elem
+ if len(text) > max_len:
+ return None
+ if len(text) == 0:
+ if char_or_elem == "char":
+ return [self.dict_character['space']]
+ else:
+ return None
+ text_list = []
+ for char in text:
+ if char not in current_dict:
+ return None
+ text_list.append(current_dict[char])
+ if len(text_list) == 0:
+ if char_or_elem == "char":
+ return [self.dict_character['space']]
+ else:
+ return None
+ return text_list
+
+ def get_ignored_tokens(self, char_or_elem):
+ beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
+ end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
+ return [beg_idx, end_idx]
+
+ def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
+ if char_or_elem == "char":
+ if beg_or_end == "beg":
+ idx = np.array(self.dict_character[self.beg_str])
+ elif beg_or_end == "end":
+ idx = np.array(self.dict_character[self.end_str])
+ else:
+ assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
+ % beg_or_end
+ elif char_or_elem == "elem":
+ if beg_or_end == "beg":
+ idx = np.array(self.dict_elem[self.beg_str])
+ elif beg_or_end == "end":
+ idx = np.array(self.dict_elem[self.end_str])
+ else:
+ assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
+ % beg_or_end
+ else:
+ assert False, "Unsupport type %s in char_or_elem" \
+ % char_or_elem
+ return idx
+
+
+class SARLabelEncode(BaseRecLabelEncode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ **kwargs):
+ super(SARLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char)
+
+ def add_special_char(self, dict_character):
+ beg_end_str = ""
+ unknown_str = ""
+ padding_str = ""
+ dict_character = dict_character + [unknown_str]
+ self.unknown_idx = len(dict_character) - 1
+ dict_character = dict_character + [beg_end_str]
+ self.start_idx = len(dict_character) - 1
+ self.end_idx = len(dict_character) - 1
+ dict_character = dict_character + [padding_str]
+ self.padding_idx = len(dict_character) - 1
+
+ return dict_character
+
+ def __call__(self, data):
+ text = data['label']
+ text = self.encode(text)
+ if text is None:
+ return None
+ if len(text) >= self.max_text_len - 1:
+ return None
+ data['length'] = np.array(len(text))
+ target = [self.start_idx] + text + [self.end_idx]
+ padded_text = [self.padding_idx for _ in range(self.max_text_len)]
+
+ padded_text[:len(target)] = target
+ data['label'] = np.array(padded_text)
+ return data
+
+ def get_ignored_tokens(self):
+ return [self.padding_idx]
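+
+ # Illustrative example (toy dictionary): with base characters ['a', 'b', 'c'],
+ # SAR appends unknown=3, start/end=4 and padding=5, so "ab" with
+ # max_text_length=6 becomes [4, 0, 1, 4, 5, 5]  # start, 'a', 'b', end, padding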
+
+
+class PRENLabelEncode(BaseRecLabelEncode):
+ def __init__(self,
+ max_text_length,
+ character_dict_path,
+ use_space_char=False,
+ **kwargs):
+ super(PRENLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char)
+
+ def add_special_char(self, dict_character):
+ padding_str = '<PAD>'  # 0
+ end_str = '<EOS>'  # 1
+ unknown_str = '<UNK>'  # 2
+
+ dict_character = [padding_str, end_str, unknown_str] + dict_character
+ self.padding_idx = 0
+ self.end_idx = 1
+ self.unknown_idx = 2
+
+ return dict_character
+
+ def encode(self, text):
+ if len(text) == 0 or len(text) >= self.max_text_len:
+ return None
+ if self.lower:
+ text = text.lower()
+ text_list = []
+ for char in text:
+ if char not in self.dict:
+ text_list.append(self.unknown_idx)
+ else:
+ text_list.append(self.dict[char])
+ text_list.append(self.end_idx)
+ if len(text_list) < self.max_text_len:
+ text_list += [self.padding_idx] * (
+ self.max_text_len - len(text_list))
+ return text_list
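+
+ # Illustrative example (toy dictionary): the specials sit at the head of the
+ # dict (pad=0, eos=1, unk=2), so with base characters ['a', 'b', 'c'] and
+ # max_text_length=6, "ab" is encoded as [3, 4, 1, 0, 0, 0]  # 'a', 'b', eos, pad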
+
+ def __call__(self, data):
+ text = data['label']
+ encoded_text = self.encode(text)
+ if encoded_text is None:
+ return None
+ data['label'] = np.array(encoded_text)
+ return data
+
+
+class VQATokenLabelEncode(object):
+ """
+ Label encode for NLP VQA methods
+ """
+
+ def __init__(self,
+ class_path,
+ contains_re=False,
+ add_special_ids=False,
+ algorithm='LayoutXLM',
+ infer_mode=False,
+ ocr_engine=None,
+ **kwargs):
+ super(VQATokenLabelEncode, self).__init__()
+ from paddlenlp.transformers import LayoutXLMTokenizer, LayoutLMTokenizer, LayoutLMv2Tokenizer
+ from ppocr.utils.utility import load_vqa_bio_label_maps
+ tokenizer_dict = {
+ 'LayoutXLM': {
+ 'class': LayoutXLMTokenizer,
+ 'pretrained_model': 'layoutxlm-base-uncased'
+ },
+ 'LayoutLM': {
+ 'class': LayoutLMTokenizer,
+ 'pretrained_model': 'layoutlm-base-uncased'
+ },
+ 'LayoutLMv2': {
+ 'class': LayoutLMv2Tokenizer,
+ 'pretrained_model': 'layoutlmv2-base-uncased'
+ }
+ }
+ self.contains_re = contains_re
+ tokenizer_config = tokenizer_dict[algorithm]
+ self.tokenizer = tokenizer_config['class'].from_pretrained(
+ tokenizer_config['pretrained_model'])
+ self.label2id_map, id2label_map = load_vqa_bio_label_maps(class_path)
+ self.add_special_ids = add_special_ids
+ self.infer_mode = infer_mode
+ self.ocr_engine = ocr_engine
+
+ def __call__(self, data):
+ # load bbox and label info
+ ocr_info = self._load_ocr_info(data)
+
+ height, width, _ = data['image'].shape
+
+ words_list = []
+ bbox_list = []
+ input_ids_list = []
+ token_type_ids_list = []
+ segment_offset_id = []
+ gt_label_list = []
+
+ entities = []
+
+ # for re
+ train_re = self.contains_re and not self.infer_mode
+ if train_re:
+ relations = []
+ id2label = {}
+ entity_id_to_index_map = {}
+ empty_entity = set()
+
+ data['ocr_info'] = copy.deepcopy(ocr_info)
+
+ for info in ocr_info:
+ if train_re:
+ # for re
+ if len(info["text"]) == 0:
+ empty_entity.add(info["id"])
+ continue
+ id2label[info["id"]] = info["label"]
+ relations.extend([tuple(sorted(l)) for l in info["linking"]])
+ # smooth_box
+ bbox = self._smooth_box(info["bbox"], height, width)
+
+ text = info["text"]
+ encode_res = self.tokenizer.encode(
+ text, pad_to_max_seq_len=False, return_attention_mask=True)
+
+ if not self.add_special_ids:
+ # TODO: use tok.all_special_ids to remove
+ encode_res["input_ids"] = encode_res["input_ids"][1:-1]
+ encode_res["token_type_ids"] = encode_res["token_type_ids"][1:
+ -1]
+ encode_res["attention_mask"] = encode_res["attention_mask"][1:
+ -1]
+ # parse label
+ if not self.infer_mode:
+ label = info['label']
+ gt_label = self._parse_label(label, encode_res)
+
+ # construct entities for re
+ if train_re:
+ if gt_label[0] != self.label2id_map["O"]:
+ entity_id_to_index_map[info["id"]] = len(entities)
+ label = label.upper()
+ entities.append({
+ "start": len(input_ids_list),
+ "end":
+ len(input_ids_list) + len(encode_res["input_ids"]),
+ "label": label.upper(),
+ })
+ else:
+ entities.append({
+ "start": len(input_ids_list),
+ "end": len(input_ids_list) + len(encode_res["input_ids"]),
+ "label": 'O',
+ })
+ input_ids_list.extend(encode_res["input_ids"])
+ token_type_ids_list.extend(encode_res["token_type_ids"])
+ bbox_list.extend([bbox] * len(encode_res["input_ids"]))
+ words_list.append(text)
+ segment_offset_id.append(len(input_ids_list))
+ if not self.infer_mode:
+ gt_label_list.extend(gt_label)
+
+ data['input_ids'] = input_ids_list
+ data['token_type_ids'] = token_type_ids_list
+ data['bbox'] = bbox_list
+ data['attention_mask'] = [1] * len(input_ids_list)
+ data['labels'] = gt_label_list
+ data['segment_offset_id'] = segment_offset_id
+ data['tokenizer_params'] = dict(
+ padding_side=self.tokenizer.padding_side,
+ pad_token_type_id=self.tokenizer.pad_token_type_id,
+ pad_token_id=self.tokenizer.pad_token_id)
+ data['entities'] = entities
+
+ if train_re:
+ data['relations'] = relations
+ data['id2label'] = id2label
+ data['empty_entity'] = empty_entity
+ data['entity_id_to_index_map'] = entity_id_to_index_map
+ return data
+
+ def _load_ocr_info(self, data):
+ def trans_poly_to_bbox(poly):
+ x1 = np.min([p[0] for p in poly])
+ x2 = np.max([p[0] for p in poly])
+ y1 = np.min([p[1] for p in poly])
+ y2 = np.max([p[1] for p in poly])
+ return [x1, y1, x2, y2]
+
+ if self.infer_mode:
+ ocr_result = self.ocr_engine.ocr(data['image'], cls=False)
+ ocr_info = []
+ for res in ocr_result:
+ ocr_info.append({
+ "text": res[1][0],
+ "bbox": trans_poly_to_bbox(res[0]),
+ "poly": res[0],
+ })
+ return ocr_info
+ else:
+ info = data['label']
+ # read text info
+ info_dict = json.loads(info)
+ return info_dict["ocr_info"]
+
+ def _smooth_box(self, bbox, height, width):
+ bbox[0] = int(bbox[0] * 1000.0 / width)
+ bbox[2] = int(bbox[2] * 1000.0 / width)
+ bbox[1] = int(bbox[1] * 1000.0 / height)
+ bbox[3] = int(bbox[3] * 1000.0 / height)
+ return bbox
+
+ def _parse_label(self, label, encode_res):
+ gt_label = []
+ if label.lower() == "other":
+ gt_label.extend([0] * len(encode_res["input_ids"]))
+ else:
+ gt_label.append(self.label2id_map[("b-" + label).upper()])
+ gt_label.extend([self.label2id_map[("i-" + label).upper()]] *
+ (len(encode_res["input_ids"]) - 1))
+ return gt_label
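+
+ # Illustrative example: _parse_label expands a segment-level tag into token-level
+ # BIO ids, e.g. a "question" segment split into three sub-tokens yields the ids
+ # of [B-QUESTION, I-QUESTION, I-QUESTION] from label2id_map, while "other" maps
+ # every sub-token to 0.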
+
+
+class MultiLabelEncode(BaseRecLabelEncode):
+ def __init__(self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ **kwargs):
+ super(MultiLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char)
+
+ self.ctc_encode = CTCLabelEncode(max_text_length, character_dict_path,
+ use_space_char, **kwargs)
+ self.sar_encode = SARLabelEncode(max_text_length, character_dict_path,
+ use_space_char, **kwargs)
+
+ def __call__(self, data):
+
+ data_ctc = copy.deepcopy(data)
+ data_sar = copy.deepcopy(data)
+ data_out = dict()
+ data_out['img_path'] = data.get('img_path', None)
+ data_out['image'] = data['image']
+ ctc = self.ctc_encode.__call__(data_ctc)
+ sar = self.sar_encode.__call__(data_sar)
+ if ctc is None or sar is None:
+ return None
+ data_out['label_ctc'] = ctc['label']
+ data_out['label_sar'] = sar['label']
+ data_out['length'] = ctc['length']
+ return data_out
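+
+ # MultiLabelEncode simply runs CTCLabelEncode and SARLabelEncode on copies of the
+ # same sample and keeps both results ('label_ctc' and 'label_sar'), e.g. for
+ # multi-head recognition training.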
diff --git a/backend/ppocr/data/imaug/make_border_map.py b/backend/ppocr/data/imaug/make_border_map.py
index cc2c9034..abab3836 100644
--- a/backend/ppocr/data/imaug/make_border_map.py
+++ b/backend/ppocr/data/imaug/make_border_map.py
@@ -1,4 +1,20 @@
-# -*- coding:utf-8 -*-
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is adapted from:
+https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/make_border_map.py
+"""
from __future__ import absolute_import
from __future__ import division
diff --git a/backend/ppocr/data/imaug/make_pse_gt.py b/backend/ppocr/data/imaug/make_pse_gt.py
new file mode 100644
index 00000000..255d076b
--- /dev/null
+++ b/backend/ppocr/data/imaug/make_pse_gt.py
@@ -0,0 +1,106 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import cv2
+import numpy as np
+import pyclipper
+from shapely.geometry import Polygon
+
+__all__ = ['MakePseGt']
+
+
+class MakePseGt(object):
+ def __init__(self, kernel_num=7, size=640, min_shrink_ratio=0.4, **kwargs):
+ self.kernel_num = kernel_num
+ self.min_shrink_ratio = min_shrink_ratio
+ self.size = size
+
+ def __call__(self, data):
+
+ image = data['image']
+ text_polys = data['polys']
+ ignore_tags = data['ignore_tags']
+
+ h, w, _ = image.shape
+ short_edge = min(h, w)
+ if short_edge < self.size:
+ # keep the short edge >= self.size
+ scale = self.size / short_edge
+ image = cv2.resize(image, dsize=None, fx=scale, fy=scale)
+ text_polys *= scale
+
+ gt_kernels = []
+ for i in range(1, self.kernel_num + 1):
+ # s1->sn, from big to small
+ rate = 1.0 - (1.0 - self.min_shrink_ratio) / (self.kernel_num - 1
+ ) * i
+ text_kernel, ignore_tags = self.generate_kernel(
+ image.shape[0:2], rate, text_polys, ignore_tags)
+ gt_kernels.append(text_kernel)
+
+ training_mask = np.ones(image.shape[0:2], dtype='uint8')
+ for i in range(text_polys.shape[0]):
+ if ignore_tags[i]:
+ cv2.fillPoly(training_mask,
+ text_polys[i].astype(np.int32)[np.newaxis, :, :],
+ 0)
+
+ gt_kernels = np.array(gt_kernels)
+ gt_kernels[gt_kernels > 0] = 1
+
+ data['image'] = image
+ data['polys'] = text_polys
+ data['gt_kernels'] = gt_kernels[0:]
+ data['gt_text'] = gt_kernels[0]
+ data['mask'] = training_mask.astype('float32')
+ return data
+
+ def generate_kernel(self,
+ img_size,
+ shrink_ratio,
+ text_polys,
+ ignore_tags=None):
+ """
+ Refer to part of the code:
+ https://github.com/open-mmlab/mmocr/blob/main/mmocr/datasets/pipelines/textdet_targets/base_textdet_targets.py
+ """
+
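+ # Each polygon is shrunk inward by the PSENet-style offset
+ # d = Area * (1 - shrink_ratio^2) / Perimeter (computed with pyclipper below),
+ # and the i-th kernel region is filled with label i + 1.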
+ h, w = img_size
+ text_kernel = np.zeros((h, w), dtype=np.float32)
+ for i, poly in enumerate(text_polys):
+ polygon = Polygon(poly)
+ distance = polygon.area * (1 - shrink_ratio * shrink_ratio) / (
+ polygon.length + 1e-6)
+ subject = [tuple(l) for l in poly]
+ pco = pyclipper.PyclipperOffset()
+ pco.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
+ shrinked = np.array(pco.Execute(-distance))
+
+ if len(shrinked) == 0 or shrinked.size == 0:
+ if ignore_tags is not None:
+ ignore_tags[i] = True
+ continue
+ try:
+ shrinked = np.array(shrinked[0]).reshape(-1, 2)
+ except:
+ if ignore_tags is not None:
+ ignore_tags[i] = True
+ continue
+ cv2.fillPoly(text_kernel, [shrinked.astype(np.int32)], i + 1)
+ return text_kernel, ignore_tags
diff --git a/backend/ppocr/data/imaug/make_shrink_map.py b/backend/ppocr/data/imaug/make_shrink_map.py
index ccdcd015..6c65c20e 100644
--- a/backend/ppocr/data/imaug/make_shrink_map.py
+++ b/backend/ppocr/data/imaug/make_shrink_map.py
@@ -1,4 +1,20 @@
-# -*- coding:utf-8 -*-
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is adapted from:
+https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/make_shrink_map.py
+"""
from __future__ import absolute_import
from __future__ import division
@@ -49,7 +65,7 @@ def __call__(self, data):
pyclipper.ET_CLOSEDPOLYGON)
shrinked = []
- # Increase the shrink ratio every time we get multiple polygon returned back
+ # Increase the shrink ratio every time we get multiple polygon returned back
possible_ratios = np.arange(self.shrink_ratio, 1,
self.shrink_ratio)
np.append(possible_ratios, 1)
@@ -71,7 +87,6 @@ def __call__(self, data):
for each_shirnk in shrinked:
shirnk = np.array(each_shirnk).reshape(-1, 2)
cv2.fillPoly(gt, [shirnk.astype(np.int32)], 1)
- # cv2.fillPoly(gt[0], [shrinked.astype(np.int32)], 1)
data['shrink_map'] = gt
data['shrink_mask'] = mask
@@ -97,11 +112,12 @@ def validate_polygons(self, polygons, ignore_tags, h, w):
return polygons, ignore_tags
def polygon_area(self, polygon):
- # return cv2.contourArea(polygon.astype(np.float32))
- edge = 0
- for i in range(polygon.shape[0]):
- next_index = (i + 1) % polygon.shape[0]
- edge += (polygon[next_index, 0] - polygon[i, 0]) * (
- polygon[next_index, 1] - polygon[i, 1])
-
- return edge / 2.
+ """
+ compute polygon area
+ """
+ area = 0
+ q = polygon[-1]
+ for p in polygon:
+ area += p[0] * q[1] - p[1] * q[0]
+ q = p
+ return area / 2.0
diff --git a/backend/ppocr/data/imaug/operators.py b/backend/ppocr/data/imaug/operators.py
new file mode 100644
index 00000000..09736515
--- /dev/null
+++ b/backend/ppocr/data/imaug/operators.py
@@ -0,0 +1,468 @@
+"""
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import sys
+import six
+import cv2
+import numpy as np
+import math
+
+
+class DecodeImage(object):
+ """ decode image """
+
+ def __init__(self,
+ img_mode='RGB',
+ channel_first=False,
+ ignore_orientation=False,
+ **kwargs):
+ self.img_mode = img_mode
+ self.channel_first = channel_first
+ self.ignore_orientation = ignore_orientation
+
+ def __call__(self, data):
+ img = data['image']
+ if six.PY2:
+ assert type(img) is str and len(
+ img) > 0, "invalid input 'img' in DecodeImage"
+ else:
+ assert type(img) is bytes and len(
+ img) > 0, "invalid input 'img' in DecodeImage"
+ img = np.frombuffer(img, dtype='uint8')
+ if self.ignore_orientation:
+ img = cv2.imdecode(img, cv2.IMREAD_IGNORE_ORIENTATION |
+ cv2.IMREAD_COLOR)
+ else:
+ img = cv2.imdecode(img, 1)
+ if img is None:
+ return None
+ if self.img_mode == 'GRAY':
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+ elif self.img_mode == 'RGB':
+ assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
+ img = img[:, :, ::-1]
+
+ if self.channel_first:
+ img = img.transpose((2, 0, 1))
+
+ data['image'] = img
+ return data
+
+
+class NRTRDecodeImage(object):
+ """ decode image """
+
+ def __init__(self, img_mode='RGB', channel_first=False, **kwargs):
+ self.img_mode = img_mode
+ self.channel_first = channel_first
+
+ def __call__(self, data):
+ img = data['image']
+ if six.PY2:
+ assert type(img) is str and len(
+ img) > 0, "invalid input 'img' in DecodeImage"
+ else:
+ assert type(img) is bytes and len(
+ img) > 0, "invalid input 'img' in DecodeImage"
+ img = np.frombuffer(img, dtype='uint8')
+
+ img = cv2.imdecode(img, 1)
+
+ if img is None:
+ return None
+ if self.img_mode == 'GRAY':
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+ elif self.img_mode == 'RGB':
+ assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
+ img = img[:, :, ::-1]
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ if self.channel_first:
+ img = img.transpose((2, 0, 1))
+ data['image'] = img
+ return data
+
+
+class NormalizeImage(object):
+ """ normalize image such as substract mean, divide std
+ """
+
+ def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
+ if isinstance(scale, str):
+ scale = eval(scale)
+ self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
+ mean = mean if mean is not None else [0.485, 0.456, 0.406]
+ std = std if std is not None else [0.229, 0.224, 0.225]
+
+ shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
+ self.mean = np.array(mean).reshape(shape).astype('float32')
+ self.std = np.array(std).reshape(shape).astype('float32')
+
+ def __call__(self, data):
+ img = data['image']
+ from PIL import Image
+ if isinstance(img, Image.Image):
+ img = np.array(img)
+ assert isinstance(img,
+ np.ndarray), "invalid input 'img' in NormalizeImage"
+ data['image'] = (
+ img.astype('float32') * self.scale - self.mean) / self.std
+ return data
+
+
+class ToCHWImage(object):
+ """ convert hwc image to chw image
+ """
+
+ def __init__(self, **kwargs):
+ pass
+
+ def __call__(self, data):
+ img = data['image']
+ from PIL import Image
+ if isinstance(img, Image.Image):
+ img = np.array(img)
+ data['image'] = img.transpose((2, 0, 1))
+ return data
+
+
+class Fasttext(object):
+ def __init__(self, path="None", **kwargs):
+ import fasttext
+ self.fast_model = fasttext.load_model(path)
+
+ def __call__(self, data):
+ label = data['label']
+ fast_label = self.fast_model[label]
+ data['fast_label'] = fast_label
+ return data
+
+
+class KeepKeys(object):
+ def __init__(self, keep_keys, **kwargs):
+ self.keep_keys = keep_keys
+
+ def __call__(self, data):
+ data_list = []
+ for key in self.keep_keys:
+ data_list.append(data[key])
+ return data_list
+
+
+class Pad(object):
+ def __init__(self, size=None, size_div=32, **kwargs):
+ if size is not None and not isinstance(size, (int, list, tuple)):
+ raise TypeError("Type of target_size is invalid. Now is {}".format(
+ type(size)))
+ if isinstance(size, int):
+ size = [size, size]
+ self.size = size
+ self.size_div = size_div
+
+ def __call__(self, data):
+
+ img = data['image']
+ img_h, img_w = img.shape[0], img.shape[1]
+ if self.size:
+ resize_h2, resize_w2 = self.size
+ assert (
+ img_h < resize_h2 and img_w < resize_w2
+ ), '(h, w) of target size should be greater than (img_h, img_w)'
+ else:
+ resize_h2 = max(
+ int(math.ceil(img.shape[0] / self.size_div) * self.size_div),
+ self.size_div)
+ resize_w2 = max(
+ int(math.ceil(img.shape[1] / self.size_div) * self.size_div),
+ self.size_div)
+ img = cv2.copyMakeBorder(
+ img,
+ 0,
+ resize_h2 - img_h,
+ 0,
+ resize_w2 - img_w,
+ cv2.BORDER_CONSTANT,
+ value=0)
+ data['image'] = img
+ return data
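+
+ # Illustrative example: with size=None and size_div=32, a 700x900 image is padded
+ # on the bottom/right to 704x928 (the next multiples of 32).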
+
+
+class Resize(object):
+ def __init__(self, size=(640, 640), **kwargs):
+ self.size = size
+
+ def resize_image(self, img):
+ resize_h, resize_w = self.size
+ ori_h, ori_w = img.shape[:2] # (h, w, c)
+ ratio_h = float(resize_h) / ori_h
+ ratio_w = float(resize_w) / ori_w
+ img = cv2.resize(img, (int(resize_w), int(resize_h)))
+ return img, [ratio_h, ratio_w]
+
+ def __call__(self, data):
+ img = data['image']
+ if 'polys' in data:
+ text_polys = data['polys']
+
+ img_resize, [ratio_h, ratio_w] = self.resize_image(img)
+ if 'polys' in data:
+ new_boxes = []
+ for box in text_polys:
+ new_box = []
+ for cord in box:
+ new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
+ new_boxes.append(new_box)
+ data['polys'] = np.array(new_boxes, dtype=np.float32)
+ data['image'] = img_resize
+ return data
+
+
+class DetResizeForTest(object):
+ def __init__(self, **kwargs):
+ super(DetResizeForTest, self).__init__()
+ self.resize_type = 0
+ if 'image_shape' in kwargs:
+ self.image_shape = kwargs['image_shape']
+ self.resize_type = 1
+ elif 'limit_side_len' in kwargs:
+ self.limit_side_len = kwargs['limit_side_len']
+ self.limit_type = kwargs.get('limit_type', 'min')
+ elif 'resize_long' in kwargs:
+ self.resize_type = 2
+ self.resize_long = kwargs.get('resize_long', 960)
+ else:
+ self.limit_side_len = 736
+ self.limit_type = 'min'
+
+ def __call__(self, data):
+ img = data['image']
+ src_h, src_w, _ = img.shape
+
+ if self.resize_type == 0:
+ # img, shape = self.resize_image_type0(img)
+ img, [ratio_h, ratio_w] = self.resize_image_type0(img)
+ elif self.resize_type == 2:
+ img, [ratio_h, ratio_w] = self.resize_image_type2(img)
+ else:
+ # img, shape = self.resize_image_type1(img)
+ img, [ratio_h, ratio_w] = self.resize_image_type1(img)
+ data['image'] = img
+ data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
+ return data
+
+ def resize_image_type1(self, img):
+ resize_h, resize_w = self.image_shape
+ ori_h, ori_w = img.shape[:2] # (h, w, c)
+ ratio_h = float(resize_h) / ori_h
+ ratio_w = float(resize_w) / ori_w
+ img = cv2.resize(img, (int(resize_w), int(resize_h)))
+ # return img, np.array([ori_h, ori_w])
+ return img, [ratio_h, ratio_w]
+
+ def resize_image_type0(self, img):
+ """
+ resize image to a size multiple of 32 which is required by the network
+ args:
+ img(array): array with shape [h, w, c]
+ return(tuple):
+ img, (ratio_h, ratio_w)
+ """
+ limit_side_len = self.limit_side_len
+ h, w, c = img.shape
+
+ # limit the max side
+ if self.limit_type == 'max':
+ if max(h, w) > limit_side_len:
+ if h > w:
+ ratio = float(limit_side_len) / h
+ else:
+ ratio = float(limit_side_len) / w
+ else:
+ ratio = 1.
+ elif self.limit_type == 'min':
+ if min(h, w) < limit_side_len:
+ if h < w:
+ ratio = float(limit_side_len) / h
+ else:
+ ratio = float(limit_side_len) / w
+ else:
+ ratio = 1.
+ elif self.limit_type == 'resize_long':
+ ratio = float(limit_side_len) / max(h, w)
+ else:
+ raise Exception('unsupported limit type: {}'.format(self.limit_type))
+ resize_h = int(h * ratio)
+ resize_w = int(w * ratio)
+
+ resize_h = max(int(round(resize_h / 32) * 32), 32)
+ resize_w = max(int(round(resize_w / 32) * 32), 32)
+
+ try:
+ if int(resize_w) <= 0 or int(resize_h) <= 0:
+ return None, (None, None)
+ img = cv2.resize(img, (int(resize_w), int(resize_h)))
+ except:
+ print(img.shape, resize_w, resize_h)
+ sys.exit(0)
+ ratio_h = resize_h / float(h)
+ ratio_w = resize_w / float(w)
+ return img, [ratio_h, ratio_w]
+
+ def resize_image_type2(self, img):
+ h, w, _ = img.shape
+
+ resize_w = w
+ resize_h = h
+
+ if resize_h > resize_w:
+ ratio = float(self.resize_long) / resize_h
+ else:
+ ratio = float(self.resize_long) / resize_w
+
+ resize_h = int(resize_h * ratio)
+ resize_w = int(resize_w * ratio)
+
+ max_stride = 128
+ resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+ resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
+ img = cv2.resize(img, (int(resize_w), int(resize_h)))
+ ratio_h = resize_h / float(h)
+ ratio_w = resize_w / float(w)
+
+ return img, [ratio_h, ratio_w]
+
+
+class E2EResizeForTest(object):
+ def __init__(self, **kwargs):
+ super(E2EResizeForTest, self).__init__()
+ self.max_side_len = kwargs['max_side_len']
+ self.valid_set = kwargs['valid_set']
+
+ def __call__(self, data):
+ img = data['image']
+ src_h, src_w, _ = img.shape
+ if self.valid_set == 'totaltext':
+ im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext(
+ img, max_side_len=self.max_side_len)
+ else:
+ im_resized, (ratio_h, ratio_w) = self.resize_image(
+ img, max_side_len=self.max_side_len)
+ data['image'] = im_resized
+ data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
+ return data
+
+ def resize_image_for_totaltext(self, im, max_side_len=512):
+
+ h, w, _ = im.shape
+ resize_w = w
+ resize_h = h
+ ratio = 1.25
+ if h * ratio > max_side_len:
+ ratio = float(max_side_len) / resize_h
+ resize_h = int(resize_h * ratio)
+ resize_w = int(resize_w * ratio)
+
+ max_stride = 128
+ resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+ resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
+ im = cv2.resize(im, (int(resize_w), int(resize_h)))
+ ratio_h = resize_h / float(h)
+ ratio_w = resize_w / float(w)
+ return im, (ratio_h, ratio_w)
+
+ def resize_image(self, im, max_side_len=512):
+ """
+ resize image to a size multiple of max_stride which is required by the network
+ :param im: input image
+ :param max_side_len: limit of max image size to avoid out of memory in gpu
+ :return: the resized image and the resize ratio
+ """
+ h, w, _ = im.shape
+
+ resize_w = w
+ resize_h = h
+
+ # Fix the longer side
+ if resize_h > resize_w:
+ ratio = float(max_side_len) / resize_h
+ else:
+ ratio = float(max_side_len) / resize_w
+
+ resize_h = int(resize_h * ratio)
+ resize_w = int(resize_w * ratio)
+
+ max_stride = 128
+ resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+ resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
+ im = cv2.resize(im, (int(resize_w), int(resize_h)))
+ ratio_h = resize_h / float(h)
+ ratio_w = resize_w / float(w)
+
+ return im, (ratio_h, ratio_w)
+
+
+class KieResize(object):
+ def __init__(self, **kwargs):
+ super(KieResize, self).__init__()
+ self.max_side, self.min_side = kwargs['img_scale'][0], kwargs[
+ 'img_scale'][1]
+
+ def __call__(self, data):
+ img = data['image']
+ points = data['points']
+ src_h, src_w, _ = img.shape
+ im_resized, scale_factor, [ratio_h, ratio_w
+ ], [new_h, new_w] = self.resize_image(img)
+ resize_points = self.resize_boxes(img, points, scale_factor)
+ data['ori_image'] = img
+ data['ori_boxes'] = points
+ data['points'] = resize_points
+ data['image'] = im_resized
+ data['shape'] = np.array([new_h, new_w])
+ return data
+
+ def resize_image(self, img):
+ norm_img = np.zeros([1024, 1024, 3], dtype='float32')
+ scale = [512, 1024]
+ h, w = img.shape[:2]
+ max_long_edge = max(scale)
+ max_short_edge = min(scale)
+ scale_factor = min(max_long_edge / max(h, w),
+ max_short_edge / min(h, w))
+ resize_w, resize_h = int(w * float(scale_factor) + 0.5), int(h * float(
+ scale_factor) + 0.5)
+ max_stride = 32
+ resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+ resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
+ im = cv2.resize(img, (resize_w, resize_h))
+ new_h, new_w = im.shape[:2]
+ w_scale = new_w / w
+ h_scale = new_h / h
+ scale_factor = np.array(
+ [w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
+ norm_img[:new_h, :new_w, :] = im
+ return norm_img, scale_factor, [h_scale, w_scale], [new_h, new_w]
+
+ def resize_boxes(self, im, points, scale_factor):
+ points = points * scale_factor
+ img_shape = im.shape[:2]
+ points[:, 0::2] = np.clip(points[:, 0::2], 0, img_shape[1])
+ points[:, 1::2] = np.clip(points[:, 1::2], 0, img_shape[0])
+ return points
diff --git a/backend/ppocr/data/imaug/pg_process.py b/backend/ppocr/data/imaug/pg_process.py
new file mode 100644
index 00000000..53031064
--- /dev/null
+++ b/backend/ppocr/data/imaug/pg_process.py
@@ -0,0 +1,906 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import cv2
+import numpy as np
+
+__all__ = ['PGProcessTrain']
+
+
+class PGProcessTrain(object):
+ def __init__(self,
+ character_dict_path,
+ max_text_length,
+ max_text_nums,
+ tcl_len,
+ batch_size=14,
+ min_crop_size=24,
+ min_text_size=4,
+ max_text_size=512,
+ **kwargs):
+ self.tcl_len = tcl_len
+ self.max_text_length = max_text_length
+ self.max_text_nums = max_text_nums
+ self.batch_size = batch_size
+ self.min_crop_size = min_crop_size
+ self.min_text_size = min_text_size
+ self.max_text_size = max_text_size
+ self.Lexicon_Table = self.get_dict(character_dict_path)
+ self.pad_num = len(self.Lexicon_Table)
+ self.img_id = 0
+
+ def get_dict(self, character_dict_path):
+ character_str = ""
+ with open(character_dict_path, "rb") as fin:
+ lines = fin.readlines()
+ for line in lines:
+ line = line.decode('utf-8').strip("\n").strip("\r\n")
+ character_str += line
+ dict_character = list(character_str)
+ return dict_character
+
+ def quad_area(self, poly):
+ """
+ compute area of a polygon
+ :param poly:
+ :return:
+ """
+ edge = [(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
+ (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
+ (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
+ (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])]
+ return np.sum(edge) / 2.
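+
+ # quad_area is the signed area (trapezoid form of the shoelace formula); its sign
+ # is used in check_and_validate_polys to detect quads listed in reversed order.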
+
+ def gen_quad_from_poly(self, poly):
+ """
+ Generate min area quad from poly.
+ """
+ point_num = poly.shape[0]
+ min_area_quad = np.zeros((4, 2), dtype=np.float32)
+ rect = cv2.minAreaRect(poly.astype(
+ np.int32)) # (center (x,y), (width, height), angle of rotation)
+ box = np.array(cv2.boxPoints(rect))
+
+ first_point_idx = 0
+ min_dist = 1e4
+ for i in range(4):
+ dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
+ np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
+ np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
+ np.linalg.norm(box[(i + 3) % 4] - poly[-1])
+ if dist < min_dist:
+ min_dist = dist
+ first_point_idx = i
+ for i in range(4):
+ min_area_quad[i] = box[(first_point_idx + i) % 4]
+
+ return min_area_quad
+
+ def check_and_validate_polys(self, polys, tags, im_size):
+ """
+ check that the text polys are in a consistent direction,
+ and filter out invalid polygons
+ :param polys:
+ :param tags:
+ :return:
+ """
+ (h, w) = im_size
+ if polys.shape[0] == 0:
+ return polys, np.array([]), np.array([])
+ polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
+ polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)
+
+ validated_polys = []
+ validated_tags = []
+ hv_tags = []
+ for poly, tag in zip(polys, tags):
+ quad = self.gen_quad_from_poly(poly)
+ p_area = self.quad_area(quad)
+ if abs(p_area) < 1:
+ print('invalid poly')
+ continue
+ if p_area > 0:
+ if tag == False:
+ print('poly in wrong direction')
+ tag = True # reversed cases should be ignored
+ poly = poly[(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2,
+ 1), :]
+ quad = quad[(0, 3, 2, 1), :]
+
+ len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[3] -
+ quad[2])
+ len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] -
+ quad[2])
+ hv_tag = 1
+
+ if len_w * 2.0 < len_h:
+ hv_tag = 0
+
+ validated_polys.append(poly)
+ validated_tags.append(tag)
+ hv_tags.append(hv_tag)
+ return np.array(validated_polys), np.array(validated_tags), np.array(
+ hv_tags)
+
+ def crop_area(self,
+ im,
+ polys,
+ tags,
+ hv_tags,
+ txts,
+ crop_background=False,
+ max_tries=25):
+ """
+ make random crop from the input image
+ :param im:
+ :param polys: [b,4,2]
+ :param tags:
+ :param crop_background:
+ :param max_tries: 50 -> 25
+ :return:
+ """
+ h, w, _ = im.shape
+ pad_h = h // 10
+ pad_w = w // 10
+ h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
+ w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
+ for poly in polys:
+ poly = np.round(poly, decimals=0).astype(np.int32)
+ minx = np.min(poly[:, 0])
+ maxx = np.max(poly[:, 0])
+ w_array[minx + pad_w:maxx + pad_w] = 1
+ miny = np.min(poly[:, 1])
+ maxy = np.max(poly[:, 1])
+ h_array[miny + pad_h:maxy + pad_h] = 1
+ # ensure the cropped area not across a text
+ h_axis = np.where(h_array == 0)[0]
+ w_axis = np.where(w_array == 0)[0]
+ if len(h_axis) == 0 or len(w_axis) == 0:
+ return im, polys, tags, hv_tags, txts
+ for i in range(max_tries):
+ xx = np.random.choice(w_axis, size=2)
+ xmin = np.min(xx) - pad_w
+ xmax = np.max(xx) - pad_w
+ xmin = np.clip(xmin, 0, w - 1)
+ xmax = np.clip(xmax, 0, w - 1)
+ yy = np.random.choice(h_axis, size=2)
+ ymin = np.min(yy) - pad_h
+ ymax = np.max(yy) - pad_h
+ ymin = np.clip(ymin, 0, h - 1)
+ ymax = np.clip(ymax, 0, h - 1)
+ if xmax - xmin < self.min_crop_size or \
+ ymax - ymin < self.min_crop_size:
+ continue
+ if polys.shape[0] != 0:
+ poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \
+ & (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax)
+ selected_polys = np.where(
+ np.sum(poly_axis_in_area, axis=1) == 4)[0]
+ else:
+ selected_polys = []
+ if len(selected_polys) == 0:
+ # no text in this area
+ if crop_background:
+ txts_tmp = []
+ for selected_poly in selected_polys:
+ txts_tmp.append(txts[selected_poly])
+ txts = txts_tmp
+ return im[ymin: ymax + 1, xmin: xmax + 1, :], \
+ polys[selected_polys], tags[selected_polys], hv_tags[selected_polys], txts
+ else:
+ continue
+ im = im[ymin:ymax + 1, xmin:xmax + 1, :]
+ polys = polys[selected_polys]
+ tags = tags[selected_polys]
+ hv_tags = hv_tags[selected_polys]
+ txts_tmp = []
+ for selected_poly in selected_polys:
+ txts_tmp.append(txts[selected_poly])
+ txts = txts_tmp
+ polys[:, :, 0] -= xmin
+ polys[:, :, 1] -= ymin
+ return im, polys, tags, hv_tags, txts
+
+ return im, polys, tags, hv_tags, txts
+
+ def fit_and_gather_tcl_points_v2(self,
+ min_area_quad,
+ poly,
+ max_h,
+ max_w,
+ fixed_point_num=64,
+ img_id=0,
+ reference_height=3):
+ """
+ Find the center point of poly as key_points, then fit and gather.
+ """
+ key_point_xys = []
+ point_num = poly.shape[0]
+ for idx in range(point_num // 2):
+ center_point = (poly[idx] + poly[point_num - 1 - idx]) / 2.0
+ key_point_xys.append(center_point)
+
+ tmp_image = np.zeros(
+ shape=(
+ max_h,
+ max_w, ), dtype='float32')
+ cv2.polylines(tmp_image, [np.array(key_point_xys).astype('int32')],
+ False, 1.0)
+ ys, xs = np.where(tmp_image > 0)
+ xy_text = np.array(list(zip(xs, ys)), dtype='float32')
+
+ left_center_pt = (
+ (min_area_quad[0] - min_area_quad[1]) / 2.0).reshape(1, 2)
+ right_center_pt = (
+ (min_area_quad[1] - min_area_quad[2]) / 2.0).reshape(1, 2)
+ proj_unit_vec = (right_center_pt - left_center_pt) / (
+ np.linalg.norm(right_center_pt - left_center_pt) + 1e-6)
+ proj_unit_vec_tile = np.tile(proj_unit_vec,
+ (xy_text.shape[0], 1)) # (n, 2)
+ left_center_pt_tile = np.tile(left_center_pt,
+ (xy_text.shape[0], 1)) # (n, 2)
+ xy_text_to_left_center = xy_text - left_center_pt_tile
+ proj_value = np.sum(xy_text_to_left_center * proj_unit_vec_tile, axis=1)
+ xy_text = xy_text[np.argsort(proj_value)]
+
+ # convert to np and keep the number of points not greater than fixed_point_num
+ pos_info = np.array(xy_text).reshape(-1, 2)[:, ::-1] # xy-> yx
+ point_num = len(pos_info)
+ if point_num > fixed_point_num:
+ keep_ids = [
+ int((point_num * 1.0 / fixed_point_num) * x)
+ for x in range(fixed_point_num)
+ ]
+ pos_info = pos_info[keep_ids, :]
+
+ keep = int(min(len(pos_info), fixed_point_num))
+ if np.random.rand() < 0.2 and reference_height >= 3:
+ dl = (np.random.rand(keep) - 0.5) * reference_height * 0.3
+ random_float = np.array([1, 0]).reshape([1, 2]) * dl.reshape(
+ [keep, 1])
+ pos_info += random_float
+ pos_info[:, 0] = np.clip(pos_info[:, 0], 0, max_h - 1)
+ pos_info[:, 1] = np.clip(pos_info[:, 1], 0, max_w - 1)
+
+ # padding to fixed length
+ pos_l = np.zeros((self.tcl_len, 3), dtype=np.int32)
+ pos_l[:, 0] = np.ones((self.tcl_len, )) * img_id
+ pos_m = np.zeros((self.tcl_len, 1), dtype=np.float32)
+ pos_l[:keep, 1:] = np.round(pos_info).astype(np.int32)
+ pos_m[:keep] = 1.0
+ return pos_l, pos_m
+
+ def generate_direction_map(self, poly_quads, n_char, direction_map):
+ """
+ """
+ width_list = []
+ height_list = []
+ for quad in poly_quads:
+ quad_w = (np.linalg.norm(quad[0] - quad[1]) +
+ np.linalg.norm(quad[2] - quad[3])) / 2.0
+ quad_h = (np.linalg.norm(quad[0] - quad[3]) +
+ np.linalg.norm(quad[2] - quad[1])) / 2.0
+ width_list.append(quad_w)
+ height_list.append(quad_h)
+ norm_width = max(sum(width_list) / n_char, 1.0)
+ average_height = max(sum(height_list) / len(height_list), 1.0)
+ k = 1
+ for quad in poly_quads:
+ direct_vector_full = (
+ (quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0
+ direct_vector = direct_vector_full / (
+ np.linalg.norm(direct_vector_full) + 1e-6) * norm_width
+ direction_label = tuple(
+ map(float,
+ [direct_vector[0], direct_vector[1], 1.0 / average_height]))
+ cv2.fillPoly(direction_map,
+ quad.round().astype(np.int32)[np.newaxis, :, :],
+ direction_label)
+ k += 1
+ return direction_map
+
+ def calculate_average_height(self, poly_quads):
+ """
+ """
+ height_list = []
+ for quad in poly_quads:
+ quad_h = (np.linalg.norm(quad[0] - quad[3]) +
+ np.linalg.norm(quad[2] - quad[1])) / 2.0
+ height_list.append(quad_h)
+ average_height = max(sum(height_list) / len(height_list), 1.0)
+ return average_height
+
+ def generate_tcl_ctc_label(self,
+ h,
+ w,
+ polys,
+ tags,
+ text_strs,
+ ds_ratio,
+ tcl_ratio=0.3,
+ shrink_ratio_of_width=0.15):
+ """
+ Generate polygon.
+ """
+ score_map_big = np.zeros(
+ (
+ h,
+ w, ), dtype=np.float32)
+ h, w = int(h * ds_ratio), int(w * ds_ratio)
+ polys = polys * ds_ratio
+
+ score_map = np.zeros(
+ (
+ h,
+ w, ), dtype=np.float32)
+ score_label_map = np.zeros(
+ (
+ h,
+ w, ), dtype=np.float32)
+ tbo_map = np.zeros((h, w, 5), dtype=np.float32)
+ training_mask = np.ones(
+ (
+ h,
+ w, ), dtype=np.float32)
+ direction_map = np.ones((h, w, 3)) * np.array([0, 0, 1]).reshape(
+ [1, 1, 3]).astype(np.float32)
+
+ label_idx = 0
+ score_label_map_text_label_list = []
+ pos_list, pos_mask, label_list = [], [], []
+ for poly_idx, poly_tag in enumerate(zip(polys, tags)):
+ poly = poly_tag[0]
+ tag = poly_tag[1]
+
+ # generate min_area_quad
+ min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
+ min_area_quad_h = 0.5 * (
+ np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
+ np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
+ min_area_quad_w = 0.5 * (
+ np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
+ np.linalg.norm(min_area_quad[2] - min_area_quad[3]))
+
+ if min(min_area_quad_h, min_area_quad_w) < self.min_text_size * ds_ratio \
+ or min(min_area_quad_h, min_area_quad_w) > self.max_text_size * ds_ratio:
+ continue
+
+ if tag:
+ cv2.fillPoly(training_mask,
+ poly.astype(np.int32)[np.newaxis, :, :], 0.15)
+ else:
+ text_label = text_strs[poly_idx]
+ text_label = self.prepare_text_label(text_label,
+ self.Lexicon_Table)
+
+ text_label_index_list = [[self.Lexicon_Table.index(c_)]
+ for c_ in text_label
+ if c_ in self.Lexicon_Table]
+ if len(text_label_index_list) < 1:
+ continue
+
+ tcl_poly = self.poly2tcl(poly, tcl_ratio)
+ tcl_quads = self.poly2quads(tcl_poly)
+ poly_quads = self.poly2quads(poly)
+
+ stcl_quads, quad_index = self.shrink_poly_along_width(
+ tcl_quads,
+ shrink_ratio_of_width=shrink_ratio_of_width,
+ expand_height_ratio=1.0 / tcl_ratio)
+
+ cv2.fillPoly(score_map,
+ np.round(stcl_quads).astype(np.int32), 1.0)
+ cv2.fillPoly(score_map_big,
+ np.round(stcl_quads / ds_ratio).astype(np.int32),
+ 1.0)
+
+ for idx, quad in enumerate(stcl_quads):
+ quad_mask = np.zeros((h, w), dtype=np.float32)
+ quad_mask = cv2.fillPoly(
+ quad_mask,
+ np.round(quad[np.newaxis, :, :]).astype(np.int32), 1.0)
+ tbo_map = self.gen_quad_tbo(poly_quads[quad_index[idx]],
+ quad_mask, tbo_map)
+
+ # score label map and score_label_map_text_label_list for refine
+ if label_idx == 0:
+ text_pos_list_ = [[len(self.Lexicon_Table)], ]
+ score_label_map_text_label_list.append(text_pos_list_)
+
+ label_idx += 1
+ cv2.fillPoly(score_label_map,
+ np.round(poly_quads).astype(np.int32), label_idx)
+ score_label_map_text_label_list.append(text_label_index_list)
+
+ # direction info, fix-me
+ n_char = len(text_label_index_list)
+ direction_map = self.generate_direction_map(poly_quads, n_char,
+ direction_map)
+
+ # pos info
+ average_shrink_height = self.calculate_average_height(
+ stcl_quads)
+ pos_l, pos_m = self.fit_and_gather_tcl_points_v2(
+ min_area_quad,
+ poly,
+ max_h=h,
+ max_w=w,
+ fixed_point_num=64,
+ img_id=self.img_id,
+ reference_height=average_shrink_height)
+
+ label_l = text_label_index_list
+ if len(text_label_index_list) < 2:
+ continue
+
+ pos_list.append(pos_l)
+ pos_mask.append(pos_m)
+ label_list.append(label_l)
+
+ # use big score_map for smooth tcl lines
+ score_map_big_resized = cv2.resize(
+ score_map_big, dsize=None, fx=ds_ratio, fy=ds_ratio)
+ score_map = np.array(score_map_big_resized > 1e-3, dtype='float32')
+
+ return score_map, score_label_map, tbo_map, direction_map, training_mask, \
+ pos_list, pos_mask, label_list, score_label_map_text_label_list
+
+ def adjust_point(self, poly):
+ """
+ adjust point order.
+ """
+ point_num = poly.shape[0]
+ if point_num == 4:
+ len_1 = np.linalg.norm(poly[0] - poly[1])
+ len_2 = np.linalg.norm(poly[1] - poly[2])
+ len_3 = np.linalg.norm(poly[2] - poly[3])
+ len_4 = np.linalg.norm(poly[3] - poly[0])
+
+ if (len_1 + len_3) * 1.5 < (len_2 + len_4):
+ poly = poly[[1, 2, 3, 0], :]
+
+ elif point_num > 4:
+ vector_1 = poly[0] - poly[1]
+ vector_2 = poly[1] - poly[2]
+ cos_theta = np.dot(vector_1, vector_2) / (
+ np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6)
+ theta = np.arccos(np.round(cos_theta, decimals=4))
+
+ if abs(theta) > (70 / 180 * math.pi):
+ index = list(range(1, point_num)) + [0]
+ poly = poly[np.array(index), :]
+ return poly
+
+ def gen_min_area_quad_from_poly(self, poly):
+ """
+ Generate min area quad from poly.
+ """
+ point_num = poly.shape[0]
+ min_area_quad = np.zeros((4, 2), dtype=np.float32)
+ if point_num == 4:
+ min_area_quad = poly
+ center_point = np.sum(poly, axis=0) / 4
+ else:
+ rect = cv2.minAreaRect(poly.astype(
+ np.int32)) # (center (x,y), (width, height), angle of rotation)
+ center_point = rect[0]
+ box = np.array(cv2.boxPoints(rect))
+
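+            # Align the minAreaRect corners with the poly's four corner points
+            # (poly[0], poly[mid-1], poly[mid], poly[-1]) by picking the rotation
+            # offset with the smallest total distance, so the quad keeps the
+            # poly's reading order.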
+ first_point_idx = 0
+ min_dist = 1e4
+ for i in range(4):
+ dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
+ np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
+ np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
+ np.linalg.norm(box[(i + 3) % 4] - poly[-1])
+ if dist < min_dist:
+ min_dist = dist
+ first_point_idx = i
+
+ for i in range(4):
+ min_area_quad[i] = box[(first_point_idx + i) % 4]
+
+ return min_area_quad, center_point
+
+ def shrink_quad_along_width(self,
+ quad,
+ begin_width_ratio=0.,
+ end_width_ratio=1.):
+ """
+        Shrink a quad along its width between the given begin / end ratios.
+ """
+ ratio_pair = np.array(
+ [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
+ p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
+ p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
+ return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
+
+ def shrink_poly_along_width(self,
+ quads,
+ shrink_ratio_of_width,
+ expand_height_ratio=1.0):
+ """
+        Shrink the poly along its width by the given ratio.
+ """
+ upper_edge_list = []
+
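+        # Walk the per-quad upper-edge lengths until the accumulated length
+        # covers cut_len; return the quad index and the in-quad width ratio.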
+ def get_cut_info(edge_len_list, cut_len):
+ for idx, edge_len in enumerate(edge_len_list):
+ cut_len -= edge_len
+ if cut_len <= 0.000001:
+ ratio = (cut_len + edge_len_list[idx]) / edge_len_list[idx]
+ return idx, ratio
+
+ for quad in quads:
+ upper_edge_len = np.linalg.norm(quad[0] - quad[1])
+ upper_edge_list.append(upper_edge_len)
+
+ # length of left edge and right edge.
+ left_length = np.linalg.norm(quads[0][0] - quads[0][
+ 3]) * expand_height_ratio
+ right_length = np.linalg.norm(quads[-1][1] - quads[-1][
+ 2]) * expand_height_ratio
+
+ shrink_length = min(left_length, right_length,
+ sum(upper_edge_list)) * shrink_ratio_of_width
+ # shrinking length
+ upper_len_left = shrink_length
+ upper_len_right = sum(upper_edge_list) - shrink_length
+
+ left_idx, left_ratio = get_cut_info(upper_edge_list, upper_len_left)
+ left_quad = self.shrink_quad_along_width(
+ quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1)
+ right_idx, right_ratio = get_cut_info(upper_edge_list, upper_len_right)
+ right_quad = self.shrink_quad_along_width(
+ quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio)
+
+ out_quad_list = []
+ if left_idx == right_idx:
+ out_quad_list.append(
+ [left_quad[0], right_quad[1], right_quad[2], left_quad[3]])
+ else:
+ out_quad_list.append(left_quad)
+ for idx in range(left_idx + 1, right_idx):
+ out_quad_list.append(quads[idx])
+ out_quad_list.append(right_quad)
+
+ return np.array(out_quad_list), list(range(left_idx, right_idx + 1))
+
+ def prepare_text_label(self, label_str, Lexicon_Table):
+ """
+        Prepare the text label according to the given Lexicon_Table.
+ """
+ if len(Lexicon_Table) == 36:
+ return label_str.lower()
+ else:
+ return label_str
+
+ def vector_angle(self, A, B):
+ """
+ Calculate the angle between vector AB and x-axis positive direction.
+ """
+ AB = np.array([B[1] - A[1], B[0] - A[0]])
+ return np.arctan2(*AB)
+
+ def theta_line_cross_point(self, theta, point):
+ """
+ Calculate the line through given point and angle in ax + by + c =0 form.
+ """
+ x, y = point
+ cos = np.cos(theta)
+ sin = np.sin(theta)
+ return [sin, -cos, cos * y - sin * x]
+
+ def line_cross_two_point(self, A, B):
+ """
+ Calculate the line through given point A and B in ax + by + c =0 form.
+ """
+ angle = self.vector_angle(A, B)
+ return self.theta_line_cross_point(angle, A)
+
+ def average_angle(self, poly):
+ """
+ Calculate the average angle between left and right edge in given poly.
+ """
+ p0, p1, p2, p3 = poly
+ angle30 = self.vector_angle(p3, p0)
+ angle21 = self.vector_angle(p2, p1)
+ return (angle30 + angle21) / 2
+
+ def line_cross_point(self, line1, line2):
+ """
+        Compute the intersection point of line1 and line2, both given in ax + by + c = 0 form.
+ """
+ a1, b1, c1 = line1
+ a2, b2, c2 = line2
+ d = a1 * b2 - a2 * b1
+
+ if d == 0:
+ print('Cross point does not exist')
+ return np.array([0, 0], dtype=np.float32)
+ else:
+ x = (b1 * c2 - b2 * c1) / d
+ y = (a2 * c1 - a1 * c2) / d
+
+ return np.array([x, y], dtype=np.float32)
+
+ def quad2tcl(self, poly, ratio):
+ """
+        Generate the text center line from a clockwise quad of shape (4, 2).
+ """
+ ratio_pair = np.array(
+ [[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
+ p0_3 = poly[0] + (poly[3] - poly[0]) * ratio_pair
+ p1_2 = poly[1] + (poly[2] - poly[1]) * ratio_pair
+ return np.array([p0_3[0], p1_2[0], p1_2[1], p0_3[1]])
+
+ def poly2tcl(self, poly, ratio):
+ """
+        Generate the text center line polygon from clockwise poly points.
+ """
+ ratio_pair = np.array(
+ [[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
+ tcl_poly = np.zeros_like(poly)
+ point_num = poly.shape[0]
+
+ for idx in range(point_num // 2):
+ point_pair = poly[idx] + (poly[point_num - 1 - idx] - poly[idx]
+ ) * ratio_pair
+ tcl_poly[idx] = point_pair[0]
+ tcl_poly[point_num - 1 - idx] = point_pair[1]
+ return tcl_poly
+
+ def gen_quad_tbo(self, quad, tcl_mask, tbo_map):
+ """
+        Generate tbo_map for the given quad.
+ """
+ # upper and lower line function: ax + by + c = 0;
+ up_line = self.line_cross_two_point(quad[0], quad[1])
+ lower_line = self.line_cross_two_point(quad[3], quad[2])
+
+ quad_h = 0.5 * (np.linalg.norm(quad[0] - quad[3]) +
+ np.linalg.norm(quad[1] - quad[2]))
+ quad_w = 0.5 * (np.linalg.norm(quad[0] - quad[1]) +
+ np.linalg.norm(quad[2] - quad[3]))
+
+ # average angle of left and right line.
+ angle = self.average_angle(quad)
+
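+        # For every pixel inside the tcl mask, cast a line at the quad's average
+        # side angle and store the (y, x) offsets to its intersections with the
+        # upper and lower borders in channels 0-3; channel 4 stores
+        # 2 / max(min(quad_h, quad_w), 1) as a per-pixel scale.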
+ xy_in_poly = np.argwhere(tcl_mask == 1)
+ for y, x in xy_in_poly:
+ point = (x, y)
+ line = self.theta_line_cross_point(angle, point)
+ cross_point_upper = self.line_cross_point(up_line, line)
+ cross_point_lower = self.line_cross_point(lower_line, line)
+ ##FIX, offset reverse
+ upper_offset_x, upper_offset_y = cross_point_upper - point
+ lower_offset_x, lower_offset_y = cross_point_lower - point
+ tbo_map[y, x, 0] = upper_offset_y
+ tbo_map[y, x, 1] = upper_offset_x
+ tbo_map[y, x, 2] = lower_offset_y
+ tbo_map[y, x, 3] = lower_offset_x
+ tbo_map[y, x, 4] = 1.0 / max(min(quad_h, quad_w), 1.0) * 2
+ return tbo_map
+
+ def poly2quads(self, poly):
+ """
+ Split poly into quads.
+ """
+ quad_list = []
+ point_num = poly.shape[0]
+
+ # point pair
+ point_pair_list = []
+ for idx in range(point_num // 2):
+ point_pair = [poly[idx], poly[point_num - 1 - idx]]
+ point_pair_list.append(point_pair)
+
+ quad_num = point_num // 2 - 1
+ for idx in range(quad_num):
+ # reshape and adjust to clock-wise
+ quad_list.append((np.array(point_pair_list)[[idx, idx + 1]]
+ ).reshape(4, 2)[[0, 2, 3, 1]])
+
+ return np.array(quad_list)
+
+ def rotate_im_poly(self, im, text_polys):
+ """
+        Rotate the image (by 90 or 270 degrees) together with its text polygons.
+ """
+ im_w, im_h = im.shape[1], im.shape[0]
+ dst_im = im.copy()
+ dst_polys = []
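+        # Rotate the image by 90 or 270 degrees (one or three np.rot90 calls),
+        # then rotate each polygon point by the same angle about the old image
+        # center and re-center it in the rotated image.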
+ rand_degree_ratio = np.random.rand()
+ rand_degree_cnt = 1
+ if rand_degree_ratio > 0.5:
+ rand_degree_cnt = 3
+ for i in range(rand_degree_cnt):
+ dst_im = np.rot90(dst_im)
+ rot_degree = -90 * rand_degree_cnt
+ rot_angle = rot_degree * math.pi / 180.0
+ n_poly = text_polys.shape[0]
+ cx, cy = 0.5 * im_w, 0.5 * im_h
+ ncx, ncy = 0.5 * dst_im.shape[1], 0.5 * dst_im.shape[0]
+ for i in range(n_poly):
+ wordBB = text_polys[i]
+ poly = []
+ for j in range(4): # 16->4
+ sx, sy = wordBB[j][0], wordBB[j][1]
+ dx = math.cos(rot_angle) * (sx - cx) - math.sin(rot_angle) * (
+ sy - cy) + ncx
+ dy = math.sin(rot_angle) * (sx - cx) + math.cos(rot_angle) * (
+ sy - cy) + ncy
+ poly.append([dx, dy])
+ dst_polys.append(poly)
+ return dst_im, np.array(dst_polys, dtype=np.float32)
+
+ def __call__(self, data):
+ input_size = 512
+ im = data['image']
+ text_polys = data['polys']
+ text_tags = data['ignore_tags']
+ text_strs = data['texts']
+ h, w, _ = im.shape
+ text_polys, text_tags, hv_tags = self.check_and_validate_polys(
+ text_polys, text_tags, (h, w))
+ if text_polys.shape[0] <= 0:
+ return None
+        # jitter the aspect ratio while keeping the image area fixed
+ asp_scales = np.arange(1.0, 1.55, 0.1)
+ asp_scale = np.random.choice(asp_scales)
+ if np.random.rand() < 0.5:
+ asp_scale = 1.0 / asp_scale
+ asp_scale = math.sqrt(asp_scale)
+
+ asp_wx = asp_scale
+ asp_hy = 1.0 / asp_scale
+ im = cv2.resize(im, dsize=None, fx=asp_wx, fy=asp_hy)
+ text_polys[:, :, 0] *= asp_wx
+ text_polys[:, :, 1] *= asp_hy
+
+ h, w, _ = im.shape
+ if max(h, w) > 2048:
+ rd_scale = 2048.0 / max(h, w)
+ im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
+ text_polys *= rd_scale
+ h, w, _ = im.shape
+ if min(h, w) < 16:
+ return None
+
+ # no background
+ im, text_polys, text_tags, hv_tags, text_strs = self.crop_area(
+ im,
+ text_polys,
+ text_tags,
+ hv_tags,
+ text_strs,
+ crop_background=False)
+
+ if text_polys.shape[0] == 0:
+ return None
+        # skip samples where every text box is ignored
+ if np.sum((text_tags * 1.0)) >= text_tags.size:
+ return None
+ new_h, new_w, _ = im.shape
+ if (new_h is None) or (new_w is None):
+ return None
+ # resize image
+ std_ratio = float(input_size) / max(new_w, new_h)
+ rand_scales = np.array(
+ [0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0])
+ rz_scale = std_ratio * np.random.choice(rand_scales)
+ im = cv2.resize(im, dsize=None, fx=rz_scale, fy=rz_scale)
+ text_polys[:, :, 0] *= rz_scale
+ text_polys[:, :, 1] *= rz_scale
+
+ # add gaussian blur
+ if np.random.rand() < 0.1 * 0.5:
+ ks = np.random.permutation(5)[0] + 1
+ ks = int(ks / 2) * 2 + 1
+ im = cv2.GaussianBlur(im, ksize=(ks, ks), sigmaX=0, sigmaY=0)
+ # add brighter
+ if np.random.rand() < 0.1 * 0.5:
+ im = im * (1.0 + np.random.rand() * 0.5)
+ im = np.clip(im, 0.0, 255.0)
+ # add darker
+ if np.random.rand() < 0.1 * 0.5:
+ im = im * (1.0 - np.random.rand() * 0.5)
+ im = np.clip(im, 0.0, 255.0)
+
+ # Padding the im to [input_size, input_size]
+ new_h, new_w, _ = im.shape
+ if min(new_w, new_h) < input_size * 0.5:
+ return None
+ im_padded = np.ones((input_size, input_size, 3), dtype=np.float32)
+ im_padded[:, :, 2] = 0.485 * 255
+ im_padded[:, :, 1] = 0.456 * 255
+ im_padded[:, :, 0] = 0.406 * 255
+
+ # Random the start position
+ del_h = input_size - new_h
+ del_w = input_size - new_w
+ sh, sw = 0, 0
+ if del_h > 1:
+ sh = int(np.random.rand() * del_h)
+ if del_w > 1:
+ sw = int(np.random.rand() * del_w)
+
+ # Padding
+ im_padded[sh:sh + new_h, sw:sw + new_w, :] = im.copy()
+ text_polys[:, :, 0] += sw
+ text_polys[:, :, 1] += sh
+
+ score_map, score_label_map, border_map, direction_map, training_mask, \
+ pos_list, pos_mask, label_list, score_label_map_text_label = self.generate_tcl_ctc_label(input_size,
+ input_size,
+ text_polys,
+ text_tags,
+ text_strs, 0.25)
+ if len(label_list) <= 0: # eliminate negative samples
+ return None
+ pos_list_temp = np.zeros([64, 3])
+ pos_mask_temp = np.zeros([64, 1])
+ label_list_temp = np.zeros([self.max_text_length, 1]) + self.pad_num
+
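+        # Pad every label to max_text_length with pad_num (truncate if longer),
+        # then pad pos_list / pos_mask / label_list with dummy entries up to
+        # max_text_nums.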
+ for i, label in enumerate(label_list):
+ n = len(label)
+ if n > self.max_text_length:
+ label_list[i] = label[:self.max_text_length]
+ continue
+ while n < self.max_text_length:
+ label.append([self.pad_num])
+ n += 1
+
+ for i in range(len(label_list)):
+ label_list[i] = np.array(label_list[i])
+
+ if len(pos_list) <= 0 or len(pos_list) > self.max_text_nums:
+ return None
+ for __ in range(self.max_text_nums - len(pos_list), 0, -1):
+ pos_list.append(pos_list_temp)
+ pos_mask.append(pos_mask_temp)
+ label_list.append(label_list_temp)
+
+ if self.img_id == self.batch_size - 1:
+ self.img_id = 0
+ else:
+ self.img_id += 1
+
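+        # Normalize with the ImageNet mean / std; channel index 2 is red because
+        # the image is BGR here, and the [::-1] below flips the CHW tensor to RGB.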
+ im_padded[:, :, 2] -= 0.485 * 255
+ im_padded[:, :, 1] -= 0.456 * 255
+ im_padded[:, :, 0] -= 0.406 * 255
+ im_padded[:, :, 2] /= (255.0 * 0.229)
+ im_padded[:, :, 1] /= (255.0 * 0.224)
+ im_padded[:, :, 0] /= (255.0 * 0.225)
+ im_padded = im_padded.transpose((2, 0, 1))
+ images = im_padded[::-1, :, :]
+ tcl_maps = score_map[np.newaxis, :, :]
+ tcl_label_maps = score_label_map[np.newaxis, :, :]
+ border_maps = border_map.transpose((2, 0, 1))
+ direction_maps = direction_map.transpose((2, 0, 1))
+ training_masks = training_mask[np.newaxis, :, :]
+ pos_list = np.array(pos_list)
+ pos_mask = np.array(pos_mask)
+ label_list = np.array(label_list)
+ data['images'] = images
+ data['tcl_maps'] = tcl_maps
+ data['tcl_label_maps'] = tcl_label_maps
+ data['border_maps'] = border_maps
+ data['direction_maps'] = direction_maps
+ data['training_masks'] = training_masks
+ data['label_list'] = label_list
+ data['pos_list'] = pos_list
+ data['pos_mask'] = pos_mask
+ return data
diff --git a/backend/ppocr/data/imaug/randaugment.py b/backend/ppocr/data/imaug/randaugment.py
new file mode 100644
index 00000000..56f114d2
--- /dev/null
+++ b/backend/ppocr/data/imaug/randaugment.py
@@ -0,0 +1,143 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from PIL import Image, ImageEnhance, ImageOps
+import numpy as np
+import random
+import six
+
+
+class RawRandAugment(object):
+ def __init__(self,
+ num_layers=2,
+ magnitude=5,
+ fillcolor=(128, 128, 128),
+ **kwargs):
+ self.num_layers = num_layers
+ self.magnitude = magnitude
+ self.max_level = 10
+
+ abso_level = self.magnitude / self.max_level
+ self.level_map = {
+ "shearX": 0.3 * abso_level,
+ "shearY": 0.3 * abso_level,
+ "translateX": 150.0 / 331 * abso_level,
+ "translateY": 150.0 / 331 * abso_level,
+ "rotate": 30 * abso_level,
+ "color": 0.9 * abso_level,
+ "posterize": int(4.0 * abso_level),
+ "solarize": 256.0 * abso_level,
+ "contrast": 0.9 * abso_level,
+ "sharpness": 0.9 * abso_level,
+ "brightness": 0.9 * abso_level,
+ "autocontrast": 0,
+ "equalize": 0,
+ "invert": 0
+ }
+
+ # from https://stackoverflow.com/questions/5252170/
+ # specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
+ def rotate_with_fill(img, magnitude):
+ rot = img.convert("RGBA").rotate(magnitude)
+ return Image.composite(rot,
+ Image.new("RGBA", rot.size, (128, ) * 4),
+ rot).convert(img.mode)
+
+ rnd_ch_op = random.choice
+
+ self.func = {
+ "shearX": lambda img, magnitude: img.transform(
+ img.size,
+ Image.AFFINE,
+ (1, magnitude * rnd_ch_op([-1, 1]), 0, 0, 1, 0),
+ Image.BICUBIC,
+ fillcolor=fillcolor),
+ "shearY": lambda img, magnitude: img.transform(
+ img.size,
+ Image.AFFINE,
+ (1, 0, 0, magnitude * rnd_ch_op([-1, 1]), 1, 0),
+ Image.BICUBIC,
+ fillcolor=fillcolor),
+ "translateX": lambda img, magnitude: img.transform(
+ img.size,
+ Image.AFFINE,
+ (1, 0, magnitude * img.size[0] * rnd_ch_op([-1, 1]), 0, 1, 0),
+ fillcolor=fillcolor),
+ "translateY": lambda img, magnitude: img.transform(
+ img.size,
+ Image.AFFINE,
+ (1, 0, 0, 0, 1, magnitude * img.size[1] * rnd_ch_op([-1, 1])),
+ fillcolor=fillcolor),
+ "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
+ "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(
+ 1 + magnitude * rnd_ch_op([-1, 1])),
+ "posterize": lambda img, magnitude:
+ ImageOps.posterize(img, magnitude),
+ "solarize": lambda img, magnitude:
+ ImageOps.solarize(img, magnitude),
+ "contrast": lambda img, magnitude:
+ ImageEnhance.Contrast(img).enhance(
+ 1 + magnitude * rnd_ch_op([-1, 1])),
+ "sharpness": lambda img, magnitude:
+ ImageEnhance.Sharpness(img).enhance(
+ 1 + magnitude * rnd_ch_op([-1, 1])),
+ "brightness": lambda img, magnitude:
+ ImageEnhance.Brightness(img).enhance(
+ 1 + magnitude * rnd_ch_op([-1, 1])),
+ "autocontrast": lambda img, magnitude:
+ ImageOps.autocontrast(img),
+ "equalize": lambda img, magnitude: ImageOps.equalize(img),
+ "invert": lambda img, magnitude: ImageOps.invert(img)
+ }
+
+ def __call__(self, img):
+        available_op_names = list(self.level_map.keys())
+        for layer_num in range(self.num_layers):
+            op_name = np.random.choice(available_op_names)
+ img = self.func[op_name](img, self.level_map[op_name])
+ return img
+
+
+class RandAugment(RawRandAugment):
+ """ RandAugment wrapper to auto fit different img types """
+
+ def __init__(self, prob=0.5, *args, **kwargs):
+ self.prob = prob
+ if six.PY2:
+ super(RandAugment, self).__init__(*args, **kwargs)
+ else:
+ super().__init__(*args, **kwargs)
+
+ def __call__(self, data):
+ if np.random.rand() > self.prob:
+ return data
+ img = data['image']
+ if not isinstance(img, Image.Image):
+ img = np.ascontiguousarray(img)
+ img = Image.fromarray(img)
+
+ if six.PY2:
+ img = super(RandAugment, self).__call__(img)
+ else:
+ img = super().__call__(img)
+
+ if isinstance(img, Image.Image):
+ img = np.asarray(img)
+ data['image'] = img
+ return data
diff --git a/backend/ppocr/data/imaug/random_crop_data.py b/backend/ppocr/data/imaug/random_crop_data.py
new file mode 100644
index 00000000..64aa110d
--- /dev/null
+++ b/backend/ppocr/data/imaug/random_crop_data.py
@@ -0,0 +1,234 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is adapted from:
+https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/random_crop_data.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import numpy as np
+import cv2
+import random
+
+
+def is_poly_in_rect(poly, x, y, w, h):
+ poly = np.array(poly)
+ if poly[:, 0].min() < x or poly[:, 0].max() > x + w:
+ return False
+ if poly[:, 1].min() < y or poly[:, 1].max() > y + h:
+ return False
+ return True
+
+
+def is_poly_outside_rect(poly, x, y, w, h):
+ poly = np.array(poly)
+ if poly[:, 0].max() < x or poly[:, 0].min() > x + w:
+ return True
+ if poly[:, 1].max() < y or poly[:, 1].min() > y + h:
+ return True
+ return False
+
+
+def split_regions(axis):
+ regions = []
+ min_axis = 0
+ for i in range(1, axis.shape[0]):
+ if axis[i] != axis[i - 1] + 1:
+ region = axis[min_axis:i]
+ min_axis = i
+ regions.append(region)
+ return regions
+
+
+def random_select(axis, max_size):
+ xx = np.random.choice(axis, size=2)
+ xmin = np.min(xx)
+ xmax = np.max(xx)
+ xmin = np.clip(xmin, 0, max_size - 1)
+ xmax = np.clip(xmax, 0, max_size - 1)
+ return xmin, xmax
+
+
+def region_wise_random_select(regions, max_size):
+ selected_index = list(np.random.choice(len(regions), 2))
+ selected_values = []
+ for index in selected_index:
+ axis = regions[index]
+ xx = int(np.random.choice(axis, size=1))
+ selected_values.append(xx)
+ xmin = min(selected_values)
+ xmax = max(selected_values)
+ return xmin, xmax
+
+
+def crop_area(im, text_polys, min_crop_side_ratio, max_tries):
+ h, w, _ = im.shape
+ h_array = np.zeros(h, dtype=np.int32)
+ w_array = np.zeros(w, dtype=np.int32)
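+    # h_array / w_array mark the rows / columns covered by any text polygon;
+    # runs of zeros are the positions where a crop edge may be placed.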
+ for points in text_polys:
+ points = np.round(points, decimals=0).astype(np.int32)
+ minx = np.min(points[:, 0])
+ maxx = np.max(points[:, 0])
+ w_array[minx:maxx] = 1
+ miny = np.min(points[:, 1])
+ maxy = np.max(points[:, 1])
+ h_array[miny:maxy] = 1
+ # ensure the cropped area not across a text
+ h_axis = np.where(h_array == 0)[0]
+ w_axis = np.where(w_array == 0)[0]
+
+ if len(h_axis) == 0 or len(w_axis) == 0:
+ return 0, 0, w, h
+
+ h_regions = split_regions(h_axis)
+ w_regions = split_regions(w_axis)
+
+ for i in range(max_tries):
+ if len(w_regions) > 1:
+ xmin, xmax = region_wise_random_select(w_regions, w)
+ else:
+ xmin, xmax = random_select(w_axis, w)
+ if len(h_regions) > 1:
+ ymin, ymax = region_wise_random_select(h_regions, h)
+ else:
+ ymin, ymax = random_select(h_axis, h)
+
+ if xmax - xmin < min_crop_side_ratio * w or ymax - ymin < min_crop_side_ratio * h:
+ # area too small
+ continue
+ num_poly_in_rect = 0
+ for poly in text_polys:
+ if not is_poly_outside_rect(poly, xmin, ymin, xmax - xmin,
+ ymax - ymin):
+ num_poly_in_rect += 1
+ break
+
+ if num_poly_in_rect > 0:
+ return xmin, ymin, xmax - xmin, ymax - ymin
+
+ return 0, 0, w, h
+
+
+class EastRandomCropData(object):
+ def __init__(self,
+ size=(640, 640),
+ max_tries=10,
+ min_crop_side_ratio=0.1,
+ keep_ratio=True,
+ **kwargs):
+ self.size = size
+ self.max_tries = max_tries
+ self.min_crop_side_ratio = min_crop_side_ratio
+ self.keep_ratio = keep_ratio
+
+ def __call__(self, data):
+ img = data['image']
+ text_polys = data['polys']
+ ignore_tags = data['ignore_tags']
+ texts = data['texts']
+ all_care_polys = [
+ text_polys[i] for i, tag in enumerate(ignore_tags) if not tag
+ ]
+        # compute the crop region
+ crop_x, crop_y, crop_w, crop_h = crop_area(
+ img, all_care_polys, self.min_crop_side_ratio, self.max_tries)
+        # crop the image, keeping the aspect ratio and padding
+ scale_w = self.size[0] / crop_w
+ scale_h = self.size[1] / crop_h
+ scale = min(scale_w, scale_h)
+ h = int(crop_h * scale)
+ w = int(crop_w * scale)
+ if self.keep_ratio:
+ padimg = np.zeros((self.size[1], self.size[0], img.shape[2]),
+ img.dtype)
+ padimg[:h, :w] = cv2.resize(
+ img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h))
+ img = padimg
+ else:
+ img = cv2.resize(
+ img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w],
+ tuple(self.size))
+        # crop the text boxes
+ text_polys_crop = []
+ ignore_tags_crop = []
+ texts_crop = []
+ for poly, text, tag in zip(text_polys, texts, ignore_tags):
+ poly = ((poly - (crop_x, crop_y)) * scale).tolist()
+ if not is_poly_outside_rect(poly, 0, 0, w, h):
+ text_polys_crop.append(poly)
+ ignore_tags_crop.append(tag)
+ texts_crop.append(text)
+ data['image'] = img
+ data['polys'] = np.array(text_polys_crop)
+ data['ignore_tags'] = ignore_tags_crop
+ data['texts'] = texts_crop
+ return data
+
+
+class RandomCropImgMask(object):
+ def __init__(self, size, main_key, crop_keys, p=3 / 8, **kwargs):
+ self.size = size
+ self.main_key = main_key
+ self.crop_keys = crop_keys
+ self.p = p
+
+ def __call__(self, data):
+ image = data['image']
+
+ h, w = image.shape[0:2]
+ th, tw = self.size
+ if w == tw and h == th:
+ return data
+
+ mask = data[self.main_key]
+ if np.max(mask) > 0 and random.random() > self.p:
+ # make sure to crop the text region
+ tl = np.min(np.where(mask > 0), axis=1) - (th, tw)
+ tl[tl < 0] = 0
+ br = np.max(np.where(mask > 0), axis=1) - (th, tw)
+ br[br < 0] = 0
+
+ br[0] = min(br[0], h - th)
+ br[1] = min(br[1], w - tw)
+
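+            # tl / br bound the crop's top-left corner so that the th x tw window
+            # is guaranteed to overlap the positive region of the mask.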
+ i = random.randint(tl[0], br[0]) if tl[0] < br[0] else 0
+ j = random.randint(tl[1], br[1]) if tl[1] < br[1] else 0
+ else:
+ i = random.randint(0, h - th) if h - th > 0 else 0
+ j = random.randint(0, w - tw) if w - tw > 0 else 0
+
+ # return i, j, th, tw
+ for k in data:
+ if k in self.crop_keys:
+ if len(data[k].shape) == 3:
+ if np.argmin(data[k].shape) == 0:
+ img = data[k][:, i:i + th, j:j + tw]
+ if img.shape[1] != img.shape[2]:
+ a = 1
+ elif np.argmin(data[k].shape) == 2:
+ img = data[k][i:i + th, j:j + tw, :]
+ if img.shape[1] != img.shape[0]:
+ a = 1
+ else:
+ img = data[k]
+ else:
+ img = data[k][i:i + th, j:j + tw]
+ if img.shape[0] != img.shape[1]:
+ a = 1
+ data[k] = img
+ return data
diff --git a/backend/ppocr/data/imaug/rec_img_aug.py b/backend/ppocr/data/imaug/rec_img_aug.py
new file mode 100644
index 00000000..7483dffe
--- /dev/null
+++ b/backend/ppocr/data/imaug/rec_img_aug.py
@@ -0,0 +1,601 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import cv2
+import numpy as np
+import random
+import copy
+from PIL import Image
+from .text_image_aug import tia_perspective, tia_stretch, tia_distort
+
+
+class RecAug(object):
+ def __init__(self, use_tia=True, aug_prob=0.4, **kwargs):
+ self.use_tia = use_tia
+ self.aug_prob = aug_prob
+
+ def __call__(self, data):
+ img = data['image']
+ img = warp(img, 10, self.use_tia, self.aug_prob)
+ data['image'] = img
+ return data
+
+
+class RecConAug(object):
+ def __init__(self,
+ prob=0.5,
+ image_shape=(32, 320, 3),
+ max_text_length=25,
+ ext_data_num=1,
+ **kwargs):
+ self.ext_data_num = ext_data_num
+ self.prob = prob
+ self.max_text_length = max_text_length
+ self.image_shape = image_shape
+ self.max_wh_ratio = self.image_shape[1] / self.image_shape[0]
+
+ def merge_ext_data(self, data, ext_data):
+ ori_w = round(data['image'].shape[1] / data['image'].shape[0] *
+ self.image_shape[0])
+ ext_w = round(ext_data['image'].shape[1] / ext_data['image'].shape[0] *
+ self.image_shape[0])
+ data['image'] = cv2.resize(data['image'], (ori_w, self.image_shape[0]))
+ ext_data['image'] = cv2.resize(ext_data['image'],
+ (ext_w, self.image_shape[0]))
+ data['image'] = np.concatenate(
+ [data['image'], ext_data['image']], axis=1)
+ data["label"] += ext_data["label"]
+ return data
+
+ def __call__(self, data):
+ rnd_num = random.random()
+ if rnd_num > self.prob:
+ return data
+ for idx, ext_data in enumerate(data["ext_data"]):
+ if len(data["label"]) + len(ext_data[
+ "label"]) > self.max_text_length:
+ break
+ concat_ratio = data['image'].shape[1] / data['image'].shape[
+ 0] + ext_data['image'].shape[1] / ext_data['image'].shape[0]
+ if concat_ratio > self.max_wh_ratio:
+ break
+ data = self.merge_ext_data(data, ext_data)
+ data.pop("ext_data")
+ return data
+
+
+class ClsResizeImg(object):
+ def __init__(self, image_shape, **kwargs):
+ self.image_shape = image_shape
+
+ def __call__(self, data):
+ img = data['image']
+ norm_img, _ = resize_norm_img(img, self.image_shape)
+ data['image'] = norm_img
+ return data
+
+
+class NRTRRecResizeImg(object):
+ def __init__(self, image_shape, resize_type, padding=False, **kwargs):
+ self.image_shape = image_shape
+ self.resize_type = resize_type
+ self.padding = padding
+
+ def __call__(self, data):
+ img = data['image']
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ image_shape = self.image_shape
+ if self.padding:
+ imgC, imgH, imgW = image_shape
+ # todo: change to 0 and modified image shape
+ h = img.shape[0]
+ w = img.shape[1]
+ ratio = w / float(h)
+ if math.ceil(imgH * ratio) > imgW:
+ resized_w = imgW
+ else:
+ resized_w = int(math.ceil(imgH * ratio))
+ resized_image = cv2.resize(img, (resized_w, imgH))
+ norm_img = np.expand_dims(resized_image, -1)
+ norm_img = norm_img.transpose((2, 0, 1))
+ resized_image = norm_img.astype(np.float32) / 128. - 1.
+ padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
+ padding_im[:, :, 0:resized_w] = resized_image
+ data['image'] = padding_im
+ return data
+ if self.resize_type == 'PIL':
+ image_pil = Image.fromarray(np.uint8(img))
+ img = image_pil.resize(self.image_shape, Image.ANTIALIAS)
+ img = np.array(img)
+ if self.resize_type == 'OpenCV':
+ img = cv2.resize(img, self.image_shape)
+ norm_img = np.expand_dims(img, -1)
+ norm_img = norm_img.transpose((2, 0, 1))
+ data['image'] = norm_img.astype(np.float32) / 128. - 1.
+ return data
+
+
+class RecResizeImg(object):
+ def __init__(self,
+ image_shape,
+ infer_mode=False,
+ character_dict_path='./ppocr/utils/ppocr_keys_v1.txt',
+ padding=True,
+ **kwargs):
+ self.image_shape = image_shape
+ self.infer_mode = infer_mode
+ self.character_dict_path = character_dict_path
+ self.padding = padding
+
+ def __call__(self, data):
+ img = data['image']
+ if self.infer_mode and self.character_dict_path is not None:
+ norm_img, valid_ratio = resize_norm_img_chinese(img,
+ self.image_shape)
+ else:
+ norm_img, valid_ratio = resize_norm_img(img, self.image_shape,
+ self.padding)
+ data['image'] = norm_img
+ data['valid_ratio'] = valid_ratio
+ return data
+
+
+class SRNRecResizeImg(object):
+ def __init__(self, image_shape, num_heads, max_text_length, **kwargs):
+ self.image_shape = image_shape
+ self.num_heads = num_heads
+ self.max_text_length = max_text_length
+
+ def __call__(self, data):
+ img = data['image']
+ norm_img = resize_norm_img_srn(img, self.image_shape)
+ data['image'] = norm_img
+ [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
+ srn_other_inputs(self.image_shape, self.num_heads, self.max_text_length)
+
+ data['encoder_word_pos'] = encoder_word_pos
+ data['gsrm_word_pos'] = gsrm_word_pos
+ data['gsrm_slf_attn_bias1'] = gsrm_slf_attn_bias1
+ data['gsrm_slf_attn_bias2'] = gsrm_slf_attn_bias2
+ return data
+
+
+class SARRecResizeImg(object):
+ def __init__(self, image_shape, width_downsample_ratio=0.25, **kwargs):
+ self.image_shape = image_shape
+ self.width_downsample_ratio = width_downsample_ratio
+
+ def __call__(self, data):
+ img = data['image']
+ norm_img, resize_shape, pad_shape, valid_ratio = resize_norm_img_sar(
+ img, self.image_shape, self.width_downsample_ratio)
+ data['image'] = norm_img
+ data['resized_shape'] = resize_shape
+ data['pad_shape'] = pad_shape
+ data['valid_ratio'] = valid_ratio
+ return data
+
+
+class PRENResizeImg(object):
+ def __init__(self, image_shape, **kwargs):
+ """
+        According to the original paper's implementation, this is a hard resize
+        method, so you may want to adapt it to fit your task better.
+ """
+ self.dst_h, self.dst_w = image_shape
+
+ def __call__(self, data):
+ img = data['image']
+ resized_img = cv2.resize(
+ img, (self.dst_w, self.dst_h), interpolation=cv2.INTER_LINEAR)
+ resized_img = resized_img.transpose((2, 0, 1)) / 255
+ resized_img -= 0.5
+ resized_img /= 0.5
+ data['image'] = resized_img.astype(np.float32)
+ return data
+
+
+def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
+ imgC, imgH, imgW_min, imgW_max = image_shape
+ h = img.shape[0]
+ w = img.shape[1]
+ valid_ratio = 1.0
+ # make sure new_width is an integral multiple of width_divisor.
+ width_divisor = int(1 / width_downsample_ratio)
+ # resize
+ ratio = w / float(h)
+ resize_w = math.ceil(imgH * ratio)
+ if resize_w % width_divisor != 0:
+ resize_w = round(resize_w / width_divisor) * width_divisor
+ if imgW_min is not None:
+ resize_w = max(imgW_min, resize_w)
+ if imgW_max is not None:
+ valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
+ resize_w = min(imgW_max, resize_w)
+ resized_image = cv2.resize(img, (resize_w, imgH))
+ resized_image = resized_image.astype('float32')
+ # norm
+ if image_shape[0] == 1:
+ resized_image = resized_image / 255
+ resized_image = resized_image[np.newaxis, :]
+ else:
+ resized_image = resized_image.transpose((2, 0, 1)) / 255
+ resized_image -= 0.5
+ resized_image /= 0.5
+ resize_shape = resized_image.shape
+ padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
+ padding_im[:, :, 0:resize_w] = resized_image
+ pad_shape = padding_im.shape
+
+ return padding_im, resize_shape, pad_shape, valid_ratio
+
+
+def resize_norm_img(img, image_shape, padding=True):
+ imgC, imgH, imgW = image_shape
+ h = img.shape[0]
+ w = img.shape[1]
+ if not padding:
+ resized_image = cv2.resize(
+ img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
+ resized_w = imgW
+ else:
+ ratio = w / float(h)
+ if math.ceil(imgH * ratio) > imgW:
+ resized_w = imgW
+ else:
+ resized_w = int(math.ceil(imgH * ratio))
+ resized_image = cv2.resize(img, (resized_w, imgH))
+ resized_image = resized_image.astype('float32')
+ if image_shape[0] == 1:
+ resized_image = resized_image / 255
+ resized_image = resized_image[np.newaxis, :]
+ else:
+ resized_image = resized_image.transpose((2, 0, 1)) / 255
+ resized_image -= 0.5
+ resized_image /= 0.5
+ padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
+ padding_im[:, :, 0:resized_w] = resized_image
+ valid_ratio = min(1.0, float(resized_w / imgW))
+ return padding_im, valid_ratio
+
+
+def resize_norm_img_chinese(img, image_shape):
+ imgC, imgH, imgW = image_shape
+ # todo: change to 0 and modified image shape
+ max_wh_ratio = imgW * 1.0 / imgH
+ h, w = img.shape[0], img.shape[1]
+ ratio = w * 1.0 / h
+ max_wh_ratio = max(max_wh_ratio, ratio)
+ imgW = int(imgH * max_wh_ratio)
+ if math.ceil(imgH * ratio) > imgW:
+ resized_w = imgW
+ else:
+ resized_w = int(math.ceil(imgH * ratio))
+ resized_image = cv2.resize(img, (resized_w, imgH))
+ resized_image = resized_image.astype('float32')
+ if image_shape[0] == 1:
+ resized_image = resized_image / 255
+ resized_image = resized_image[np.newaxis, :]
+ else:
+ resized_image = resized_image.transpose((2, 0, 1)) / 255
+ resized_image -= 0.5
+ resized_image /= 0.5
+ padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
+ padding_im[:, :, 0:resized_w] = resized_image
+ valid_ratio = min(1.0, float(resized_w / imgW))
+ return padding_im, valid_ratio
+
+
+def resize_norm_img_srn(img, image_shape):
+ imgC, imgH, imgW = image_shape
+
+ img_black = np.zeros((imgH, imgW))
+ im_hei = img.shape[0]
+ im_wid = img.shape[1]
+
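+    # Resize into one of the width buckets (1x / 2x / 3x of the target height,
+    # otherwise the full target width), convert to grayscale, and left-align
+    # the result on a black canvas of shape (imgH, imgW).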
+ if im_wid <= im_hei * 1:
+ img_new = cv2.resize(img, (imgH * 1, imgH))
+ elif im_wid <= im_hei * 2:
+ img_new = cv2.resize(img, (imgH * 2, imgH))
+ elif im_wid <= im_hei * 3:
+ img_new = cv2.resize(img, (imgH * 3, imgH))
+ else:
+ img_new = cv2.resize(img, (imgW, imgH))
+
+ img_np = np.asarray(img_new)
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
+ img_black[:, 0:img_np.shape[1]] = img_np
+ img_black = img_black[:, :, np.newaxis]
+
+ row, col, c = img_black.shape
+ c = 1
+
+ return np.reshape(img_black, (c, row, col)).astype(np.float32)
+
+
+def srn_other_inputs(image_shape, num_heads, max_text_length):
+
+ imgC, imgH, imgW = image_shape
+ feature_dim = int((imgH / 8) * (imgW / 8))
+
+ encoder_word_pos = np.array(range(0, feature_dim)).reshape(
+ (feature_dim, 1)).astype('int64')
+ gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
+ (max_text_length, 1)).astype('int64')
+
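+    # Self-attention bias masks: bias1 blocks future positions (upper triangle)
+    # and bias2 blocks past positions (lower triangle); both are scaled by -1e9
+    # and tiled across the attention heads.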
+ gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
+ gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
+ [1, max_text_length, max_text_length])
+ gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1,
+ [num_heads, 1, 1]) * [-1e9]
+
+ gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
+ [1, max_text_length, max_text_length])
+ gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2,
+ [num_heads, 1, 1]) * [-1e9]
+
+ return [
+ encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
+ gsrm_slf_attn_bias2
+ ]
+
+
+def flag():
+ """
+    Randomly return 1 or -1.
+ """
+ return 1 if random.random() > 0.5000001 else -1
+
+
+def cvtColor(img):
+ """
+    Randomly jitter the image brightness (V channel) in HSV space.
+ """
+ hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
+ delta = 0.001 * random.random() * flag()
+ hsv[:, :, 2] = hsv[:, :, 2] * (1 + delta)
+ new_img = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
+ return new_img
+
+
+def blur(img):
+ """
+    Apply a 5x5 Gaussian blur to images larger than 10x10.
+ """
+ h, w, _ = img.shape
+ if h > 10 and w > 10:
+ return cv2.GaussianBlur(img, (5, 5), 1)
+ else:
+ return img
+
+
+def jitter(img):
+ """
+    Shift the image diagonally to add a jitter effect.
+ """
+ w, h, _ = img.shape
+ if h > 10 and w > 10:
+ thres = min(w, h)
+ s = int(random.random() * thres * 0.01)
+ src_img = img.copy()
+ for i in range(s):
+ img[i:, i:, :] = src_img[:w - i, :h - i, :]
+ return img
+ else:
+ return img
+
+
+def add_gasuss_noise(image, mean=0, var=0.1):
+ """
+    Add Gaussian noise.
+ """
+
+ noise = np.random.normal(mean, var**0.5, image.shape)
+ out = image + 0.5 * noise
+ out = np.clip(out, 0, 255)
+ out = np.uint8(out)
+ return out
+
+
+def get_crop(image):
+ """
+ random crop
+ """
+ h, w, _ = image.shape
+ top_min = 1
+ top_max = 8
+ top_crop = int(random.randint(top_min, top_max))
+ top_crop = min(top_crop, h - 1)
+ crop_img = image.copy()
+ ratio = random.randint(0, 1)
+ if ratio:
+ crop_img = crop_img[top_crop:h, :, :]
+ else:
+ crop_img = crop_img[0:h - top_crop, :, :]
+ return crop_img
+
+
+class Config:
+ """
+ Config
+ """
+
+ def __init__(self, use_tia):
+ self.anglex = random.random() * 30
+ self.angley = random.random() * 15
+ self.anglez = random.random() * 10
+ self.fov = 42
+ self.r = 0
+ self.shearx = random.random() * 0.3
+ self.sheary = random.random() * 0.05
+ self.borderMode = cv2.BORDER_REPLICATE
+ self.use_tia = use_tia
+
+ def make(self, w, h, ang):
+ """
+        Re-sample augmentation parameters for the given image size and angle.
+ """
+ self.anglex = random.random() * 5 * flag()
+ self.angley = random.random() * 5 * flag()
+ self.anglez = -1 * random.random() * int(ang) * flag()
+ self.fov = 42
+ self.r = 0
+ self.shearx = 0
+ self.sheary = 0
+ self.borderMode = cv2.BORDER_REPLICATE
+ self.w = w
+ self.h = h
+
+ self.perspective = self.use_tia
+ self.stretch = self.use_tia
+ self.distort = self.use_tia
+
+ self.crop = True
+ self.affine = False
+ self.reverse = True
+ self.noise = True
+ self.jitter = True
+ self.blur = True
+ self.color = True
+
+
+def rad(x):
+ """
+    Convert degrees to radians.
+ """
+ return x * np.pi / 180
+
+
+def get_warpR(config):
+ """
+    Compute the perspective warp matrix and related geometry.
+ """
+ anglex, angley, anglez, fov, w, h, r = \
+ config.anglex, config.angley, config.anglez, config.fov, config.w, config.h, config.r
+ if w > 69 and w < 112:
+ anglex = anglex * 1.5
+
+ z = np.sqrt(w**2 + h**2) / 2 / np.tan(rad(fov / 2))
+ # Homogeneous coordinate transformation matrix
+ rx = np.array([[1, 0, 0, 0],
+ [0, np.cos(rad(anglex)), -np.sin(rad(anglex)), 0], [
+ 0,
+ -np.sin(rad(anglex)),
+ np.cos(rad(anglex)),
+ 0,
+ ], [0, 0, 0, 1]], np.float32)
+ ry = np.array([[np.cos(rad(angley)), 0, np.sin(rad(angley)), 0],
+ [0, 1, 0, 0], [
+ -np.sin(rad(angley)),
+ 0,
+ np.cos(rad(angley)),
+ 0,
+ ], [0, 0, 0, 1]], np.float32)
+ rz = np.array([[np.cos(rad(anglez)), np.sin(rad(anglez)), 0, 0],
+ [-np.sin(rad(anglez)), np.cos(rad(anglez)), 0, 0],
+ [0, 0, 1, 0], [0, 0, 0, 1]], np.float32)
+ r = rx.dot(ry).dot(rz)
+ # generate 4 points
+ pcenter = np.array([h / 2, w / 2, 0, 0], np.float32)
+ p1 = np.array([0, 0, 0, 0], np.float32) - pcenter
+ p2 = np.array([w, 0, 0, 0], np.float32) - pcenter
+ p3 = np.array([0, h, 0, 0], np.float32) - pcenter
+ p4 = np.array([w, h, 0, 0], np.float32) - pcenter
+ dst1 = r.dot(p1)
+ dst2 = r.dot(p2)
+ dst3 = r.dot(p3)
+ dst4 = r.dot(p4)
+ list_dst = np.array([dst1, dst2, dst3, dst4])
+ org = np.array([[0, 0], [w, 0], [0, h], [w, h]], np.float32)
+ dst = np.zeros((4, 2), np.float32)
+ # Project onto the image plane
+ dst[:, 0] = list_dst[:, 0] * z / (z - list_dst[:, 2]) + pcenter[0]
+ dst[:, 1] = list_dst[:, 1] * z / (z - list_dst[:, 2]) + pcenter[1]
+
+ warpR = cv2.getPerspectiveTransform(org, dst)
+
+ dst1, dst2, dst3, dst4 = dst
+ r1 = int(min(dst1[1], dst2[1]))
+ r2 = int(max(dst3[1], dst4[1]))
+ c1 = int(min(dst1[0], dst3[0]))
+ c2 = int(max(dst2[0], dst4[0]))
+
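+    # Rescale and translate the warp so the projected quad fits back into the
+    # original w x h frame; fall back to an identity transform if the geometry
+    # degenerates (e.g. a zero-sized projection).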
+ try:
+ ratio = min(1.0 * h / (r2 - r1), 1.0 * w / (c2 - c1))
+
+ dx = -c1
+ dy = -r1
+ T1 = np.float32([[1., 0, dx], [0, 1., dy], [0, 0, 1.0 / ratio]])
+ ret = T1.dot(warpR)
+ except:
+ ratio = 1.0
+ T1 = np.float32([[1., 0, 0], [0, 1., 0], [0, 0, 1.]])
+ ret = T1
+ return ret, (-r1, -c1), ratio, dst
+
+
+def get_warpAffine(config):
+ """
+    Compute the affine (in-plane rotation) warp matrix.
+ """
+ anglez = config.anglez
+ rz = np.array([[np.cos(rad(anglez)), np.sin(rad(anglez)), 0],
+ [-np.sin(rad(anglez)), np.cos(rad(anglez)), 0]], np.float32)
+ return rz
+
+
+def warp(img, ang, use_tia=True, prob=0.4):
+ """
+    Apply random warp-based augmentations to the image.
+ """
+ h, w, _ = img.shape
+ config = Config(use_tia=use_tia)
+ config.make(w, h, ang)
+ new_img = img
+
+ if config.distort:
+ img_height, img_width = img.shape[0:2]
+ if random.random() <= prob and img_height >= 20 and img_width >= 20:
+ new_img = tia_distort(new_img, random.randint(3, 6))
+
+ if config.stretch:
+ img_height, img_width = img.shape[0:2]
+ if random.random() <= prob and img_height >= 20 and img_width >= 20:
+ new_img = tia_stretch(new_img, random.randint(3, 6))
+
+ if config.perspective:
+ if random.random() <= prob:
+ new_img = tia_perspective(new_img)
+
+ if config.crop:
+ img_height, img_width = img.shape[0:2]
+ if random.random() <= prob and img_height >= 20 and img_width >= 20:
+ new_img = get_crop(new_img)
+
+ if config.blur:
+ if random.random() <= prob:
+ new_img = blur(new_img)
+ if config.color:
+ if random.random() <= prob:
+ new_img = cvtColor(new_img)
+ if config.jitter:
+ new_img = jitter(new_img)
+ if config.noise:
+ if random.random() <= prob:
+ new_img = add_gasuss_noise(new_img)
+ if config.reverse:
+ if random.random() <= prob:
+ new_img = 255 - new_img
+ return new_img
diff --git a/backend/ppocr/data/imaug/sast_process.py b/backend/ppocr/data/imaug/sast_process.py
index 1536dceb..08d03b19 100644
--- a/backend/ppocr/data/imaug/sast_process.py
+++ b/backend/ppocr/data/imaug/sast_process.py
@@ -11,7 +11,10 @@
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
-
+"""
+This part of the code is adapted from:
+https://github.com/songdejia/EAST/blob/master/data_utils.py
+"""
import math
import cv2
import numpy as np
diff --git a/backend/ppocr/data/imaug/ssl_img_aug.py b/backend/ppocr/data/imaug/ssl_img_aug.py
new file mode 100644
index 00000000..f9ed6ac3
--- /dev/null
+++ b/backend/ppocr/data/imaug/ssl_img_aug.py
@@ -0,0 +1,60 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import cv2
+import numpy as np
+import random
+from PIL import Image
+
+from .rec_img_aug import resize_norm_img
+
+
+class SSLRotateResize(object):
+ def __init__(self,
+ image_shape,
+ padding=False,
+ select_all=True,
+ mode="train",
+ **kwargs):
+ self.image_shape = image_shape
+ self.padding = padding
+ self.select_all = select_all
+ self.mode = mode
+
+ def __call__(self, data):
+ img = data["image"]
+
+ data["image_r90"] = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
+ data["image_r180"] = cv2.rotate(data["image_r90"],
+ cv2.ROTATE_90_CLOCKWISE)
+ data["image_r270"] = cv2.rotate(data["image_r180"],
+ cv2.ROTATE_90_CLOCKWISE)
+
+ images = []
+ for key in ["image", "image_r90", "image_r180", "image_r270"]:
+ images.append(
+ resize_norm_img(
+ data.pop(key),
+ image_shape=self.image_shape,
+ padding=self.padding)[0])
+ data["image"] = np.stack(images, axis=0)
+ data["label"] = np.array(list(range(4)))
+ if not self.select_all:
+ data["image"] = data["image"][0::2] # just choose 0 and 180
+ data["label"] = data["label"][0:2] # label needs to be continuous
+ if self.mode == "test":
+ data["image"] = data["image"][0]
+ data["label"] = data["label"][0]
+ return data
diff --git a/backend/ppocr/data/imaug/text_image_aug/augment.py b/backend/ppocr/data/imaug/text_image_aug/augment.py
index 1aeff373..2d15dd5f 100644
--- a/backend/ppocr/data/imaug/text_image_aug/augment.py
+++ b/backend/ppocr/data/imaug/text_image_aug/augment.py
@@ -11,6 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+"""
+This code is adapted from:
+https://github.com/RubanSeven/Text-Image-Augmentation-python/blob/master/augment.py
+"""
import numpy as np
from .warp_mls import WarpMLS
diff --git a/backend/ppocr/data/imaug/text_image_aug/warp_mls.py b/backend/ppocr/data/imaug/text_image_aug/warp_mls.py
index d6cbe749..75de1111 100644
--- a/backend/ppocr/data/imaug/text_image_aug/warp_mls.py
+++ b/backend/ppocr/data/imaug/text_image_aug/warp_mls.py
@@ -11,6 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+"""
+This code is adapted from:
+https://github.com/RubanSeven/Text-Image-Augmentation-python/blob/master/warp_mls.py
+"""
import numpy as np
@@ -161,4 +165,4 @@ def gen_img(self):
dst = np.clip(dst, 0, 255)
dst = np.array(dst, dtype=np.uint8)
- return dst
\ No newline at end of file
+ return dst
diff --git a/backend/ppocr/data/imaug/vqa/__init__.py b/backend/ppocr/data/imaug/vqa/__init__.py
new file mode 100644
index 00000000..a5025e79
--- /dev/null
+++ b/backend/ppocr/data/imaug/vqa/__init__.py
@@ -0,0 +1,19 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation
+
+__all__ = [
+ 'VQATokenPad', 'VQASerTokenChunk', 'VQAReTokenChunk', 'VQAReTokenRelation'
+]
diff --git a/backend/ppocr/data/imaug/vqa/token/__init__.py b/backend/ppocr/data/imaug/vqa/token/__init__.py
new file mode 100644
index 00000000..7c115661
--- /dev/null
+++ b/backend/ppocr/data/imaug/vqa/token/__init__.py
@@ -0,0 +1,17 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .vqa_token_chunk import VQASerTokenChunk, VQAReTokenChunk
+from .vqa_token_pad import VQATokenPad
+from .vqa_token_relation import VQAReTokenRelation
diff --git a/backend/ppocr/data/imaug/vqa/token/vqa_token_chunk.py b/backend/ppocr/data/imaug/vqa/token/vqa_token_chunk.py
new file mode 100644
index 00000000..1fa949e6
--- /dev/null
+++ b/backend/ppocr/data/imaug/vqa/token/vqa_token_chunk.py
@@ -0,0 +1,122 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import defaultdict
+
+
+class VQASerTokenChunk(object):
+ def __init__(self, max_seq_len=512, infer_mode=False, **kwargs):
+ self.max_seq_len = max_seq_len
+ self.infer_mode = infer_mode
+
+ def __call__(self, data):
+ encoded_inputs_all = []
+ seq_len = len(data['input_ids'])
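+        # Split the token-level fields into windows of max_seq_len; sequence-level
+        # fields are copied unchanged, and only the first chunk is returned.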
+ for index in range(0, seq_len, self.max_seq_len):
+ chunk_beg = index
+ chunk_end = min(index + self.max_seq_len, seq_len)
+ encoded_inputs_example = {}
+ for key in data:
+ if key in [
+ 'label', 'input_ids', 'labels', 'token_type_ids',
+ 'bbox', 'attention_mask'
+ ]:
+ if self.infer_mode and key == 'labels':
+ encoded_inputs_example[key] = data[key]
+ else:
+ encoded_inputs_example[key] = data[key][chunk_beg:
+ chunk_end]
+ else:
+ encoded_inputs_example[key] = data[key]
+
+ encoded_inputs_all.append(encoded_inputs_example)
+ if len(encoded_inputs_all) == 0:
+ return None
+ return encoded_inputs_all[0]
+
+
+class VQAReTokenChunk(object):
+ def __init__(self,
+ max_seq_len=512,
+ entities_labels=None,
+ infer_mode=False,
+ **kwargs):
+ self.max_seq_len = max_seq_len
+ self.entities_labels = {
+ 'HEADER': 0,
+ 'QUESTION': 1,
+ 'ANSWER': 2
+ } if entities_labels is None else entities_labels
+ self.infer_mode = infer_mode
+
+ def __call__(self, data):
+ # prepare data
+ entities = data.pop('entities')
+ relations = data.pop('relations')
+ encoded_inputs_all = []
+ for index in range(0, len(data["input_ids"]), self.max_seq_len):
+ item = {}
+ for key in data:
+ if key in [
+ 'label', 'input_ids', 'labels', 'token_type_ids',
+ 'bbox', 'attention_mask'
+ ]:
+ if self.infer_mode and key == 'labels':
+ item[key] = data[key]
+ else:
+ item[key] = data[key][index:index + self.max_seq_len]
+ else:
+ item[key] = data[key]
+ # select entity in current chunk
+ entities_in_this_span = []
+ global_to_local_map = {} #
+ for entity_id, entity in enumerate(entities):
+ if (index <= entity["start"] < index + self.max_seq_len and
+ index <= entity["end"] < index + self.max_seq_len):
+ entity["start"] = entity["start"] - index
+ entity["end"] = entity["end"] - index
+ global_to_local_map[entity_id] = len(entities_in_this_span)
+ entities_in_this_span.append(entity)
+
+ # select relations in current chunk
+ relations_in_this_span = []
+ for relation in relations:
+ if (index <= relation["start_index"] < index + self.max_seq_len
+ and index <= relation["end_index"] <
+ index + self.max_seq_len):
+ relations_in_this_span.append({
+ "head": global_to_local_map[relation["head"]],
+ "tail": global_to_local_map[relation["tail"]],
+ "start_index": relation["start_index"] - index,
+ "end_index": relation["end_index"] - index,
+ })
+ item.update({
+ "entities": self.reformat(entities_in_this_span),
+ "relations": self.reformat(relations_in_this_span),
+ })
+ if len(item['entities']) > 0:
+ item['entities']['label'] = [
+ self.entities_labels[x] for x in item['entities']['label']
+ ]
+ encoded_inputs_all.append(item)
+ if len(encoded_inputs_all) == 0:
+ return None
+ return encoded_inputs_all[0]
+
+ def reformat(self, data):
+ new_data = defaultdict(list)
+ for item in data:
+ for k, v in item.items():
+ new_data[k].append(v)
+ return new_data
diff --git a/backend/ppocr/data/imaug/vqa/token/vqa_token_pad.py b/backend/ppocr/data/imaug/vqa/token/vqa_token_pad.py
new file mode 100644
index 00000000..8e5a20f9
--- /dev/null
+++ b/backend/ppocr/data/imaug/vqa/token/vqa_token_pad.py
@@ -0,0 +1,104 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddle
+import numpy as np
+
+
+class VQATokenPad(object):
+ def __init__(self,
+ max_seq_len=512,
+ pad_to_max_seq_len=True,
+ return_attention_mask=True,
+ return_token_type_ids=True,
+ truncation_strategy="longest_first",
+ return_overflowing_tokens=False,
+ return_special_tokens_mask=False,
+ infer_mode=False,
+ **kwargs):
+ self.max_seq_len = max_seq_len
+        self.pad_to_max_seq_len = pad_to_max_seq_len
+ self.return_attention_mask = return_attention_mask
+ self.return_token_type_ids = return_token_type_ids
+ self.truncation_strategy = truncation_strategy
+ self.return_overflowing_tokens = return_overflowing_tokens
+ self.return_special_tokens_mask = return_special_tokens_mask
+ self.pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
+ self.infer_mode = infer_mode
+
+ def __call__(self, data):
+ needs_to_be_padded = self.pad_to_max_seq_len and len(data[
+ "input_ids"]) < self.max_seq_len
+
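+        # Pad input_ids, attention_mask, token_type_ids, bbox and labels on the
+        # tokenizer's padding side up to max_seq_len, then cast the token-level
+        # fields to int64 arrays (truncated to max_seq_len in infer mode).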
+ if needs_to_be_padded:
+ if 'tokenizer_params' in data:
+ tokenizer_params = data.pop('tokenizer_params')
+ else:
+ tokenizer_params = dict(
+ padding_side='right', pad_token_type_id=0, pad_token_id=1)
+
+ difference = self.max_seq_len - len(data["input_ids"])
+ if tokenizer_params['padding_side'] == 'right':
+ if self.return_attention_mask:
+ data["attention_mask"] = [1] * len(data[
+ "input_ids"]) + [0] * difference
+ if self.return_token_type_ids:
+ data["token_type_ids"] = (
+ data["token_type_ids"] +
+ [tokenizer_params['pad_token_type_id']] * difference)
+ if self.return_special_tokens_mask:
+ data["special_tokens_mask"] = data[
+ "special_tokens_mask"] + [1] * difference
+ data["input_ids"] = data["input_ids"] + [
+ tokenizer_params['pad_token_id']
+ ] * difference
+ if not self.infer_mode:
+ data["labels"] = data[
+ "labels"] + [self.pad_token_label_id] * difference
+ data["bbox"] = data["bbox"] + [[0, 0, 0, 0]] * difference
+ elif tokenizer_params['padding_side'] == 'left':
+ if self.return_attention_mask:
+ data["attention_mask"] = [0] * difference + [
+ 1
+ ] * len(data["input_ids"])
+ if self.return_token_type_ids:
+ data["token_type_ids"] = (
+ [tokenizer_params['pad_token_type_id']] * difference +
+ data["token_type_ids"])
+ if self.return_special_tokens_mask:
+ data["special_tokens_mask"] = [
+ 1
+ ] * difference + data["special_tokens_mask"]
+ data["input_ids"] = [tokenizer_params['pad_token_id']
+ ] * difference + data["input_ids"]
+ if not self.infer_mode:
+ data["labels"] = [self.pad_token_label_id
+ ] * difference + data["labels"]
+ data["bbox"] = [[0, 0, 0, 0]] * difference + data["bbox"]
+ else:
+ if self.return_attention_mask:
+ data["attention_mask"] = [1] * len(data["input_ids"])
+
+ for key in data:
+ if key in [
+ 'input_ids', 'labels', 'token_type_ids', 'bbox',
+ 'attention_mask'
+ ]:
+ if self.infer_mode:
+ if key != 'labels':
+ length = min(len(data[key]), self.max_seq_len)
+ data[key] = data[key][:length]
+ else:
+ continue
+ data[key] = np.array(data[key], dtype='int64')
+ return data
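
Note: VQATokenPad pads every token-level field (input_ids, token_type_ids, bbox, labels and the derived attention_mask) out to max_seq_len on the tokenizer's padding side. A minimal sketch of the right-side case; the pad ids and the toy max_seq_len of 8 are assumed values, not taken from this file:

    import numpy as np

    max_seq_len = 8
    pad_token_id, pad_token_type_id, pad_label_id = 1, 0, -100  # assumed values

    sample = {
        "input_ids": [101, 2054, 2003, 102],
        "token_type_ids": [0, 0, 0, 0],
        "bbox": [[0, 0, 10, 10]] * 4,
        "labels": [0, 1, 2, 0],
    }
    pad_len = max_seq_len - len(sample["input_ids"])
    sample["attention_mask"] = [1] * len(sample["input_ids"]) + [0] * pad_len
    sample["input_ids"] += [pad_token_id] * pad_len
    sample["token_type_ids"] += [pad_token_type_id] * pad_len
    sample["labels"] += [pad_label_id] * pad_len
    sample["bbox"] += [[0, 0, 0, 0]] * pad_len
    sample = {k: np.array(v, dtype="int64") for k, v in sample.items()}
    print(sample["input_ids"])  # [101 2054 2003 102 1 1 1 1]
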
diff --git a/backend/ppocr/data/imaug/vqa/token/vqa_token_relation.py b/backend/ppocr/data/imaug/vqa/token/vqa_token_relation.py
new file mode 100644
index 00000000..293988ff
--- /dev/null
+++ b/backend/ppocr/data/imaug/vqa/token/vqa_token_relation.py
@@ -0,0 +1,67 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class VQAReTokenRelation(object):
+ def __init__(self, **kwargs):
+ pass
+
+ def __call__(self, data):
+ """
+ build relations
+ """
+ entities = data['entities']
+ relations = data['relations']
+ id2label = data.pop('id2label')
+ empty_entity = data.pop('empty_entity')
+ entity_id_to_index_map = data.pop('entity_id_to_index_map')
+
+ relations = list(set(relations))
+ relations = [
+ rel for rel in relations
+ if rel[0] not in empty_entity and rel[1] not in empty_entity
+ ]
+ kv_relations = []
+ for rel in relations:
+ pair = [id2label[rel[0]], id2label[rel[1]]]
+ if pair == ["question", "answer"]:
+ kv_relations.append({
+ "head": entity_id_to_index_map[rel[0]],
+ "tail": entity_id_to_index_map[rel[1]]
+ })
+ elif pair == ["answer", "question"]:
+ kv_relations.append({
+ "head": entity_id_to_index_map[rel[1]],
+ "tail": entity_id_to_index_map[rel[0]]
+ })
+ else:
+ continue
+ relations = sorted(
+ [{
+ "head": rel["head"],
+ "tail": rel["tail"],
+ "start_index": self.get_relation_span(rel, entities)[0],
+ "end_index": self.get_relation_span(rel, entities)[1],
+ } for rel in kv_relations],
+ key=lambda x: x["head"], )
+
+ data['relations'] = relations
+ return data
+
+ def get_relation_span(self, rel, entities):
+ bound = []
+ for entity_index in [rel["head"], rel["tail"]]:
+ bound.append(entities[entity_index]["start"])
+ bound.append(entities[entity_index]["end"])
+ return min(bound), max(bound)
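
Note: the relation builder above keeps only question-to-answer pairs and flips answer-to-question pairs, so both orientations map to the same head/tail indices. A toy illustration of that rule; all ids and labels here are made up for the example:

    id2label = {1: "question", 2: "answer", 3: "header"}
    entity_id_to_index_map = {1: 0, 2: 1, 3: 2}
    relations = [(1, 2), (2, 1), (3, 1)]

    kv_relations = []
    for head, tail in set(relations):
        pair = [id2label[head], id2label[tail]]
        if pair == ["question", "answer"]:
            kv_relations.append({"head": entity_id_to_index_map[head],
                                 "tail": entity_id_to_index_map[tail]})
        elif pair == ["answer", "question"]:
            kv_relations.append({"head": entity_id_to_index_map[tail],
                                 "tail": entity_id_to_index_map[head]})
    print(kv_relations)  # (1, 2) and (2, 1) both become {'head': 0, 'tail': 1}; (3, 1) is dropped
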
diff --git a/backend/ppocr/data/lmdb_dataset.py b/backend/ppocr/data/lmdb_dataset.py
new file mode 100644
index 00000000..e1b49809
--- /dev/null
+++ b/backend/ppocr/data/lmdb_dataset.py
@@ -0,0 +1,118 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import os
+from paddle.io import Dataset
+import lmdb
+import cv2
+
+from .imaug import transform, create_operators
+
+
+class LMDBDataSet(Dataset):
+ def __init__(self, config, mode, logger, seed=None):
+ super(LMDBDataSet, self).__init__()
+
+ global_config = config['Global']
+ dataset_config = config[mode]['dataset']
+ loader_config = config[mode]['loader']
+ batch_size = loader_config['batch_size_per_card']
+ data_dir = dataset_config['data_dir']
+ self.do_shuffle = loader_config['shuffle']
+
+ self.lmdb_sets = self.load_hierarchical_lmdb_dataset(data_dir)
+ logger.info("Initialize indexs of datasets:%s" % data_dir)
+ self.data_idx_order_list = self.dataset_traversal()
+ if self.do_shuffle:
+ np.random.shuffle(self.data_idx_order_list)
+ self.ops = create_operators(dataset_config['transforms'], global_config)
+
+ ratio_list = dataset_config.get("ratio_list", [1.0])
+ self.need_reset = True in [x < 1 for x in ratio_list]
+
+ def load_hierarchical_lmdb_dataset(self, data_dir):
+ lmdb_sets = {}
+ dataset_idx = 0
+ for dirpath, dirnames, filenames in os.walk(data_dir + '/'):
+ if not dirnames:
+ env = lmdb.open(
+ dirpath,
+ max_readers=32,
+ readonly=True,
+ lock=False,
+ readahead=False,
+ meminit=False)
+ txn = env.begin(write=False)
+ num_samples = int(txn.get('num-samples'.encode()))
+ lmdb_sets[dataset_idx] = {"dirpath":dirpath, "env":env, \
+ "txn":txn, "num_samples":num_samples}
+ dataset_idx += 1
+ return lmdb_sets
+
+ def dataset_traversal(self):
+ lmdb_num = len(self.lmdb_sets)
+ total_sample_num = 0
+ for lno in range(lmdb_num):
+ total_sample_num += self.lmdb_sets[lno]['num_samples']
+ data_idx_order_list = np.zeros((total_sample_num, 2))
+ beg_idx = 0
+ for lno in range(lmdb_num):
+ tmp_sample_num = self.lmdb_sets[lno]['num_samples']
+ end_idx = beg_idx + tmp_sample_num
+ data_idx_order_list[beg_idx:end_idx, 0] = lno
+ data_idx_order_list[beg_idx:end_idx, 1] \
+ = list(range(tmp_sample_num))
+ data_idx_order_list[beg_idx:end_idx, 1] += 1
+ beg_idx = beg_idx + tmp_sample_num
+ return data_idx_order_list
+
+ def get_img_data(self, value):
+ """get_img_data"""
+ if not value:
+ return None
+ imgdata = np.frombuffer(value, dtype='uint8')
+ if imgdata is None:
+ return None
+ imgori = cv2.imdecode(imgdata, 1)
+ if imgori is None:
+ return None
+ return imgori
+
+ def get_lmdb_sample_info(self, txn, index):
+ label_key = 'label-%09d'.encode() % index
+ label = txn.get(label_key)
+ if label is None:
+ return None
+ label = label.decode('utf-8')
+ img_key = 'image-%09d'.encode() % index
+ imgbuf = txn.get(img_key)
+ return imgbuf, label
+
+ def __getitem__(self, idx):
+ lmdb_idx, file_idx = self.data_idx_order_list[idx]
+ lmdb_idx = int(lmdb_idx)
+ file_idx = int(file_idx)
+ sample_info = self.get_lmdb_sample_info(self.lmdb_sets[lmdb_idx]['txn'],
+ file_idx)
+ if sample_info is None:
+ return self.__getitem__(np.random.randint(self.__len__()))
+ img, label = sample_info
+ data = {'image': img, 'label': label}
+ outs = transform(data, self.ops)
+ if outs is None:
+ return self.__getitem__(np.random.randint(self.__len__()))
+ return outs
+
+ def __len__(self):
+ return self.data_idx_order_list.shape[0]
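
Note: LMDBDataSet reads the LMDB layout common to scene-text recognition datasets: a num-samples key plus image-%09d / label-%09d pairs indexed from 1, which is why dataset_traversal shifts the per-set index by 1. A hedged sketch of writing such an LMDB; the path, map_size and sample list are placeholders:

    import os
    import lmdb

    def write_rec_lmdb(out_dir, samples):
        # samples: list of (jpeg_bytes, transcription) pairs
        os.makedirs(out_dir, exist_ok=True)
        env = lmdb.open(out_dir, map_size=1 << 30)
        with env.begin(write=True) as txn:
            for i, (img_bytes, text) in enumerate(samples, start=1):
                txn.put(b'image-%09d' % i, img_bytes)
                txn.put(b'label-%09d' % i, text.encode('utf-8'))
            txn.put(b'num-samples', str(len(samples)).encode())
        env.close()
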
diff --git a/backend/ppocr/data/pgnet_dataset.py b/backend/ppocr/data/pgnet_dataset.py
new file mode 100644
index 00000000..6f80179c
--- /dev/null
+++ b/backend/ppocr/data/pgnet_dataset.py
@@ -0,0 +1,106 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import os
+from paddle.io import Dataset
+from .imaug import transform, create_operators
+import random
+
+
+class PGDataSet(Dataset):
+ def __init__(self, config, mode, logger, seed=None):
+ super(PGDataSet, self).__init__()
+
+ self.logger = logger
+ self.seed = seed
+ self.mode = mode
+ global_config = config['Global']
+ dataset_config = config[mode]['dataset']
+ loader_config = config[mode]['loader']
+
+ self.delimiter = dataset_config.get('delimiter', '\t')
+ label_file_list = dataset_config.pop('label_file_list')
+ data_source_num = len(label_file_list)
+ ratio_list = dataset_config.get("ratio_list", [1.0])
+ if isinstance(ratio_list, (float, int)):
+ ratio_list = [float(ratio_list)] * int(data_source_num)
+ assert len(
+ ratio_list
+ ) == data_source_num, "The length of ratio_list should be the same as the file_list."
+ self.data_dir = dataset_config['data_dir']
+ self.do_shuffle = loader_config['shuffle']
+
+ logger.info("Initialize indexs of datasets:%s" % label_file_list)
+ self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
+ self.data_idx_order_list = list(range(len(self.data_lines)))
+ if mode.lower() == "train":
+ self.shuffle_data_random()
+
+ self.ops = create_operators(dataset_config['transforms'], global_config)
+
+ self.need_reset = True in [x < 1 for x in ratio_list]
+
+ def shuffle_data_random(self):
+ if self.do_shuffle:
+ random.seed(self.seed)
+ random.shuffle(self.data_lines)
+ return
+
+ def get_image_info_list(self, file_list, ratio_list):
+ if isinstance(file_list, str):
+ file_list = [file_list]
+ data_lines = []
+ for idx, file in enumerate(file_list):
+ with open(file, "rb") as f:
+ lines = f.readlines()
+ if self.mode == "train" or ratio_list[idx] < 1.0:
+ random.seed(self.seed)
+ lines = random.sample(lines,
+ round(len(lines) * ratio_list[idx]))
+ data_lines.extend(lines)
+ return data_lines
+
+ def __getitem__(self, idx):
+ file_idx = self.data_idx_order_list[idx]
+ data_line = self.data_lines[file_idx]
+ img_id = 0
+ try:
+ data_line = data_line.decode('utf-8')
+ substr = data_line.strip("\n").split(self.delimiter)
+ file_name = substr[0]
+ label = substr[1]
+ img_path = os.path.join(self.data_dir, file_name)
+ if self.mode.lower() == 'eval':
+ try:
+ img_id = int(data_line.split(".")[0][7:])
+ except:
+ img_id = 0
+ data = {'img_path': img_path, 'label': label, 'img_id': img_id}
+ if not os.path.exists(img_path):
+ raise Exception("{} does not exist!".format(img_path))
+ with open(data['img_path'], 'rb') as f:
+ img = f.read()
+ data['image'] = img
+ outs = transform(data, self.ops)
+ except Exception as e:
+ self.logger.error(
+ "When parsing line {}, error happened with msg: {}".format(
+ self.data_idx_order_list[idx], e))
+ outs = None
+ if outs is None:
+ return self.__getitem__(np.random.randint(self.__len__()))
+ return outs
+
+ def __len__(self):
+ return len(self.data_idx_order_list)
diff --git a/backend/ppocr/data/pubtab_dataset.py b/backend/ppocr/data/pubtab_dataset.py
new file mode 100644
index 00000000..671cda76
--- /dev/null
+++ b/backend/ppocr/data/pubtab_dataset.py
@@ -0,0 +1,114 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import os
+import random
+from paddle.io import Dataset
+import json
+
+from .imaug import transform, create_operators
+
+
+class PubTabDataSet(Dataset):
+ def __init__(self, config, mode, logger, seed=None):
+ super(PubTabDataSet, self).__init__()
+ self.logger = logger
+
+ global_config = config['Global']
+ dataset_config = config[mode]['dataset']
+ loader_config = config[mode]['loader']
+
+ label_file_path = dataset_config.pop('label_file_path')
+
+ self.data_dir = dataset_config['data_dir']
+ self.do_shuffle = loader_config['shuffle']
+ self.do_hard_select = False
+ if 'hard_select' in loader_config:
+ self.do_hard_select = loader_config['hard_select']
+ self.hard_prob = loader_config['hard_prob']
+ if self.do_hard_select:
+ self.img_select_prob = self.load_hard_select_prob()
+ self.table_select_type = None
+ if 'table_select_type' in loader_config:
+ self.table_select_type = loader_config['table_select_type']
+ self.table_select_prob = loader_config['table_select_prob']
+
+ self.seed = seed
+ logger.info("Initialize indexs of datasets:%s" % label_file_path)
+ with open(label_file_path, "rb") as f:
+ self.data_lines = f.readlines()
+ self.data_idx_order_list = list(range(len(self.data_lines)))
+ if mode.lower() == "train":
+ self.shuffle_data_random()
+ self.ops = create_operators(dataset_config['transforms'], global_config)
+
+ ratio_list = dataset_config.get("ratio_list", [1.0])
+ self.need_reset = True in [x < 1 for x in ratio_list]
+
+ def shuffle_data_random(self):
+ if self.do_shuffle:
+ random.seed(self.seed)
+ random.shuffle(self.data_lines)
+ return
+
+ def __getitem__(self, idx):
+ try:
+ data_line = self.data_lines[idx]
+ data_line = data_line.decode('utf-8').strip("\n")
+ info = json.loads(data_line)
+ file_name = info['filename']
+ select_flag = True
+ if self.do_hard_select:
+ prob = self.img_select_prob[file_name]
+ if prob < random.uniform(0, 1):
+ select_flag = False
+
+ if self.table_select_type:
+ structure = info['html']['structure']['tokens'].copy()
+ structure_str = ''.join(structure)
+ table_type = "simple"
+ if 'colspan' in structure_str or 'rowspan' in structure_str:
+ table_type = "complex"
+ if table_type == "complex":
+ if self.table_select_prob < random.uniform(0, 1):
+ select_flag = False
+
+ if select_flag:
+ cells = info['html']['cells'].copy()
+ structure = info['html']['structure'].copy()
+ img_path = os.path.join(self.data_dir, file_name)
+ data = {
+ 'img_path': img_path,
+ 'cells': cells,
+ 'structure': structure
+ }
+ if not os.path.exists(img_path):
+ raise Exception("{} does not exist!".format(img_path))
+ with open(data['img_path'], 'rb') as f:
+ img = f.read()
+ data['image'] = img
+ outs = transform(data, self.ops)
+ else:
+ outs = None
+ except Exception as e:
+ self.logger.error(
+ "When parsing line {}, error happened with msg: {}".format(
+ data_line, e))
+ outs = None
+ if outs is None:
+ return self.__getitem__(np.random.randint(self.__len__()))
+ return outs
+
+ def __len__(self):
+ return len(self.data_idx_order_list)
diff --git a/backend/ppocr/data/simple_dataset.py b/backend/ppocr/data/simple_dataset.py
index d2a86b0f..b5da9b88 100644
--- a/backend/ppocr/data/simple_dataset.py
+++ b/backend/ppocr/data/simple_dataset.py
@@ -13,9 +13,10 @@
# limitations under the License.
import numpy as np
import os
+import json
import random
+import traceback
from paddle.io import Dataset
-
from .imaug import transform, create_operators
@@ -23,6 +24,7 @@ class SimpleDataSet(Dataset):
def __init__(self, config, mode, logger, seed=None):
super(SimpleDataSet, self).__init__()
self.logger = logger
+ self.mode = mode.lower()
global_config = config['Global']
dataset_config = config[mode]['dataset']
@@ -40,14 +42,16 @@ def __init__(self, config, mode, logger, seed=None):
) == data_source_num, "The length of ratio_list should be the same as the file_list."
self.data_dir = dataset_config['data_dir']
self.do_shuffle = loader_config['shuffle']
-
self.seed = seed
logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
self.data_idx_order_list = list(range(len(self.data_lines)))
- if mode.lower() == "train":
+ if self.mode == "train" and self.do_shuffle:
self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config)
+ self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx",
+ 2)
+ self.need_reset = True in [x < 1 for x in ratio_list]
def get_image_info_list(self, file_list, ratio_list):
if isinstance(file_list, str):
@@ -56,18 +60,63 @@ def get_image_info_list(self, file_list, ratio_list):
for idx, file in enumerate(file_list):
with open(file, "rb") as f:
lines = f.readlines()
- random.seed(self.seed)
- lines = random.sample(lines,
- round(len(lines) * ratio_list[idx]))
+ if self.mode == "train" or ratio_list[idx] < 1.0:
+ random.seed(self.seed)
+ lines = random.sample(lines,
+ round(len(lines) * ratio_list[idx]))
data_lines.extend(lines)
return data_lines
def shuffle_data_random(self):
- if self.do_shuffle:
- random.seed(self.seed)
- random.shuffle(self.data_lines)
+ random.seed(self.seed)
+ random.shuffle(self.data_lines)
return
+ def _try_parse_filename_list(self, file_name):
+ # multiple images -> one gt label
+ if len(file_name) > 0 and file_name[0] == "[":
+ try:
+ info = json.loads(file_name)
+ file_name = random.choice(info)
+ except:
+ pass
+ return file_name
+
+ def get_ext_data(self):
+ ext_data_num = 0
+ for op in self.ops:
+ if hasattr(op, 'ext_data_num'):
+ ext_data_num = getattr(op, 'ext_data_num')
+ break
+ load_data_ops = self.ops[:self.ext_op_transform_idx]
+ ext_data = []
+
+ while len(ext_data) < ext_data_num:
+ file_idx = self.data_idx_order_list[np.random.randint(self.__len__(
+ ))]
+ data_line = self.data_lines[file_idx]
+ data_line = data_line.decode('utf-8')
+ substr = data_line.strip("\n").split(self.delimiter)
+ file_name = substr[0]
+ file_name = self._try_parse_filename_list(file_name)
+ label = substr[1]
+ img_path = os.path.join(self.data_dir, file_name)
+ data = {'img_path': img_path, 'label': label}
+ if not os.path.exists(img_path):
+ continue
+ with open(data['img_path'], 'rb') as f:
+ img = f.read()
+ data['image'] = img
+ data = transform(data, load_data_ops)
+
+ if data is None:
+ continue
+ if 'polys' in data.keys():
+ if data['polys'].shape[1] != 4:
+ continue
+ ext_data.append(data)
+ return ext_data
+
def __getitem__(self, idx):
file_idx = self.data_idx_order_list[idx]
data_line = self.data_lines[file_idx]
@@ -75,6 +124,7 @@ def __getitem__(self, idx):
data_line = data_line.decode('utf-8')
substr = data_line.strip("\n").split(self.delimiter)
file_name = substr[0]
+ file_name = self._try_parse_filename_list(file_name)
label = substr[1]
img_path = os.path.join(self.data_dir, file_name)
data = {'img_path': img_path, 'label': label}
@@ -83,14 +133,18 @@ def __getitem__(self, idx):
with open(data['img_path'], 'rb') as f:
img = f.read()
data['image'] = img
+ data['ext_data'] = self.get_ext_data()
outs = transform(data, self.ops)
- except Exception as e:
+ except:
self.logger.error(
"When parsing line {}, error happened with msg: {}".format(
- data_line, e))
+ data_line, traceback.format_exc()))
outs = None
if outs is None:
- return self.__getitem__(np.random.randint(self.__len__()))
+ # during evaluation, we should fix the idx to get same results for many times of evaluation.
+ rnd_idx = np.random.randint(self.__len__(
+ )) if self.mode == "train" else (idx + 1) % self.__len__()
+ return self.__getitem__(rnd_idx)
return outs
def __len__(self):
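
Note: each SimpleDataSet label line is an image path, the delimiter (tab by default), then the label; _try_parse_filename_list additionally accepts a JSON list of paths so several crops can share one ground truth. A minimal parse of one such line, with illustrative content:

    import json
    import random

    line = '["train/a.jpg", "train/b.jpg"]\tSHARED_LABEL'
    file_name, label = line.strip("\n").split("\t")
    if file_name.startswith("["):      # multiple images -> one gt label
        file_name = random.choice(json.loads(file_name))
    print(file_name, label)
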
diff --git a/backend/ppocr/losses/__init__.py b/backend/ppocr/losses/__init__.py
index 3881abf7..de8419b7 100755
--- a/backend/ppocr/losses/__init__.py
+++ b/backend/ppocr/losses/__init__.py
@@ -13,27 +13,56 @@
# limitations under the License.
import copy
+import paddle
+import paddle.nn as nn
+# basic_loss
+from .basic_loss import LossFromOutput
-def build_loss(config):
- # det loss
- from .det_db_loss import DBLoss
- from .det_east_loss import EASTLoss
- from .det_sast_loss import SASTLoss
+# det loss
+from .det_db_loss import DBLoss
+from .det_east_loss import EASTLoss
+from .det_sast_loss import SASTLoss
+from .det_pse_loss import PSELoss
+from .det_fce_loss import FCELoss
+
+# rec loss
+from .rec_ctc_loss import CTCLoss
+from .rec_att_loss import AttentionLoss
+from .rec_srn_loss import SRNLoss
+from .rec_nrtr_loss import NRTRLoss
+from .rec_sar_loss import SARLoss
+from .rec_aster_loss import AsterLoss
+from .rec_pren_loss import PRENLoss
+from .rec_multi_loss import MultiLoss
+
+# cls loss
+from .cls_loss import ClsLoss
+
+# e2e loss
+from .e2e_pg_loss import PGLoss
+from .kie_sdmgr_loss import SDMGRLoss
+
+# basic loss function
+from .basic_loss import DistanceLoss
- # rec loss
- from .rec_ctc_loss import CTCLoss
- from .rec_att_loss import AttentionLoss
- from .rec_srn_loss import SRNLoss
+# combined loss function
+from .combined_loss import CombinedLoss
- # cls loss
- from .cls_loss import ClsLoss
+# table loss
+from .table_att_loss import TableAttentionLoss
+# vqa token loss
+from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
+
+
+def build_loss(config):
support_dict = [
- 'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
- 'SRNLoss'
+ 'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'FCELoss', 'CTCLoss',
+ 'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
+ 'NRTRLoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
+ 'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss'
]
-
config = copy.deepcopy(config)
module_name = config.pop('name')
assert module_name in support_dict, Exception('loss only support {}'.format(
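
Note: build_loss consumes the parsed "Loss" section of the config: a dict whose name key selects one of the classes in support_dict and whose remaining keys become constructor kwargs. A hedged example for DBLoss; the hyper-parameter values shown are typical defaults, not values mandated by this repository:

    loss_config = {
        'name': 'DBLoss',
        'balance_loss': True,
        'main_loss_type': 'DiceLoss',
        'alpha': 5,
        'beta': 10,
        'ohem_ratio': 3,
    }
    # loss_fn = build_loss(loss_config)  # returns a DBLoss instance
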
diff --git a/backend/ppocr/losses/ace_loss.py b/backend/ppocr/losses/ace_loss.py
new file mode 100644
index 00000000..915b99e6
--- /dev/null
+++ b/backend/ppocr/losses/ace_loss.py
@@ -0,0 +1,52 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This code is refer from: https://github.com/viig99/LS-ACELoss
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+import paddle.nn as nn
+
+
+class ACELoss(nn.Layer):
+ def __init__(self, **kwargs):
+ super().__init__()
+ self.loss_func = nn.CrossEntropyLoss(
+ weight=None,
+ ignore_index=0,
+ reduction='none',
+ soft_label=True,
+ axis=-1)
+
+ def __call__(self, predicts, batch):
+ if isinstance(predicts, (list, tuple)):
+ predicts = predicts[-1]
+
+ B, N = predicts.shape[:2]
+ div = paddle.to_tensor([N]).astype('float32')
+
+ predicts = nn.functional.softmax(predicts, axis=-1)
+ aggregation_preds = paddle.sum(predicts, axis=1)
+ aggregation_preds = paddle.divide(aggregation_preds, div)
+
+ length = batch[2].astype("float32")
+ batch = batch[3].astype("float32")
+ batch[:, 0] = paddle.subtract(div, length)
+ batch = paddle.divide(batch, div)
+
+ loss = self.loss_func(aggregation_preds, batch)
+ return {"loss_ace": loss}
diff --git a/backend/ppocr/losses/basic_loss.py b/backend/ppocr/losses/basic_loss.py
new file mode 100644
index 00000000..2df96ea2
--- /dev/null
+++ b/backend/ppocr/losses/basic_loss.py
@@ -0,0 +1,155 @@
+#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+from paddle.nn import L1Loss
+from paddle.nn import MSELoss as L2Loss
+from paddle.nn import SmoothL1Loss
+
+
+class CELoss(nn.Layer):
+ def __init__(self, epsilon=None):
+ super().__init__()
+ if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
+ epsilon = None
+ self.epsilon = epsilon
+
+ def _labelsmoothing(self, target, class_num):
+ if target.shape[-1] != class_num:
+ one_hot_target = F.one_hot(target, class_num)
+ else:
+ one_hot_target = target
+ soft_target = F.label_smooth(one_hot_target, epsilon=self.epsilon)
+ soft_target = paddle.reshape(soft_target, shape=[-1, class_num])
+ return soft_target
+
+ def forward(self, x, label):
+ loss_dict = {}
+ if self.epsilon is not None:
+ class_num = x.shape[-1]
+ label = self._labelsmoothing(label, class_num)
+ x = -F.log_softmax(x, axis=-1)
+ loss = paddle.sum(x * label, axis=-1)
+ else:
+ if label.shape[-1] == x.shape[-1]:
+ label = F.softmax(label, axis=-1)
+ soft_label = True
+ else:
+ soft_label = False
+ loss = F.cross_entropy(x, label=label, soft_label=soft_label)
+ return loss
+
+
+class KLJSLoss(object):
+ def __init__(self, mode='kl'):
+ assert mode in ['kl', 'js', 'KL', 'JS'
+ ], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
+ self.mode = mode
+
+ def __call__(self, p1, p2, reduction="mean"):
+
+ loss = paddle.multiply(p2, paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
+
+ if self.mode.lower() == "js":
+ loss += paddle.multiply(
+ p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
+ loss *= 0.5
+ if reduction == "mean":
+ loss = paddle.mean(loss, axis=[1, 2])
+ elif reduction == "none" or reduction is None:
+ return loss
+ else:
+ loss = paddle.sum(loss, axis=[1, 2])
+
+ return loss
+
+
+class DMLLoss(nn.Layer):
+ """
+ DMLLoss
+ """
+
+ def __init__(self, act=None, use_log=False):
+ super().__init__()
+ if act is not None:
+ assert act in ["softmax", "sigmoid"]
+ if act == "softmax":
+ self.act = nn.Softmax(axis=-1)
+ elif act == "sigmoid":
+ self.act = nn.Sigmoid()
+ else:
+ self.act = None
+
+ self.use_log = use_log
+ self.jskl_loss = KLJSLoss(mode="js")
+
+ def _kldiv(self, x, target):
+ eps = 1.0e-10
+ loss = target * (paddle.log(target + eps) - x)
+ # batch mean loss
+ loss = paddle.sum(loss) / loss.shape[0]
+ return loss
+
+ def forward(self, out1, out2):
+ if self.act is not None:
+ out1 = self.act(out1) + 1e-10
+ out2 = self.act(out2) + 1e-10
+ if self.use_log:
+ # for recognition distillation, log is needed for feature map
+ log_out1 = paddle.log(out1)
+ log_out2 = paddle.log(out2)
+ loss = (
+ self._kldiv(log_out1, out2) + self._kldiv(log_out2, out1)) / 2.0
+ else:
+ # for detection distillation log is not needed
+ loss = self.jskl_loss(out1, out2)
+ return loss
+
+
+class DistanceLoss(nn.Layer):
+ """
+    DistanceLoss:
+        mode: the distance type, one of "l1", "l2" or "smooth_l1"
+ """
+
+ def __init__(self, mode="l2", **kargs):
+ super().__init__()
+ assert mode in ["l1", "l2", "smooth_l1"]
+ if mode == "l1":
+ self.loss_func = nn.L1Loss(**kargs)
+ elif mode == "l2":
+ self.loss_func = nn.MSELoss(**kargs)
+ elif mode == "smooth_l1":
+ self.loss_func = nn.SmoothL1Loss(**kargs)
+
+ def forward(self, x, y):
+ return self.loss_func(x, y)
+
+
+class LossFromOutput(nn.Layer):
+ def __init__(self, key='loss', reduction='none'):
+ super().__init__()
+ self.key = key
+ self.reduction = reduction
+
+ def forward(self, predicts, batch):
+ loss = predicts[self.key]
+ if self.reduction == 'mean':
+ loss = paddle.mean(loss)
+ elif self.reduction == 'sum':
+ loss = paddle.sum(loss)
+ return {'loss': loss}
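
Note: with mode="js", KLJSLoss adds the two directed terms and halves the sum, which makes the result symmetric in p1 and p2. A numpy check of the same arithmetic; the 1e-5 smoothing mirrors the code above and the random maps are only for illustration:

    import numpy as np

    eps = 1e-5
    p1 = np.random.rand(2, 4, 4)
    p2 = np.random.rand(2, 4, 4)

    loss = p2 * np.log((p2 + eps) / (p1 + eps) + eps)
    loss += p1 * np.log((p1 + eps) / (p2 + eps) + eps)
    loss *= 0.5
    print(loss.mean(axis=(1, 2)))  # one value per sample, as with reduction="mean"
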
diff --git a/backend/ppocr/losses/center_loss.py b/backend/ppocr/losses/center_loss.py
new file mode 100644
index 00000000..f62b8af3
--- /dev/null
+++ b/backend/ppocr/losses/center_loss.py
@@ -0,0 +1,88 @@
+#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+# This code is refer from: https://github.com/KaiyangZhou/pytorch-center-loss
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import os
+import pickle
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+class CenterLoss(nn.Layer):
+ """
+ Reference: Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
+ """
+
+ def __init__(self, num_classes=6625, feat_dim=96, center_file_path=None):
+ super().__init__()
+ self.num_classes = num_classes
+ self.feat_dim = feat_dim
+ self.centers = paddle.randn(
+ shape=[self.num_classes, self.feat_dim]).astype("float64")
+
+ if center_file_path is not None:
+ assert os.path.exists(
+ center_file_path
+ ), f"center path({center_file_path}) must exist when it is not None."
+ with open(center_file_path, 'rb') as f:
+ char_dict = pickle.load(f)
+ for key in char_dict.keys():
+ self.centers[key] = paddle.to_tensor(char_dict[key])
+
+ def __call__(self, predicts, batch):
+ assert isinstance(predicts, (list, tuple))
+ features, predicts = predicts
+
+ feats_reshape = paddle.reshape(
+ features, [-1, features.shape[-1]]).astype("float64")
+ label = paddle.argmax(predicts, axis=2)
+ label = paddle.reshape(label, [label.shape[0] * label.shape[1]])
+
+ batch_size = feats_reshape.shape[0]
+
+ #calc l2 distance between feats and centers
+ square_feat = paddle.sum(paddle.square(feats_reshape),
+ axis=1,
+ keepdim=True)
+ square_feat = paddle.expand(square_feat, [batch_size, self.num_classes])
+
+ square_center = paddle.sum(paddle.square(self.centers),
+ axis=1,
+ keepdim=True)
+ square_center = paddle.expand(
+ square_center, [self.num_classes, batch_size]).astype("float64")
+ square_center = paddle.transpose(square_center, [1, 0])
+
+ distmat = paddle.add(square_feat, square_center)
+ feat_dot_center = paddle.matmul(feats_reshape,
+ paddle.transpose(self.centers, [1, 0]))
+ distmat = distmat - 2.0 * feat_dot_center
+
+ #generate the mask
+ classes = paddle.arange(self.num_classes).astype("int64")
+ label = paddle.expand(
+ paddle.unsqueeze(label, 1), (batch_size, self.num_classes))
+ mask = paddle.equal(
+ paddle.expand(classes, [batch_size, self.num_classes]),
+ label).astype("float64")
+ dist = paddle.multiply(distmat, mask)
+
+ loss = paddle.sum(paddle.clip(dist, min=1e-12, max=1e+12)) / batch_size
+ return {'loss_center': loss}
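
Note: the distmat computed above uses the expansion ||f - c||^2 = ||f||^2 + ||c||^2 - 2 f.c rather than materialising all pairwise differences. A small numpy check with made-up shapes:

    import numpy as np

    feats = np.random.rand(5, 96)      # 5 time steps, feat_dim = 96
    centers = np.random.rand(10, 96)   # 10 classes

    distmat = (np.square(feats).sum(1, keepdims=True)
               + np.square(centers).sum(1)
               - 2.0 * feats @ centers.T)
    direct = np.square(feats[:, None, :] - centers[None, :, :]).sum(-1)
    print(np.allclose(distmat, direct))  # True
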
diff --git a/backend/ppocr/losses/cls_loss.py b/backend/ppocr/losses/cls_loss.py
index 41c7db02..abc5e5b7 100755
--- a/backend/ppocr/losses/cls_loss.py
+++ b/backend/ppocr/losses/cls_loss.py
@@ -24,7 +24,7 @@ def __init__(self, **kwargs):
super(ClsLoss, self).__init__()
self.loss_func = nn.CrossEntropyLoss(reduction='mean')
- def __call__(self, predicts, batch):
- label = batch[1]
+ def forward(self, predicts, batch):
+ label = batch[1].astype("int64")
loss = self.loss_func(input=predicts, label=label)
return {'loss': loss}
diff --git a/backend/ppocr/losses/combined_loss.py b/backend/ppocr/losses/combined_loss.py
new file mode 100644
index 00000000..f4cdee8f
--- /dev/null
+++ b/backend/ppocr/losses/combined_loss.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+
+from .rec_ctc_loss import CTCLoss
+from .center_loss import CenterLoss
+from .ace_loss import ACELoss
+from .rec_sar_loss import SARLoss
+
+from .distillation_loss import DistillationCTCLoss
+from .distillation_loss import DistillationSARLoss
+from .distillation_loss import DistillationDMLLoss
+from .distillation_loss import DistillationDistanceLoss, DistillationDBLoss, DistillationDilaDBLoss
+
+
+class CombinedLoss(nn.Layer):
+ """
+ CombinedLoss:
+        a combination of loss functions
+ """
+
+ def __init__(self, loss_config_list=None):
+ super().__init__()
+ self.loss_func = []
+ self.loss_weight = []
+ assert isinstance(loss_config_list, list), (
+ 'operator config should be a list')
+ for config in loss_config_list:
+ assert isinstance(config,
+ dict) and len(config) == 1, "yaml format error"
+ name = list(config)[0]
+ param = config[name]
+ assert "weight" in param, "weight must be in param, but param just contains {}".format(
+ param.keys())
+ self.loss_weight.append(param.pop("weight"))
+ self.loss_func.append(eval(name)(**param))
+
+ def forward(self, input, batch, **kargs):
+ loss_dict = {}
+ loss_all = 0.
+ for idx, loss_func in enumerate(self.loss_func):
+ loss = loss_func(input, batch, **kargs)
+ if isinstance(loss, paddle.Tensor):
+ loss = {"loss_{}_{}".format(str(loss), idx): loss}
+
+ weight = self.loss_weight[idx]
+
+ loss = {key: loss[key] * weight for key in loss}
+
+ if "loss" in loss:
+ loss_all += loss["loss"]
+ else:
+ loss_all += paddle.add_n(list(loss.values()))
+ loss_dict.update(loss)
+ loss_dict["loss"] = loss_all
+ return loss_dict
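
Note: loss_config_list is a list of single-key dicts, each carrying a weight plus the kwargs of the wrapped loss. A hedged example of a recognition-distillation mix; the particular losses, model names and output key are illustrative, not required by this file:

    loss_config_list = [
        {'DistillationCTCLoss': {'weight': 1.0,
                                 'model_name_list': ['Student', 'Teacher'],
                                 'key': 'head_out'}},
        {'DistillationDMLLoss': {'weight': 1.0,
                                 'act': 'softmax',
                                 'use_log': True,
                                 'model_name_pairs': [['Student', 'Teacher']],
                                 'key': 'head_out'}},
    ]
    # combined = CombinedLoss(loss_config_list=loss_config_list)
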
diff --git a/backend/ppocr/losses/det_basic_loss.py b/backend/ppocr/losses/det_basic_loss.py
index 57b3667d..61ea579b 100644
--- a/backend/ppocr/losses/det_basic_loss.py
+++ b/backend/ppocr/losses/det_basic_loss.py
@@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+"""
+This code is refer from:
+https://github.com/WenmuZhou/DBNet.pytorch/blob/master/models/losses/basic_loss.py
+"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -75,12 +78,6 @@ def forward(self, pred, gt, mask=None):
mask (variable): masked maps.
return: (variable) balanced loss
"""
- # if self.main_loss_type in ['DiceLoss']:
- # # For the loss that returns to scalar value, perform ohem on the mask
- # mask = ohem_batch(pred, gt, mask, self.negative_ratio)
- # loss = self.loss(pred, gt, mask)
- # return loss
-
positive = gt * mask
negative = (1 - gt) * mask
@@ -154,52 +151,3 @@ def __init__(self, reduction='mean'):
def forward(self, input, label, mask=None, weight=None, name=None):
loss = F.binary_cross_entropy(input, label, reduction=self.reduction)
return loss
-
-
-def ohem_single(score, gt_text, training_mask, ohem_ratio):
- pos_num = (int)(np.sum(gt_text > 0.5)) - (
- int)(np.sum((gt_text > 0.5) & (training_mask <= 0.5)))
-
- if pos_num == 0:
- # selected_mask = gt_text.copy() * 0 # may be not good
- selected_mask = training_mask
- selected_mask = selected_mask.reshape(
- 1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
- return selected_mask
-
- neg_num = (int)(np.sum(gt_text <= 0.5))
- neg_num = (int)(min(pos_num * ohem_ratio, neg_num))
-
- if neg_num == 0:
- selected_mask = training_mask
- selected_mask = selected_mask.reshape(
- 1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
- return selected_mask
-
- neg_score = score[gt_text <= 0.5]
- # 将负样本得分从高到低排序
- neg_score_sorted = np.sort(-neg_score)
- threshold = -neg_score_sorted[neg_num - 1]
- # 选出 得分高的 负样本 和正样本 的 mask
- selected_mask = ((score >= threshold) |
- (gt_text > 0.5)) & (training_mask > 0.5)
- selected_mask = selected_mask.reshape(
- 1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
- return selected_mask
-
-
-def ohem_batch(scores, gt_texts, training_masks, ohem_ratio):
- scores = scores.numpy()
- gt_texts = gt_texts.numpy()
- training_masks = training_masks.numpy()
-
- selected_masks = []
- for i in range(scores.shape[0]):
- selected_masks.append(
- ohem_single(scores[i, :, :], gt_texts[i, :, :], training_masks[
- i, :, :], ohem_ratio))
-
- selected_masks = np.concatenate(selected_masks, 0)
- selected_masks = paddle.to_variable(selected_masks)
-
- return selected_masks
diff --git a/backend/ppocr/losses/det_db_loss.py b/backend/ppocr/losses/det_db_loss.py
index b079aabf..708ffbdb 100755
--- a/backend/ppocr/losses/det_db_loss.py
+++ b/backend/ppocr/losses/det_db_loss.py
@@ -11,6 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+"""
+This code is refer from:
+https://github.com/WenmuZhou/DBNet.pytorch/blob/master/models/losses/DB_loss.py
+"""
from __future__ import absolute_import
from __future__ import division
diff --git a/backend/ppocr/losses/det_fce_loss.py b/backend/ppocr/losses/det_fce_loss.py
new file mode 100644
index 00000000..d7dfb5aa
--- /dev/null
+++ b/backend/ppocr/losses/det_fce_loss.py
@@ -0,0 +1,227 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from:
+https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textdet/losses/fce_loss.py
+"""
+
+import numpy as np
+from paddle import nn
+import paddle
+import paddle.nn.functional as F
+from functools import partial
+
+
+def multi_apply(func, *args, **kwargs):
+ pfunc = partial(func, **kwargs) if kwargs else func
+ map_results = map(pfunc, *args)
+ return tuple(map(list, zip(*map_results)))
+
+
+class FCELoss(nn.Layer):
+ """The class for implementing FCENet loss
+ FCENet(CVPR2021): Fourier Contour Embedding for Arbitrary-shaped
+ Text Detection
+
+ [https://arxiv.org/abs/2104.10442]
+
+ Args:
+ fourier_degree (int) : The maximum Fourier transform degree k.
+ num_sample (int) : The sampling points number of regression
+ loss. If it is too small, fcenet tends to be overfitting.
+ ohem_ratio (float): the negative/positive ratio in OHEM.
+ """
+
+ def __init__(self, fourier_degree, num_sample, ohem_ratio=3.):
+ super().__init__()
+ self.fourier_degree = fourier_degree
+ self.num_sample = num_sample
+ self.ohem_ratio = ohem_ratio
+
+ def forward(self, preds, labels):
+ assert isinstance(preds, dict)
+ preds = preds['levels']
+
+ p3_maps, p4_maps, p5_maps = labels[1:]
+ assert p3_maps[0].shape[0] == 4 * self.fourier_degree + 5,\
+ 'fourier degree not equal in FCEhead and FCEtarget'
+
+ # to tensor
+ gts = [p3_maps, p4_maps, p5_maps]
+ for idx, maps in enumerate(gts):
+ gts[idx] = paddle.to_tensor(np.stack(maps))
+
+ losses = multi_apply(self.forward_single, preds, gts)
+
+ loss_tr = paddle.to_tensor(0.).astype('float32')
+ loss_tcl = paddle.to_tensor(0.).astype('float32')
+ loss_reg_x = paddle.to_tensor(0.).astype('float32')
+ loss_reg_y = paddle.to_tensor(0.).astype('float32')
+ loss_all = paddle.to_tensor(0.).astype('float32')
+
+ for idx, loss in enumerate(losses):
+ loss_all += sum(loss)
+ if idx == 0:
+ loss_tr += sum(loss)
+ elif idx == 1:
+ loss_tcl += sum(loss)
+ elif idx == 2:
+ loss_reg_x += sum(loss)
+ else:
+ loss_reg_y += sum(loss)
+
+ results = dict(
+ loss=loss_all,
+ loss_text=loss_tr,
+ loss_center=loss_tcl,
+ loss_reg_x=loss_reg_x,
+ loss_reg_y=loss_reg_y, )
+ return results
+
+ def forward_single(self, pred, gt):
+ cls_pred = paddle.transpose(pred[0], (0, 2, 3, 1))
+ reg_pred = paddle.transpose(pred[1], (0, 2, 3, 1))
+ gt = paddle.transpose(gt, (0, 2, 3, 1))
+
+ k = 2 * self.fourier_degree + 1
+ tr_pred = paddle.reshape(cls_pred[:, :, :, :2], (-1, 2))
+ tcl_pred = paddle.reshape(cls_pred[:, :, :, 2:], (-1, 2))
+ x_pred = paddle.reshape(reg_pred[:, :, :, 0:k], (-1, k))
+ y_pred = paddle.reshape(reg_pred[:, :, :, k:2 * k], (-1, k))
+
+ tr_mask = gt[:, :, :, :1].reshape([-1])
+ tcl_mask = gt[:, :, :, 1:2].reshape([-1])
+ train_mask = gt[:, :, :, 2:3].reshape([-1])
+ x_map = paddle.reshape(gt[:, :, :, 3:3 + k], (-1, k))
+ y_map = paddle.reshape(gt[:, :, :, 3 + k:], (-1, k))
+
+ tr_train_mask = (train_mask * tr_mask).astype('bool')
+ tr_train_mask2 = paddle.concat(
+ [tr_train_mask.unsqueeze(1), tr_train_mask.unsqueeze(1)], axis=1)
+ # tr loss
+ loss_tr = self.ohem(tr_pred, tr_mask, train_mask)
+ # tcl loss
+ loss_tcl = paddle.to_tensor(0.).astype('float32')
+ tr_neg_mask = tr_train_mask.logical_not()
+ tr_neg_mask2 = paddle.concat(
+ [tr_neg_mask.unsqueeze(1), tr_neg_mask.unsqueeze(1)], axis=1)
+ if tr_train_mask.sum().item() > 0:
+ loss_tcl_pos = F.cross_entropy(
+ tcl_pred.masked_select(tr_train_mask2).reshape([-1, 2]),
+ tcl_mask.masked_select(tr_train_mask).astype('int64'))
+ loss_tcl_neg = F.cross_entropy(
+ tcl_pred.masked_select(tr_neg_mask2).reshape([-1, 2]),
+ tcl_mask.masked_select(tr_neg_mask).astype('int64'))
+ loss_tcl = loss_tcl_pos + 0.5 * loss_tcl_neg
+
+ # regression loss
+ loss_reg_x = paddle.to_tensor(0.).astype('float32')
+ loss_reg_y = paddle.to_tensor(0.).astype('float32')
+ if tr_train_mask.sum().item() > 0:
+ weight = (tr_mask.masked_select(tr_train_mask.astype('bool'))
+ .astype('float32') + tcl_mask.masked_select(
+ tr_train_mask.astype('bool')).astype('float32')) / 2
+ weight = weight.reshape([-1, 1])
+
+ ft_x, ft_y = self.fourier2poly(x_map, y_map)
+ ft_x_pre, ft_y_pre = self.fourier2poly(x_pred, y_pred)
+
+ dim = ft_x.shape[1]
+
+ tr_train_mask3 = paddle.concat(
+ [tr_train_mask.unsqueeze(1) for i in range(dim)], axis=1)
+
+ loss_reg_x = paddle.mean(weight * F.smooth_l1_loss(
+ ft_x_pre.masked_select(tr_train_mask3).reshape([-1, dim]),
+ ft_x.masked_select(tr_train_mask3).reshape([-1, dim]),
+ reduction='none'))
+ loss_reg_y = paddle.mean(weight * F.smooth_l1_loss(
+ ft_y_pre.masked_select(tr_train_mask3).reshape([-1, dim]),
+ ft_y.masked_select(tr_train_mask3).reshape([-1, dim]),
+ reduction='none'))
+
+ return loss_tr, loss_tcl, loss_reg_x, loss_reg_y
+
+ def ohem(self, predict, target, train_mask):
+
+ pos = (target * train_mask).astype('bool')
+ neg = ((1 - target) * train_mask).astype('bool')
+
+ pos2 = paddle.concat([pos.unsqueeze(1), pos.unsqueeze(1)], axis=1)
+ neg2 = paddle.concat([neg.unsqueeze(1), neg.unsqueeze(1)], axis=1)
+
+ n_pos = pos.astype('float32').sum()
+
+ if n_pos.item() > 0:
+ loss_pos = F.cross_entropy(
+ predict.masked_select(pos2).reshape([-1, 2]),
+ target.masked_select(pos).astype('int64'),
+ reduction='sum')
+ loss_neg = F.cross_entropy(
+ predict.masked_select(neg2).reshape([-1, 2]),
+ target.masked_select(neg).astype('int64'),
+ reduction='none')
+ n_neg = min(
+ int(neg.astype('float32').sum().item()),
+ int(self.ohem_ratio * n_pos.astype('float32')))
+ else:
+ loss_pos = paddle.to_tensor(0.)
+ loss_neg = F.cross_entropy(
+ predict.masked_select(neg2).reshape([-1, 2]),
+ target.masked_select(neg).astype('int64'),
+ reduction='none')
+ n_neg = 100
+ if len(loss_neg) > n_neg:
+ loss_neg, _ = paddle.topk(loss_neg, n_neg)
+
+ return (loss_pos + loss_neg.sum()) / (n_pos + n_neg).astype('float32')
+
+ def fourier2poly(self, real_maps, imag_maps):
+ """Transform Fourier coefficient maps to polygon maps.
+
+ Args:
+ real_maps (tensor): A map composed of the real parts of the
+ Fourier coefficients, whose shape is (-1, 2k+1)
+            imag_maps (tensor): A map composed of the imaginary parts of the
+                Fourier coefficients, whose shape is (-1, 2k+1)
+
+        Returns:
+            x_maps (tensor): A map composed of the x values of the polygon
+                represented by n sample points (xn, yn), whose shape is (-1, n)
+            y_maps (tensor): A map composed of the y values of the polygon
+                represented by n sample points (xn, yn), whose shape is (-1, n)
+ """
+
+ k_vect = paddle.arange(
+ -self.fourier_degree, self.fourier_degree + 1,
+ dtype='float32').reshape([-1, 1])
+ i_vect = paddle.arange(
+ 0, self.num_sample, dtype='float32').reshape([1, -1])
+
+ transform_matrix = 2 * np.pi / self.num_sample * paddle.matmul(k_vect,
+ i_vect)
+
+ x1 = paddle.einsum('ak, kn-> an', real_maps,
+ paddle.cos(transform_matrix))
+ x2 = paddle.einsum('ak, kn-> an', imag_maps,
+ paddle.sin(transform_matrix))
+ y1 = paddle.einsum('ak, kn-> an', real_maps,
+ paddle.sin(transform_matrix))
+ y2 = paddle.einsum('ak, kn-> an', imag_maps,
+ paddle.cos(transform_matrix))
+
+ x_maps = x1 - x2
+ y_maps = y1 + y2
+
+ return x_maps, y_maps
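
Note: fourier2poly is the inverse Fourier mapping: each row of 2k+1 real/imaginary coefficients becomes the x/y coordinates of n resampled contour points. A numpy sketch of the same matrix arithmetic; the degree and sample count below are arbitrary:

    import numpy as np

    def fourier2poly_np(real_maps, imag_maps, fourier_degree=5, num_sample=50):
        k = np.arange(-fourier_degree, fourier_degree + 1, dtype='float32').reshape(-1, 1)
        i = np.arange(num_sample, dtype='float32').reshape(1, -1)
        t = 2 * np.pi / num_sample * (k @ i)
        x = real_maps @ np.cos(t) - imag_maps @ np.sin(t)
        y = real_maps @ np.sin(t) + imag_maps @ np.cos(t)
        return x, y

    real = np.random.rand(3, 11).astype('float32')   # 2 * 5 + 1 coefficients
    imag = np.random.rand(3, 11).astype('float32')
    x, y = fourier2poly_np(real, imag)
    print(x.shape, y.shape)  # (3, 50) (3, 50)
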
diff --git a/backend/ppocr/losses/det_pse_loss.py b/backend/ppocr/losses/det_pse_loss.py
new file mode 100644
index 00000000..6b31343e
--- /dev/null
+++ b/backend/ppocr/losses/det_pse_loss.py
@@ -0,0 +1,149 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from:
+https://github.com/whai362/PSENet/blob/python3/models/head/psenet_head.py
+"""
+
+import paddle
+from paddle import nn
+from paddle.nn import functional as F
+import numpy as np
+from ppocr.utils.iou import iou
+
+
+class PSELoss(nn.Layer):
+ def __init__(self,
+ alpha,
+ ohem_ratio=3,
+ kernel_sample_mask='pred',
+ reduction='sum',
+ eps=1e-6,
+ **kwargs):
+ """Implement PSE Loss.
+ """
+ super(PSELoss, self).__init__()
+ assert reduction in ['sum', 'mean', 'none']
+ self.alpha = alpha
+ self.ohem_ratio = ohem_ratio
+ self.kernel_sample_mask = kernel_sample_mask
+ self.reduction = reduction
+ self.eps = eps
+
+ def forward(self, outputs, labels):
+ predicts = outputs['maps']
+ predicts = F.interpolate(predicts, scale_factor=4)
+
+ texts = predicts[:, 0, :, :]
+ kernels = predicts[:, 1:, :, :]
+ gt_texts, gt_kernels, training_masks = labels[1:]
+
+ # text loss
+ selected_masks = self.ohem_batch(texts, gt_texts, training_masks)
+
+ loss_text = self.dice_loss(texts, gt_texts, selected_masks)
+ iou_text = iou((texts > 0).astype('int64'),
+ gt_texts,
+ training_masks,
+ reduce=False)
+ losses = dict(loss_text=loss_text, iou_text=iou_text)
+
+ # kernel loss
+ loss_kernels = []
+ if self.kernel_sample_mask == 'gt':
+ selected_masks = gt_texts * training_masks
+ elif self.kernel_sample_mask == 'pred':
+ selected_masks = (
+ F.sigmoid(texts) > 0.5).astype('float32') * training_masks
+
+ for i in range(kernels.shape[1]):
+ kernel_i = kernels[:, i, :, :]
+ gt_kernel_i = gt_kernels[:, i, :, :]
+ loss_kernel_i = self.dice_loss(kernel_i, gt_kernel_i,
+ selected_masks)
+ loss_kernels.append(loss_kernel_i)
+ loss_kernels = paddle.mean(paddle.stack(loss_kernels, axis=1), axis=1)
+ iou_kernel = iou((kernels[:, -1, :, :] > 0).astype('int64'),
+ gt_kernels[:, -1, :, :],
+ training_masks * gt_texts,
+ reduce=False)
+ losses.update(dict(loss_kernels=loss_kernels, iou_kernel=iou_kernel))
+ loss = self.alpha * loss_text + (1 - self.alpha) * loss_kernels
+ losses['loss'] = loss
+ if self.reduction == 'sum':
+ losses = {x: paddle.sum(v) for x, v in losses.items()}
+ elif self.reduction == 'mean':
+ losses = {x: paddle.mean(v) for x, v in losses.items()}
+ return losses
+
+ def dice_loss(self, input, target, mask):
+ input = F.sigmoid(input)
+
+ input = input.reshape([input.shape[0], -1])
+ target = target.reshape([target.shape[0], -1])
+ mask = mask.reshape([mask.shape[0], -1])
+
+ input = input * mask
+ target = target * mask
+
+ a = paddle.sum(input * target, 1)
+ b = paddle.sum(input * input, 1) + self.eps
+ c = paddle.sum(target * target, 1) + self.eps
+ d = (2 * a) / (b + c)
+ return 1 - d
+
+ def ohem_single(self, score, gt_text, training_mask, ohem_ratio=3):
+ pos_num = int(paddle.sum((gt_text > 0.5).astype('float32'))) - int(
+ paddle.sum(
+ paddle.logical_and((gt_text > 0.5), (training_mask <= 0.5))
+ .astype('float32')))
+
+ if pos_num == 0:
+ selected_mask = training_mask
+ selected_mask = selected_mask.reshape(
+ [1, selected_mask.shape[0], selected_mask.shape[1]]).astype(
+ 'float32')
+ return selected_mask
+
+ neg_num = int(paddle.sum((gt_text <= 0.5).astype('float32')))
+ neg_num = int(min(pos_num * ohem_ratio, neg_num))
+
+ if neg_num == 0:
+ selected_mask = training_mask
+ selected_mask = selected_mask.reshape(
+ [1, selected_mask.shape[0], selected_mask.shape[1]]).astype(
+ 'float32')
+ return selected_mask
+
+ neg_score = paddle.masked_select(score, gt_text <= 0.5)
+ neg_score_sorted = paddle.sort(-neg_score)
+ threshold = -neg_score_sorted[neg_num - 1]
+
+ selected_mask = paddle.logical_and(
+ paddle.logical_or((score >= threshold), (gt_text > 0.5)),
+ (training_mask > 0.5))
+ selected_mask = selected_mask.reshape(
+ [1, selected_mask.shape[0], selected_mask.shape[1]]).astype(
+ 'float32')
+ return selected_mask
+
+ def ohem_batch(self, scores, gt_texts, training_masks, ohem_ratio=3):
+ selected_masks = []
+ for i in range(scores.shape[0]):
+ selected_masks.append(
+ self.ohem_single(scores[i, :, :], gt_texts[i, :, :],
+ training_masks[i, :, :], ohem_ratio))
+
+ selected_masks = paddle.concat(selected_masks, 0).astype('float32')
+ return selected_masks
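
Note: dice_loss above is the masked dice formulation 1 - 2*sum(p*t) / (sum(p*p) + sum(t*t)), computed per sample after a sigmoid. A numpy sketch with toy maps:

    import numpy as np

    eps = 1e-6
    pred = 1 / (1 + np.exp(-np.random.randn(2, 16)))     # sigmoid(score), flattened
    gt = (np.random.rand(2, 16) > 0.5).astype('float32')
    mask = np.ones_like(gt)

    pred, gt = pred * mask, gt * mask
    a = (pred * gt).sum(1)
    b = (pred * pred).sum(1) + eps
    c = (gt * gt).sum(1) + eps
    print(1 - 2 * a / (b + c))   # one dice loss value per sample
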
diff --git a/backend/ppocr/losses/distillation_loss.py b/backend/ppocr/losses/distillation_loss.py
new file mode 100644
index 00000000..565b066d
--- /dev/null
+++ b/backend/ppocr/losses/distillation_loss.py
@@ -0,0 +1,324 @@
+#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import numpy as np
+import cv2
+
+from .rec_ctc_loss import CTCLoss
+from .rec_sar_loss import SARLoss
+from .basic_loss import DMLLoss
+from .basic_loss import DistanceLoss
+from .det_db_loss import DBLoss
+from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
+
+
+def _sum_loss(loss_dict):
+ if "loss" in loss_dict.keys():
+ return loss_dict
+ else:
+ loss_dict["loss"] = 0.
+ for k, value in loss_dict.items():
+ if k == "loss":
+ continue
+ else:
+ loss_dict["loss"] += value
+ return loss_dict
+
+
+class DistillationDMLLoss(DMLLoss):
+ """
+ """
+
+ def __init__(self,
+ model_name_pairs=[],
+ act=None,
+ use_log=False,
+ key=None,
+ multi_head=False,
+ dis_head='ctc',
+ maps_name=None,
+ name="dml"):
+ super().__init__(act=act, use_log=use_log)
+ assert isinstance(model_name_pairs, list)
+ self.key = key
+ self.multi_head = multi_head
+ self.dis_head = dis_head
+ self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
+ self.name = name
+ self.maps_name = self._check_maps_name(maps_name)
+
+ def _check_model_name_pairs(self, model_name_pairs):
+ if not isinstance(model_name_pairs, list):
+ return []
+ elif isinstance(model_name_pairs[0], list) and isinstance(
+ model_name_pairs[0][0], str):
+ return model_name_pairs
+ else:
+ return [model_name_pairs]
+
+ def _check_maps_name(self, maps_name):
+ if maps_name is None:
+ return None
+ elif type(maps_name) == str:
+ return [maps_name]
+ elif type(maps_name) == list:
+            return list(maps_name)
+ else:
+ return None
+
+ def _slice_out(self, outs):
+ new_outs = {}
+ for k in self.maps_name:
+ if k == "thrink_maps":
+ new_outs[k] = outs[:, 0, :, :]
+ elif k == "threshold_maps":
+ new_outs[k] = outs[:, 1, :, :]
+ elif k == "binary_maps":
+ new_outs[k] = outs[:, 2, :, :]
+ else:
+ continue
+ return new_outs
+
+ def forward(self, predicts, batch):
+ loss_dict = dict()
+ for idx, pair in enumerate(self.model_name_pairs):
+ out1 = predicts[pair[0]]
+ out2 = predicts[pair[1]]
+ if self.key is not None:
+ out1 = out1[self.key]
+ out2 = out2[self.key]
+
+ if self.maps_name is None:
+ if self.multi_head:
+ loss = super().forward(out1[self.dis_head],
+ out2[self.dis_head])
+ else:
+ loss = super().forward(out1, out2)
+ if isinstance(loss, dict):
+ for key in loss:
+ loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
+ idx)] = loss[key]
+ else:
+ loss_dict["{}_{}".format(self.name, idx)] = loss
+ else:
+ outs1 = self._slice_out(out1)
+ outs2 = self._slice_out(out2)
+ for _c, k in enumerate(outs1.keys()):
+ loss = super().forward(outs1[k], outs2[k])
+ if isinstance(loss, dict):
+ for key in loss:
+ loss_dict["{}_{}_{}_{}_{}".format(key, pair[
+ 0], pair[1], self.maps_name, idx)] = loss[key]
+ else:
+ loss_dict["{}_{}_{}".format(self.name, self.maps_name[
+ _c], idx)] = loss
+
+ loss_dict = _sum_loss(loss_dict)
+
+ return loss_dict
+
+
+class DistillationCTCLoss(CTCLoss):
+ def __init__(self,
+ model_name_list=[],
+ key=None,
+ multi_head=False,
+ name="loss_ctc"):
+ super().__init__()
+ self.model_name_list = model_name_list
+ self.key = key
+ self.name = name
+ self.multi_head = multi_head
+
+ def forward(self, predicts, batch):
+ loss_dict = dict()
+ for idx, model_name in enumerate(self.model_name_list):
+ out = predicts[model_name]
+ if self.key is not None:
+ out = out[self.key]
+ if self.multi_head:
+ assert 'ctc' in out, 'multi head has multi out'
+ loss = super().forward(out['ctc'], batch[:2] + batch[3:])
+ else:
+ loss = super().forward(out, batch)
+ if isinstance(loss, dict):
+ for key in loss:
+ loss_dict["{}_{}_{}".format(self.name, model_name,
+ idx)] = loss[key]
+ else:
+ loss_dict["{}_{}".format(self.name, model_name)] = loss
+ return loss_dict
+
+
+class DistillationSARLoss(SARLoss):
+ def __init__(self,
+ model_name_list=[],
+ key=None,
+ multi_head=False,
+ name="loss_sar",
+ **kwargs):
+ ignore_index = kwargs.get('ignore_index', 92)
+ super().__init__(ignore_index=ignore_index)
+ self.model_name_list = model_name_list
+ self.key = key
+ self.name = name
+ self.multi_head = multi_head
+
+ def forward(self, predicts, batch):
+ loss_dict = dict()
+ for idx, model_name in enumerate(self.model_name_list):
+ out = predicts[model_name]
+ if self.key is not None:
+ out = out[self.key]
+ if self.multi_head:
+ assert 'sar' in out, 'multi head has multi out'
+ loss = super().forward(out['sar'], batch[:1] + batch[2:])
+ else:
+ loss = super().forward(out, batch)
+ if isinstance(loss, dict):
+ for key in loss:
+ loss_dict["{}_{}_{}".format(self.name, model_name,
+ idx)] = loss[key]
+ else:
+ loss_dict["{}_{}".format(self.name, model_name)] = loss
+ return loss_dict
+
+
+class DistillationDBLoss(DBLoss):
+ def __init__(self,
+ model_name_list=[],
+ balance_loss=True,
+ main_loss_type='DiceLoss',
+ alpha=5,
+ beta=10,
+ ohem_ratio=3,
+ eps=1e-6,
+ name="db",
+ **kwargs):
+ super().__init__()
+ self.model_name_list = model_name_list
+ self.name = name
+ self.key = None
+
+ def forward(self, predicts, batch):
+ loss_dict = {}
+ for idx, model_name in enumerate(self.model_name_list):
+ out = predicts[model_name]
+ if self.key is not None:
+ out = out[self.key]
+ loss = super().forward(out, batch)
+
+ if isinstance(loss, dict):
+ for key in loss.keys():
+ if key == "loss":
+ continue
+ name = "{}_{}_{}".format(self.name, model_name, key)
+ loss_dict[name] = loss[key]
+ else:
+ loss_dict["{}_{}".format(self.name, model_name)] = loss
+
+ loss_dict = _sum_loss(loss_dict)
+ return loss_dict
+
+
+class DistillationDilaDBLoss(DBLoss):
+ def __init__(self,
+ model_name_pairs=[],
+ key=None,
+ balance_loss=True,
+ main_loss_type='DiceLoss',
+ alpha=5,
+ beta=10,
+ ohem_ratio=3,
+ eps=1e-6,
+ name="dila_dbloss"):
+ super().__init__()
+ self.model_name_pairs = model_name_pairs
+ self.name = name
+ self.key = key
+
+ def forward(self, predicts, batch):
+ loss_dict = dict()
+ for idx, pair in enumerate(self.model_name_pairs):
+ stu_outs = predicts[pair[0]]
+ tch_outs = predicts[pair[1]]
+ if self.key is not None:
+ stu_preds = stu_outs[self.key]
+ tch_preds = tch_outs[self.key]
+
+ stu_shrink_maps = stu_preds[:, 0, :, :]
+ stu_binary_maps = stu_preds[:, 2, :, :]
+
+ # dilation to teacher prediction
+ dilation_w = np.array([[1, 1], [1, 1]])
+ th_shrink_maps = tch_preds[:, 0, :, :]
+ th_shrink_maps = th_shrink_maps.numpy() > 0.3 # thresh = 0.3
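+            # binarizing the teacher shrink map at 0.3 and dilating it with a
+            # 2x2 kernel slightly enlarges the teacher text regions used to
+            # supervise the student's shrink and binary maps below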
+ dilate_maps = np.zeros_like(th_shrink_maps).astype(np.float32)
+ for i in range(th_shrink_maps.shape[0]):
+ dilate_maps[i] = cv2.dilate(
+ th_shrink_maps[i, :, :].astype(np.uint8), dilation_w)
+ th_shrink_maps = paddle.to_tensor(dilate_maps)
+
+ label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = batch[
+ 1:]
+
+ # calculate the shrink map loss
+ bce_loss = self.alpha * self.bce_loss(
+ stu_shrink_maps, th_shrink_maps, label_shrink_mask)
+ loss_binary_maps = self.dice_loss(stu_binary_maps, th_shrink_maps,
+ label_shrink_mask)
+
+ # k = f"{self.name}_{pair[0]}_{pair[1]}"
+ k = "{}_{}_{}".format(self.name, pair[0], pair[1])
+ loss_dict[k] = bce_loss + loss_binary_maps
+
+ loss_dict = _sum_loss(loss_dict)
+ return loss_dict
+
+
+class DistillationDistanceLoss(DistanceLoss):
+ """
+ """
+
+ def __init__(self,
+ mode="l2",
+ model_name_pairs=[],
+ key=None,
+ name="loss_distance",
+ **kargs):
+ super().__init__(mode=mode, **kargs)
+ assert isinstance(model_name_pairs, list)
+ self.key = key
+ self.model_name_pairs = model_name_pairs
+ self.name = name + "_l2"
+
+ def forward(self, predicts, batch):
+ loss_dict = dict()
+ for idx, pair in enumerate(self.model_name_pairs):
+ out1 = predicts[pair[0]]
+ out2 = predicts[pair[1]]
+ if self.key is not None:
+ out1 = out1[self.key]
+ out2 = out2[self.key]
+ loss = super().forward(out1, out2)
+ if isinstance(loss, dict):
+ for key in loss:
+ loss_dict["{}_{}_{}".format(self.name, key, idx)] = loss[
+ key]
+ else:
+ loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1],
+ idx)] = loss
+ return loss_dict
diff --git a/backend/ppocr/losses/e2e_pg_loss.py b/backend/ppocr/losses/e2e_pg_loss.py
new file mode 100644
index 00000000..10a8ed0a
--- /dev/null
+++ b/backend/ppocr/losses/e2e_pg_loss.py
@@ -0,0 +1,140 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from paddle import nn
+import paddle
+
+from .det_basic_loss import DiceLoss
+from ppocr.utils.e2e_utils.extract_batchsize import pre_process
+
+
+class PGLoss(nn.Layer):
+ def __init__(self,
+ tcl_bs,
+ max_text_length,
+ max_text_nums,
+ pad_num,
+ eps=1e-6,
+ **kwargs):
+ super(PGLoss, self).__init__()
+ self.tcl_bs = tcl_bs
+ self.max_text_nums = max_text_nums
+ self.max_text_length = max_text_length
+ self.pad_num = pad_num
+ self.dice_loss = DiceLoss(eps=eps)
+
+ def border_loss(self, f_border, l_border, l_score, l_mask):
+ l_border_split, l_border_norm = paddle.tensor.split(
+ l_border, num_or_sections=[4, 1], axis=1)
+ f_border_split = f_border
+ b, c, h, w = l_border_norm.shape
+ l_border_norm_split = paddle.expand(
+ x=l_border_norm, shape=[b, 4 * c, h, w])
+ b, c, h, w = l_score.shape
+ l_border_score = paddle.expand(x=l_score, shape=[b, 4 * c, h, w])
+ b, c, h, w = l_mask.shape
+ l_border_mask = paddle.expand(x=l_mask, shape=[b, 4 * c, h, w])
+ border_diff = l_border_split - f_border_split
+ abs_border_diff = paddle.abs(border_diff)
+ border_sign = abs_border_diff < 1.0
+ border_sign = paddle.cast(border_sign, dtype='float32')
+ border_sign.stop_gradient = True
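+        # smooth-L1: quadratic (0.5 * d^2) where |d| < 1, linear (|d| - 0.5) elsewhere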
+ border_in_loss = 0.5 * abs_border_diff * abs_border_diff * border_sign + \
+ (abs_border_diff - 0.5) * (1.0 - border_sign)
+ border_out_loss = l_border_norm_split * border_in_loss
+ border_loss = paddle.sum(border_out_loss * l_border_score * l_border_mask) / \
+ (paddle.sum(l_border_score * l_border_mask) + 1e-5)
+ return border_loss
+
+ def direction_loss(self, f_direction, l_direction, l_score, l_mask):
+ l_direction_split, l_direction_norm = paddle.tensor.split(
+ l_direction, num_or_sections=[2, 1], axis=1)
+ f_direction_split = f_direction
+ b, c, h, w = l_direction_norm.shape
+ l_direction_norm_split = paddle.expand(
+ x=l_direction_norm, shape=[b, 2 * c, h, w])
+ b, c, h, w = l_score.shape
+ l_direction_score = paddle.expand(x=l_score, shape=[b, 2 * c, h, w])
+ b, c, h, w = l_mask.shape
+ l_direction_mask = paddle.expand(x=l_mask, shape=[b, 2 * c, h, w])
+ direction_diff = l_direction_split - f_direction_split
+ abs_direction_diff = paddle.abs(direction_diff)
+ direction_sign = abs_direction_diff < 1.0
+ direction_sign = paddle.cast(direction_sign, dtype='float32')
+ direction_sign.stop_gradient = True
+ direction_in_loss = 0.5 * abs_direction_diff * abs_direction_diff * direction_sign + \
+ (abs_direction_diff - 0.5) * (1.0 - direction_sign)
+ direction_out_loss = l_direction_norm_split * direction_in_loss
+ direction_loss = paddle.sum(direction_out_loss * l_direction_score * l_direction_mask) / \
+ (paddle.sum(l_direction_score * l_direction_mask) + 1e-5)
+ return direction_loss
+
+ def ctcloss(self, f_char, tcl_pos, tcl_mask, tcl_label, label_t):
+ f_char = paddle.transpose(f_char, [0, 2, 3, 1])
+ tcl_pos = paddle.reshape(tcl_pos, [-1, 3])
+ tcl_pos = paddle.cast(tcl_pos, dtype=int)
+ f_tcl_char = paddle.gather_nd(f_char, tcl_pos)
+ f_tcl_char = paddle.reshape(f_tcl_char,
+ [-1, 64, 37]) # len(Lexicon_Table)+1
+ f_tcl_char_fg, f_tcl_char_bg = paddle.split(f_tcl_char, [36, 1], axis=2)
+ f_tcl_char_bg = f_tcl_char_bg * tcl_mask + (1.0 - tcl_mask) * 20.0
+ b, c, l = tcl_mask.shape
+ tcl_mask_fg = paddle.expand(x=tcl_mask, shape=[b, c, 36 * l])
+ tcl_mask_fg.stop_gradient = True
+ f_tcl_char_fg = f_tcl_char_fg * tcl_mask_fg + (1.0 - tcl_mask_fg) * (
+ -20.0)
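+        # outside the TCL mask, force a large background logit (+20) and small
+        # foreground logits (-20) so padded positions are pushed toward the
+        # background (blank) class in the CTC loss below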
+ f_tcl_char_mask = paddle.concat([f_tcl_char_fg, f_tcl_char_bg], axis=2)
+ f_tcl_char_ld = paddle.transpose(f_tcl_char_mask, (1, 0, 2))
+ N, B, _ = f_tcl_char_ld.shape
+ input_lengths = paddle.to_tensor([N] * B, dtype='int64')
+ cost = paddle.nn.functional.ctc_loss(
+ log_probs=f_tcl_char_ld,
+ labels=tcl_label,
+ input_lengths=input_lengths,
+ label_lengths=label_t,
+ blank=self.pad_num,
+ reduction='none')
+ cost = cost.mean()
+ return cost
+
+ def forward(self, predicts, labels):
+ images, tcl_maps, tcl_label_maps, border_maps \
+ , direction_maps, training_masks, label_list, pos_list, pos_mask = labels
+ # for all the batch_size
+ pos_list, pos_mask, label_list, label_t = pre_process(
+ label_list, pos_list, pos_mask, self.max_text_length,
+ self.max_text_nums, self.pad_num, self.tcl_bs)
+
+ f_score, f_border, f_direction, f_char = predicts['f_score'], predicts['f_border'], predicts['f_direction'], \
+ predicts['f_char']
+ score_loss = self.dice_loss(f_score, tcl_maps, training_masks)
+ border_loss = self.border_loss(f_border, border_maps, tcl_maps,
+ training_masks)
+ direction_loss = self.direction_loss(f_direction, direction_maps,
+ tcl_maps, training_masks)
+ ctc_loss = self.ctcloss(f_char, pos_list, pos_mask, label_list, label_t)
+ loss_all = score_loss + border_loss + direction_loss + 5 * ctc_loss
+
+ losses = {
+ 'loss': loss_all,
+ "score_loss": score_loss,
+ "border_loss": border_loss,
+ "direction_loss": direction_loss,
+ "ctc_loss": ctc_loss
+ }
+ return losses
diff --git a/backend/ppocr/losses/kie_sdmgr_loss.py b/backend/ppocr/losses/kie_sdmgr_loss.py
new file mode 100644
index 00000000..745671f5
--- /dev/null
+++ b/backend/ppocr/losses/kie_sdmgr_loss.py
@@ -0,0 +1,115 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# reference from : https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/kie/losses/sdmgr_loss.py
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from paddle import nn
+import paddle
+
+
+class SDMGRLoss(nn.Layer):
+ def __init__(self, node_weight=1.0, edge_weight=1.0, ignore=0):
+ super().__init__()
+ self.loss_node = nn.CrossEntropyLoss(ignore_index=ignore)
+ self.loss_edge = nn.CrossEntropyLoss(ignore_index=-1)
+ self.node_weight = node_weight
+ self.edge_weight = edge_weight
+ self.ignore = ignore
+
+ def pre_process(self, gts, tag):
+ gts, tag = gts.numpy(), tag.numpy().tolist()
+ temp_gts = []
+ batch = len(tag)
+ for i in range(batch):
+ num, recoder_len = tag[i][0], tag[i][1]
+ temp_gts.append(
+ paddle.to_tensor(
+ gts[i, :num, :num + 1], dtype='int64'))
+ return temp_gts
+
+ def accuracy(self, pred, target, topk=1, thresh=None):
+ """Calculate accuracy according to the prediction and target.
+
+ Args:
+            pred (paddle.Tensor): The model prediction, shape (N, num_class)
+            target (paddle.Tensor): The target of each prediction, shape (N, )
+ topk (int | tuple[int], optional): If the predictions in ``topk``
+ matches the target, the predictions will be regarded as
+ correct ones. Defaults to 1.
+ thresh (float, optional): If not None, predictions with scores under
+ this threshold are considered incorrect. Default to None.
+
+ Returns:
+ float | tuple[float]: If the input ``topk`` is a single integer,
+ the function will return a single float as accuracy. If
+ ``topk`` is a tuple containing multiple integers, the
+ function will return a tuple containing accuracies of
+ each ``topk`` number.
+ """
+ assert isinstance(topk, (int, tuple))
+ if isinstance(topk, int):
+ topk = (topk, )
+ return_single = True
+ else:
+ return_single = False
+
+ maxk = max(topk)
+ if pred.shape[0] == 0:
+            accu = [paddle.to_tensor(0.) for _ in range(len(topk))]
+ return accu[0] if return_single else accu
+ pred_value, pred_label = paddle.topk(pred, maxk, axis=1)
+ pred_label = pred_label.transpose(
+ [1, 0]) # transpose to shape (maxk, N)
+ correct = paddle.equal(pred_label,
+ (target.reshape([1, -1]).expand_as(pred_label)))
+ res = []
+ for k in topk:
+ correct_k = paddle.sum(correct[:k].reshape([-1]).astype('float32'),
+ axis=0,
+ keepdim=True)
+ res.append(
+ paddle.multiply(correct_k,
+ paddle.to_tensor(100.0 / pred.shape[0])))
+ return res[0] if return_single else res
+
+ def forward(self, pred, batch):
+ node_preds, edge_preds = pred
+ gts, tag = batch[4], batch[5]
+ gts = self.pre_process(gts, tag)
+ node_gts, edge_gts = [], []
+ for gt in gts:
+ node_gts.append(gt[:, 0])
+ edge_gts.append(gt[:, 1:].reshape([-1]))
+ node_gts = paddle.concat(node_gts)
+ edge_gts = paddle.concat(edge_gts)
+
+ node_valids = paddle.nonzero(node_gts != self.ignore).reshape([-1])
+ edge_valids = paddle.nonzero(edge_gts != -1).reshape([-1])
+ loss_node = self.loss_node(node_preds, node_gts)
+ loss_edge = self.loss_edge(edge_preds, edge_gts)
+ loss = self.node_weight * loss_node + self.edge_weight * loss_edge
+ return dict(
+ loss=loss,
+ loss_node=loss_node,
+ loss_edge=loss_edge,
+ acc_node=self.accuracy(
+ paddle.gather(node_preds, node_valids),
+ paddle.gather(node_gts, node_valids)),
+ acc_edge=self.accuracy(
+ paddle.gather(edge_preds, edge_valids),
+ paddle.gather(edge_gts, edge_valids)))
diff --git a/backend/ppocr/losses/rec_aster_loss.py b/backend/ppocr/losses/rec_aster_loss.py
new file mode 100644
index 00000000..fbb99d29
--- /dev/null
+++ b/backend/ppocr/losses/rec_aster_loss.py
@@ -0,0 +1,99 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+
+
+class CosineEmbeddingLoss(nn.Layer):
+ def __init__(self, margin=0.):
+ super(CosineEmbeddingLoss, self).__init__()
+ self.margin = margin
+ self.epsilon = 1e-12
+
+ def forward(self, x1, x2, target):
+ similarity = paddle.fluid.layers.reduce_sum(
+ x1 * x2, dim=-1) / (paddle.norm(
+ x1, axis=-1) * paddle.norm(
+ x2, axis=-1) + self.epsilon)
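+        # cosine-embedding loss: 1 - cos(x1, x2) for positive pairs (target == 1),
+        # max(0, cos(x1, x2) - margin) otherwise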
+ one_list = paddle.full_like(target, fill_value=1)
+ out = paddle.fluid.layers.reduce_mean(
+ paddle.where(
+ paddle.equal(target, one_list), 1. - similarity,
+ paddle.maximum(
+ paddle.zeros_like(similarity), similarity - self.margin)))
+
+ return out
+
+
+class AsterLoss(nn.Layer):
+ def __init__(self,
+ weight=None,
+ size_average=True,
+ ignore_index=-100,
+ sequence_normalize=False,
+ sample_normalize=True,
+ **kwargs):
+ super(AsterLoss, self).__init__()
+ self.weight = weight
+ self.size_average = size_average
+ self.ignore_index = ignore_index
+ self.sequence_normalize = sequence_normalize
+ self.sample_normalize = sample_normalize
+ self.loss_sem = CosineEmbeddingLoss()
+ self.is_cosin_loss = True
+ self.loss_func_rec = nn.CrossEntropyLoss(weight=None, reduction='none')
+
+ def forward(self, predicts, batch):
+ targets = batch[1].astype("int64")
+ label_lengths = batch[2].astype('int64')
+ sem_target = batch[3].astype('float32')
+ embedding_vectors = predicts['embedding_vectors']
+ rec_pred = predicts['rec_pred']
+
+ if not self.is_cosin_loss:
+ sem_loss = paddle.sum(self.loss_sem(embedding_vectors, sem_target))
+ else:
+ label_target = paddle.ones([embedding_vectors.shape[0]])
+ sem_loss = paddle.sum(
+ self.loss_sem(embedding_vectors, sem_target, label_target))
+
+ # rec loss
+ batch_size, def_max_length = targets.shape[0], targets.shape[1]
+
+ mask = paddle.zeros([batch_size, def_max_length])
+ for i in range(batch_size):
+ mask[i, :label_lengths[i]] = 1
+ mask = paddle.cast(mask, "float32")
+ max_length = max(label_lengths)
+ assert max_length == rec_pred.shape[1]
+ targets = targets[:, :max_length]
+ mask = mask[:, :max_length]
+ rec_pred = paddle.reshape(rec_pred, [-1, rec_pred.shape[2]])
+ input = nn.functional.log_softmax(rec_pred, axis=1)
+ targets = paddle.reshape(targets, [-1, 1])
+ mask = paddle.reshape(mask, [-1, 1])
+ output = -paddle.index_sample(input, index=targets) * mask
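+        # per-token negative log-likelihood gathered with index_sample, masked
+        # to the valid label length of each sample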
+ output = paddle.sum(output)
+ if self.sequence_normalize:
+ output = output / paddle.sum(mask)
+ if self.sample_normalize:
+ output = output / batch_size
+
+ loss = output + sem_loss * 0.1
+ return {'loss': loss}
diff --git a/backend/ppocr/losses/rec_ctc_loss.py b/backend/ppocr/losses/rec_ctc_loss.py
index 425de587..502fc8c5 100755
--- a/backend/ppocr/losses/rec_ctc_loss.py
+++ b/backend/ppocr/losses/rec_ctc_loss.py
@@ -21,16 +21,25 @@
class CTCLoss(nn.Layer):
- def __init__(self, **kwargs):
+ def __init__(self, use_focal_loss=False, **kwargs):
super(CTCLoss, self).__init__()
self.loss_func = nn.CTCLoss(blank=0, reduction='none')
+ self.use_focal_loss = use_focal_loss
- def __call__(self, predicts, batch):
+ def forward(self, predicts, batch):
+ if isinstance(predicts, (list, tuple)):
+ predicts = predicts[-1]
predicts = predicts.transpose((1, 0, 2))
N, B, _ = predicts.shape
- preds_lengths = paddle.to_tensor([N] * B, dtype='int64')
+ preds_lengths = paddle.to_tensor(
+ [N] * B, dtype='int64', place=paddle.CPUPlace())
labels = batch[1].astype("int32")
label_lengths = batch[2].astype('int64')
loss = self.loss_func(predicts, labels, preds_lengths, label_lengths)
- loss = loss.mean() # sum
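+        # optional focal-style reweighting: scale each sample's CTC loss by
+        # (1 - exp(-loss))^2 so well-recognized (low-loss) samples contribute less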
+ if self.use_focal_loss:
+ weight = paddle.exp(-loss)
+ weight = paddle.subtract(paddle.to_tensor([1.0]), weight)
+ weight = paddle.square(weight)
+ loss = paddle.multiply(loss, weight)
+ loss = loss.mean()
return {'loss': loss}
diff --git a/backend/ppocr/losses/rec_enhanced_ctc_loss.py b/backend/ppocr/losses/rec_enhanced_ctc_loss.py
new file mode 100644
index 00000000..b57be646
--- /dev/null
+++ b/backend/ppocr/losses/rec_enhanced_ctc_loss.py
@@ -0,0 +1,70 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+from .ace_loss import ACELoss
+from .center_loss import CenterLoss
+from .rec_ctc_loss import CTCLoss
+
+
+class EnhancedCTCLoss(nn.Layer):
+ def __init__(self,
+ use_focal_loss=False,
+ use_ace_loss=False,
+ ace_loss_weight=0.1,
+ use_center_loss=False,
+ center_loss_weight=0.05,
+ num_classes=6625,
+ feat_dim=96,
+ init_center=False,
+ center_file_path=None,
+ **kwargs):
+ super(EnhancedCTCLoss, self).__init__()
+ self.ctc_loss_func = CTCLoss(use_focal_loss=use_focal_loss)
+
+ self.use_ace_loss = False
+ if use_ace_loss:
+ self.use_ace_loss = use_ace_loss
+ self.ace_loss_func = ACELoss()
+ self.ace_loss_weight = ace_loss_weight
+
+ self.use_center_loss = False
+ if use_center_loss:
+ self.use_center_loss = use_center_loss
+ self.center_loss_func = CenterLoss(
+ num_classes=num_classes,
+ feat_dim=feat_dim,
+ init_center=init_center,
+ center_file_path=center_file_path)
+ self.center_loss_weight = center_loss_weight
+
+ def __call__(self, predicts, batch):
+ loss = self.ctc_loss_func(predicts, batch)["loss"]
+
+ if self.use_center_loss:
+ center_loss = self.center_loss_func(
+ predicts, batch)["loss_center"] * self.center_loss_weight
+ loss = loss + center_loss
+
+ if self.use_ace_loss:
+ ace_loss = self.ace_loss_func(
+ predicts, batch)["loss_ace"] * self.ace_loss_weight
+ loss = loss + ace_loss
+
+ return {'enhanced_ctc_loss': loss}
diff --git a/backend/ppocr/losses/rec_multi_loss.py b/backend/ppocr/losses/rec_multi_loss.py
new file mode 100644
index 00000000..09f007af
--- /dev/null
+++ b/backend/ppocr/losses/rec_multi_loss.py
@@ -0,0 +1,58 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+
+from .rec_ctc_loss import CTCLoss
+from .rec_sar_loss import SARLoss
+
+
+class MultiLoss(nn.Layer):
+ def __init__(self, **kwargs):
+ super().__init__()
+ self.loss_funcs = {}
+ self.loss_list = kwargs.pop('loss_config_list')
+ self.weight_1 = kwargs.get('weight_1', 1.0)
+ self.weight_2 = kwargs.get('weight_2', 1.0)
+ self.gtc_loss = kwargs.get('gtc_loss', 'sar')
+ for loss_info in self.loss_list:
+ for name, param in loss_info.items():
+ if param is not None:
+ kwargs.update(param)
+ loss = eval(name)(**kwargs)
+ self.loss_funcs[name] = loss
+
+ def forward(self, predicts, batch):
+ self.total_loss = {}
+ total_loss = 0.0
+ # batch [image, label_ctc, label_sar, length, valid_ratio]
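+        # the CTC branch drops label_sar (batch[:2] + batch[3:]) and the SAR
+        # branch drops label_ctc (batch[:1] + batch[2:]); see the slicing below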
+ for name, loss_func in self.loss_funcs.items():
+ if name == 'CTCLoss':
+ loss = loss_func(predicts['ctc'],
+ batch[:2] + batch[3:])['loss'] * self.weight_1
+ elif name == 'SARLoss':
+ loss = loss_func(predicts['sar'],
+ batch[:1] + batch[2:])['loss'] * self.weight_2
+ else:
+ raise NotImplementedError(
+ '{} is not supported in MultiLoss yet'.format(name))
+ self.total_loss[name] = loss
+ total_loss += loss
+ self.total_loss['loss'] = total_loss
+ return self.total_loss
diff --git a/backend/ppocr/losses/rec_nrtr_loss.py b/backend/ppocr/losses/rec_nrtr_loss.py
new file mode 100644
index 00000000..200a6d04
--- /dev/null
+++ b/backend/ppocr/losses/rec_nrtr_loss.py
@@ -0,0 +1,30 @@
+import paddle
+from paddle import nn
+import paddle.nn.functional as F
+
+
+class NRTRLoss(nn.Layer):
+ def __init__(self, smoothing=True, **kwargs):
+ super(NRTRLoss, self).__init__()
+ self.loss_func = nn.CrossEntropyLoss(reduction='mean', ignore_index=0)
+ self.smoothing = smoothing
+
+ def forward(self, pred, batch):
+ pred = pred.reshape([-1, pred.shape[2]])
+ max_len = batch[2].max()
+ tgt = batch[1][:, 1:2 + max_len]
+ tgt = tgt.reshape([-1])
+ if self.smoothing:
+ eps = 0.1
+ n_class = pred.shape[1]
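+            # label smoothing: the target distribution puts 1 - eps on the true
+            # class and eps / (n_class - 1) on every other class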
+ one_hot = F.one_hot(tgt, pred.shape[1])
+ one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
+ log_prb = F.log_softmax(pred, axis=1)
+ non_pad_mask = paddle.not_equal(
+ tgt, paddle.zeros(
+ tgt.shape, dtype=tgt.dtype))
+ loss = -(one_hot * log_prb).sum(axis=1)
+ loss = loss.masked_select(non_pad_mask).mean()
+ else:
+ loss = self.loss_func(pred, tgt)
+ return {'loss': loss}
diff --git a/backend/ppocr/losses/rec_pren_loss.py b/backend/ppocr/losses/rec_pren_loss.py
new file mode 100644
index 00000000..7bc53d29
--- /dev/null
+++ b/backend/ppocr/losses/rec_pren_loss.py
@@ -0,0 +1,30 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from paddle import nn
+
+
+class PRENLoss(nn.Layer):
+ def __init__(self, **kwargs):
+ super(PRENLoss, self).__init__()
+ # note: 0 is padding idx
+ self.loss_func = nn.CrossEntropyLoss(reduction='mean', ignore_index=0)
+
+ def forward(self, predicts, batch):
+ loss = self.loss_func(predicts, batch[1].astype('int64'))
+ return {'loss': loss}
diff --git a/backend/ppocr/losses/rec_sar_loss.py b/backend/ppocr/losses/rec_sar_loss.py
new file mode 100644
index 00000000..a4f83f03
--- /dev/null
+++ b/backend/ppocr/losses/rec_sar_loss.py
@@ -0,0 +1,29 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+
+
+class SARLoss(nn.Layer):
+ def __init__(self, **kwargs):
+ super(SARLoss, self).__init__()
+ ignore_index = kwargs.get('ignore_index', 92) # 6626
+ self.loss_func = paddle.nn.loss.CrossEntropyLoss(
+ reduction="mean", ignore_index=ignore_index)
+
+ def forward(self, predicts, batch):
+        # drop the last prediction step and the first target token so that
+        # predictions and targets share the same sequence length
+        predict = predicts[:, :-1, :]
+        label = batch[1].astype("int64")[:, 1:]
+ batch_size, num_steps, num_classes = predict.shape[0], predict.shape[
+ 1], predict.shape[2]
+        assert len(label.shape) == len(list(predict.shape)) - 1, \
+            "the prediction should have exactly one more dimension (the class dimension) than the target"
+
+ inputs = paddle.reshape(predict, [-1, num_classes])
+ targets = paddle.reshape(label, [-1])
+ loss = self.loss_func(inputs, targets)
+ return {'loss': loss}
diff --git a/backend/ppocr/losses/table_att_loss.py b/backend/ppocr/losses/table_att_loss.py
new file mode 100644
index 00000000..d7fd99e6
--- /dev/null
+++ b/backend/ppocr/losses/table_att_loss.py
@@ -0,0 +1,109 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+from paddle.nn import functional as F
+from paddle import fluid
+
+class TableAttentionLoss(nn.Layer):
+ def __init__(self, structure_weight, loc_weight, use_giou=False, giou_weight=1.0, **kwargs):
+ super(TableAttentionLoss, self).__init__()
+ self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none')
+ self.structure_weight = structure_weight
+ self.loc_weight = loc_weight
+ self.use_giou = use_giou
+ self.giou_weight = giou_weight
+
+ def giou_loss(self, preds, bbox, eps=1e-7, reduction='mean'):
+ '''
+ :param preds:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
+ :param bbox:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
+ :return: loss
+ '''
+ ix1 = fluid.layers.elementwise_max(preds[:, 0], bbox[:, 0])
+ iy1 = fluid.layers.elementwise_max(preds[:, 1], bbox[:, 1])
+ ix2 = fluid.layers.elementwise_min(preds[:, 2], bbox[:, 2])
+ iy2 = fluid.layers.elementwise_min(preds[:, 3], bbox[:, 3])
+
+ iw = fluid.layers.clip(ix2 - ix1 + 1e-3, 0., 1e10)
+ ih = fluid.layers.clip(iy2 - iy1 + 1e-3, 0., 1e10)
+
+ # overlap
+ inters = iw * ih
+
+ # union
+ uni = (preds[:, 2] - preds[:, 0] + 1e-3) * (preds[:, 3] - preds[:, 1] + 1e-3
+ ) + (bbox[:, 2] - bbox[:, 0] + 1e-3) * (
+ bbox[:, 3] - bbox[:, 1] + 1e-3) - inters + eps
+
+ # ious
+ ious = inters / uni
+
+ ex1 = fluid.layers.elementwise_min(preds[:, 0], bbox[:, 0])
+ ey1 = fluid.layers.elementwise_min(preds[:, 1], bbox[:, 1])
+ ex2 = fluid.layers.elementwise_max(preds[:, 2], bbox[:, 2])
+ ey2 = fluid.layers.elementwise_max(preds[:, 3], bbox[:, 3])
+ ew = fluid.layers.clip(ex2 - ex1 + 1e-3, 0., 1e10)
+ eh = fluid.layers.clip(ey2 - ey1 + 1e-3, 0., 1e10)
+
+        # enclosing area
+ enclose = ew * eh + eps
+ giou = ious - (enclose - uni) / enclose
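+        # GIoU = IoU - (enclosing area - union) / enclosing area; the loss 1 - GIoU
+        # also penalizes predicted boxes that do not overlap the target box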
+
+ loss = 1 - giou
+
+ if reduction == 'mean':
+ loss = paddle.mean(loss)
+ elif reduction == 'sum':
+ loss = paddle.sum(loss)
+ else:
+ raise NotImplementedError
+ return loss
+
+ def forward(self, predicts, batch):
+ structure_probs = predicts['structure_probs']
+ structure_targets = batch[1].astype("int64")
+ structure_targets = structure_targets[:, 1:]
+ if len(batch) == 6:
+ structure_mask = batch[5].astype("int64")
+ structure_mask = structure_mask[:, 1:]
+ structure_mask = paddle.reshape(structure_mask, [-1])
+ structure_probs = paddle.reshape(structure_probs, [-1, structure_probs.shape[-1]])
+ structure_targets = paddle.reshape(structure_targets, [-1])
+ structure_loss = self.loss_func(structure_probs, structure_targets)
+
+ if len(batch) == 6:
+ structure_loss = structure_loss * structure_mask
+
+# structure_loss = paddle.sum(structure_loss) * self.structure_weight
+ structure_loss = paddle.mean(structure_loss) * self.structure_weight
+
+ loc_preds = predicts['loc_preds']
+ loc_targets = batch[2].astype("float32")
+ loc_targets_mask = batch[4].astype("float32")
+ loc_targets = loc_targets[:, 1:, :]
+ loc_targets_mask = loc_targets_mask[:, 1:, :]
+ loc_loss = F.mse_loss(loc_preds * loc_targets_mask, loc_targets) * self.loc_weight
+ if self.use_giou:
+ loc_loss_giou = self.giou_loss(loc_preds * loc_targets_mask, loc_targets) * self.giou_weight
+ total_loss = structure_loss + loc_loss + loc_loss_giou
+ return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss, "loc_loss_giou":loc_loss_giou}
+ else:
+ total_loss = structure_loss + loc_loss
+ return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss}
\ No newline at end of file
diff --git a/backend/ppocr/losses/vqa_token_layoutlm_loss.py b/backend/ppocr/losses/vqa_token_layoutlm_loss.py
new file mode 100755
index 00000000..244893d9
--- /dev/null
+++ b/backend/ppocr/losses/vqa_token_layoutlm_loss.py
@@ -0,0 +1,42 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from paddle import nn
+
+
+class VQASerTokenLayoutLMLoss(nn.Layer):
+ def __init__(self, num_classes):
+ super().__init__()
+ self.loss_class = nn.CrossEntropyLoss()
+ self.num_classes = num_classes
+ self.ignore_index = self.loss_class.ignore_index
+
+ def forward(self, predicts, batch):
+ labels = batch[1]
+ attention_mask = batch[4]
+ if attention_mask is not None:
+ active_loss = attention_mask.reshape([-1, ]) == 1
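+            # keep only tokens whose attention mask is 1, i.e. ignore padding
+            # positions when computing the token classification loss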
+ active_outputs = predicts.reshape(
+ [-1, self.num_classes])[active_loss]
+ active_labels = labels.reshape([-1, ])[active_loss]
+ loss = self.loss_class(active_outputs, active_labels)
+ else:
+ loss = self.loss_class(
+ predicts.reshape([-1, self.num_classes]),
+ labels.reshape([-1, ]))
+ return {'loss': loss}
diff --git a/backend/ppocr/metrics/__init__.py b/backend/ppocr/metrics/__init__.py
index a0e7d912..c244066c 100644
--- a/backend/ppocr/metrics/__init__.py
+++ b/backend/ppocr/metrics/__init__.py
@@ -19,19 +19,29 @@
import copy
-__all__ = ['build_metric']
+__all__ = ["build_metric"]
+from .det_metric import DetMetric, DetFCEMetric
+from .rec_metric import RecMetric
+from .cls_metric import ClsMetric
+from .e2e_metric import E2EMetric
+from .distillation_metric import DistillationMetric
+from .table_metric import TableMetric
+from .kie_metric import KIEMetric
+from .vqa_token_ser_metric import VQASerTokenMetric
+from .vqa_token_re_metric import VQAReTokenMetric
-def build_metric(config):
- from .det_metric import DetMetric
- from .rec_metric import RecMetric
- from .cls_metric import ClsMetric
- support_dict = ['DetMetric', 'RecMetric', 'ClsMetric']
+def build_metric(config):
+ support_dict = [
+ "DetMetric", "DetFCEMetric", "RecMetric", "ClsMetric", "E2EMetric",
+ "DistillationMetric", "TableMetric", 'KIEMetric', 'VQASerTokenMetric',
+ 'VQAReTokenMetric'
+ ]
config = copy.deepcopy(config)
- module_name = config.pop('name')
+ module_name = config.pop("name")
assert module_name in support_dict, Exception(
- 'metric only support {}'.format(support_dict))
+ "metric only support {}".format(support_dict))
module_class = eval(module_name)(**config)
return module_class
diff --git a/backend/ppocr/metrics/cls_metric.py b/backend/ppocr/metrics/cls_metric.py
index 09817200..6c077518 100644
--- a/backend/ppocr/metrics/cls_metric.py
+++ b/backend/ppocr/metrics/cls_metric.py
@@ -16,6 +16,7 @@
class ClsMetric(object):
def __init__(self, main_indicator='acc', **kwargs):
self.main_indicator = main_indicator
+ self.eps = 1e-5
self.reset()
def __call__(self, pred_label, *args, **kwargs):
@@ -28,7 +29,7 @@ def __call__(self, pred_label, *args, **kwargs):
all_num += 1
self.correct_num += correct_num
self.all_num += all_num
- return {'acc': correct_num / all_num, }
+ return {'acc': correct_num / (all_num + self.eps), }
def get_metric(self):
"""
@@ -36,7 +37,7 @@ def get_metric(self):
'acc': 0
}
"""
- acc = self.correct_num / self.all_num
+ acc = self.correct_num / (self.all_num + self.eps)
self.reset()
return {'acc': acc}
diff --git a/backend/ppocr/metrics/det_metric.py b/backend/ppocr/metrics/det_metric.py
index 0f9e94df..dca94c09 100644
--- a/backend/ppocr/metrics/det_metric.py
+++ b/backend/ppocr/metrics/det_metric.py
@@ -16,7 +16,7 @@
from __future__ import division
from __future__ import print_function
-__all__ = ['DetMetric']
+__all__ = ['DetMetric', 'DetFCEMetric']
from .eval_det_iou import DetectionIoUEvaluator
@@ -64,9 +64,91 @@ def get_metric(self):
}
"""
- metircs = self.evaluator.combine_results(self.results)
+ metrics = self.evaluator.combine_results(self.results)
self.reset()
- return metircs
+ return metrics
def reset(self):
self.results = [] # clear results
+
+
+class DetFCEMetric(object):
+ def __init__(self, main_indicator='hmean', **kwargs):
+ self.evaluator = DetectionIoUEvaluator()
+ self.main_indicator = main_indicator
+ self.reset()
+
+ def __call__(self, preds, batch, **kwargs):
+ '''
+ batch: a list produced by dataloaders.
+ image: np.ndarray of shape (N, C, H, W).
+ ratio_list: np.ndarray of shape(N,2)
+ polygons: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
+ ignore_tags: np.ndarray of shape (N, K), indicates whether a region is ignorable or not.
+ preds: a list of dict produced by post process
+ points: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
+ '''
+ gt_polyons_batch = batch[2]
+ ignore_tags_batch = batch[3]
+
+ for pred, gt_polyons, ignore_tags in zip(preds, gt_polyons_batch,
+ ignore_tags_batch):
+ # prepare gt
+ gt_info_list = [{
+ 'points': gt_polyon,
+ 'text': '',
+ 'ignore': ignore_tag
+ } for gt_polyon, ignore_tag in zip(gt_polyons, ignore_tags)]
+ # prepare det
+ det_info_list = [{
+ 'points': det_polyon,
+ 'text': '',
+ 'score': score
+ } for det_polyon, score in zip(pred['points'], pred['scores'])]
+
+ for score_thr in self.results.keys():
+ det_info_list_thr = [
+ det_info for det_info in det_info_list
+ if det_info['score'] >= score_thr
+ ]
+ result = self.evaluator.evaluate_image(gt_info_list,
+ det_info_list_thr)
+ self.results[score_thr].append(result)
+
+ def get_metric(self):
+ """
+        return metrics {'hmean': 0,
+ 'thr 0.3':'precision: 0 recall: 0 hmean: 0',
+ 'thr 0.4':'precision: 0 recall: 0 hmean: 0',
+ 'thr 0.5':'precision: 0 recall: 0 hmean: 0',
+ 'thr 0.6':'precision: 0 recall: 0 hmean: 0',
+ 'thr 0.7':'precision: 0 recall: 0 hmean: 0',
+ 'thr 0.8':'precision: 0 recall: 0 hmean: 0',
+ 'thr 0.9':'precision: 0 recall: 0 hmean: 0',
+ }
+ """
+ metrics = {}
+ hmean = 0
+ for score_thr in self.results.keys():
+ metric = self.evaluator.combine_results(self.results[score_thr])
+ # for key, value in metric.items():
+ # metrics['{}_{}'.format(key, score_thr)] = value
+ metric_str = 'precision:{:.5f} recall:{:.5f} hmean:{:.5f}'.format(
+ metric['precision'], metric['recall'], metric['hmean'])
+ metrics['thr {}'.format(score_thr)] = metric_str
+ hmean = max(hmean, metric['hmean'])
+ metrics['hmean'] = hmean
+
+ self.reset()
+ return metrics
+
+ def reset(self):
+ self.results = {
+ 0.3: [],
+ 0.4: [],
+ 0.5: [],
+ 0.6: [],
+ 0.7: [],
+ 0.8: [],
+ 0.9: []
+ } # clear results
diff --git a/backend/ppocr/metrics/distillation_metric.py b/backend/ppocr/metrics/distillation_metric.py
new file mode 100644
index 00000000..c440cebd
--- /dev/null
+++ b/backend/ppocr/metrics/distillation_metric.py
@@ -0,0 +1,73 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+import copy
+
+from .rec_metric import RecMetric
+from .det_metric import DetMetric
+from .e2e_metric import E2EMetric
+from .cls_metric import ClsMetric
+
+
+class DistillationMetric(object):
+ def __init__(self,
+ key=None,
+ base_metric_name=None,
+ main_indicator=None,
+ **kwargs):
+ self.main_indicator = main_indicator
+ self.key = key
+ self.main_indicator = main_indicator
+ self.base_metric_name = base_metric_name
+ self.kwargs = kwargs
+ self.metrics = None
+
+    def _init_metrics(self, preds):
+ self.metrics = dict()
+ mod = importlib.import_module(__name__)
+ for key in preds:
+ self.metrics[key] = getattr(mod, self.base_metric_name)(
+ main_indicator=self.main_indicator, **self.kwargs)
+ self.metrics[key].reset()
+
+ def __call__(self, preds, batch, **kwargs):
+ assert isinstance(preds, dict)
+ if self.metrics is None:
+            self._init_metrics(preds)
+ output = dict()
+ for key in preds:
+ self.metrics[key].__call__(preds[key], batch, **kwargs)
+
+ def get_metric(self):
+ """
+ return metrics {
+ 'acc': 0,
+ 'norm_edit_dis': 0,
+ }
+ """
+ output = dict()
+ for key in self.metrics:
+ metric = self.metrics[key].get_metric()
+ # main indicator
+ if key == self.key:
+ output.update(metric)
+ else:
+ for sub_key in metric:
+ output["{}_{}".format(key, sub_key)] = metric[sub_key]
+ return output
+
+ def reset(self):
+ for key in self.metrics:
+ self.metrics[key].reset()
diff --git a/backend/ppocr/metrics/e2e_metric.py b/backend/ppocr/metrics/e2e_metric.py
new file mode 100644
index 00000000..2f8ba3b2
--- /dev/null
+++ b/backend/ppocr/metrics/e2e_metric.py
@@ -0,0 +1,86 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+__all__ = ['E2EMetric']
+
+from ppocr.utils.e2e_metric.Deteval import get_socre_A, get_socre_B, combine_results
+from ppocr.utils.e2e_utils.extract_textpoint_slow import get_dict
+
+
+class E2EMetric(object):
+ def __init__(self,
+ mode,
+ gt_mat_dir,
+ character_dict_path,
+ main_indicator='f_score_e2e',
+ **kwargs):
+ self.mode = mode
+ self.gt_mat_dir = gt_mat_dir
+ self.label_list = get_dict(character_dict_path)
+ self.max_index = len(self.label_list)
+ self.main_indicator = main_indicator
+ self.reset()
+
+ def __call__(self, preds, batch, **kwargs):
+ if self.mode == 'A':
+ gt_polyons_batch = batch[2]
+ temp_gt_strs_batch = batch[3][0]
+ ignore_tags_batch = batch[4]
+ gt_strs_batch = []
+
+ for temp_list in temp_gt_strs_batch:
+ t = ""
+ for index in temp_list:
+ if index < self.max_index:
+ t += self.label_list[index]
+ gt_strs_batch.append(t)
+
+ for pred, gt_polyons, gt_strs, ignore_tags in zip(
+ [preds], gt_polyons_batch, [gt_strs_batch], ignore_tags_batch):
+ # prepare gt
+ gt_info_list = [{
+ 'points': gt_polyon,
+ 'text': gt_str,
+ 'ignore': ignore_tag
+ } for gt_polyon, gt_str, ignore_tag in
+ zip(gt_polyons, gt_strs, ignore_tags)]
+ # prepare det
+ e2e_info_list = [{
+ 'points': det_polyon,
+ 'texts': pred_str
+ } for det_polyon, pred_str in
+ zip(pred['points'], pred['texts'])]
+
+ result = get_socre_A(gt_info_list, e2e_info_list)
+ self.results.append(result)
+ else:
+ img_id = batch[5][0]
+ e2e_info_list = [{
+ 'points': det_polyon,
+ 'texts': pred_str
+ } for det_polyon, pred_str in zip(preds['points'], preds['texts'])]
+ result = get_socre_B(self.gt_mat_dir, img_id, e2e_info_list)
+ self.results.append(result)
+
+ def get_metric(self):
+ metrics = combine_results(self.results)
+ self.reset()
+ return metrics
+
+ def reset(self):
+ self.results = [] # clear results
diff --git a/backend/ppocr/metrics/eval_det_iou.py b/backend/ppocr/metrics/eval_det_iou.py
index a2a3f418..bc05e7df 100644
--- a/backend/ppocr/metrics/eval_det_iou.py
+++ b/backend/ppocr/metrics/eval_det_iou.py
@@ -150,7 +150,7 @@ def compute_ap(confList, matchList, numGtCare):
pairs.append({'gt': gtNum, 'det': detNum})
detMatchedNums.append(detNum)
evaluationLog += "Match GT #" + \
- str(gtNum) + " with Det #" + str(detNum) + "\n"
+ str(gtNum) + " with Det #" + str(detNum) + "\n"
numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
numDetCare = (len(detPols) - len(detDontCarePolsNum))
@@ -162,28 +162,17 @@ def compute_ap(confList, matchList, numGtCare):
precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare
hmean = 0 if (precision + recall) == 0 else 2.0 * \
- precision * recall / (precision + recall)
+ precision * recall / (precision + recall)
matchedSum += detMatched
numGlobalCareGt += numGtCare
numGlobalCareDet += numDetCare
perSampleMetrics = {
- 'precision': precision,
- 'recall': recall,
- 'hmean': hmean,
- 'pairs': pairs,
- 'iouMat': [] if len(detPols) > 100 else iouMat.tolist(),
- 'gtPolPoints': gtPolPoints,
- 'detPolPoints': detPolPoints,
'gtCare': numGtCare,
'detCare': numDetCare,
- 'gtDontCare': gtDontCarePolsNum,
- 'detDontCare': detDontCarePolsNum,
'detMatched': detMatched,
- 'evaluationLog': evaluationLog
}
-
return perSampleMetrics
def combine_results(self, results):
@@ -200,7 +189,8 @@ def combine_results(self, results):
methodPrecision = 0 if numGlobalCareDet == 0 else float(
matchedSum) / numGlobalCareDet
methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \
- methodRecall * methodPrecision / (methodRecall + methodPrecision)
+ methodRecall * methodPrecision / (
+ methodRecall + methodPrecision)
# print(methodRecall, methodPrecision, methodHmean)
# sys.exit(-1)
methodMetrics = {
diff --git a/backend/ppocr/metrics/kie_metric.py b/backend/ppocr/metrics/kie_metric.py
new file mode 100644
index 00000000..28ab22b8
--- /dev/null
+++ b/backend/ppocr/metrics/kie_metric.py
@@ -0,0 +1,71 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# The code is refer from: https://github.com/open-mmlab/mmocr/blob/main/mmocr/core/evaluation/kie_metric.py
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import paddle
+
+__all__ = ['KIEMetric']
+
+
+class KIEMetric(object):
+ def __init__(self, main_indicator='hmean', **kwargs):
+ self.main_indicator = main_indicator
+ self.reset()
+ self.node = []
+ self.gt = []
+
+ def __call__(self, preds, batch, **kwargs):
+ nodes, _ = preds
+ gts, tag = batch[4].squeeze(0), batch[5].tolist()[0]
+ gts = gts[:tag[0], :1].reshape([-1])
+ self.node.append(nodes.numpy())
+ self.gt.append(gts)
+ # result = self.compute_f1_score(nodes, gts)
+ # self.results.append(result)
+
+ def compute_f1_score(self, preds, gts):
+ ignores = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 25]
+ C = preds.shape[1]
+ classes = np.array(sorted(set(range(C)) - set(ignores)))
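+        # hist is a C x C confusion matrix built via bincount: rows index the
+        # ground-truth class, columns the predicted class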
+        hist = np.bincount(
+            (gts * C).astype('int64') + preds.argmax(1),
+            minlength=C**2).reshape([C, C]).astype('float32')
+ diag = np.diag(hist)
+ recalls = diag / hist.sum(1).clip(min=1)
+ precisions = diag / hist.sum(0).clip(min=1)
+ f1 = 2 * recalls * precisions / (recalls + precisions).clip(min=1e-8)
+ return f1[classes]
+
+ def combine_results(self, results):
+ node = np.concatenate(self.node, 0)
+ gts = np.concatenate(self.gt, 0)
+ results = self.compute_f1_score(node, gts)
+ data = {'hmean': results.mean()}
+ return data
+
+ def get_metric(self):
+
+ metrics = self.combine_results(self.results)
+ self.reset()
+ return metrics
+
+ def reset(self):
+ self.results = [] # clear results
+ self.node = []
+ self.gt = []
diff --git a/backend/ppocr/metrics/rec_metric.py b/backend/ppocr/metrics/rec_metric.py
index 66c084d7..515b9372 100644
--- a/backend/ppocr/metrics/rec_metric.py
+++ b/backend/ppocr/metrics/rec_metric.py
@@ -13,21 +13,38 @@
# limitations under the License.
import Levenshtein
+import string
class RecMetric(object):
- def __init__(self, main_indicator='acc', **kwargs):
+ def __init__(self,
+ main_indicator='acc',
+ is_filter=False,
+ ignore_space=True,
+ **kwargs):
self.main_indicator = main_indicator
+ self.is_filter = is_filter
+ self.ignore_space = ignore_space
+ self.eps = 1e-5
self.reset()
+ def _normalize_text(self, text):
+ text = ''.join(
+ filter(lambda x: x in (string.digits + string.ascii_letters), text))
+ return text.lower()
+
def __call__(self, pred_label, *args, **kwargs):
preds, labels = pred_label
correct_num = 0
all_num = 0
norm_edit_dis = 0.0
for (pred, pred_conf), (target, _) in zip(preds, labels):
- pred = pred.replace(" ", "")
- target = target.replace(" ", "")
+ if self.ignore_space:
+ pred = pred.replace(" ", "")
+ target = target.replace(" ", "")
+ if self.is_filter:
+ pred = self._normalize_text(pred)
+ target = self._normalize_text(target)
norm_edit_dis += Levenshtein.distance(pred, target) / max(
len(pred), len(target), 1)
if pred == target:
@@ -37,8 +54,8 @@ def __call__(self, pred_label, *args, **kwargs):
self.all_num += all_num
self.norm_edit_dis += norm_edit_dis
return {
- 'acc': correct_num / all_num,
- 'norm_edit_dis': 1 - norm_edit_dis / all_num
+ 'acc': correct_num / (all_num + self.eps),
+ 'norm_edit_dis': 1 - norm_edit_dis / (all_num + self.eps)
}
def get_metric(self):
@@ -48,8 +65,8 @@ def get_metric(self):
'norm_edit_dis': 0,
}
"""
- acc = 1.0 * self.correct_num / self.all_num
- norm_edit_dis = 1 - self.norm_edit_dis / self.all_num
+ acc = 1.0 * self.correct_num / (self.all_num + self.eps)
+ norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num + self.eps)
self.reset()
return {'acc': acc, 'norm_edit_dis': norm_edit_dis}
diff --git a/backend/ppocr/metrics/table_metric.py b/backend/ppocr/metrics/table_metric.py
new file mode 100644
index 00000000..ca4d6474
--- /dev/null
+++ b/backend/ppocr/metrics/table_metric.py
@@ -0,0 +1,51 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+
+
+class TableMetric(object):
+ def __init__(self, main_indicator='acc', **kwargs):
+ self.main_indicator = main_indicator
+ self.eps = 1e-5
+ self.reset()
+
+ def __call__(self, pred, batch, *args, **kwargs):
+ structure_probs = pred['structure_probs'].numpy()
+ structure_labels = batch[1]
+ correct_num = 0
+ all_num = 0
+ structure_probs = np.argmax(structure_probs, axis=2)
+ structure_labels = structure_labels[:, 1:]
+ batch_size = structure_probs.shape[0]
+ for bno in range(batch_size):
+ all_num += 1
+ if (structure_probs[bno] == structure_labels[bno]).all():
+ correct_num += 1
+ self.correct_num += correct_num
+ self.all_num += all_num
+ return {'acc': correct_num * 1.0 / (all_num + self.eps), }
+
+ def get_metric(self):
+ """
+ return metrics {
+ 'acc': 0,
+ }
+ """
+ acc = 1.0 * self.correct_num / (self.all_num + self.eps)
+ self.reset()
+ return {'acc': acc}
+
+ def reset(self):
+ self.correct_num = 0
+ self.all_num = 0
diff --git a/backend/ppocr/metrics/vqa_token_re_metric.py b/backend/ppocr/metrics/vqa_token_re_metric.py
new file mode 100644
index 00000000..8a13bc08
--- /dev/null
+++ b/backend/ppocr/metrics/vqa_token_re_metric.py
@@ -0,0 +1,176 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import paddle
+
+__all__ = ['VQAReTokenMetric']
+
+
+class VQAReTokenMetric(object):
+ def __init__(self, main_indicator='hmean', **kwargs):
+ self.main_indicator = main_indicator
+ self.reset()
+
+ def __call__(self, preds, batch, **kwargs):
+ pred_relations, relations, entities = preds
+ self.pred_relations_list.extend(pred_relations)
+ self.relations_list.extend(relations)
+ self.entities_list.extend(entities)
+
+ def get_metric(self):
+ gt_relations = []
+ for b in range(len(self.relations_list)):
+ rel_sent = []
+ for head, tail in zip(self.relations_list[b]["head"],
+ self.relations_list[b]["tail"]):
+ rel = {}
+ rel["head_id"] = head
+ rel["head"] = (self.entities_list[b]["start"][rel["head_id"]],
+ self.entities_list[b]["end"][rel["head_id"]])
+ rel["head_type"] = self.entities_list[b]["label"][rel[
+ "head_id"]]
+
+ rel["tail_id"] = tail
+ rel["tail"] = (self.entities_list[b]["start"][rel["tail_id"]],
+ self.entities_list[b]["end"][rel["tail_id"]])
+ rel["tail_type"] = self.entities_list[b]["label"][rel[
+ "tail_id"]]
+
+ rel["type"] = 1
+ rel_sent.append(rel)
+ gt_relations.append(rel_sent)
+ re_metrics = self.re_score(
+ self.pred_relations_list, gt_relations, mode="boundaries")
+ metrics = {
+ "precision": re_metrics["ALL"]["p"],
+ "recall": re_metrics["ALL"]["r"],
+ "hmean": re_metrics["ALL"]["f1"],
+ }
+ self.reset()
+ return metrics
+
+ def reset(self):
+ self.pred_relations_list = []
+ self.relations_list = []
+ self.entities_list = []
+
+ def re_score(self, pred_relations, gt_relations, mode="strict"):
+ """Evaluate RE predictions
+
+ Args:
+ pred_relations (list) : list of list of predicted relations (several relations in each sentence)
+ gt_relations (list) : list of list of ground truth relations
+
+ rel = { "head": (start_idx (inclusive), end_idx (exclusive)),
+ "tail": (start_idx (inclusive), end_idx (exclusive)),
+ "head_type": ent_type,
+ "tail_type": ent_type,
+ "type": rel_type}
+
+            mode (str) : 'strict' or 'boundaries'"""
+
+ assert mode in ["strict", "boundaries"]
+
+ relation_types = [v for v in [0, 1] if not v == 0]
+ scores = {
+ rel: {
+ "tp": 0,
+ "fp": 0,
+ "fn": 0
+ }
+ for rel in relation_types + ["ALL"]
+ }
+
+ # Count GT relations and Predicted relations
+ n_sents = len(gt_relations)
+ n_rels = sum([len([rel for rel in sent]) for sent in gt_relations])
+ n_found = sum([len([rel for rel in sent]) for sent in pred_relations])
+
+ # Count TP, FP and FN per type
+ for pred_sent, gt_sent in zip(pred_relations, gt_relations):
+ for rel_type in relation_types:
+ # strict mode takes argument types into account
+ if mode == "strict":
+ pred_rels = {(rel["head"], rel["head_type"], rel["tail"],
+ rel["tail_type"])
+ for rel in pred_sent
+ if rel["type"] == rel_type}
+ gt_rels = {(rel["head"], rel["head_type"], rel["tail"],
+ rel["tail_type"])
+ for rel in gt_sent if rel["type"] == rel_type}
+
+ # boundaries mode only takes argument spans into account
+ elif mode == "boundaries":
+ pred_rels = {(rel["head"], rel["tail"])
+ for rel in pred_sent
+ if rel["type"] == rel_type}
+ gt_rels = {(rel["head"], rel["tail"])
+ for rel in gt_sent if rel["type"] == rel_type}
+
+ scores[rel_type]["tp"] += len(pred_rels & gt_rels)
+ scores[rel_type]["fp"] += len(pred_rels - gt_rels)
+ scores[rel_type]["fn"] += len(gt_rels - pred_rels)
+
+ # Compute per entity Precision / Recall / F1
+ for rel_type in scores.keys():
+ if scores[rel_type]["tp"]:
+ scores[rel_type]["p"] = scores[rel_type]["tp"] / (
+ scores[rel_type]["fp"] + scores[rel_type]["tp"])
+ scores[rel_type]["r"] = scores[rel_type]["tp"] / (
+ scores[rel_type]["fn"] + scores[rel_type]["tp"])
+ else:
+ scores[rel_type]["p"], scores[rel_type]["r"] = 0, 0
+
+ if not scores[rel_type]["p"] + scores[rel_type]["r"] == 0:
+ scores[rel_type]["f1"] = (
+ 2 * scores[rel_type]["p"] * scores[rel_type]["r"] /
+ (scores[rel_type]["p"] + scores[rel_type]["r"]))
+ else:
+ scores[rel_type]["f1"] = 0
+
+ # Compute micro F1 Scores
+ tp = sum([scores[rel_type]["tp"] for rel_type in relation_types])
+ fp = sum([scores[rel_type]["fp"] for rel_type in relation_types])
+ fn = sum([scores[rel_type]["fn"] for rel_type in relation_types])
+
+ if tp:
+ precision = tp / (tp + fp)
+ recall = tp / (tp + fn)
+ f1 = 2 * precision * recall / (precision + recall)
+
+ else:
+ precision, recall, f1 = 0, 0, 0
+
+ scores["ALL"]["p"] = precision
+ scores["ALL"]["r"] = recall
+ scores["ALL"]["f1"] = f1
+ scores["ALL"]["tp"] = tp
+ scores["ALL"]["fp"] = fp
+ scores["ALL"]["fn"] = fn
+
+ # Compute Macro F1 Scores
+ scores["ALL"]["Macro_f1"] = np.mean(
+ [scores[ent_type]["f1"] for ent_type in relation_types])
+ scores["ALL"]["Macro_p"] = np.mean(
+ [scores[ent_type]["p"] for ent_type in relation_types])
+ scores["ALL"]["Macro_r"] = np.mean(
+ [scores[ent_type]["r"] for ent_type in relation_types])
+
+ return scores
diff --git a/backend/ppocr/metrics/vqa_token_ser_metric.py b/backend/ppocr/metrics/vqa_token_ser_metric.py
new file mode 100644
index 00000000..286d8add
--- /dev/null
+++ b/backend/ppocr/metrics/vqa_token_ser_metric.py
@@ -0,0 +1,47 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import paddle
+
+__all__ = ['VQASerTokenMetric']
+
+
+class VQASerTokenMetric(object):
+ def __init__(self, main_indicator='hmean', **kwargs):
+ self.main_indicator = main_indicator
+ self.reset()
+
+ def __call__(self, preds, batch, **kwargs):
+ preds, labels = preds
+ self.pred_list.extend(preds)
+ self.gt_list.extend(labels)
+
+ def get_metric(self):
+ from seqeval.metrics import f1_score, precision_score, recall_score
+ metrics = {
+ "precision": precision_score(self.gt_list, self.pred_list),
+ "recall": recall_score(self.gt_list, self.pred_list),
+ "hmean": f1_score(self.gt_list, self.pred_list),
+ }
+ self.reset()
+ return metrics
+
+ def reset(self):
+ self.pred_list = []
+ self.gt_list = []
diff --git a/backend/ppocr/modeling/architectures/__init__.py b/backend/ppocr/modeling/architectures/__init__.py
index 86eaf7c9..e9a01cf0 100755
--- a/backend/ppocr/modeling/architectures/__init__.py
+++ b/backend/ppocr/modeling/architectures/__init__.py
@@ -13,12 +13,20 @@
# limitations under the License.
import copy
+import importlib
+
+from .base_model import BaseModel
+from .distillation_model import DistillationModel
__all__ = ['build_model']
+
def build_model(config):
- from .base_model import BaseModel
-
config = copy.deepcopy(config)
- module_class = BaseModel(config)
- return module_class
\ No newline at end of file
+ if not "name" in config:
+ arch = BaseModel(config)
+ else:
+ name = config.pop("name")
+ mod = importlib.import_module(__name__)
+ arch = getattr(mod, name)(config)
+ return arch
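A hedged sketch of how the new dispatch is exercised: a config without a "name" key goes straight to BaseModel, while one with a "name" is resolved through importlib/getattr. The minimal config below is illustrative only; real configs carry full Backbone/Neck/Head (or Models) sub-sections.

from ppocr.modeling.architectures import build_model

cfg = {"name": "DistillationModel", "Models": {}}   # empty Models, just to show the dispatch
model = build_model(cfg)                            # resolved via getattr on this module
print(type(model).__name__)                         # DistillationModel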
diff --git a/backend/ppocr/modeling/architectures/base_model.py b/backend/ppocr/modeling/architectures/base_model.py
index 09b6e034..c6b50d48 100644
--- a/backend/ppocr/modeling/architectures/base_model.py
+++ b/backend/ppocr/modeling/architectures/base_model.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,7 +14,6 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-
from paddle import nn
from ppocr.modeling.transforms import build_transform
from ppocr.modeling.backbones import build_backbone
@@ -32,7 +31,6 @@ def __init__(self, config):
config (dict): the super parameters for module.
"""
super(BaseModel, self).__init__()
-
in_channels = config.get('in_channels', 3)
model_type = config['model_type']
# build transfrom,
@@ -65,17 +63,38 @@ def __init__(self, config):
in_channels = self.neck.out_channels
# # build head, head is need for det, rec and cls
- config["Head"]['in_channels'] = in_channels
- self.head = build_head(config["Head"])
+ if 'Head' not in config or config['Head'] is None:
+ self.use_head = False
+ else:
+ self.use_head = True
+ config["Head"]['in_channels'] = in_channels
+ self.head = build_head(config["Head"])
+
+ self.return_all_feats = config.get("return_all_feats", False)
def forward(self, x, data=None):
+ y = dict()
if self.use_transform:
x = self.transform(x)
x = self.backbone(x)
+ y["backbone_out"] = x
if self.use_neck:
x = self.neck(x)
- if data is None:
- x = self.head(x)
+ y["neck_out"] = x
+ if self.use_head:
+ x = self.head(x, targets=data)
+ # for multi-head models, save the CTC neck output for UDML distillation
+ if isinstance(x, dict) and 'ctc_neck' in x.keys():
+ y["neck_out"] = x["ctc_neck"]
+ y["head_out"] = x
+ elif isinstance(x, dict):
+ y.update(x)
+ else:
+ y["head_out"] = x
+ if self.return_all_feats:
+ if self.training:
+ return y
+ else:
+ return {"head_out": y["head_out"]}
else:
- x = self.head(x, data)
- return x
+ return x
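For orientation, a rough and deliberately simplified sketch of what the reworked forward returns when return_all_feats is enabled:

# during training:
#   {
#       "backbone_out": <backbone features>,
#       "neck_out":     <neck output, or the 'ctc_neck' entry for multi-head models>,
#       "head_out":     <head output; dict-valued heads are merged into the result instead>,
#   }
# during eval only {"head_out": ...} is kept, and with return_all_feats=False the raw
# head output is returned unchanged, as before.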
diff --git a/backend/ppocr/modeling/architectures/distillation_model.py b/backend/ppocr/modeling/architectures/distillation_model.py
new file mode 100644
index 00000000..cce8fd31
--- /dev/null
+++ b/backend/ppocr/modeling/architectures/distillation_model.py
@@ -0,0 +1,60 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from paddle import nn
+from ppocr.modeling.transforms import build_transform
+from ppocr.modeling.backbones import build_backbone
+from ppocr.modeling.necks import build_neck
+from ppocr.modeling.heads import build_head
+from .base_model import BaseModel
+from ppocr.utils.save_load import load_pretrained_params
+
+__all__ = ['DistillationModel']
+
+
+class DistillationModel(nn.Layer):
+ def __init__(self, config):
+ """
+ The module for OCR knowledge distillation.
+ Args:
+ config (dict): the hyperparameters for the module.
+ """
+ super().__init__()
+ self.model_list = []
+ self.model_name_list = []
+ for key in config["Models"]:
+ model_config = config["Models"][key]
+ freeze_params = False
+ pretrained = None
+ if "freeze_params" in model_config:
+ freeze_params = model_config.pop("freeze_params")
+ if "pretrained" in model_config:
+ pretrained = model_config.pop("pretrained")
+ model = BaseModel(model_config)
+ if pretrained is not None:
+ load_pretrained_params(model, pretrained)
+ if freeze_params:
+ for param in model.parameters():
+ param.trainable = False
+ self.model_list.append(self.add_sublayer(key, model))
+ self.model_name_list.append(key)
+
+ def forward(self, x, data=None):
+ result_dict = dict()
+ for idx, model_name in enumerate(self.model_name_list):
+ result_dict[model_name] = self.model_list[idx](x, data)
+ return result_dict
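A hedged sketch of the Models section this class consumes; the sub-model names and the pretrained path are illustrative, and each entry is otherwise a full BaseModel config:

# Models:
#   Teacher:
#     pretrained: ./pretrained/teacher    # popped, then loaded via load_pretrained_params
#     freeze_params: true                 # popped, then all parameters set non-trainable
#     model_type: rec
#     Backbone: ...
#     Neck: ...
#     Head: ...
#   Student:
#     freeze_params: false
#     model_type: rec
#     Backbone: ...
#     Neck: ...
#     Head: ...
# forward(x) then returns {"Teacher": teacher_out, "Student": student_out}.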
diff --git a/backend/ppocr/modeling/backbones/__init__.py b/backend/ppocr/modeling/backbones/__init__.py
index 03c15508..072d6e0f 100755
--- a/backend/ppocr/modeling/backbones/__init__.py
+++ b/backend/ppocr/modeling/backbones/__init__.py
@@ -12,26 +12,53 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__all__ = ['build_backbone']
+__all__ = ["build_backbone"]
def build_backbone(config, model_type):
- if model_type == 'det':
+ if model_type == "det" or model_type == "table":
from .det_mobilenet_v3 import MobileNetV3
from .det_resnet_vd import ResNet
from .det_resnet_vd_sast import ResNet_SAST
- support_dict = ['MobileNetV3', 'ResNet', 'ResNet_SAST']
- elif model_type == 'rec' or model_type == 'cls':
+ support_dict = ["MobileNetV3", "ResNet", "ResNet_SAST"]
+ elif model_type == "rec" or model_type == "cls":
from .rec_mobilenet_v3 import MobileNetV3
from .rec_resnet_vd import ResNet
from .rec_resnet_fpn import ResNetFPN
- support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN']
+ from .rec_mv1_enhance import MobileNetV1Enhance
+ from .rec_nrtr_mtb import MTB
+ from .rec_resnet_31 import ResNet31
+ from .rec_resnet_aster import ResNet_ASTER
+ from .rec_micronet import MicroNet
+ from .rec_efficientb3_pren import EfficientNetb3_PREN
+ from .rec_svtrnet import SVTRNet
+ support_dict = [
+ 'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
+ "ResNet31", "ResNet_ASTER", 'MicroNet', 'EfficientNetb3_PREN',
+ 'SVTRNet'
+ ]
+ elif model_type == "e2e":
+ from .e2e_resnet_vd_pg import ResNet
+ support_dict = ['ResNet']
+ elif model_type == 'kie':
+ from .kie_unet_sdmgr import Kie_backbone
+ support_dict = ['Kie_backbone']
+ elif model_type == "table":
+ from .table_resnet_vd import ResNet
+ from .table_mobilenet_v3 import MobileNetV3
+ support_dict = ["ResNet", "MobileNetV3"]
+ elif model_type == 'vqa':
+ from .vqa_layoutlm import LayoutLMForSer, LayoutLMv2ForSer, LayoutLMv2ForRe, LayoutXLMForSer, LayoutXLMForRe
+ support_dict = [
+ "LayoutLMForSer", "LayoutLMv2ForSer", 'LayoutLMv2ForRe',
+ "LayoutXLMForSer", 'LayoutXLMForRe'
+ ]
else:
raise NotImplementedError
- module_name = config.pop('name')
+ module_name = config.pop("name")
assert module_name in support_dict, Exception(
- 'when model typs is {}, backbone only support {}'.format(model_type,
+ "when model typs is {}, backbone only support {}".format(model_type,
support_dict))
module_class = eval(module_name)(**config)
return module_class
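A hedged usage sketch of the factory above; the layer count is illustrative, "name" is popped, and the remaining keys are passed to the class as keyword arguments:

from ppocr.modeling.backbones import build_backbone

cfg = {"name": "ResNet", "layers": 18}
backbone = build_backbone(cfg, model_type="det")
print(backbone.out_channels)    # e.g. [64, 128, 256, 512] for the 18-layer variant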
diff --git a/backend/ppocr/modeling/backbones/det_mobilenet_v3.py b/backend/ppocr/modeling/backbones/det_mobilenet_v3.py
index bb451bbe..05113ea8 100755
--- a/backend/ppocr/modeling/backbones/det_mobilenet_v3.py
+++ b/backend/ppocr/modeling/backbones/det_mobilenet_v3.py
@@ -102,8 +102,7 @@ def __init__(self,
padding=1,
groups=1,
if_act=True,
- act='hardswish',
- name='conv1')
+ act='hardswish')
self.stages = []
self.out_channels = []
@@ -125,8 +124,7 @@ def __init__(self,
kernel_size=k,
stride=s,
use_se=se,
- act=nl,
- name="conv" + str(i + 2)))
+ act=nl))
inplanes = make_divisible(scale * c)
i += 1
block_list.append(
@@ -138,8 +136,7 @@ def __init__(self,
padding=0,
groups=1,
if_act=True,
- act='hardswish',
- name='conv_last'))
+ act='hardswish'))
self.stages.append(nn.Sequential(*block_list))
self.out_channels.append(make_divisible(scale * cls_ch_squeeze))
for i, stage in enumerate(self.stages):
@@ -163,8 +160,7 @@ def __init__(self,
padding,
groups=1,
if_act=True,
- act=None,
- name=None):
+ act=None):
super(ConvBNLayer, self).__init__()
self.if_act = if_act
self.act = act
@@ -175,16 +171,9 @@ def __init__(self,
stride=stride,
padding=padding,
groups=groups,
- weight_attr=ParamAttr(name=name + '_weights'),
bias_attr=False)
- self.bn = nn.BatchNorm(
- num_channels=out_channels,
- act=None,
- param_attr=ParamAttr(name=name + "_bn_scale"),
- bias_attr=ParamAttr(name=name + "_bn_offset"),
- moving_mean_name=name + "_bn_mean",
- moving_variance_name=name + "_bn_variance")
+ self.bn = nn.BatchNorm(num_channels=out_channels, act=None)
def forward(self, x):
x = self.conv(x)
@@ -209,8 +198,7 @@ def __init__(self,
kernel_size,
stride,
use_se,
- act=None,
- name=''):
+ act=None):
super(ResidualUnit, self).__init__()
self.if_shortcut = stride == 1 and in_channels == out_channels
self.if_se = use_se
@@ -222,8 +210,7 @@ def __init__(self,
stride=1,
padding=0,
if_act=True,
- act=act,
- name=name + "_expand")
+ act=act)
self.bottleneck_conv = ConvBNLayer(
in_channels=mid_channels,
out_channels=mid_channels,
@@ -232,10 +219,9 @@ def __init__(self,
padding=int((kernel_size - 1) // 2),
groups=mid_channels,
if_act=True,
- act=act,
- name=name + "_depthwise")
+ act=act)
if self.if_se:
- self.mid_se = SEModule(mid_channels, name=name + "_se")
+ self.mid_se = SEModule(mid_channels)
self.linear_conv = ConvBNLayer(
in_channels=mid_channels,
out_channels=out_channels,
@@ -243,8 +229,7 @@ def __init__(self,
stride=1,
padding=0,
if_act=False,
- act=None,
- name=name + "_linear")
+ act=None)
def forward(self, inputs):
x = self.expand_conv(inputs)
@@ -258,7 +243,7 @@ def forward(self, inputs):
class SEModule(nn.Layer):
- def __init__(self, in_channels, reduction=4, name=""):
+ def __init__(self, in_channels, reduction=4):
super(SEModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2D(1)
self.conv1 = nn.Conv2D(
@@ -266,17 +251,13 @@ def __init__(self, in_channels, reduction=4, name=""):
out_channels=in_channels // reduction,
kernel_size=1,
stride=1,
- padding=0,
- weight_attr=ParamAttr(name=name + "_1_weights"),
- bias_attr=ParamAttr(name=name + "_1_offset"))
+ padding=0)
self.conv2 = nn.Conv2D(
in_channels=in_channels // reduction,
out_channels=in_channels,
kernel_size=1,
stride=1,
- padding=0,
- weight_attr=ParamAttr(name + "_2_weights"),
- bias_attr=ParamAttr(name=name + "_2_offset"))
+ padding=0)
def forward(self, inputs):
outputs = self.avg_pool(inputs)
diff --git a/backend/ppocr/modeling/backbones/det_resnet_vd.py b/backend/ppocr/modeling/backbones/det_resnet_vd.py
index 3bb4a0d5..8c955a4a 100644
--- a/backend/ppocr/modeling/backbones/det_resnet_vd.py
+++ b/backend/ppocr/modeling/backbones/det_resnet_vd.py
@@ -21,45 +21,116 @@
import paddle.nn as nn
import paddle.nn.functional as F
+from paddle.vision.ops import DeformConv2D
+from paddle.regularizer import L2Decay
+from paddle.nn.initializer import Normal, Constant, XavierUniform
+
__all__ = ["ResNet"]
-class ConvBNLayer(nn.Layer):
- def __init__(
- self,
+class DeformableConvV2(nn.Layer):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ weight_attr=None,
+ bias_attr=None,
+ lr_scale=1,
+ regularizer=None,
+ skip_quant=False,
+ dcn_bias_regularizer=L2Decay(0.),
+ dcn_bias_lr_scale=2.):
+ super(DeformableConvV2, self).__init__()
+ self.offset_channel = 2 * kernel_size**2 * groups
+ self.mask_channel = kernel_size**2 * groups
+
+ if bias_attr:
+ # the FCOS-DCN head specifically needs learning_rate and regularizer on the bias
+ dcn_bias_attr = ParamAttr(
+ initializer=Constant(value=0),
+ regularizer=dcn_bias_regularizer,
+ learning_rate=dcn_bias_lr_scale)
+ else:
+ # the ResNet backbone does not need a bias term
+ dcn_bias_attr = False
+ self.conv_dcn = DeformConv2D(
in_channels,
out_channels,
kernel_size,
- stride=1,
- groups=1,
- is_vd_mode=False,
- act=None,
- name=None, ):
+ stride=stride,
+ padding=(kernel_size - 1) // 2 * dilation,
+ dilation=dilation,
+ deformable_groups=groups,
+ weight_attr=weight_attr,
+ bias_attr=dcn_bias_attr)
+
+ if lr_scale == 1 and regularizer is None:
+ offset_bias_attr = ParamAttr(initializer=Constant(0.))
+ else:
+ offset_bias_attr = ParamAttr(
+ initializer=Constant(0.),
+ learning_rate=lr_scale,
+ regularizer=regularizer)
+ self.conv_offset = nn.Conv2D(
+ in_channels,
+ groups * 3 * kernel_size**2,
+ kernel_size,
+ stride=stride,
+ padding=(kernel_size - 1) // 2,
+ weight_attr=ParamAttr(initializer=Constant(0.0)),
+ bias_attr=offset_bias_attr)
+ if skip_quant:
+ self.conv_offset.skip_quant = True
+
+ def forward(self, x):
+ offset_mask = self.conv_offset(x)
+ offset, mask = paddle.split(
+ offset_mask,
+ num_or_sections=[self.offset_channel, self.mask_channel],
+ axis=1)
+ mask = F.sigmoid(mask)
+ y = self.conv_dcn(x, offset, mask=mask)
+ return y
+
+
+class ConvBNLayer(nn.Layer):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ groups=1,
+ is_vd_mode=False,
+ act=None,
+ is_dcn=False):
super(ConvBNLayer, self).__init__()
self.is_vd_mode = is_vd_mode
self._pool2d_avg = nn.AvgPool2D(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
- self._conv = nn.Conv2D(
- in_channels=in_channels,
- out_channels=out_channels,
- kernel_size=kernel_size,
- stride=stride,
- padding=(kernel_size - 1) // 2,
- groups=groups,
- weight_attr=ParamAttr(name=name + "_weights"),
- bias_attr=False)
- if name == "conv1":
- bn_name = "bn_" + name
+ if not is_dcn:
+ self._conv = nn.Conv2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=(kernel_size - 1) // 2,
+ groups=groups,
+ bias_attr=False)
else:
- bn_name = "bn" + name[3:]
- self._batch_norm = nn.BatchNorm(
- out_channels,
- act=act,
- param_attr=ParamAttr(name=bn_name + '_scale'),
- bias_attr=ParamAttr(bn_name + '_offset'),
- moving_mean_name=bn_name + '_mean',
- moving_variance_name=bn_name + '_variance')
+ self._conv = DeformableConvV2(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=(kernel_size - 1) // 2,
+ groups=2, #groups,
+ bias_attr=False)
+ self._batch_norm = nn.BatchNorm(out_channels, act=act)
def forward(self, inputs):
if self.is_vd_mode:
@@ -70,34 +141,33 @@ def forward(self, inputs):
class BottleneckBlock(nn.Layer):
- def __init__(self,
- in_channels,
- out_channels,
- stride,
- shortcut=True,
- if_first=False,
- name=None):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ stride,
+ shortcut=True,
+ if_first=False,
+ is_dcn=False, ):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
- act='relu',
- name=name + "_branch2a")
+ act='relu')
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
act='relu',
- name=name + "_branch2b")
+ is_dcn=is_dcn)
self.conv2 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels * 4,
kernel_size=1,
- act=None,
- name=name + "_branch2c")
+ act=None)
if not shortcut:
self.short = ConvBNLayer(
@@ -105,8 +175,7 @@ def __init__(self,
out_channels=out_channels * 4,
kernel_size=1,
stride=1,
- is_vd_mode=False if if_first else True,
- name=name + "_branch1")
+ is_vd_mode=False if if_first else True)
self.shortcut = shortcut
@@ -125,13 +194,13 @@ def forward(self, inputs):
class BasicBlock(nn.Layer):
- def __init__(self,
- in_channels,
- out_channels,
- stride,
- shortcut=True,
- if_first=False,
- name=None):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ stride,
+ shortcut=True,
+ if_first=False, ):
super(BasicBlock, self).__init__()
self.stride = stride
self.conv0 = ConvBNLayer(
@@ -139,14 +208,12 @@ def __init__(self,
out_channels=out_channels,
kernel_size=3,
stride=stride,
- act='relu',
- name=name + "_branch2a")
+ act='relu')
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
- act=None,
- name=name + "_branch2b")
+ act=None)
if not shortcut:
self.short = ConvBNLayer(
@@ -154,8 +221,7 @@ def __init__(self,
out_channels=out_channels,
kernel_size=1,
stride=1,
- is_vd_mode=False if if_first else True,
- name=name + "_branch1")
+ is_vd_mode=False if if_first else True)
self.shortcut = shortcut
@@ -173,7 +239,12 @@ def forward(self, inputs):
class ResNet(nn.Layer):
- def __init__(self, in_channels=3, layers=50, **kwargs):
+ def __init__(self,
+ in_channels=3,
+ layers=50,
+ dcn_stage=None,
+ out_indices=None,
+ **kwargs):
super(ResNet, self).__init__()
self.layers = layers
@@ -196,27 +267,31 @@ def __init__(self, in_channels=3, layers=50, **kwargs):
1024] if layers >= 50 else [64, 64, 128, 256]
num_filters = [64, 128, 256, 512]
+ self.dcn_stage = dcn_stage if dcn_stage is not None else [
+ False, False, False, False
+ ]
+ self.out_indices = out_indices if out_indices is not None else [
+ 0, 1, 2, 3
+ ]
+
self.conv1_1 = ConvBNLayer(
in_channels=in_channels,
out_channels=32,
kernel_size=3,
stride=2,
- act='relu',
- name="conv1_1")
+ act='relu')
self.conv1_2 = ConvBNLayer(
in_channels=32,
out_channels=32,
kernel_size=3,
stride=1,
- act='relu',
- name="conv1_2")
+ act='relu')
self.conv1_3 = ConvBNLayer(
in_channels=32,
out_channels=64,
kernel_size=3,
stride=1,
- act='relu',
- name="conv1_3")
+ act='relu')
self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
self.stages = []
@@ -225,14 +300,8 @@ def __init__(self, in_channels=3, layers=50, **kwargs):
for block in range(len(depth)):
block_list = []
shortcut = False
+ is_dcn = self.dcn_stage[block]
for i in range(depth[block]):
- if layers in [101, 152] and block == 2:
- if i == 0:
- conv_name = "res" + str(block + 2) + "a"
- else:
- conv_name = "res" + str(block + 2) + "b" + str(i)
- else:
- conv_name = "res" + str(block + 2) + chr(97 + i)
bottleneck_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
BottleneckBlock(
@@ -242,17 +311,18 @@ def __init__(self, in_channels=3, layers=50, **kwargs):
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
if_first=block == i == 0,
- name=conv_name))
+ is_dcn=is_dcn))
shortcut = True
block_list.append(bottleneck_block)
- self.out_channels.append(num_filters[block] * 4)
+ if block in self.out_indices:
+ self.out_channels.append(num_filters[block] * 4)
self.stages.append(nn.Sequential(*block_list))
else:
for block in range(len(depth)):
block_list = []
shortcut = False
+ # is_dcn = self.dcn_stage[block]
for i in range(depth[block]):
- conv_name = "res" + str(block + 2) + chr(97 + i)
basic_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
BasicBlock(
@@ -261,11 +331,11 @@ def __init__(self, in_channels=3, layers=50, **kwargs):
out_channels=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
- if_first=block == i == 0,
- name=conv_name))
+ if_first=block == i == 0))
shortcut = True
block_list.append(basic_block)
- self.out_channels.append(num_filters[block])
+ if block in self.out_indices:
+ self.out_channels.append(num_filters[block])
self.stages.append(nn.Sequential(*block_list))
def forward(self, inputs):
@@ -274,7 +344,8 @@ def forward(self, inputs):
y = self.conv1_3(y)
y = self.pool2d_max(y)
out = []
- for block in self.stages:
+ for i, block in enumerate(self.stages):
y = block(y)
- out.append(y)
+ if i in self.out_indices:
+ out.append(y)
return out
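A hedged sketch of the two new switches; the values are illustrative, and is_dcn is only wired into the bottleneck path (the 50+ layer variants):

import paddle
from ppocr.modeling.backbones.det_resnet_vd import ResNet

model = ResNet(layers=50,
               dcn_stage=[False, True, True, True],   # deformable conv in stages 2-4
               out_indices=[1, 2, 3])                 # only return the last three stages
feats = model(paddle.rand([1, 3, 640, 640]))
print(len(feats), model.out_channels)                 # 3 feature maps, [512, 1024, 2048]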
diff --git a/backend/ppocr/modeling/backbones/e2e_resnet_vd_pg.py b/backend/ppocr/modeling/backbones/e2e_resnet_vd_pg.py
new file mode 100644
index 00000000..97afd346
--- /dev/null
+++ b/backend/ppocr/modeling/backbones/e2e_resnet_vd_pg.py
@@ -0,0 +1,265 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import ParamAttr
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+__all__ = ["ResNet"]
+
+
+class ConvBNLayer(nn.Layer):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ groups=1,
+ is_vd_mode=False,
+ act=None,
+ name=None, ):
+ super(ConvBNLayer, self).__init__()
+
+ self.is_vd_mode = is_vd_mode
+ self._pool2d_avg = nn.AvgPool2D(
+ kernel_size=2, stride=2, padding=0, ceil_mode=True)
+ self._conv = nn.Conv2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=(kernel_size - 1) // 2,
+ groups=groups,
+ weight_attr=ParamAttr(name=name + "_weights"),
+ bias_attr=False)
+ if name == "conv1":
+ bn_name = "bn_" + name
+ else:
+ bn_name = "bn" + name[3:]
+ self._batch_norm = nn.BatchNorm(
+ out_channels,
+ act=act,
+ param_attr=ParamAttr(name=bn_name + '_scale'),
+ bias_attr=ParamAttr(bn_name + '_offset'),
+ moving_mean_name=bn_name + '_mean',
+ moving_variance_name=bn_name + '_variance')
+
+ def forward(self, inputs):
+ y = self._conv(inputs)
+ y = self._batch_norm(y)
+ return y
+
+
+class BottleneckBlock(nn.Layer):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ stride,
+ shortcut=True,
+ if_first=False,
+ name=None):
+ super(BottleneckBlock, self).__init__()
+
+ self.conv0 = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ act='relu',
+ name=name + "_branch2a")
+ self.conv1 = ConvBNLayer(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=stride,
+ act='relu',
+ name=name + "_branch2b")
+ self.conv2 = ConvBNLayer(
+ in_channels=out_channels,
+ out_channels=out_channels * 4,
+ kernel_size=1,
+ act=None,
+ name=name + "_branch2c")
+
+ if not shortcut:
+ self.short = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=out_channels * 4,
+ kernel_size=1,
+ stride=stride,
+ is_vd_mode=False if if_first else True,
+ name=name + "_branch1")
+
+ self.shortcut = shortcut
+
+ def forward(self, inputs):
+ y = self.conv0(inputs)
+ conv1 = self.conv1(y)
+ conv2 = self.conv2(conv1)
+
+ if self.shortcut:
+ short = inputs
+ else:
+ short = self.short(inputs)
+ y = paddle.add(x=short, y=conv2)
+ y = F.relu(y)
+ return y
+
+
+class BasicBlock(nn.Layer):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ stride,
+ shortcut=True,
+ if_first=False,
+ name=None):
+ super(BasicBlock, self).__init__()
+ self.stride = stride
+ self.conv0 = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=stride,
+ act='relu',
+ name=name + "_branch2a")
+ self.conv1 = ConvBNLayer(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ act=None,
+ name=name + "_branch2b")
+
+ if not shortcut:
+ self.short = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ stride=1,
+ is_vd_mode=False if if_first else True,
+ name=name + "_branch1")
+
+ self.shortcut = shortcut
+
+ def forward(self, inputs):
+ y = self.conv0(inputs)
+ conv1 = self.conv1(y)
+
+ if self.shortcut:
+ short = inputs
+ else:
+ short = self.short(inputs)
+ y = paddle.add(x=short, y=conv1)
+ y = F.relu(y)
+ return y
+
+
+class ResNet(nn.Layer):
+ def __init__(self, in_channels=3, layers=50, **kwargs):
+ super(ResNet, self).__init__()
+
+ self.layers = layers
+ supported_layers = [18, 34, 50, 101, 152, 200]
+ assert layers in supported_layers, \
+ "supported layers are {} but input layer is {}".format(
+ supported_layers, layers)
+
+ if layers == 18:
+ depth = [2, 2, 2, 2]
+ elif layers == 34 or layers == 50:
+ # depth = [3, 4, 6, 3]
+ depth = [3, 4, 6, 3, 3]
+ elif layers == 101:
+ depth = [3, 4, 23, 3]
+ elif layers == 152:
+ depth = [3, 8, 36, 3]
+ elif layers == 200:
+ depth = [3, 12, 48, 3]
+ num_channels = [64, 256, 512, 1024,
+ 2048] if layers >= 50 else [64, 64, 128, 256]
+ num_filters = [64, 128, 256, 512, 512]
+
+ self.conv1_1 = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=64,
+ kernel_size=7,
+ stride=2,
+ act='relu',
+ name="conv1_1")
+ self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
+
+ self.stages = []
+ self.out_channels = [3, 64]
+ # num_filters = [64, 128, 256, 512, 512]
+ if layers >= 50:
+ for block in range(len(depth)):
+ block_list = []
+ shortcut = False
+ for i in range(depth[block]):
+ if layers in [101, 152] and block == 2:
+ if i == 0:
+ conv_name = "res" + str(block + 2) + "a"
+ else:
+ conv_name = "res" + str(block + 2) + "b" + str(i)
+ else:
+ conv_name = "res" + str(block + 2) + chr(97 + i)
+ bottleneck_block = self.add_sublayer(
+ 'bb_%d_%d' % (block, i),
+ BottleneckBlock(
+ in_channels=num_channels[block]
+ if i == 0 else num_filters[block] * 4,
+ out_channels=num_filters[block],
+ stride=2 if i == 0 and block != 0 else 1,
+ shortcut=shortcut,
+ if_first=block == i == 0,
+ name=conv_name))
+ shortcut = True
+ block_list.append(bottleneck_block)
+ self.out_channels.append(num_filters[block] * 4)
+ self.stages.append(nn.Sequential(*block_list))
+ else:
+ for block in range(len(depth)):
+ block_list = []
+ shortcut = False
+ for i in range(depth[block]):
+ conv_name = "res" + str(block + 2) + chr(97 + i)
+ basic_block = self.add_sublayer(
+ 'bb_%d_%d' % (block, i),
+ BasicBlock(
+ in_channels=num_channels[block]
+ if i == 0 else num_filters[block],
+ out_channels=num_filters[block],
+ stride=2 if i == 0 and block != 0 else 1,
+ shortcut=shortcut,
+ if_first=block == i == 0,
+ name=conv_name))
+ shortcut = True
+ block_list.append(basic_block)
+ self.out_channels.append(num_filters[block])
+ self.stages.append(nn.Sequential(*block_list))
+
+ def forward(self, inputs):
+ out = [inputs]
+ y = self.conv1_1(inputs)
+ out.append(y)
+ y = self.pool2d_max(y)
+ for block in self.stages:
+ y = block(y)
+ out.append(y)
+ return out
diff --git a/backend/ppocr/modeling/backbones/kie_unet_sdmgr.py b/backend/ppocr/modeling/backbones/kie_unet_sdmgr.py
new file mode 100644
index 00000000..545e4e75
--- /dev/null
+++ b/backend/ppocr/modeling/backbones/kie_unet_sdmgr.py
@@ -0,0 +1,186 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+import numpy as np
+import cv2
+
+__all__ = ["Kie_backbone"]
+
+
+class Encoder(nn.Layer):
+ def __init__(self, num_channels, num_filters):
+ super(Encoder, self).__init__()
+ self.conv1 = nn.Conv2D(
+ num_channels,
+ num_filters,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias_attr=False)
+ self.bn1 = nn.BatchNorm(num_filters, act='relu')
+
+ self.conv2 = nn.Conv2D(
+ num_filters,
+ num_filters,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias_attr=False)
+ self.bn2 = nn.BatchNorm(num_filters, act='relu')
+
+ self.pool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
+
+ def forward(self, inputs):
+ x = self.conv1(inputs)
+ x = self.bn1(x)
+ x = self.conv2(x)
+ x = self.bn2(x)
+ x_pooled = self.pool(x)
+ return x, x_pooled
+
+
+class Decoder(nn.Layer):
+ def __init__(self, num_channels, num_filters):
+ super(Decoder, self).__init__()
+
+ self.conv1 = nn.Conv2D(
+ num_channels,
+ num_filters,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias_attr=False)
+ self.bn1 = nn.BatchNorm(num_filters, act='relu')
+
+ self.conv2 = nn.Conv2D(
+ num_filters,
+ num_filters,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias_attr=False)
+ self.bn2 = nn.BatchNorm(num_filters, act='relu')
+
+ self.conv0 = nn.Conv2D(
+ num_channels,
+ num_filters,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias_attr=False)
+ self.bn0 = nn.BatchNorm(num_filters, act='relu')
+
+ def forward(self, inputs_prev, inputs):
+ x = self.conv0(inputs)
+ x = self.bn0(x)
+ x = paddle.nn.functional.interpolate(
+ x, scale_factor=2, mode='bilinear', align_corners=False)
+ x = paddle.concat([inputs_prev, x], axis=1)
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.conv2(x)
+ x = self.bn2(x)
+ return x
+
+
+class UNet(nn.Layer):
+ def __init__(self):
+ super(UNet, self).__init__()
+ self.down1 = Encoder(num_channels=3, num_filters=16)
+ self.down2 = Encoder(num_channels=16, num_filters=32)
+ self.down3 = Encoder(num_channels=32, num_filters=64)
+ self.down4 = Encoder(num_channels=64, num_filters=128)
+ self.down5 = Encoder(num_channels=128, num_filters=256)
+
+ self.up1 = Decoder(32, 16)
+ self.up2 = Decoder(64, 32)
+ self.up3 = Decoder(128, 64)
+ self.up4 = Decoder(256, 128)
+ self.out_channels = 16
+
+ def forward(self, inputs):
+ x1, _ = self.down1(inputs)
+ _, x2 = self.down2(x1)
+ _, x3 = self.down3(x2)
+ _, x4 = self.down4(x3)
+ _, x5 = self.down5(x4)
+
+ x = self.up4(x4, x5)
+ x = self.up3(x3, x)
+ x = self.up2(x2, x)
+ x = self.up1(x1, x)
+ return x
+
+
+class Kie_backbone(nn.Layer):
+ def __init__(self, in_channels, **kwargs):
+ super(Kie_backbone, self).__init__()
+ self.out_channels = 16
+ self.img_feat = UNet()
+ self.maxpool = nn.MaxPool2D(kernel_size=7)
+
+ def bbox2roi(self, bbox_list):
+ rois_list = []
+ rois_num = []
+ for img_id, bboxes in enumerate(bbox_list):
+ rois_num.append(bboxes.shape[0])
+ rois_list.append(bboxes)
+ rois = paddle.concat(rois_list, 0)
+ rois_num = paddle.to_tensor(rois_num, dtype='int32')
+ return rois, rois_num
+
+ def pre_process(self, img, relations, texts, gt_bboxes, tag, img_size):
+ img, relations, texts, gt_bboxes, tag, img_size = img.numpy(
+ ), relations.numpy(), texts.numpy(), gt_bboxes.numpy(), tag.numpy(
+ ).tolist(), img_size.numpy()
+ temp_relations, temp_texts, temp_gt_bboxes = [], [], []
+ h, w = int(np.max(img_size[:, 0])), int(np.max(img_size[:, 1]))
+ img = paddle.to_tensor(img[:, :, :h, :w])
+ batch = len(tag)
+ for i in range(batch):
+ num, recoder_len = tag[i][0], tag[i][1]
+ temp_relations.append(
+ paddle.to_tensor(
+ relations[i, :num, :num, :], dtype='float32'))
+ temp_texts.append(
+ paddle.to_tensor(
+ texts[i, :num, :recoder_len], dtype='float32'))
+ temp_gt_bboxes.append(
+ paddle.to_tensor(
+ gt_bboxes[i, :num, ...], dtype='float32'))
+ return img, temp_relations, temp_texts, temp_gt_bboxes
+
+ def forward(self, inputs):
+ img = inputs[0]
+ relations, texts, gt_bboxes, tag, img_size = inputs[1], inputs[
+ 2], inputs[3], inputs[5], inputs[-1]
+ img, relations, texts, gt_bboxes = self.pre_process(
+ img, relations, texts, gt_bboxes, tag, img_size)
+ x = self.img_feat(img)
+ boxes, rois_num = self.bbox2roi(gt_bboxes)
+ feats = paddle.fluid.layers.roi_align(
+ x,
+ boxes,
+ spatial_scale=1.0,
+ pooled_height=7,
+ pooled_width=7,
+ rois_num=rois_num)
+ feats = self.maxpool(feats).squeeze(-1).squeeze(-1)
+ return [relations, texts, feats]
diff --git a/backend/ppocr/modeling/backbones/rec_efficientb3_pren.py b/backend/ppocr/modeling/backbones/rec_efficientb3_pren.py
new file mode 100644
index 00000000..57eef178
--- /dev/null
+++ b/backend/ppocr/modeling/backbones/rec_efficientb3_pren.py
@@ -0,0 +1,228 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Code is refer from:
+https://github.com/RuijieJ/pren/blob/main/Nets/EfficientNet.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+from collections import namedtuple
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+__all__ = ['EfficientNetb3_PREN']
+
+
+class EffB3Params:
+ @staticmethod
+ def get_global_params():
+ """
+ The following are EfficientNet-B3's architecture hyperparameters; to fit the
+ scene text recognition task, the resolution (image_size) here is changed
+ from 300 to 64.
+ """
+ GlobalParams = namedtuple('GlobalParams', [
+ 'drop_connect_rate', 'width_coefficient', 'depth_coefficient',
+ 'depth_divisor', 'image_size'
+ ])
+ global_params = GlobalParams(
+ drop_connect_rate=0.3,
+ width_coefficient=1.2,
+ depth_coefficient=1.4,
+ depth_divisor=8,
+ image_size=64)
+ return global_params
+
+ @staticmethod
+ def get_block_params():
+ BlockParams = namedtuple('BlockParams', [
+ 'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
+ 'expand_ratio', 'id_skip', 'se_ratio', 'stride'
+ ])
+ block_params = [
+ BlockParams(3, 1, 32, 16, 1, True, 0.25, 1),
+ BlockParams(3, 2, 16, 24, 6, True, 0.25, 2),
+ BlockParams(5, 2, 24, 40, 6, True, 0.25, 2),
+ BlockParams(3, 3, 40, 80, 6, True, 0.25, 2),
+ BlockParams(5, 3, 80, 112, 6, True, 0.25, 1),
+ BlockParams(5, 4, 112, 192, 6, True, 0.25, 2),
+ BlockParams(3, 1, 192, 320, 6, True, 0.25, 1)
+ ]
+ return block_params
+
+
+class EffUtils:
+ @staticmethod
+ def round_filters(filters, global_params):
+ """Calculate and round number of filters based on depth multiplier."""
+ multiplier = global_params.width_coefficient
+ if not multiplier:
+ return filters
+ divisor = global_params.depth_divisor
+ filters *= multiplier
+ new_filters = int(filters + divisor / 2) // divisor * divisor
+ if new_filters < 0.9 * filters:
+ new_filters += divisor
+ return int(new_filters)
+
+ @staticmethod
+ def round_repeats(repeats, global_params):
+ """Round number of filters based on depth multiplier."""
+ multiplier = global_params.depth_coefficient
+ if not multiplier:
+ return repeats
+ return int(math.ceil(multiplier * repeats))
+
+
+class ConvBlock(nn.Layer):
+ def __init__(self, block_params):
+ super(ConvBlock, self).__init__()
+ self.block_args = block_params
+ self.has_se = (self.block_args.se_ratio is not None) and \
+ (0 < self.block_args.se_ratio <= 1)
+ self.id_skip = block_params.id_skip
+
+ # expansion phase
+ self.input_filters = self.block_args.input_filters
+ output_filters = \
+ self.block_args.input_filters * self.block_args.expand_ratio
+ if self.block_args.expand_ratio != 1:
+ self.expand_conv = nn.Conv2D(
+ self.input_filters, output_filters, 1, bias_attr=False)
+ self.bn0 = nn.BatchNorm(output_filters)
+
+ # depthwise conv phase
+ k = self.block_args.kernel_size
+ s = self.block_args.stride
+ self.depthwise_conv = nn.Conv2D(
+ output_filters,
+ output_filters,
+ groups=output_filters,
+ kernel_size=k,
+ stride=s,
+ padding='same',
+ bias_attr=False)
+ self.bn1 = nn.BatchNorm(output_filters)
+
+ # squeeze and excitation layer, if desired
+ if self.has_se:
+ num_squeezed_channels = max(1,
+ int(self.block_args.input_filters *
+ self.block_args.se_ratio))
+ self.se_reduce = nn.Conv2D(output_filters, num_squeezed_channels, 1)
+ self.se_expand = nn.Conv2D(num_squeezed_channels, output_filters, 1)
+
+ # output phase
+ self.final_oup = self.block_args.output_filters
+ self.project_conv = nn.Conv2D(
+ output_filters, self.final_oup, 1, bias_attr=False)
+ self.bn2 = nn.BatchNorm(self.final_oup)
+ self.swish = nn.Swish()
+
+ def drop_connect(self, inputs, p, training):
+ if not training:
+ return inputs
+
+ batch_size = inputs.shape[0]
+ keep_prob = 1 - p
+ random_tensor = keep_prob
+ random_tensor += paddle.rand([batch_size, 1, 1, 1], dtype=inputs.dtype)
+ random_tensor = paddle.to_tensor(random_tensor, place=inputs.place)
+ binary_tensor = paddle.floor(random_tensor)
+ output = inputs / keep_prob * binary_tensor
+ return output
+
+ def forward(self, inputs, drop_connect_rate=None):
+ # expansion and depthwise conv
+ x = inputs
+ if self.block_args.expand_ratio != 1:
+ x = self.swish(self.bn0(self.expand_conv(inputs)))
+ x = self.swish(self.bn1(self.depthwise_conv(x)))
+
+ # squeeze and excitation
+ if self.has_se:
+ x_squeezed = F.adaptive_avg_pool2d(x, 1)
+ x_squeezed = self.se_expand(self.swish(self.se_reduce(x_squeezed)))
+ x = F.sigmoid(x_squeezed) * x
+ x = self.bn2(self.project_conv(x))
+
+ # skip connection and drop connect
+ if self.id_skip and self.block_args.stride == 1 and \
+ self.input_filters == self.final_oup:
+ if drop_connect_rate:
+ x = self.drop_connect(
+ x, p=drop_connect_rate, training=self.training)
+ x = x + inputs
+ return x
+
+
+class EfficientNetb3_PREN(nn.Layer):
+ def __init__(self, in_channels):
+ super(EfficientNetb3_PREN, self).__init__()
+ self.blocks_params = EffB3Params.get_block_params()
+ self.global_params = EffB3Params.get_global_params()
+ self.out_channels = []
+ # stem
+ stem_channels = EffUtils.round_filters(32, self.global_params)
+ self.conv_stem = nn.Conv2D(
+ in_channels, stem_channels, 3, 2, padding='same', bias_attr=False)
+ self.bn0 = nn.BatchNorm(stem_channels)
+
+ self.blocks = []
+ # to extract three feature maps for fpn based on efficientnetb3 backbone
+ self.concerned_block_idxes = [7, 17, 25]
+ concerned_idx = 0
+ for i, block_params in enumerate(self.blocks_params):
+ block_params = block_params._replace(
+ input_filters=EffUtils.round_filters(block_params.input_filters,
+ self.global_params),
+ output_filters=EffUtils.round_filters(
+ block_params.output_filters, self.global_params),
+ num_repeat=EffUtils.round_repeats(block_params.num_repeat,
+ self.global_params))
+ self.blocks.append(
+ self.add_sublayer("{}-0".format(i), ConvBlock(block_params)))
+ concerned_idx += 1
+ if concerned_idx in self.concerned_block_idxes:
+ self.out_channels.append(block_params.output_filters)
+ if block_params.num_repeat > 1:
+ block_params = block_params._replace(
+ input_filters=block_params.output_filters, stride=1)
+ for j in range(block_params.num_repeat - 1):
+ self.blocks.append(
+ self.add_sublayer('{}-{}'.format(i, j + 1),
+ ConvBlock(block_params)))
+ concerned_idx += 1
+ if concerned_idx in self.concerned_block_idxes:
+ self.out_channels.append(block_params.output_filters)
+
+ self.swish = nn.Swish()
+
+ def forward(self, inputs):
+ outs = []
+
+ x = self.swish(self.bn0(self.conv_stem(inputs)))
+ for idx, block in enumerate(self.blocks):
+ drop_connect_rate = self.global_params.drop_connect_rate
+ if drop_connect_rate:
+ drop_connect_rate *= float(idx) / len(self.blocks)
+ x = block(x, drop_connect_rate=drop_connect_rate)
+ if idx in self.concerned_block_idxes:
+ outs.append(x)
+ return outs
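A hedged usage sketch; the input size follows the 64-pixel resolution mentioned in the docstring above, and the exact channel counts depend on the width rounding:

import paddle
from ppocr.modeling.backbones.rec_efficientb3_pren import EfficientNetb3_PREN

net = EfficientNetb3_PREN(in_channels=3)
feats = net(paddle.rand([1, 3, 64, 256]))
print(len(feats), net.out_channels)     # three feature maps for the FPN-style neck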
diff --git a/backend/ppocr/modeling/backbones/rec_micronet.py b/backend/ppocr/modeling/backbones/rec_micronet.py
new file mode 100644
index 00000000..b0ae5a14
--- /dev/null
+++ b/backend/ppocr/modeling/backbones/rec_micronet.py
@@ -0,0 +1,528 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from:
+https://github.com/liyunsheng13/micronet/blob/main/backbone/micronet.py
+https://github.com/liyunsheng13/micronet/blob/main/backbone/activation.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+import paddle.nn as nn
+
+from ppocr.modeling.backbones.det_mobilenet_v3 import make_divisible
+
+M0_cfgs = [
+ # s, n, c, ks, c1, c2, g1, g2, c3, g3, g4, y1, y2, y3, r
+ [2, 1, 8, 3, 2, 2, 0, 4, 8, 2, 2, 2, 0, 1, 1],
+ [2, 1, 12, 3, 2, 2, 0, 8, 12, 4, 4, 2, 2, 1, 1],
+ [2, 1, 16, 5, 2, 2, 0, 12, 16, 4, 4, 2, 2, 1, 1],
+ [1, 1, 32, 5, 1, 4, 4, 4, 32, 4, 4, 2, 2, 1, 1],
+ [2, 1, 64, 5, 1, 4, 8, 8, 64, 8, 8, 2, 2, 1, 1],
+ [1, 1, 96, 3, 1, 4, 8, 8, 96, 8, 8, 2, 2, 1, 2],
+ [1, 1, 384, 3, 1, 4, 12, 12, 0, 0, 0, 2, 2, 1, 2],
+]
+M1_cfgs = [
+ # s, n, c, ks, c1, c2, g1, g2, c3, g3, g4, y1, y2, y3, r
+ [2, 1, 8, 3, 2, 2, 0, 6, 8, 2, 2, 2, 0, 1, 1],
+ [2, 1, 16, 3, 2, 2, 0, 8, 16, 4, 4, 2, 2, 1, 1],
+ [2, 1, 16, 5, 2, 2, 0, 16, 16, 4, 4, 2, 2, 1, 1],
+ [1, 1, 32, 5, 1, 6, 4, 4, 32, 4, 4, 2, 2, 1, 1],
+ [2, 1, 64, 5, 1, 6, 8, 8, 64, 8, 8, 2, 2, 1, 1],
+ [1, 1, 96, 3, 1, 6, 8, 8, 96, 8, 8, 2, 2, 1, 2],
+ [1, 1, 576, 3, 1, 6, 12, 12, 0, 0, 0, 2, 2, 1, 2],
+]
+M2_cfgs = [
+ # s, n, c, ks, c1, c2, g1, g2, c3, g3, g4, y1, y2, y3, r
+ [2, 1, 12, 3, 2, 2, 0, 8, 12, 4, 4, 2, 0, 1, 1],
+ [2, 1, 16, 3, 2, 2, 0, 12, 16, 4, 4, 2, 2, 1, 1],
+ [1, 1, 24, 3, 2, 2, 0, 16, 24, 4, 4, 2, 2, 1, 1],
+ [2, 1, 32, 5, 1, 6, 6, 6, 32, 4, 4, 2, 2, 1, 1],
+ [1, 1, 32, 5, 1, 6, 8, 8, 32, 4, 4, 2, 2, 1, 2],
+ [1, 1, 64, 5, 1, 6, 8, 8, 64, 8, 8, 2, 2, 1, 2],
+ [2, 1, 96, 5, 1, 6, 8, 8, 96, 8, 8, 2, 2, 1, 2],
+ [1, 1, 128, 3, 1, 6, 12, 12, 128, 8, 8, 2, 2, 1, 2],
+ [1, 1, 768, 3, 1, 6, 16, 16, 0, 0, 0, 2, 2, 1, 2],
+]
+M3_cfgs = [
+ # s, n, c, ks, c1, c2, g1, g2, c3, g3, g4, y1, y2, y3, r
+ [2, 1, 16, 3, 2, 2, 0, 12, 16, 4, 4, 0, 2, 0, 1],
+ [2, 1, 24, 3, 2, 2, 0, 16, 24, 4, 4, 0, 2, 0, 1],
+ [1, 1, 24, 3, 2, 2, 0, 24, 24, 4, 4, 0, 2, 0, 1],
+ [2, 1, 32, 5, 1, 6, 6, 6, 32, 4, 4, 0, 2, 0, 1],
+ [1, 1, 32, 5, 1, 6, 8, 8, 32, 4, 4, 0, 2, 0, 2],
+ [1, 1, 64, 5, 1, 6, 8, 8, 48, 8, 8, 0, 2, 0, 2],
+ [1, 1, 80, 5, 1, 6, 8, 8, 80, 8, 8, 0, 2, 0, 2],
+ [1, 1, 80, 5, 1, 6, 10, 10, 80, 8, 8, 0, 2, 0, 2],
+ [1, 1, 120, 5, 1, 6, 10, 10, 120, 10, 10, 0, 2, 0, 2],
+ [1, 1, 120, 5, 1, 6, 12, 12, 120, 10, 10, 0, 2, 0, 2],
+ [1, 1, 144, 3, 1, 6, 12, 12, 144, 12, 12, 0, 2, 0, 2],
+ [1, 1, 432, 3, 1, 3, 12, 12, 0, 0, 0, 0, 2, 0, 2],
+]
+
+
+def get_micronet_config(mode):
+ return eval(mode + '_cfgs')
+
+
+class MaxGroupPooling(nn.Layer):
+ def __init__(self, channel_per_group=2):
+ super(MaxGroupPooling, self).__init__()
+ self.channel_per_group = channel_per_group
+
+ def forward(self, x):
+ if self.channel_per_group == 1:
+ return x
+ # max op
+ b, c, h, w = x.shape
+
+ # reshape
+ y = paddle.reshape(x, [b, c // self.channel_per_group, -1, h, w])
+ out = paddle.max(y, axis=2)
+ return out
+
+
+class SpatialSepConvSF(nn.Layer):
+ def __init__(self, inp, oups, kernel_size, stride):
+ super(SpatialSepConvSF, self).__init__()
+
+ oup1, oup2 = oups
+ self.conv = nn.Sequential(
+ nn.Conv2D(
+ inp,
+ oup1, (kernel_size, 1), (stride, 1), (kernel_size // 2, 0),
+ bias_attr=False,
+ groups=1),
+ nn.BatchNorm2D(oup1),
+ nn.Conv2D(
+ oup1,
+ oup1 * oup2, (1, kernel_size), (1, stride),
+ (0, kernel_size // 2),
+ bias_attr=False,
+ groups=oup1),
+ nn.BatchNorm2D(oup1 * oup2),
+ ChannelShuffle(oup1), )
+
+ def forward(self, x):
+ out = self.conv(x)
+ return out
+
+
+class ChannelShuffle(nn.Layer):
+ def __init__(self, groups):
+ super(ChannelShuffle, self).__init__()
+ self.groups = groups
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+
+ channels_per_group = c // self.groups
+
+ # reshape
+ x = paddle.reshape(x, [b, self.groups, channels_per_group, h, w])
+
+ x = paddle.transpose(x, (0, 2, 1, 3, 4))
+ out = paddle.reshape(x, [b, -1, h, w])
+
+ return out
+
+
+class StemLayer(nn.Layer):
+ def __init__(self, inp, oup, stride, groups=(4, 4)):
+ super(StemLayer, self).__init__()
+
+ g1, g2 = groups
+ self.stem = nn.Sequential(
+ SpatialSepConvSF(inp, groups, 3, stride),
+ MaxGroupPooling(2) if g1 * g2 == 2 * oup else nn.ReLU6())
+
+ def forward(self, x):
+ out = self.stem(x)
+ return out
+
+
+class DepthSpatialSepConv(nn.Layer):
+ def __init__(self, inp, expand, kernel_size, stride):
+ super(DepthSpatialSepConv, self).__init__()
+
+ exp1, exp2 = expand
+
+ hidden_dim = inp * exp1
+ oup = inp * exp1 * exp2
+
+ self.conv = nn.Sequential(
+ nn.Conv2D(
+ inp,
+ inp * exp1, (kernel_size, 1), (stride, 1),
+ (kernel_size // 2, 0),
+ bias_attr=False,
+ groups=inp),
+ nn.BatchNorm2D(inp * exp1),
+ nn.Conv2D(
+ hidden_dim,
+ oup, (1, kernel_size),
+ 1, (0, kernel_size // 2),
+ bias_attr=False,
+ groups=hidden_dim),
+ nn.BatchNorm2D(oup))
+
+ def forward(self, x):
+ x = self.conv(x)
+ return x
+
+
+class GroupConv(nn.Layer):
+ def __init__(self, inp, oup, groups=2):
+ super(GroupConv, self).__init__()
+ self.inp = inp
+ self.oup = oup
+ self.groups = groups
+ self.conv = nn.Sequential(
+ nn.Conv2D(
+ inp, oup, 1, 1, 0, bias_attr=False, groups=self.groups[0]),
+ nn.BatchNorm2D(oup))
+
+ def forward(self, x):
+ x = self.conv(x)
+ return x
+
+
+class DepthConv(nn.Layer):
+ def __init__(self, inp, oup, kernel_size, stride):
+ super(DepthConv, self).__init__()
+ self.conv = nn.Sequential(
+ nn.Conv2D(
+ inp,
+ oup,
+ kernel_size,
+ stride,
+ kernel_size // 2,
+ bias_attr=False,
+ groups=inp),
+ nn.BatchNorm2D(oup))
+
+ def forward(self, x):
+ out = self.conv(x)
+ return out
+
+
+class DYShiftMax(nn.Layer):
+ def __init__(self,
+ inp,
+ oup,
+ reduction=4,
+ act_max=1.0,
+ act_relu=True,
+ init_a=[0.0, 0.0],
+ init_b=[0.0, 0.0],
+ relu_before_pool=False,
+ g=None,
+ expansion=False):
+ super(DYShiftMax, self).__init__()
+ self.oup = oup
+ self.act_max = act_max * 2
+ self.act_relu = act_relu
+ self.avg_pool = nn.Sequential(nn.ReLU() if relu_before_pool else
+ nn.Sequential(), nn.AdaptiveAvgPool2D(1))
+
+ self.exp = 4 if act_relu else 2
+ self.init_a = init_a
+ self.init_b = init_b
+
+ # determine squeeze
+ squeeze = make_divisible(inp // reduction, 4)
+ if squeeze < 4:
+ squeeze = 4
+
+ self.fc = nn.Sequential(
+ nn.Linear(inp, squeeze),
+ nn.ReLU(), nn.Linear(squeeze, oup * self.exp), nn.Hardsigmoid())
+
+ if g is None:
+ g = 1
+ self.g = g[1]
+ if self.g != 1 and expansion:
+ self.g = inp // self.g
+
+ self.gc = inp // self.g
+ index = paddle.to_tensor([range(inp)])
+ index = paddle.reshape(index, [1, inp, 1, 1])
+ index = paddle.reshape(index, [1, self.g, self.gc, 1, 1])
+ indexgs = paddle.split(index, [1, self.g - 1], axis=1)
+ indexgs = paddle.concat((indexgs[1], indexgs[0]), axis=1)
+ indexs = paddle.split(indexgs, [1, self.gc - 1], axis=2)
+ indexs = paddle.concat((indexs[1], indexs[0]), axis=2)
+ self.index = paddle.reshape(indexs, [inp])
+ self.expansion = expansion
+
+ def forward(self, x):
+ x_in = x
+ x_out = x
+
+ b, c, _, _ = x_in.shape
+ y = self.avg_pool(x_in)
+ y = paddle.reshape(y, [b, c])
+ y = self.fc(y)
+ y = paddle.reshape(y, [b, self.oup * self.exp, 1, 1])
+ y = (y - 0.5) * self.act_max
+
+ n2, c2, h2, w2 = x_out.shape
+ x2 = paddle.to_tensor(x_out.numpy()[:, self.index.numpy(), :, :])
+
+ if self.exp == 4:
+ temp = y.shape
+ a1, b1, a2, b2 = paddle.split(y, temp[1] // self.oup, axis=1)
+
+ a1 = a1 + self.init_a[0]
+ a2 = a2 + self.init_a[1]
+
+ b1 = b1 + self.init_b[0]
+ b2 = b2 + self.init_b[1]
+
+ z1 = x_out * a1 + x2 * b1
+ z2 = x_out * a2 + x2 * b2
+
+ out = paddle.maximum(z1, z2)
+
+ elif self.exp == 2:
+ temp = y.shape
+ a1, b1 = paddle.split(y, temp[1] // self.oup, axis=1)
+ a1 = a1 + self.init_a[0]
+ b1 = b1 + self.init_b[0]
+ out = x_out * a1 + x2 * b1
+
+ return out
+
+
+class DYMicroBlock(nn.Layer):
+ def __init__(self,
+ inp,
+ oup,
+ kernel_size=3,
+ stride=1,
+ ch_exp=(2, 2),
+ ch_per_group=4,
+ groups_1x1=(1, 1),
+ depthsep=True,
+ shuffle=False,
+ activation_cfg=None):
+ super(DYMicroBlock, self).__init__()
+
+ self.identity = stride == 1 and inp == oup
+
+ y1, y2, y3 = activation_cfg['dy']
+ act_reduction = 8 * activation_cfg['ratio']
+ init_a = activation_cfg['init_a']
+ init_b = activation_cfg['init_b']
+
+ t1 = ch_exp
+ gs1 = ch_per_group
+ hidden_fft, g1, g2 = groups_1x1
+ hidden_dim2 = inp * t1[0] * t1[1]
+
+ if gs1[0] == 0:
+ self.layers = nn.Sequential(
+ DepthSpatialSepConv(inp, t1, kernel_size, stride),
+ DYShiftMax(
+ hidden_dim2,
+ hidden_dim2,
+ act_max=2.0,
+ act_relu=True if y2 == 2 else False,
+ init_a=init_a,
+ reduction=act_reduction,
+ init_b=init_b,
+ g=gs1,
+ expansion=False) if y2 > 0 else nn.ReLU6(),
+ ChannelShuffle(gs1[1]) if shuffle else nn.Sequential(),
+ ChannelShuffle(hidden_dim2 // 2)
+ if shuffle and y2 != 0 else nn.Sequential(),
+ GroupConv(hidden_dim2, oup, (g1, g2)),
+ DYShiftMax(
+ oup,
+ oup,
+ act_max=2.0,
+ act_relu=False,
+ init_a=[1.0, 0.0],
+ reduction=act_reduction // 2,
+ init_b=[0.0, 0.0],
+ g=(g1, g2),
+ expansion=False) if y3 > 0 else nn.Sequential(),
+ ChannelShuffle(g2) if shuffle else nn.Sequential(),
+ ChannelShuffle(oup // 2)
+ if shuffle and oup % 2 == 0 and y3 != 0 else nn.Sequential(), )
+ elif g2 == 0:
+ self.layers = nn.Sequential(
+ GroupConv(inp, hidden_dim2, gs1),
+ DYShiftMax(
+ hidden_dim2,
+ hidden_dim2,
+ act_max=2.0,
+ act_relu=False,
+ init_a=[1.0, 0.0],
+ reduction=act_reduction,
+ init_b=[0.0, 0.0],
+ g=gs1,
+ expansion=False) if y3 > 0 else nn.Sequential(), )
+ else:
+ self.layers = nn.Sequential(
+ GroupConv(inp, hidden_dim2, gs1),
+ DYShiftMax(
+ hidden_dim2,
+ hidden_dim2,
+ act_max=2.0,
+ act_relu=True if y1 == 2 else False,
+ init_a=init_a,
+ reduction=act_reduction,
+ init_b=init_b,
+ g=gs1,
+ expansion=False) if y1 > 0 else nn.ReLU6(),
+ ChannelShuffle(gs1[1]) if shuffle else nn.Sequential(),
+ DepthSpatialSepConv(hidden_dim2, (1, 1), kernel_size, stride)
+ if depthsep else
+ DepthConv(hidden_dim2, hidden_dim2, kernel_size, stride),
+ nn.Sequential(),
+ DYShiftMax(
+ hidden_dim2,
+ hidden_dim2,
+ act_max=2.0,
+ act_relu=True if y2 == 2 else False,
+ init_a=init_a,
+ reduction=act_reduction,
+ init_b=init_b,
+ g=gs1,
+ expansion=True) if y2 > 0 else nn.ReLU6(),
+ ChannelShuffle(hidden_dim2 // 4)
+ if shuffle and y1 != 0 and y2 != 0 else nn.Sequential()
+ if y1 == 0 and y2 == 0 else ChannelShuffle(hidden_dim2 // 2),
+ GroupConv(hidden_dim2, oup, (g1, g2)),
+ DYShiftMax(
+ oup,
+ oup,
+ act_max=2.0,
+ act_relu=False,
+ init_a=[1.0, 0.0],
+ reduction=act_reduction // 2
+ if oup < hidden_dim2 else act_reduction,
+ init_b=[0.0, 0.0],
+ g=(g1, g2),
+ expansion=False) if y3 > 0 else nn.Sequential(),
+ ChannelShuffle(g2) if shuffle else nn.Sequential(),
+ ChannelShuffle(oup // 2)
+ if shuffle and y3 != 0 else nn.Sequential(), )
+
+ def forward(self, x):
+ identity = x
+ out = self.layers(x)
+
+ if self.identity:
+ out = out + identity
+
+ return out
+
+
+class MicroNet(nn.Layer):
+ """
+ The MicroNet backbone network for the recognition module.
+ Args:
+ mode(str): {'M0', 'M1', 'M2', 'M3'}
+ Four models are proposed based on four different computational costs (4M, 6M, 12M, 21M MAdds)
+ Default: 'M3'.
+ """
+
+ def __init__(self, mode='M3', **kwargs):
+ super(MicroNet, self).__init__()
+
+ self.cfgs = get_micronet_config(mode)
+
+ activation_cfg = {}
+ if mode == 'M0':
+ input_channel = 4
+ stem_groups = 2, 2
+ out_ch = 384
+ activation_cfg['init_a'] = 1.0, 1.0
+ activation_cfg['init_b'] = 0.0, 0.0
+ elif mode == 'M1':
+ input_channel = 6
+ stem_groups = 3, 2
+ out_ch = 576
+ activation_cfg['init_a'] = 1.0, 1.0
+ activation_cfg['init_b'] = 0.0, 0.0
+ elif mode == 'M2':
+ input_channel = 8
+ stem_groups = 4, 2
+ out_ch = 768
+ activation_cfg['init_a'] = 1.0, 1.0
+ activation_cfg['init_b'] = 0.0, 0.0
+ elif mode == 'M3':
+ input_channel = 12
+ stem_groups = 4, 3
+ out_ch = 432
+ activation_cfg['init_a'] = 1.0, 0.5
+ activation_cfg['init_b'] = 0.0, 0.5
+ else:
+ raise NotImplementedError("mode[" + mode +
+ "_model] is not implemented!")
+
+ layers = [StemLayer(3, input_channel, stride=2, groups=stem_groups)]
+
+ for idx, val in enumerate(self.cfgs):
+ s, n, c, ks, c1, c2, g1, g2, c3, g3, g4, y1, y2, y3, r = val
+
+ t1 = (c1, c2)
+ gs1 = (g1, g2)
+ gs2 = (c3, g3, g4)
+ activation_cfg['dy'] = [y1, y2, y3]
+ activation_cfg['ratio'] = r
+
+ output_channel = c
+ layers.append(
+ DYMicroBlock(
+ input_channel,
+ output_channel,
+ kernel_size=ks,
+ stride=s,
+ ch_exp=t1,
+ ch_per_group=gs1,
+ groups_1x1=gs2,
+ depthsep=True,
+ shuffle=True,
+ activation_cfg=activation_cfg, ))
+ input_channel = output_channel
+ for i in range(1, n):
+ layers.append(
+ DYMicroBlock(
+ input_channel,
+ output_channel,
+ kernel_size=ks,
+ stride=1,
+ ch_exp=t1,
+ ch_per_group=gs1,
+ groups_1x1=gs2,
+ depthsep=True,
+ shuffle=True,
+ activation_cfg=activation_cfg, ))
+ input_channel = output_channel
+ self.features = nn.Sequential(*layers)
+
+ self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
+
+ self.out_channels = make_divisible(out_ch)
+
+ def forward(self, x):
+ x = self.features(x)
+ x = self.pool(x)
+ return x
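A hedged usage sketch; the input size is illustrative, recognition crops are typically 32 pixels high:

import paddle
from ppocr.modeling.backbones.rec_micronet import MicroNet

net = MicroNet(mode='M3')                  # 'M0'..'M3' pick one of the cfg tables above
y = net(paddle.rand([1, 3, 32, 320]))
print(y.shape, net.out_channels)           # downsampled feature map, 432 channels for M3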
diff --git a/backend/ppocr/modeling/backbones/rec_mobilenet_v3.py b/backend/ppocr/modeling/backbones/rec_mobilenet_v3.py
index 1ff17159..917e000d 100644
--- a/backend/ppocr/modeling/backbones/rec_mobilenet_v3.py
+++ b/backend/ppocr/modeling/backbones/rec_mobilenet_v3.py
@@ -26,8 +26,10 @@ def __init__(self,
scale=0.5,
large_stride=None,
small_stride=None,
+ disable_se=False,
**kwargs):
super(MobileNetV3, self).__init__()
+ self.disable_se = disable_se
if small_stride is None:
small_stride = [2, 2, 2, 2]
if large_stride is None:
@@ -96,12 +98,12 @@ def __init__(self,
padding=1,
groups=1,
if_act=True,
- act='hardswish',
- name='conv1')
+ act='hardswish')
i = 0
block_list = []
inplanes = make_divisible(inplanes * scale)
for (k, exp, c, se, nl, s) in cfg:
+ se = se and not self.disable_se
block_list.append(
ResidualUnit(
in_channels=inplanes,
@@ -110,8 +112,7 @@ def __init__(self,
kernel_size=k,
stride=s,
use_se=se,
- act=nl,
- name='conv' + str(i + 2)))
+ act=nl))
inplanes = make_divisible(scale * c)
i += 1
self.blocks = nn.Sequential(*block_list)
@@ -124,8 +125,7 @@ def __init__(self,
padding=0,
groups=1,
if_act=True,
- act='hardswish',
- name='conv_last')
+ act='hardswish')
self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
self.out_channels = make_divisible(scale * cls_ch_squeeze)
diff --git a/backend/ppocr/modeling/backbones/rec_mv1_enhance.py b/backend/ppocr/modeling/backbones/rec_mv1_enhance.py
new file mode 100644
index 00000000..bb6af5e8
--- /dev/null
+++ b/backend/ppocr/modeling/backbones/rec_mv1_enhance.py
@@ -0,0 +1,256 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This code is refer from: https://github.com/PaddlePaddle/PaddleClas/blob/develop/ppcls/arch/backbone/legendary_models/pp_lcnet.py
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import numpy as np
+import paddle
+from paddle import ParamAttr, reshape, transpose
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
+from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
+from paddle.nn.initializer import KaimingNormal
+from paddle.regularizer import L2Decay
+from paddle.nn.functional import hardswish, hardsigmoid
+
+
+class ConvBNLayer(nn.Layer):
+ def __init__(self,
+ num_channels,
+ filter_size,
+ num_filters,
+ stride,
+ padding,
+ channels=None,
+ num_groups=1,
+ act='hard_swish'):
+ super(ConvBNLayer, self).__init__()
+
+ self._conv = Conv2D(
+ in_channels=num_channels,
+ out_channels=num_filters,
+ kernel_size=filter_size,
+ stride=stride,
+ padding=padding,
+ groups=num_groups,
+ weight_attr=ParamAttr(initializer=KaimingNormal()),
+ bias_attr=False)
+
+ self._batch_norm = BatchNorm(
+ num_filters,
+ act=act,
+ param_attr=ParamAttr(regularizer=L2Decay(0.0)),
+ bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
+
+ def forward(self, inputs):
+ y = self._conv(inputs)
+ y = self._batch_norm(y)
+ return y
+
+
+class DepthwiseSeparable(nn.Layer):
+ def __init__(self,
+ num_channels,
+ num_filters1,
+ num_filters2,
+ num_groups,
+ stride,
+ scale,
+ dw_size=3,
+ padding=1,
+ use_se=False):
+ super(DepthwiseSeparable, self).__init__()
+ self.use_se = use_se
+ self._depthwise_conv = ConvBNLayer(
+ num_channels=num_channels,
+ num_filters=int(num_filters1 * scale),
+ filter_size=dw_size,
+ stride=stride,
+ padding=padding,
+ num_groups=int(num_groups * scale))
+ if use_se:
+ self._se = SEModule(int(num_filters1 * scale))
+ self._pointwise_conv = ConvBNLayer(
+ num_channels=int(num_filters1 * scale),
+ filter_size=1,
+ num_filters=int(num_filters2 * scale),
+ stride=1,
+ padding=0)
+
+ def forward(self, inputs):
+ y = self._depthwise_conv(inputs)
+ if self.use_se:
+ y = self._se(y)
+ y = self._pointwise_conv(y)
+ return y
+
+
+class MobileNetV1Enhance(nn.Layer):
+ def __init__(self,
+ in_channels=3,
+ scale=0.5,
+ last_conv_stride=1,
+ last_pool_type='max',
+ **kwargs):
+ super().__init__()
+ self.scale = scale
+ self.block_list = []
+
+ self.conv1 = ConvBNLayer(
+ num_channels=3,
+ filter_size=3,
+ channels=3,
+ num_filters=int(32 * scale),
+ stride=2,
+ padding=1)
+
+ conv2_1 = DepthwiseSeparable(
+ num_channels=int(32 * scale),
+ num_filters1=32,
+ num_filters2=64,
+ num_groups=32,
+ stride=1,
+ scale=scale)
+ self.block_list.append(conv2_1)
+
+ conv2_2 = DepthwiseSeparable(
+ num_channels=int(64 * scale),
+ num_filters1=64,
+ num_filters2=128,
+ num_groups=64,
+ stride=1,
+ scale=scale)
+ self.block_list.append(conv2_2)
+
+ conv3_1 = DepthwiseSeparable(
+ num_channels=int(128 * scale),
+ num_filters1=128,
+ num_filters2=128,
+ num_groups=128,
+ stride=1,
+ scale=scale)
+ self.block_list.append(conv3_1)
+
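+ # from here on, downsampling blocks use stride (2, 1): height is halved while width is kept, preserving horizontal resolution for text recognition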
+ conv3_2 = DepthwiseSeparable(
+ num_channels=int(128 * scale),
+ num_filters1=128,
+ num_filters2=256,
+ num_groups=128,
+ stride=(2, 1),
+ scale=scale)
+ self.block_list.append(conv3_2)
+
+ conv4_1 = DepthwiseSeparable(
+ num_channels=int(256 * scale),
+ num_filters1=256,
+ num_filters2=256,
+ num_groups=256,
+ stride=1,
+ scale=scale)
+ self.block_list.append(conv4_1)
+
+ conv4_2 = DepthwiseSeparable(
+ num_channels=int(256 * scale),
+ num_filters1=256,
+ num_filters2=512,
+ num_groups=256,
+ stride=(2, 1),
+ scale=scale)
+ self.block_list.append(conv4_2)
+
+ for _ in range(5):
+ conv5 = DepthwiseSeparable(
+ num_channels=int(512 * scale),
+ num_filters1=512,
+ num_filters2=512,
+ num_groups=512,
+ stride=1,
+ dw_size=5,
+ padding=2,
+ scale=scale,
+ use_se=False)
+ self.block_list.append(conv5)
+
+ conv5_6 = DepthwiseSeparable(
+ num_channels=int(512 * scale),
+ num_filters1=512,
+ num_filters2=1024,
+ num_groups=512,
+ stride=(2, 1),
+ dw_size=5,
+ padding=2,
+ scale=scale,
+ use_se=True)
+ self.block_list.append(conv5_6)
+
+ conv6 = DepthwiseSeparable(
+ num_channels=int(1024 * scale),
+ num_filters1=1024,
+ num_filters2=1024,
+ num_groups=1024,
+ stride=last_conv_stride,
+ dw_size=5,
+ padding=2,
+ use_se=True,
+ scale=scale)
+ self.block_list.append(conv6)
+
+ self.block_list = nn.Sequential(*self.block_list)
+ if last_pool_type == 'avg':
+ self.pool = nn.AvgPool2D(kernel_size=2, stride=2, padding=0)
+ else:
+ self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
+ self.out_channels = int(1024 * scale)
+
+ def forward(self, inputs):
+ y = self.conv1(inputs)
+ y = self.block_list(y)
+ y = self.pool(y)
+ return y
+
+
+class SEModule(nn.Layer):
+ def __init__(self, channel, reduction=4):
+ super(SEModule, self).__init__()
+ self.avg_pool = AdaptiveAvgPool2D(1)
+ self.conv1 = Conv2D(
+ in_channels=channel,
+ out_channels=channel // reduction,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ weight_attr=ParamAttr(),
+ bias_attr=ParamAttr())
+ self.conv2 = Conv2D(
+ in_channels=channel // reduction,
+ out_channels=channel,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ weight_attr=ParamAttr(),
+ bias_attr=ParamAttr())
+
+ def forward(self, inputs):
+ outputs = self.avg_pool(inputs)
+ outputs = self.conv1(outputs)
+ outputs = F.relu(outputs)
+ outputs = self.conv2(outputs)
+ outputs = hardsigmoid(outputs)
+ return paddle.multiply(x=inputs, y=outputs)
diff --git a/backend/ppocr/modeling/backbones/rec_nrtr_mtb.py b/backend/ppocr/modeling/backbones/rec_nrtr_mtb.py
new file mode 100644
index 00000000..22e02a63
--- /dev/null
+++ b/backend/ppocr/modeling/backbones/rec_nrtr_mtb.py
@@ -0,0 +1,48 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle import nn
+import paddle
+
+
+class MTB(nn.Layer):
+ def __init__(self, cnn_num, in_channels):
+ super(MTB, self).__init__()
+ self.block = nn.Sequential()
+ self.out_channels = in_channels
+ self.cnn_num = cnn_num
+ if self.cnn_num == 2:
+ for i in range(self.cnn_num):
+ self.block.add_sublayer(
+ 'conv_{}'.format(i),
+ nn.Conv2D(
+ in_channels=in_channels
+ if i == 0 else 32 * (2**(i - 1)),
+ out_channels=32 * (2**i),
+ kernel_size=3,
+ stride=2,
+ padding=1))
+ self.block.add_sublayer('relu_{}'.format(i), nn.ReLU())
+ self.block.add_sublayer('bn_{}'.format(i),
+ nn.BatchNorm2D(32 * (2**i)))
+
+ def forward(self, images):
+ x = self.block(images)
+ if self.cnn_num == 2:
+ # (b, w, h, c)
+ x = paddle.transpose(x, [0, 3, 2, 1])
+ x_shape = paddle.shape(x)
+ x = paddle.reshape(
+ x, [x_shape[0], x_shape[1], x_shape[2] * x_shape[3]])
+ return x
diff --git a/backend/ppocr/modeling/backbones/rec_resnet_31.py b/backend/ppocr/modeling/backbones/rec_resnet_31.py
new file mode 100644
index 00000000..96517013
--- /dev/null
+++ b/backend/ppocr/modeling/backbones/rec_resnet_31.py
@@ -0,0 +1,210 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is adapted from:
+https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/layers/conv_layer.py
+https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/backbones/resnet31_ocr.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import ParamAttr
+import paddle.nn as nn
+import paddle.nn.functional as F
+import numpy as np
+
+__all__ = ["ResNet31"]
+
+
+def conv3x3(in_channel, out_channel, stride=1):
+ return nn.Conv2D(
+ in_channel,
+ out_channel,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ bias_attr=False)
+
+
+class BasicBlock(nn.Layer):
+ expansion = 1
+
+ def __init__(self, in_channels, channels, stride=1, downsample=False):
+ super().__init__()
+ self.conv1 = conv3x3(in_channels, channels, stride)
+ self.bn1 = nn.BatchNorm2D(channels)
+ self.relu = nn.ReLU()
+ self.conv2 = conv3x3(channels, channels)
+ self.bn2 = nn.BatchNorm2D(channels)
+ self.downsample = downsample
+ if downsample:
+ self.downsample = nn.Sequential(
+ nn.Conv2D(
+ in_channels,
+ channels * self.expansion,
+ 1,
+ stride,
+ bias_attr=False),
+ nn.BatchNorm2D(channels * self.expansion), )
+ else:
+ self.downsample = nn.Sequential()
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet31(nn.Layer):
+ '''
+ Args:
+ in_channels (int): Number of channels of the input image tensor.
+ layers (list[int]): Number of BasicBlocks in each stage.
+ channels (list[int]): Output channels of the conv layer in each stage.
+ out_indices (None | Sequence[int]): Indices of output stages.
+ last_stage_pool (bool): If True, add a `MaxPool2D` layer to the last stage.
+ '''
+
+ def __init__(self,
+ in_channels=3,
+ layers=[1, 2, 5, 3],
+ channels=[64, 128, 256, 256, 512, 512, 512],
+ out_indices=None,
+ last_stage_pool=False):
+ super(ResNet31, self).__init__()
+ assert isinstance(in_channels, int)
+ assert isinstance(last_stage_pool, bool)
+
+ self.out_indices = out_indices
+ self.last_stage_pool = last_stage_pool
+
+ # conv 1 (Conv Conv)
+ self.conv1_1 = nn.Conv2D(
+ in_channels, channels[0], kernel_size=3, stride=1, padding=1)
+ self.bn1_1 = nn.BatchNorm2D(channels[0])
+ self.relu1_1 = nn.ReLU()
+
+ self.conv1_2 = nn.Conv2D(
+ channels[0], channels[1], kernel_size=3, stride=1, padding=1)
+ self.bn1_2 = nn.BatchNorm2D(channels[1])
+ self.relu1_2 = nn.ReLU()
+
+ # conv 2 (Max-pooling, Residual block, Conv)
+ self.pool2 = nn.MaxPool2D(
+ kernel_size=2, stride=2, padding=0, ceil_mode=True)
+ self.block2 = self._make_layer(channels[1], channels[2], layers[0])
+ self.conv2 = nn.Conv2D(
+ channels[2], channels[2], kernel_size=3, stride=1, padding=1)
+ self.bn2 = nn.BatchNorm2D(channels[2])
+ self.relu2 = nn.ReLU()
+
+ # conv 3 (Max-pooling, Residual block, Conv)
+ self.pool3 = nn.MaxPool2D(
+ kernel_size=2, stride=2, padding=0, ceil_mode=True)
+ self.block3 = self._make_layer(channels[2], channels[3], layers[1])
+ self.conv3 = nn.Conv2D(
+ channels[3], channels[3], kernel_size=3, stride=1, padding=1)
+ self.bn3 = nn.BatchNorm2D(channels[3])
+ self.relu3 = nn.ReLU()
+
+ # conv 4 (Max-pooling, Residual block, Conv)
+ self.pool4 = nn.MaxPool2D(
+ kernel_size=(2, 1), stride=(2, 1), padding=0, ceil_mode=True)
+ self.block4 = self._make_layer(channels[3], channels[4], layers[2])
+ self.conv4 = nn.Conv2D(
+ channels[4], channels[4], kernel_size=3, stride=1, padding=1)
+ self.bn4 = nn.BatchNorm2D(channels[4])
+ self.relu4 = nn.ReLU()
+
+ # conv 5 ((Max-pooling), Residual block, Conv)
+ self.pool5 = None
+ if self.last_stage_pool:
+ self.pool5 = nn.MaxPool2D(
+ kernel_size=2, stride=2, padding=0, ceil_mode=True)
+ self.block5 = self._make_layer(channels[4], channels[5], layers[3])
+ self.conv5 = nn.Conv2D(
+ channels[5], channels[5], kernel_size=3, stride=1, padding=1)
+ self.bn5 = nn.BatchNorm2D(channels[5])
+ self.relu5 = nn.ReLU()
+
+ self.out_channels = channels[-1]
+
+ def _make_layer(self, input_channels, output_channels, blocks):
+ layers = []
+ for _ in range(blocks):
+ downsample = None
+ if input_channels != output_channels:
+ downsample = nn.Sequential(
+ nn.Conv2D(
+ input_channels,
+ output_channels,
+ kernel_size=1,
+ stride=1,
+ bias_attr=False),
+ nn.BatchNorm2D(output_channels), )
+
+ layers.append(
+ BasicBlock(
+ input_channels, output_channels, downsample=downsample))
+ input_channels = output_channels
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv1_1(x)
+ x = self.bn1_1(x)
+ x = self.relu1_1(x)
+
+ x = self.conv1_2(x)
+ x = self.bn1_2(x)
+ x = self.relu1_2(x)
+
+ outs = []
+ for i in range(4):
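+ # stages 2-5: optional max-pooling, residual blocks, then a 3x3 conv + BN + ReLU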
+ layer_index = i + 2
+ pool_layer = getattr(self, f'pool{layer_index}')
+ block_layer = getattr(self, f'block{layer_index}')
+ conv_layer = getattr(self, f'conv{layer_index}')
+ bn_layer = getattr(self, f'bn{layer_index}')
+ relu_layer = getattr(self, f'relu{layer_index}')
+
+ if pool_layer is not None:
+ x = pool_layer(x)
+ x = block_layer(x)
+ x = conv_layer(x)
+ x = bn_layer(x)
+ x = relu_layer(x)
+
+ outs.append(x)
+
+ if self.out_indices is not None:
+ return tuple([outs[i] for i in self.out_indices])
+
+ return x
diff --git a/backend/ppocr/modeling/backbones/rec_resnet_aster.py b/backend/ppocr/modeling/backbones/rec_resnet_aster.py
new file mode 100644
index 00000000..6a2710df
--- /dev/null
+++ b/backend/ppocr/modeling/backbones/rec_resnet_aster.py
@@ -0,0 +1,143 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is adapted from:
+https://github.com/ayumiymk/aster.pytorch/blob/master/lib/models/resnet_aster.py
+"""
+import paddle
+import paddle.nn as nn
+
+import sys
+import math
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2D(
+ in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ bias_attr=False)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2D(
+ in_planes, out_planes, kernel_size=1, stride=stride, bias_attr=False)
+
+
+def get_sinusoid_encoding(n_position, feat_dim, wave_length=10000):
+ # [n_position]
+ positions = paddle.arange(0, n_position)
+ # [feat_dim]
+ dim_range = paddle.arange(0, feat_dim)
+ dim_range = paddle.pow(wave_length, 2 * (dim_range // 2) / feat_dim)
+ # [n_position, feat_dim]
+ angles = paddle.unsqueeze(
+ positions, axis=1) / paddle.unsqueeze(
+ dim_range, axis=0)
+ angles = paddle.cast(angles, "float32")
+ angles[:, 0::2] = paddle.sin(angles[:, 0::2])
+ angles[:, 1::2] = paddle.cos(angles[:, 1::2])
+ return angles
+
+
+class AsterBlock(nn.Layer):
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(AsterBlock, self).__init__()
+ self.conv1 = conv1x1(inplanes, planes, stride)
+ self.bn1 = nn.BatchNorm2D(planes)
+ self.relu = nn.ReLU()
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = nn.BatchNorm2D(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+ out += residual
+ out = self.relu(out)
+ return out
+
+
+class ResNet_ASTER(nn.Layer):
+ """Backbone shared by the ASTER and CRNN recognizers."""
+
+ def __init__(self, with_lstm=True, n_group=1, in_channels=3):
+ super(ResNet_ASTER, self).__init__()
+ self.with_lstm = with_lstm
+ self.n_group = n_group
+
+ self.layer0 = nn.Sequential(
+ nn.Conv2D(
+ in_channels,
+ 32,
+ kernel_size=(3, 3),
+ stride=1,
+ padding=1,
+ bias_attr=False),
+ nn.BatchNorm2D(32),
+ nn.ReLU())
+
+ self.inplanes = 32
+ self.layer1 = self._make_layer(32, 3, [2, 2]) # [16, 50]
+ self.layer2 = self._make_layer(64, 4, [2, 2]) # [8, 25]
+ self.layer3 = self._make_layer(128, 6, [2, 1]) # [4, 25]
+ self.layer4 = self._make_layer(256, 6, [2, 1]) # [2, 25]
+ self.layer5 = self._make_layer(512, 3, [2, 1]) # [1, 25]
+
+ if with_lstm:
+ self.rnn = nn.LSTM(512, 256, direction="bidirect", num_layers=2)
+ self.out_channels = 2 * 256
+ else:
+ self.out_channels = 512
+
+ def _make_layer(self, planes, blocks, stride):
+ downsample = None
+ if stride != [1, 1] or self.inplanes != planes:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes, stride), nn.BatchNorm2D(planes))
+
+ layers = []
+ layers.append(AsterBlock(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes
+ for _ in range(1, blocks):
+ layers.append(AsterBlock(self.inplanes, planes))
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x0 = self.layer0(x)
+ x1 = self.layer1(x0)
+ x2 = self.layer2(x1)
+ x3 = self.layer3(x2)
+ x4 = self.layer4(x3)
+ x5 = self.layer5(x4)
+
+ cnn_feat = x5.squeeze(2) # [N, c, w]
+ cnn_feat = paddle.transpose(cnn_feat, perm=[0, 2, 1])
+ if self.with_lstm:
+ rnn_feat, _ = self.rnn(cnn_feat)
+ return rnn_feat
+ else:
+ return cnn_feat
diff --git a/backend/ppocr/modeling/backbones/rec_resnet_vd.py b/backend/ppocr/modeling/backbones/rec_resnet_vd.py
index 6837ea0f..0187deb9 100644
--- a/backend/ppocr/modeling/backbones/rec_resnet_vd.py
+++ b/backend/ppocr/modeling/backbones/rec_resnet_vd.py
@@ -249,7 +249,7 @@ def __init__(self, in_channels=3, layers=50, **kwargs):
name=conv_name))
shortcut = True
self.block_list.append(bottleneck_block)
- self.out_channels = num_filters[block]
+ self.out_channels = num_filters[block] * 4
else:
for block in range(len(depth)):
shortcut = False
diff --git a/backend/ppocr/modeling/backbones/rec_svtrnet.py b/backend/ppocr/modeling/backbones/rec_svtrnet.py
new file mode 100644
index 00000000..c57bf463
--- /dev/null
+++ b/backend/ppocr/modeling/backbones/rec_svtrnet.py
@@ -0,0 +1,584 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle import ParamAttr
+from paddle.nn.initializer import KaimingNormal
+import numpy as np
+import paddle
+import paddle.nn as nn
+from paddle.nn.initializer import TruncatedNormal, Constant, Normal
+
+trunc_normal_ = TruncatedNormal(std=.02)
+normal_ = Normal
+zeros_ = Constant(value=0.)
+ones_ = Constant(value=1.)
+
+
+def drop_path(x, drop_prob=0., training=False):
+ """Drop paths (Stochastic Depth) per sample, applied in the main path of residual blocks.
+ The original name 'Drop Connect' is misleading: DropConnect is a different form of dropout from a separate paper.
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956
+ """
+ """
+ if drop_prob == 0. or not training:
+ return x
+ keep_prob = paddle.to_tensor(1 - drop_prob)
+ shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
+ random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
+ random_tensor = paddle.floor(random_tensor) # binarize
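+ # rescale the kept paths by 1 / keep_prob so the expected activation is unchanged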
+ output = x.divide(keep_prob) * random_tensor
+ return output
+
+
+class ConvBNLayer(nn.Layer):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=0,
+ bias_attr=False,
+ groups=1,
+ act=nn.GELU):
+ super().__init__()
+ self.conv = nn.Conv2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ groups=groups,
+ weight_attr=paddle.ParamAttr(
+ initializer=nn.initializer.KaimingUniform()),
+ bias_attr=bias_attr)
+ self.norm = nn.BatchNorm2D(out_channels)
+ self.act = act()
+
+ def forward(self, inputs):
+ out = self.conv(inputs)
+ out = self.norm(out)
+ out = self.act(out)
+ return out
+
+
+class DropPath(nn.Layer):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ """
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+
+class Identity(nn.Layer):
+ def __init__(self):
+ super(Identity, self).__init__()
+
+ def forward(self, input):
+ return input
+
+
+class Mlp(nn.Layer):
+ def __init__(self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class ConvMixer(nn.Layer):
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ HW=[8, 25],
+ local_k=[3, 3], ):
+ super().__init__()
+ self.HW = HW
+ self.dim = dim
+ self.local_mixer = nn.Conv2D(
+ dim,
+ dim,
+ local_k,
+ 1, [local_k[0] // 2, local_k[1] // 2],
+ groups=num_heads,
+ weight_attr=ParamAttr(initializer=KaimingNormal()))
+
+ def forward(self, x):
+ h = self.HW[0]
+ w = self.HW[1]
+ x = x.transpose([0, 2, 1]).reshape([0, self.dim, h, w])
+ x = self.local_mixer(x)
+ x = x.flatten(2).transpose([0, 2, 1])
+ return x
+
+
+class Attention(nn.Layer):
+ def __init__(self,
+ dim,
+ num_heads=8,
+ mixer='Global',
+ HW=[8, 25],
+ local_k=[7, 11],
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.,
+ proj_drop=0.):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.HW = HW
+ if HW is not None:
+ H = HW[0]
+ W = HW[1]
+ self.N = H * W
+ self.C = dim
+ if mixer == 'Local' and HW is not None:
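+ # additive attention mask: positions outside each query's hk x wk local window are set to -inf, so 'Local' mixing only attends within the window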
+ hk = local_k[0]
+ wk = local_k[1]
+ mask = paddle.ones([H * W, H + hk - 1, W + wk - 1], dtype='float32')
+ for h in range(0, H):
+ for w in range(0, W):
+ mask[h * W + w, h:h + hk, w:w + wk] = 0.
+ mask_paddle = mask[:, hk // 2:H + hk // 2, wk // 2:W + wk //
+ 2].flatten(1)
+ mask_inf = paddle.full([H * W, H * W], '-inf', dtype='float32')
+ mask = paddle.where(mask_paddle < 1, mask_paddle, mask_inf)
+ self.mask = mask.unsqueeze([0, 1])
+ self.mixer = mixer
+
+ def forward(self, x):
+ if self.HW is not None:
+ N = self.N
+ C = self.C
+ else:
+ _, N, C = x.shape
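+ # project to q, k, v in a single linear layer, then split heads: [3, B, num_heads, N, head_dim]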
+ qkv = self.qkv(x).reshape((0, N, 3, self.num_heads, C //
+ self.num_heads)).transpose((2, 0, 3, 1, 4))
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+
+ attn = (q.matmul(k.transpose((0, 1, 3, 2))))
+ if self.mixer == 'Local':
+ attn += self.mask
+ attn = nn.functional.softmax(attn, axis=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape((0, N, C))
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Layer):
+ def __init__(self,
+ dim,
+ num_heads,
+ mixer='Global',
+ local_mixer=[7, 11],
+ HW=[8, 25],
+ mlp_ratio=4.,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.,
+ attn_drop=0.,
+ drop_path=0.,
+ act_layer=nn.GELU,
+ norm_layer='nn.LayerNorm',
+ epsilon=1e-6,
+ prenorm=True):
+ super().__init__()
+ if isinstance(norm_layer, str):
+ self.norm1 = eval(norm_layer)(dim, epsilon=epsilon)
+ else:
+ self.norm1 = norm_layer(dim)
+ if mixer == 'Global' or mixer == 'Local':
+ self.mixer = Attention(
+ dim,
+ num_heads=num_heads,
+ mixer=mixer,
+ HW=HW,
+ local_k=local_mixer,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop)
+ elif mixer == 'Conv':
+ self.mixer = ConvMixer(
+ dim, num_heads=num_heads, HW=HW, local_k=local_mixer)
+ else:
+ raise TypeError("The mixer must be one of [Global, Local, Conv]")
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
+ if isinstance(norm_layer, str):
+ self.norm2 = eval(norm_layer)(dim, epsilon=epsilon)
+ else:
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp_ratio = mlp_ratio
+ self.mlp = Mlp(in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop)
+ self.prenorm = prenorm
+
+ def forward(self, x):
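+ # with prenorm=True the LayerNorm is applied after each residual addition; otherwise the standard pre-norm residual form is used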
+ if self.prenorm:
+ x = self.norm1(x + self.drop_path(self.mixer(x)))
+ x = self.norm2(x + self.drop_path(self.mlp(x)))
+ else:
+ x = x + self.drop_path(self.mixer(self.norm1(x)))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class PatchEmbed(nn.Layer):
+ """ Image to Patch Embedding
+ """
+
+ def __init__(self,
+ img_size=[32, 100],
+ in_channels=3,
+ embed_dim=768,
+ sub_num=2):
+ super().__init__()
+ num_patches = (img_size[1] // (2 ** sub_num)) * \
+ (img_size[0] // (2 ** sub_num))
+ self.img_size = img_size
+ self.num_patches = num_patches
+ self.embed_dim = embed_dim
+ self.norm = None
+ if sub_num == 2:
+ self.proj = nn.Sequential(
+ ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=embed_dim // 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias_attr=None),
+ ConvBNLayer(
+ in_channels=embed_dim // 2,
+ out_channels=embed_dim,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias_attr=None))
+ if sub_num == 3:
+ self.proj = nn.Sequential(
+ ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=embed_dim // 4,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias_attr=None),
+ ConvBNLayer(
+ in_channels=embed_dim // 4,
+ out_channels=embed_dim // 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias_attr=None),
+ ConvBNLayer(
+ in_channels=embed_dim // 2,
+ out_channels=embed_dim,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias_attr=None))
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ assert H == self.img_size[0] and W == self.img_size[1], \
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ x = self.proj(x).flatten(2).transpose((0, 2, 1))
+ return x
+
+
+class SubSample(nn.Layer):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ types='Pool',
+ stride=[2, 1],
+ sub_norm='nn.LayerNorm',
+ act=None):
+ super().__init__()
+ self.types = types
+ if types == 'Pool':
+ self.avgpool = nn.AvgPool2D(
+ kernel_size=[3, 5], stride=stride, padding=[1, 2])
+ self.maxpool = nn.MaxPool2D(
+ kernel_size=[3, 5], stride=stride, padding=[1, 2])
+ self.proj = nn.Linear(in_channels, out_channels)
+ else:
+ self.conv = nn.Conv2D(
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ weight_attr=ParamAttr(initializer=KaimingNormal()))
+ self.norm = eval(sub_norm)(out_channels)
+ if act is not None:
+ self.act = act()
+ else:
+ self.act = None
+
+ def forward(self, x):
+
+ if self.types == 'Pool':
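+ # average the avg-pooled and max-pooled features, then project channels with a linear layer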
+ x1 = self.avgpool(x)
+ x2 = self.maxpool(x)
+ x = (x1 + x2) * 0.5
+ out = self.proj(x.flatten(2).transpose((0, 2, 1)))
+ else:
+ x = self.conv(x)
+ out = x.flatten(2).transpose((0, 2, 1))
+ out = self.norm(out)
+ if self.act is not None:
+ out = self.act(out)
+
+ return out
+
+
+class SVTRNet(nn.Layer):
+ def __init__(
+ self,
+ img_size=[32, 100],
+ in_channels=3,
+ embed_dim=[64, 128, 256],
+ depth=[3, 6, 3],
+ num_heads=[2, 4, 8],
+ mixer=['Local'] * 6 + ['Global'] *
+ 6, # Local atten, Global atten, Conv
+ local_mixer=[[7, 11], [7, 11], [7, 11]],
+ patch_merging='Conv', # Conv, Pool, None
+ mlp_ratio=4,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ last_drop=0.1,
+ attn_drop_rate=0.,
+ drop_path_rate=0.1,
+ norm_layer='nn.LayerNorm',
+ sub_norm='nn.LayerNorm',
+ epsilon=1e-6,
+ out_channels=192,
+ out_char_num=25,
+ block_unit='Block',
+ act='nn.GELU',
+ last_stage=True,
+ sub_num=2,
+ prenorm=True,
+ use_lenhead=False,
+ **kwargs):
+ super().__init__()
+ self.img_size = img_size
+ self.embed_dim = embed_dim
+ self.out_channels = out_channels
+ self.prenorm = prenorm
+ patch_merging = None if patch_merging != 'Conv' and patch_merging != 'Pool' else patch_merging
+ self.patch_embed = PatchEmbed(
+ img_size=img_size,
+ in_channels=in_channels,
+ embed_dim=embed_dim[0],
+ sub_num=sub_num)
+ num_patches = self.patch_embed.num_patches
+ self.HW = [img_size[0] // (2**sub_num), img_size[1] // (2**sub_num)]
+ self.pos_embed = self.create_parameter(
+ shape=[1, num_patches, embed_dim[0]], default_initializer=zeros_)
+ self.add_parameter("pos_embed", self.pos_embed)
+ self.pos_drop = nn.Dropout(p=drop_rate)
+ Block_unit = eval(block_unit)
+
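+ # stochastic depth: the drop-path rate grows linearly from 0 to drop_path_rate across all blocks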
+ dpr = np.linspace(0, drop_path_rate, sum(depth))
+ self.blocks1 = nn.LayerList([
+ Block_unit(
+ dim=embed_dim[0],
+ num_heads=num_heads[0],
+ mixer=mixer[0:depth[0]][i],
+ HW=self.HW,
+ local_mixer=local_mixer[0],
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ act_layer=eval(act),
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[0:depth[0]][i],
+ norm_layer=norm_layer,
+ epsilon=epsilon,
+ prenorm=prenorm) for i in range(depth[0])
+ ])
+ if patch_merging is not None:
+ self.sub_sample1 = SubSample(
+ embed_dim[0],
+ embed_dim[1],
+ sub_norm=sub_norm,
+ stride=[2, 1],
+ types=patch_merging)
+ HW = [self.HW[0] // 2, self.HW[1]]
+ else:
+ HW = self.HW
+ self.patch_merging = patch_merging
+ self.blocks2 = nn.LayerList([
+ Block_unit(
+ dim=embed_dim[1],
+ num_heads=num_heads[1],
+ mixer=mixer[depth[0]:depth[0] + depth[1]][i],
+ HW=HW,
+ local_mixer=local_mixer[1],
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ act_layer=eval(act),
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[depth[0]:depth[0] + depth[1]][i],
+ norm_layer=norm_layer,
+ epsilon=epsilon,
+ prenorm=prenorm) for i in range(depth[1])
+ ])
+ if patch_merging is not None:
+ self.sub_sample2 = SubSample(
+ embed_dim[1],
+ embed_dim[2],
+ sub_norm=sub_norm,
+ stride=[2, 1],
+ types=patch_merging)
+ HW = [self.HW[0] // 4, self.HW[1]]
+ else:
+ HW = self.HW
+ self.blocks3 = nn.LayerList([
+ Block_unit(
+ dim=embed_dim[2],
+ num_heads=num_heads[2],
+ mixer=mixer[depth[0] + depth[1]:][i],
+ HW=HW,
+ local_mixer=local_mixer[2],
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ act_layer=eval(act),
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[depth[0] + depth[1]:][i],
+ norm_layer=norm_layer,
+ epsilon=epsilon,
+ prenorm=prenorm) for i in range(depth[2])
+ ])
+ self.last_stage = last_stage
+ if last_stage:
+ self.avg_pool = nn.AdaptiveAvgPool2D([1, out_char_num])
+ self.last_conv = nn.Conv2D(
+ in_channels=embed_dim[2],
+ out_channels=self.out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias_attr=False)
+ self.hardswish = nn.Hardswish()
+ self.dropout = nn.Dropout(p=last_drop, mode="downscale_in_infer")
+ if not prenorm:
+ self.norm = eval(norm_layer)(embed_dim[-1], epsilon=epsilon)
+ self.use_lenhead = use_lenhead
+ if use_lenhead:
+ self.len_conv = nn.Linear(embed_dim[2], self.out_channels)
+ self.hardswish_len = nn.Hardswish()
+ self.dropout_len = nn.Dropout(
+ p=last_drop, mode="downscale_in_infer")
+
+ trunc_normal_(self.pos_embed)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ zeros_(m.bias)
+ elif isinstance(m, nn.LayerNorm):
+ zeros_(m.bias)
+ ones_(m.weight)
+
+ def forward_features(self, x):
+ x = self.patch_embed(x)
+ x = x + self.pos_embed
+ x = self.pos_drop(x)
+ for blk in self.blocks1:
+ x = blk(x)
+ if self.patch_merging is not None:
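+ # restore the 2D feature map, halve its height with SubSample, then continue as a flattened sequence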
+ x = self.sub_sample1(
+ x.transpose([0, 2, 1]).reshape(
+ [0, self.embed_dim[0], self.HW[0], self.HW[1]]))
+ for blk in self.blocks2:
+ x = blk(x)
+ if self.patch_merging is not None:
+ x = self.sub_sample2(
+ x.transpose([0, 2, 1]).reshape(
+ [0, self.embed_dim[1], self.HW[0] // 2, self.HW[1]]))
+ for blk in self.blocks3:
+ x = blk(x)
+ if not self.prenorm:
+ x = self.norm(x)
+ return x
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ if self.use_lenhead:
+ len_x = self.len_conv(x.mean(1))
+ len_x = self.dropout_len(self.hardswish_len(len_x))
+ if self.last_stage:
+ if self.patch_merging is not None:
+ h = self.HW[0] // 4
+ else:
+ h = self.HW[0]
+ x = self.avg_pool(
+ x.transpose([0, 2, 1]).reshape(
+ [0, self.embed_dim[2], h, self.HW[1]]))
+ x = self.last_conv(x)
+ x = self.hardswish(x)
+ x = self.dropout(x)
+ if self.use_lenhead:
+ return x, len_x
+ return x
diff --git a/backend/ppocr/modeling/backbones/vqa_layoutlm.py b/backend/ppocr/modeling/backbones/vqa_layoutlm.py
new file mode 100644
index 00000000..ede5b7a3
--- /dev/null
+++ b/backend/ppocr/modeling/backbones/vqa_layoutlm.py
@@ -0,0 +1,172 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+from paddle import nn
+
+from paddlenlp.transformers import LayoutXLMModel, LayoutXLMForTokenClassification, LayoutXLMForRelationExtraction
+from paddlenlp.transformers import LayoutLMModel, LayoutLMForTokenClassification
+from paddlenlp.transformers import LayoutLMv2Model, LayoutLMv2ForTokenClassification, LayoutLMv2ForRelationExtraction
+
+__all__ = ["LayoutXLMForSer", 'LayoutLMForSer']
+
+pretrained_model_dict = {
+ LayoutXLMModel: 'layoutxlm-base-uncased',
+ LayoutLMModel: 'layoutlm-base-uncased',
+ LayoutLMv2Model: 'layoutlmv2-base-uncased'
+}
+
+
+class NLPBaseModel(nn.Layer):
+ def __init__(self,
+ base_model_class,
+ model_class,
+ type='ser',
+ pretrained=True,
+ checkpoints=None,
+ **kwargs):
+ super(NLPBaseModel, self).__init__()
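+ # load a fine-tuned checkpoint if given; otherwise build the base model (pretrained or randomly initialized) and wrap it with the task head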
+ if checkpoints is not None:
+ self.model = model_class.from_pretrained(checkpoints)
+ else:
+ pretrained_model_name = pretrained_model_dict[base_model_class]
+ if pretrained:
+ base_model = base_model_class.from_pretrained(
+ pretrained_model_name)
+ else:
+ base_model = base_model_class(
+ **base_model_class.pretrained_init_configuration[
+ pretrained_model_name])
+ if type == 'ser':
+ self.model = model_class(
+ base_model, num_classes=kwargs['num_classes'], dropout=None)
+ else:
+ self.model = model_class(base_model, dropout=None)
+ self.out_channels = 1
+
+
+class LayoutLMForSer(NLPBaseModel):
+ def __init__(self, num_classes, pretrained=True, checkpoints=None,
+ **kwargs):
+ super(LayoutLMForSer, self).__init__(
+ LayoutLMModel,
+ LayoutLMForTokenClassification,
+ 'ser',
+ pretrained,
+ checkpoints,
+ num_classes=num_classes)
+
+ def forward(self, x):
+ x = self.model(
+ input_ids=x[0],
+ bbox=x[2],
+ attention_mask=x[4],
+ token_type_ids=x[5],
+ position_ids=None,
+ output_hidden_states=False)
+ return x
+
+
+class LayoutLMv2ForSer(NLPBaseModel):
+ def __init__(self, num_classes, pretrained=True, checkpoints=None,
+ **kwargs):
+ super(LayoutLMv2ForSer, self).__init__(
+ LayoutLMv2Model,
+ LayoutLMv2ForTokenClassification,
+ 'ser',
+ pretrained,
+ checkpoints,
+ num_classes=num_classes)
+
+ def forward(self, x):
+ x = self.model(
+ input_ids=x[0],
+ bbox=x[2],
+ image=x[3],
+ attention_mask=x[4],
+ token_type_ids=x[5],
+ position_ids=None,
+ head_mask=None,
+ labels=None)
+ return x[0]
+
+
+class LayoutXLMForSer(NLPBaseModel):
+ def __init__(self, num_classes, pretrained=True, checkpoints=None,
+ **kwargs):
+ super(LayoutXLMForSer, self).__init__(
+ LayoutXLMModel,
+ LayoutXLMForTokenClassification,
+ 'ser',
+ pretrained,
+ checkpoints,
+ num_classes=num_classes)
+
+ def forward(self, x):
+ x = self.model(
+ input_ids=x[0],
+ bbox=x[2],
+ image=x[3],
+ attention_mask=x[4],
+ token_type_ids=x[5],
+ position_ids=None,
+ head_mask=None,
+ labels=None)
+ return x[0]
+
+
+class LayoutLMv2ForRe(NLPBaseModel):
+ def __init__(self, pretrained=True, checkpoints=None, **kwargs):
+ super(LayoutLMv2ForRe, self).__init__(LayoutLMv2Model,
+ LayoutLMv2ForRelationExtraction,
+ 're', pretrained, checkpoints)
+
+ def forward(self, x):
+ x = self.model(
+ input_ids=x[0],
+ bbox=x[1],
+ labels=None,
+ image=x[2],
+ attention_mask=x[3],
+ token_type_ids=x[4],
+ position_ids=None,
+ head_mask=None,
+ entities=x[5],
+ relations=x[6])
+ return x
+
+
+class LayoutXLMForRe(NLPBaseModel):
+ def __init__(self, pretrained=True, checkpoints=None, **kwargs):
+ super(LayoutXLMForRe, self).__init__(LayoutXLMModel,
+ LayoutXLMForRelationExtraction,
+ 're', pretrained, checkpoints)
+
+ def forward(self, x):
+ x = self.model(
+ input_ids=x[0],
+ bbox=x[1],
+ labels=None,
+ image=x[2],
+ attention_mask=x[3],
+ token_type_ids=x[4],
+ position_ids=None,
+ head_mask=None,
+ entities=x[5],
+ relations=x[6])
+ return x
diff --git a/backend/ppocr/modeling/heads/__init__.py b/backend/ppocr/modeling/heads/__init__.py
index efe05718..1670ea38 100755
--- a/backend/ppocr/modeling/heads/__init__.py
+++ b/backend/ppocr/modeling/heads/__init__.py
@@ -20,19 +20,37 @@ def build_head(config):
from .det_db_head import DBHead
from .det_east_head import EASTHead
from .det_sast_head import SASTHead
+ from .det_pse_head import PSEHead
+ from .det_fce_head import FCEHead
+ from .e2e_pg_head import PGHead
# rec head
from .rec_ctc_head import CTCHead
from .rec_att_head import AttentionHead
from .rec_srn_head import SRNHead
+ from .rec_nrtr_head import Transformer
+ from .rec_sar_head import SARHead
+ from .rec_aster_head import AsterHead
+ from .rec_pren_head import PRENHead
+ from .rec_multi_head import MultiHead
# cls head
from .cls_head import ClsHead
+
+ #kie head
+ from .kie_sdmgr_head import SDMGRHead
+
+ # table head
+ from .table_att_head import TableAttentionHead
+
support_dict = [
- 'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
- 'SRNHead'
+ 'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead',
+ 'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
+ 'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
+ 'MultiHead'
]
+
module_name = config.pop('name')
assert module_name in support_dict, Exception('head only support {}'.format(
support_dict))
diff --git a/backend/ppocr/modeling/heads/cls_head.py b/backend/ppocr/modeling/heads/cls_head.py
index d9b78b84..91bfa615 100644
--- a/backend/ppocr/modeling/heads/cls_head.py
+++ b/backend/ppocr/modeling/heads/cls_head.py
@@ -43,7 +43,7 @@ def __init__(self, in_channels, class_dim, **kwargs):
initializer=nn.initializer.Uniform(-stdv, stdv)),
bias_attr=ParamAttr(name="fc_0.b_0"), )
- def forward(self, x):
+ def forward(self, x, targets=None):
x = self.pool(x)
x = paddle.reshape(x, shape=[x.shape[0], x.shape[1]])
x = self.fc(x)
diff --git a/backend/ppocr/modeling/heads/det_db_head.py b/backend/ppocr/modeling/heads/det_db_head.py
index ca18d74a..a686ae5a 100644
--- a/backend/ppocr/modeling/heads/det_db_head.py
+++ b/backend/ppocr/modeling/heads/det_db_head.py
@@ -23,64 +23,54 @@
from paddle import ParamAttr
-def get_bias_attr(k, name):
+def get_bias_attr(k):
stdv = 1.0 / math.sqrt(k * 1.0)
initializer = paddle.nn.initializer.Uniform(-stdv, stdv)
- bias_attr = ParamAttr(initializer=initializer, name=name + "_b_attr")
+ bias_attr = ParamAttr(initializer=initializer)
return bias_attr
class Head(nn.Layer):
- def __init__(self, in_channels, name_list):
+ def __init__(self, in_channels, name_list, kernel_list=[3, 2, 2], **kwargs):
super(Head, self).__init__()
+
self.conv1 = nn.Conv2D(
in_channels=in_channels,
out_channels=in_channels // 4,
- kernel_size=3,
- padding=1,
- weight_attr=ParamAttr(name=name_list[0] + '.w_0'),
+ kernel_size=kernel_list[0],
+ padding=int(kernel_list[0] // 2),
+ weight_attr=ParamAttr(),
bias_attr=False)
self.conv_bn1 = nn.BatchNorm(
num_channels=in_channels // 4,
param_attr=ParamAttr(
- name=name_list[1] + '.w_0',
initializer=paddle.nn.initializer.Constant(value=1.0)),
bias_attr=ParamAttr(
- name=name_list[1] + '.b_0',
initializer=paddle.nn.initializer.Constant(value=1e-4)),
- moving_mean_name=name_list[1] + '.w_1',
- moving_variance_name=name_list[1] + '.w_2',
act='relu')
self.conv2 = nn.Conv2DTranspose(
in_channels=in_channels // 4,
out_channels=in_channels // 4,
- kernel_size=2,
+ kernel_size=kernel_list[1],
stride=2,
weight_attr=ParamAttr(
- name=name_list[2] + '.w_0',
initializer=paddle.nn.initializer.KaimingUniform()),
- bias_attr=get_bias_attr(in_channels // 4, name_list[-1] + "conv2"))
+ bias_attr=get_bias_attr(in_channels // 4))
self.conv_bn2 = nn.BatchNorm(
num_channels=in_channels // 4,
param_attr=ParamAttr(
- name=name_list[3] + '.w_0',
initializer=paddle.nn.initializer.Constant(value=1.0)),
bias_attr=ParamAttr(
- name=name_list[3] + '.b_0',
initializer=paddle.nn.initializer.Constant(value=1e-4)),
- moving_mean_name=name_list[3] + '.w_1',
- moving_variance_name=name_list[3] + '.w_2',
act="relu")
self.conv3 = nn.Conv2DTranspose(
in_channels=in_channels // 4,
out_channels=1,
- kernel_size=2,
+ kernel_size=kernel_list[2],
stride=2,
weight_attr=ParamAttr(
- name=name_list[4] + '.w_0',
initializer=paddle.nn.initializer.KaimingUniform()),
- bias_attr=get_bias_attr(in_channels // 4, name_list[-1] + "conv3"),
- )
+ bias_attr=get_bias_attr(in_channels // 4), )
def forward(self, x):
x = self.conv1(x)
@@ -111,13 +101,13 @@ def __init__(self, in_channels, k=50, **kwargs):
'conv2d_57', 'batch_norm_49', 'conv2d_transpose_2', 'batch_norm_50',
'conv2d_transpose_3', 'thresh'
]
- self.binarize = Head(in_channels, binarize_name_list)
- self.thresh = Head(in_channels, thresh_name_list)
+ self.binarize = Head(in_channels, binarize_name_list, **kwargs)
+ self.thresh = Head(in_channels, thresh_name_list, **kwargs)
def step_function(self, x, y):
return paddle.reciprocal(1 + paddle.exp(-self.k * (x - y)))
- def forward(self, x):
+ def forward(self, x, targets=None):
shrink_maps = self.binarize(x)
if not self.training:
return {'maps': shrink_maps}
diff --git a/backend/ppocr/modeling/heads/det_east_head.py b/backend/ppocr/modeling/heads/det_east_head.py
index 9d0c3c4c..004eb5d7 100644
--- a/backend/ppocr/modeling/heads/det_east_head.py
+++ b/backend/ppocr/modeling/heads/det_east_head.py
@@ -109,7 +109,7 @@ def __init__(self, in_channels, model_name, **kwargs):
act=None,
name="f_geo")
- def forward(self, x):
+ def forward(self, x, targets=None):
f_det = self.det_conv1(x)
f_det = self.det_conv2(f_det)
f_score = self.score_conv(f_det)
diff --git a/backend/ppocr/modeling/heads/det_fce_head.py b/backend/ppocr/modeling/heads/det_fce_head.py
new file mode 100644
index 00000000..9503989f
--- /dev/null
+++ b/backend/ppocr/modeling/heads/det_fce_head.py
@@ -0,0 +1,99 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is adapted from:
+https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textdet/dense_heads/fce_head.py
+"""
+
+from paddle import nn
+from paddle import ParamAttr
+import paddle.nn.functional as F
+from paddle.nn.initializer import Normal
+import paddle
+from functools import partial
+
+
+def multi_apply(func, *args, **kwargs):
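+ # apply func to each element of args (with optional fixed kwargs) and regroup the per-call outputs into one list per return value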
+ pfunc = partial(func, **kwargs) if kwargs else func
+ map_results = map(pfunc, *args)
+ return tuple(map(list, zip(*map_results)))
+
+
+class FCEHead(nn.Layer):
+ """The class for implementing FCENet head.
+ FCENet(CVPR2021): Fourier Contour Embedding for Arbitrary-shaped Text
+ Detection.
+
+ [https://arxiv.org/abs/2104.10442]
+
+ Args:
+ in_channels (int): The number of input channels.
+ fourier_degree (int): The maximum Fourier transform degree k.
+ """
+
+ def __init__(self, in_channels, fourier_degree=5):
+ super().__init__()
+ assert isinstance(in_channels, int)
+
+ self.downsample_ratio = 1.0
+ self.in_channels = in_channels
+ self.fourier_degree = fourier_degree
+ self.out_channels_cls = 4
+ self.out_channels_reg = (2 * self.fourier_degree + 1) * 2
+
+ self.out_conv_cls = nn.Conv2D(
+ in_channels=self.in_channels,
+ out_channels=self.out_channels_cls,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ groups=1,
+ weight_attr=ParamAttr(
+ name='cls_weights',
+ initializer=Normal(
+ mean=0., std=0.01)),
+ bias_attr=True)
+ self.out_conv_reg = nn.Conv2D(
+ in_channels=self.in_channels,
+ out_channels=self.out_channels_reg,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ groups=1,
+ weight_attr=ParamAttr(
+ name='reg_weights',
+ initializer=Normal(
+ mean=0., std=0.01)),
+ bias_attr=True)
+
+ def forward(self, feats, targets=None):
+ cls_res, reg_res = multi_apply(self.forward_single, feats)
+ level_num = len(cls_res)
+ outs = {}
+ if not self.training:
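+ # at inference, split the 4 classification channels into text-region and text-center-line probabilities and append the Fourier regression map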
+ for i in range(level_num):
+ tr_pred = F.softmax(cls_res[i][:, 0:2, :, :], axis=1)
+ tcl_pred = F.softmax(cls_res[i][:, 2:, :, :], axis=1)
+ outs['level_{}'.format(i)] = paddle.concat(
+ [tr_pred, tcl_pred, reg_res[i]], axis=1)
+ else:
+ preds = [[cls_res[i], reg_res[i]] for i in range(level_num)]
+ outs['levels'] = preds
+ return outs
+
+ def forward_single(self, x):
+ cls_predict = self.out_conv_cls(x)
+ reg_predict = self.out_conv_reg(x)
+ return cls_predict, reg_predict
diff --git a/backend/ppocr/modeling/heads/det_pse_head.py b/backend/ppocr/modeling/heads/det_pse_head.py
new file mode 100644
index 00000000..32a5b48e
--- /dev/null
+++ b/backend/ppocr/modeling/heads/det_pse_head.py
@@ -0,0 +1,37 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is adapted from:
+https://github.com/whai362/PSENet/blob/python3/models/head/psenet_head.py
+"""
+
+from paddle import nn
+
+
+class PSEHead(nn.Layer):
+ def __init__(self, in_channels, hidden_dim=256, out_channels=7, **kwargs):
+ super(PSEHead, self).__init__()
+ self.conv1 = nn.Conv2D(
+ in_channels, hidden_dim, kernel_size=3, stride=1, padding=1)
+ self.bn1 = nn.BatchNorm2D(hidden_dim)
+ self.relu1 = nn.ReLU()
+
+ self.conv2 = nn.Conv2D(
+ hidden_dim, out_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x, **kwargs):
+ out = self.conv1(x)
+ out = self.relu1(self.bn1(out))
+ out = self.conv2(out)
+ return {'maps': out}
diff --git a/backend/ppocr/modeling/heads/det_sast_head.py b/backend/ppocr/modeling/heads/det_sast_head.py
index 263b2867..7a88a2db 100644
--- a/backend/ppocr/modeling/heads/det_sast_head.py
+++ b/backend/ppocr/modeling/heads/det_sast_head.py
@@ -116,7 +116,7 @@ def __init__(self, in_channels, **kwargs):
self.head1 = SAST_Header1(in_channels)
self.head2 = SAST_Header2(in_channels)
- def forward(self, x):
+ def forward(self, x, targets=None):
f_score, f_border = self.head1(x)
f_tvo, f_tco = self.head2(x)
diff --git a/backend/ppocr/modeling/heads/e2e_pg_head.py b/backend/ppocr/modeling/heads/e2e_pg_head.py
new file mode 100644
index 00000000..274e1cda
--- /dev/null
+++ b/backend/ppocr/modeling/heads/e2e_pg_head.py
@@ -0,0 +1,253 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import paddle
+from paddle import nn
+import paddle.nn.functional as F
+from paddle import ParamAttr
+
+
+class ConvBNLayer(nn.Layer):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ groups=1,
+ if_act=True,
+ act=None,
+ name=None):
+ super(ConvBNLayer, self).__init__()
+ self.if_act = if_act
+ self.act = act
+ self.conv = nn.Conv2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ groups=groups,
+ weight_attr=ParamAttr(name=name + '_weights'),
+ bias_attr=False)
+
+ self.bn = nn.BatchNorm(
+ num_channels=out_channels,
+ act=act,
+ param_attr=ParamAttr(name="bn_" + name + "_scale"),
+ bias_attr=ParamAttr(name="bn_" + name + "_offset"),
+ moving_mean_name="bn_" + name + "_mean",
+ moving_variance_name="bn_" + name + "_variance",
+ use_global_stats=False)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ return x
+
+
+class PGHead(nn.Layer):
+ """PGNet end-to-end head: predicts score, border, character and direction maps."""
+
+ def __init__(self, in_channels, **kwargs):
+ super(PGHead, self).__init__()
+ self.conv_f_score1 = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=64,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ act='relu',
+ name="conv_f_score{}".format(1))
+ self.conv_f_score2 = ConvBNLayer(
+ in_channels=64,
+ out_channels=64,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ act='relu',
+ name="conv_f_score{}".format(2))
+ self.conv_f_score3 = ConvBNLayer(
+ in_channels=64,
+ out_channels=128,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ act='relu',
+ name="conv_f_score{}".format(3))
+
+ self.conv1 = nn.Conv2D(
+ in_channels=128,
+ out_channels=1,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ groups=1,
+ weight_attr=ParamAttr(name="conv_f_score{}".format(4)),
+ bias_attr=False)
+
+ self.conv_f_boder1 = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=64,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ act='relu',
+ name="conv_f_boder{}".format(1))
+ self.conv_f_boder2 = ConvBNLayer(
+ in_channels=64,
+ out_channels=64,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ act='relu',
+ name="conv_f_boder{}".format(2))
+ self.conv_f_boder3 = ConvBNLayer(
+ in_channels=64,
+ out_channels=128,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ act='relu',
+ name="conv_f_boder{}".format(3))
+ self.conv2 = nn.Conv2D(
+ in_channels=128,
+ out_channels=4,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ groups=1,
+ weight_attr=ParamAttr(name="conv_f_boder{}".format(4)),
+ bias_attr=False)
+ self.conv_f_char1 = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=128,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ act='relu',
+ name="conv_f_char{}".format(1))
+ self.conv_f_char2 = ConvBNLayer(
+ in_channels=128,
+ out_channels=128,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ act='relu',
+ name="conv_f_char{}".format(2))
+ self.conv_f_char3 = ConvBNLayer(
+ in_channels=128,
+ out_channels=256,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ act='relu',
+ name="conv_f_char{}".format(3))
+ self.conv_f_char4 = ConvBNLayer(
+ in_channels=256,
+ out_channels=256,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ act='relu',
+ name="conv_f_char{}".format(4))
+ self.conv_f_char5 = ConvBNLayer(
+ in_channels=256,
+ out_channels=256,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ act='relu',
+ name="conv_f_char{}".format(5))
+ self.conv3 = nn.Conv2D(
+ in_channels=256,
+ out_channels=37,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ groups=1,
+ weight_attr=ParamAttr(name="conv_f_char{}".format(6)),
+ bias_attr=False)
+
+ self.conv_f_direc1 = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=64,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ act='relu',
+ name="conv_f_direc{}".format(1))
+ self.conv_f_direc2 = ConvBNLayer(
+ in_channels=64,
+ out_channels=64,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ act='relu',
+ name="conv_f_direc{}".format(2))
+ self.conv_f_direc3 = ConvBNLayer(
+ in_channels=64,
+ out_channels=128,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ act='relu',
+ name="conv_f_direc{}".format(3))
+ self.conv4 = nn.Conv2D(
+ in_channels=128,
+ out_channels=2,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ groups=1,
+ weight_attr=ParamAttr(name="conv_f_direc{}".format(4)),
+ bias_attr=False)
+
+ def forward(self, x, targets=None):
+ f_score = self.conv_f_score1(x)
+ f_score = self.conv_f_score2(f_score)
+ f_score = self.conv_f_score3(f_score)
+ f_score = self.conv1(f_score)
+ f_score = F.sigmoid(f_score)
+
+ # f_border
+ f_border = self.conv_f_boder1(x)
+ f_border = self.conv_f_boder2(f_border)
+ f_border = self.conv_f_boder3(f_border)
+ f_border = self.conv2(f_border)
+
+ f_char = self.conv_f_char1(x)
+ f_char = self.conv_f_char2(f_char)
+ f_char = self.conv_f_char3(f_char)
+ f_char = self.conv_f_char4(f_char)
+ f_char = self.conv_f_char5(f_char)
+ f_char = self.conv3(f_char)
+
+ f_direction = self.conv_f_direc1(x)
+ f_direction = self.conv_f_direc2(f_direction)
+ f_direction = self.conv_f_direc3(f_direction)
+ f_direction = self.conv4(f_direction)
+
+ predicts = {}
+ predicts['f_score'] = f_score
+ predicts['f_border'] = f_border
+ predicts['f_char'] = f_char
+ predicts['f_direction'] = f_direction
+ return predicts
diff --git a/backend/ppocr/modeling/heads/kie_sdmgr_head.py b/backend/ppocr/modeling/heads/kie_sdmgr_head.py
new file mode 100644
index 00000000..ac5f73fa
--- /dev/null
+++ b/backend/ppocr/modeling/heads/kie_sdmgr_head.py
@@ -0,0 +1,207 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This code is adapted from: https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/kie/heads/sdmgr_head.py
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import paddle
+from paddle import nn
+import paddle.nn.functional as F
+from paddle import ParamAttr
+
+
+class SDMGRHead(nn.Layer):
+ def __init__(self,
+ in_channels,
+ num_chars=92,
+ visual_dim=16,
+ fusion_dim=1024,
+ node_input=32,
+ node_embed=256,
+ edge_input=5,
+ edge_embed=256,
+ num_gnn=2,
+ num_classes=26,
+ bidirectional=False):
+ super().__init__()
+
+ self.fusion = Block([visual_dim, node_embed], node_embed, fusion_dim)
+ self.node_embed = nn.Embedding(num_chars, node_input, 0)
+ hidden = node_embed // 2 if bidirectional else node_embed
+ self.rnn = nn.LSTM(
+ input_size=node_input, hidden_size=hidden, num_layers=1)
+ self.edge_embed = nn.Linear(edge_input, edge_embed)
+ self.gnn_layers = nn.LayerList(
+ [GNNLayer(node_embed, edge_embed) for _ in range(num_gnn)])
+ self.node_cls = nn.Linear(node_embed, num_classes)
+ self.edge_cls = nn.Linear(edge_embed, 2)
+
+ def forward(self, input, targets):
+ relations, texts, x = input
+ node_nums, char_nums = [], []
+ for text in texts:
+ node_nums.append(text.shape[0])
+ char_nums.append(paddle.sum((text > -1).astype(int), axis=-1))
+
+ max_num = max([char_num.max() for char_num in char_nums])
+ all_nodes = paddle.concat([
+ paddle.concat(
+ [text, paddle.zeros(
+ (text.shape[0], max_num - text.shape[1]))], -1)
+ for text in texts
+ ])
+ temp = paddle.clip(all_nodes, min=0).astype(int)
+ embed_nodes = self.node_embed(temp)
+ rnn_nodes, _ = self.rnn(embed_nodes)
+
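+ # take the LSTM hidden state at each text's last valid character as its node feature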
+ b, h, w = rnn_nodes.shape
+ nodes = paddle.zeros([b, w])
+ all_nums = paddle.concat(char_nums)
+ valid = paddle.nonzero((all_nums > 0).astype(int))
+ temp_all_nums = (
+ paddle.gather(all_nums, valid) - 1).unsqueeze(-1).unsqueeze(-1)
+ temp_all_nums = paddle.expand(temp_all_nums, [
+ temp_all_nums.shape[0], temp_all_nums.shape[1], rnn_nodes.shape[-1]
+ ])
+ temp_all_nodes = paddle.gather(rnn_nodes, valid)
+ N, C, A = temp_all_nodes.shape
+ one_hot = F.one_hot(
+ temp_all_nums[:, 0, :], num_classes=C).transpose([0, 2, 1])
+ one_hot = paddle.multiply(
+ temp_all_nodes, one_hot.astype("float32")).sum(axis=1, keepdim=True)
+ t = one_hot.expand([N, 1, A]).squeeze(1)
+ nodes = paddle.scatter(nodes, valid.squeeze(1), t)
+
+ if x is not None:
+ nodes = self.fusion([x, nodes])
+
+ all_edges = paddle.concat(
+ [rel.reshape([-1, rel.shape[-1]]) for rel in relations])
+ embed_edges = self.edge_embed(all_edges.astype('float32'))
+ embed_edges = F.normalize(embed_edges)
+
+ for gnn_layer in self.gnn_layers:
+ nodes, cat_nodes = gnn_layer(nodes, embed_edges, node_nums)
+
+ node_cls, edge_cls = self.node_cls(nodes), self.edge_cls(cat_nodes)
+ return node_cls, edge_cls
+
+
+class GNNLayer(nn.Layer):
+ def __init__(self, node_dim=256, edge_dim=256):
+ super().__init__()
+ self.in_fc = nn.Linear(node_dim * 2 + edge_dim, node_dim)
+ self.coef_fc = nn.Linear(node_dim, 1)
+ self.out_fc = nn.Linear(node_dim, node_dim)
+ self.relu = nn.ReLU()
+
+ def forward(self, nodes, edges, nums):
+ start, cat_nodes = 0, []
+ for num in nums:
+ sample_nodes = nodes[start:start + num]
+ cat_nodes.append(
+ paddle.concat([
+ paddle.expand(sample_nodes.unsqueeze(1), [-1, num, -1]),
+ paddle.expand(sample_nodes.unsqueeze(0), [num, -1, -1])
+ ], -1).reshape([num**2, -1]))
+ start += num
+ cat_nodes = paddle.concat([paddle.concat(cat_nodes), edges], -1)
+ cat_nodes = self.relu(self.in_fc(cat_nodes))
+ coefs = self.coef_fc(cat_nodes)
+
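+        # Attention over neighbours within each sample: self-connections are masked with a
+        # large negative value on the diagonal, softmax runs over incoming edges, and the
+        # weighted pairwise features are aggregated back onto each node.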
+ start, residuals = 0, []
+ for num in nums:
+ residual = F.softmax(
+ -paddle.eye(num).unsqueeze(-1) * 1e9 +
+ coefs[start:start + num**2].reshape([num, num, -1]), 1)
+ residuals.append((residual * cat_nodes[start:start + num**2]
+ .reshape([num, num, -1])).sum(1))
+ start += num**2
+
+ nodes += self.relu(self.out_fc(paddle.concat(residuals)))
+ return [nodes, cat_nodes]
+
+
+class Block(nn.Layer):
+ def __init__(self,
+ input_dims,
+ output_dim,
+ mm_dim=1600,
+ chunks=20,
+ rank=15,
+ shared=False,
+ dropout_input=0.,
+ dropout_pre_lin=0.,
+ dropout_output=0.,
+ pos_norm='before_cat'):
+ super().__init__()
+ self.rank = rank
+ self.dropout_input = dropout_input
+ self.dropout_pre_lin = dropout_pre_lin
+ self.dropout_output = dropout_output
+ assert (pos_norm in ['before_cat', 'after_cat'])
+ self.pos_norm = pos_norm
+ # Modules
+ self.linear0 = nn.Linear(input_dims[0], mm_dim)
+ self.linear1 = (self.linear0
+ if shared else nn.Linear(input_dims[1], mm_dim))
+ self.merge_linears0 = nn.LayerList()
+ self.merge_linears1 = nn.LayerList()
+ self.chunks = self.chunk_sizes(mm_dim, chunks)
+ for size in self.chunks:
+ ml0 = nn.Linear(size, size * rank)
+ self.merge_linears0.append(ml0)
+ ml1 = ml0 if shared else nn.Linear(size, size * rank)
+ self.merge_linears1.append(ml1)
+ self.linear_out = nn.Linear(mm_dim, output_dim)
+
+ def forward(self, x):
+ x0 = self.linear0(x[0])
+ x1 = self.linear1(x[1])
+ bs = x1.shape[0]
+ if self.dropout_input > 0:
+ x0 = F.dropout(x0, p=self.dropout_input, training=self.training)
+ x1 = F.dropout(x1, p=self.dropout_input, training=self.training)
+ x0_chunks = paddle.split(x0, self.chunks, -1)
+ x1_chunks = paddle.split(x1, self.chunks, -1)
+ zs = []
+ for x0_c, x1_c, m0, m1 in zip(x0_chunks, x1_chunks, self.merge_linears0,
+ self.merge_linears1):
+ m = m0(x0_c) * m1(x1_c) # bs x split_size*rank
+ m = m.reshape([bs, self.rank, -1])
+ z = paddle.sum(m, 1)
+ if self.pos_norm == 'before_cat':
+ z = paddle.sqrt(F.relu(z)) - paddle.sqrt(F.relu(-z))
+ z = F.normalize(z)
+ zs.append(z)
+ z = paddle.concat(zs, 1)
+ if self.pos_norm == 'after_cat':
+ z = paddle.sqrt(F.relu(z)) - paddle.sqrt(F.relu(-z))
+ z = F.normalize(z)
+
+ if self.dropout_pre_lin > 0:
+ z = F.dropout(z, p=self.dropout_pre_lin, training=self.training)
+ z = self.linear_out(z)
+ if self.dropout_output > 0:
+ z = F.dropout(z, p=self.dropout_output, training=self.training)
+ return z
+
+ def chunk_sizes(self, dim, chunks):
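+        # Split `dim` into `chunks` near-equal sizes, e.g. chunk_sizes(1600, 20) -> [80] * 20
+        # and chunk_sizes(1000, 3) -> [334, 334, 332].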
+ split_size = (dim + chunks - 1) // chunks
+ sizes_list = [split_size] * chunks
+ sizes_list[-1] = sizes_list[-1] - (sum(sizes_list) - dim)
+ return sizes_list
diff --git a/backend/ppocr/modeling/heads/multiheadAttention.py b/backend/ppocr/modeling/heads/multiheadAttention.py
new file mode 100755
index 00000000..900865ba
--- /dev/null
+++ b/backend/ppocr/modeling/heads/multiheadAttention.py
@@ -0,0 +1,163 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+from paddle import nn
+import paddle.nn.functional as F
+from paddle.nn import Linear
+from paddle.nn.initializer import XavierUniform as xavier_uniform_
+from paddle.nn.initializer import Constant as constant_
+from paddle.nn.initializer import XavierNormal as xavier_normal_
+
+zeros_ = constant_(value=0.)
+ones_ = constant_(value=1.)
+
+
+class MultiheadAttention(nn.Layer):
+ """Allows the model to jointly attend to information
+ from different representation subspaces.
+ See reference: Attention Is All You Need
+
+ .. math::
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
+ \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
+
+ Args:
+ embed_dim: total dimension of the model
+ num_heads: parallel attention layers, or heads
+
+ """
+
+ def __init__(self,
+ embed_dim,
+ num_heads,
+ dropout=0.,
+ bias=True,
+ add_bias_kv=False,
+ add_zero_attn=False):
+ super(MultiheadAttention, self).__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
+ self.scaling = self.head_dim**-0.5
+ self.out_proj = Linear(embed_dim, embed_dim, bias_attr=bias)
+ self._reset_parameters()
+ self.conv1 = paddle.nn.Conv2D(
+ in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
+ self.conv2 = paddle.nn.Conv2D(
+ in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
+ self.conv3 = paddle.nn.Conv2D(
+ in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
+
+ def _reset_parameters(self):
+ xavier_uniform_(self.out_proj.weight)
+
+ def forward(self,
+ query,
+ key,
+ value,
+ key_padding_mask=None,
+ incremental_state=None,
+ attn_mask=None):
+ """
+ Inputs of forward function
+ query: [target length, batch size, embed dim]
+ key: [sequence length, batch size, embed dim]
+ value: [sequence length, batch size, embed dim]
+ key_padding_mask: if True, mask padding based on batch size
+            incremental_state: if provided, previous time steps are cached
+ need_weights: output attn_output_weights
+ static_kv: key and value are static
+
+ Outputs of forward function
+ attn_output: [target length, batch size, embed dim]
+ attn_output_weights: [batch size, target length, sequence length]
+ """
+ q_shape = paddle.shape(query)
+ src_shape = paddle.shape(key)
+ q = self._in_proj_q(query)
+ k = self._in_proj_k(key)
+ v = self._in_proj_v(value)
+ q *= self.scaling
+ q = paddle.transpose(
+ paddle.reshape(
+ q, [q_shape[0], q_shape[1], self.num_heads, self.head_dim]),
+ [1, 2, 0, 3])
+ k = paddle.transpose(
+ paddle.reshape(
+ k, [src_shape[0], q_shape[1], self.num_heads, self.head_dim]),
+ [1, 2, 0, 3])
+ v = paddle.transpose(
+ paddle.reshape(
+ v, [src_shape[0], q_shape[1], self.num_heads, self.head_dim]),
+ [1, 2, 0, 3])
+ if key_padding_mask is not None:
+ assert key_padding_mask.shape[0] == q_shape[1]
+ assert key_padding_mask.shape[1] == src_shape[0]
+ attn_output_weights = paddle.matmul(q,
+ paddle.transpose(k, [0, 1, 3, 2]))
+ if attn_mask is not None:
+ attn_mask = paddle.unsqueeze(paddle.unsqueeze(attn_mask, 0), 0)
+ attn_output_weights += attn_mask
+ if key_padding_mask is not None:
+ attn_output_weights = paddle.reshape(
+ attn_output_weights,
+ [q_shape[1], self.num_heads, q_shape[0], src_shape[0]])
+ key = paddle.unsqueeze(paddle.unsqueeze(key_padding_mask, 1), 2)
+ key = paddle.cast(key, 'float32')
+ y = paddle.full(
+ shape=paddle.shape(key), dtype='float32', fill_value='-inf')
+ y = paddle.where(key == 0., key, y)
+ attn_output_weights += y
+ attn_output_weights = F.softmax(
+ attn_output_weights.astype('float32'),
+ axis=-1,
+ dtype=paddle.float32 if attn_output_weights.dtype == paddle.float16
+ else attn_output_weights.dtype)
+ attn_output_weights = F.dropout(
+ attn_output_weights, p=self.dropout, training=self.training)
+
+ attn_output = paddle.matmul(attn_output_weights, v)
+ attn_output = paddle.reshape(
+ paddle.transpose(attn_output, [2, 0, 1, 3]),
+ [q_shape[0], q_shape[1], self.embed_dim])
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output
+
+ def _in_proj_q(self, query):
+ query = paddle.transpose(query, [1, 2, 0])
+ query = paddle.unsqueeze(query, axis=2)
+ res = self.conv1(query)
+ res = paddle.squeeze(res, axis=2)
+ res = paddle.transpose(res, [2, 0, 1])
+ return res
+
+ def _in_proj_k(self, key):
+ key = paddle.transpose(key, [1, 2, 0])
+ key = paddle.unsqueeze(key, axis=2)
+ res = self.conv2(key)
+ res = paddle.squeeze(res, axis=2)
+ res = paddle.transpose(res, [2, 0, 1])
+ return res
+
+ def _in_proj_v(self, value):
+ value = paddle.transpose(value, [1, 2, 0]) #(1, 2, 0)
+ value = paddle.unsqueeze(value, axis=2)
+ res = self.conv3(value)
+ res = paddle.squeeze(res, axis=2)
+ res = paddle.transpose(res, [2, 0, 1])
+ return res
diff --git a/backend/ppocr/modeling/heads/rec_aster_head.py b/backend/ppocr/modeling/heads/rec_aster_head.py
new file mode 100644
index 00000000..c95e8fd3
--- /dev/null
+++ b/backend/ppocr/modeling/heads/rec_aster_head.py
@@ -0,0 +1,393 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from:
+https://github.com/ayumiymk/aster.pytorch/blob/master/lib/models/attention_recognition_head.py
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+
+import paddle
+from paddle import nn
+from paddle.nn import functional as F
+
+
+class AsterHead(nn.Layer):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ sDim,
+ attDim,
+ max_len_labels,
+ time_step=25,
+ beam_width=5,
+ **kwargs):
+ super(AsterHead, self).__init__()
+ self.num_classes = out_channels
+ self.in_planes = in_channels
+ self.sDim = sDim
+ self.attDim = attDim
+ self.max_len_labels = max_len_labels
+ self.decoder = AttentionRecognitionHead(in_channels, out_channels, sDim,
+ attDim, max_len_labels)
+ self.time_step = time_step
+ self.embeder = Embedding(self.time_step, in_channels)
+ self.beam_width = beam_width
+ self.eos = self.num_classes - 3
+
+ def forward(self, x, targets=None, embed=None):
+ return_dict = {}
+ embedding_vectors = self.embeder(x)
+
+ if self.training:
+ rec_targets, rec_lengths, _ = targets
+ rec_pred = self.decoder([x, rec_targets, rec_lengths],
+ embedding_vectors)
+ return_dict['rec_pred'] = rec_pred
+ return_dict['embedding_vectors'] = embedding_vectors
+ else:
+ rec_pred, rec_pred_scores = self.decoder.beam_search(
+ x, self.beam_width, self.eos, embedding_vectors)
+ return_dict['rec_pred'] = rec_pred
+ return_dict['rec_pred_scores'] = rec_pred_scores
+ return_dict['embedding_vectors'] = embedding_vectors
+
+ return return_dict
+
+
+class Embedding(nn.Layer):
+ def __init__(self, in_timestep, in_planes, mid_dim=4096, embed_dim=300):
+ super(Embedding, self).__init__()
+ self.in_timestep = in_timestep
+ self.in_planes = in_planes
+ self.embed_dim = embed_dim
+ self.mid_dim = mid_dim
+ self.eEmbed = nn.Linear(
+ in_timestep * in_planes,
+            self.embed_dim)  # embed the encoder output into a word-embedding-like vector
+
+ def forward(self, x):
+ x = paddle.reshape(x, [paddle.shape(x)[0], -1])
+ x = self.eEmbed(x)
+ return x
+
+
+class AttentionRecognitionHead(nn.Layer):
+ """
+ input: [b x 16 x 64 x in_planes]
+ output: probability sequence: [b x T x num_classes]
+ """
+
+ def __init__(self, in_channels, out_channels, sDim, attDim, max_len_labels):
+ super(AttentionRecognitionHead, self).__init__()
+        self.num_classes = out_channels  # number of output classes, including the EOS symbol
+ self.in_planes = in_channels
+ self.sDim = sDim
+ self.attDim = attDim
+ self.max_len_labels = max_len_labels
+
+ self.decoder = DecoderUnit(
+ sDim=sDim, xDim=in_channels, yDim=self.num_classes, attDim=attDim)
+
+ def forward(self, x, embed):
+ x, targets, lengths = x
+ batch_size = paddle.shape(x)[0]
+ # Decoder
+ state = self.decoder.get_initial_state(embed)
+ outputs = []
+ for i in range(max(lengths)):
+ if i == 0:
+ y_prev = paddle.full(
+ shape=[batch_size], fill_value=self.num_classes)
+ else:
+ y_prev = targets[:, i - 1]
+ output, state = self.decoder(x, state, y_prev)
+ outputs.append(output)
+ outputs = paddle.concat([_.unsqueeze(1) for _ in outputs], 1)
+ return outputs
+
+ # inference stage.
+ def sample(self, x):
+ x, _, _ = x
+        batch_size = x.shape[0]
+ # Decoder
+ state = paddle.zeros([1, batch_size, self.sDim])
+
+ predicted_ids, predicted_scores = [], []
+ for i in range(self.max_len_labels):
+ if i == 0:
+ y_prev = paddle.full(
+ shape=[batch_size], fill_value=self.num_classes)
+ else:
+ y_prev = predicted
+
+ output, state = self.decoder(x, state, y_prev)
+ output = F.softmax(output, axis=1)
+ score, predicted = output.max(1)
+ predicted_ids.append(predicted.unsqueeze(1))
+ predicted_scores.append(score.unsqueeze(1))
+        predicted_ids = paddle.concat(predicted_ids, 1)
+        predicted_scores = paddle.concat(predicted_scores, 1)
+ # return predicted_ids.squeeze(), predicted_scores.squeeze()
+ return predicted_ids, predicted_scores
+
+ def beam_search(self, x, beam_width, eos, embed):
+ def _inflate(tensor, times, dim):
+ repeat_dims = [1] * tensor.dim()
+ repeat_dims[dim] = times
+ output = paddle.tile(tensor, repeat_dims)
+ return output
+
+ # https://github.com/IBM/pytorch-seq2seq/blob/fede87655ddce6c94b38886089e05321dc9802af/seq2seq/models/TopKDecoder.py
+ batch_size, l, d = x.shape
+ x = paddle.tile(
+ paddle.transpose(
+ x.unsqueeze(1), perm=[1, 0, 2, 3]), [beam_width, 1, 1, 1])
+ inflated_encoder_feats = paddle.reshape(
+ paddle.transpose(
+ x, perm=[1, 0, 2, 3]), [-1, l, d])
+
+ # Initialize the decoder
+ state = self.decoder.get_initial_state(embed, tile_times=beam_width)
+
+ pos_index = paddle.reshape(
+ paddle.arange(batch_size) * beam_width, shape=[-1, 1])
+
+ # Initialize the scores
+ sequence_scores = paddle.full(
+ shape=[batch_size * beam_width, 1], fill_value=-float('Inf'))
+ index = [i * beam_width for i in range(0, batch_size)]
+ sequence_scores[index] = 0.0
+
+ # Initialize the input vector
+ y_prev = paddle.full(
+ shape=[batch_size * beam_width], fill_value=self.num_classes)
+
+ # Store decisions for backtracking
+ stored_scores = list()
+ stored_predecessors = list()
+ stored_emitted_symbols = list()
+
+ for i in range(self.max_len_labels):
+ output, state = self.decoder(inflated_encoder_feats, state, y_prev)
+ state = paddle.unsqueeze(state, axis=0)
+ log_softmax_output = paddle.nn.functional.log_softmax(
+ output, axis=1)
+
+ sequence_scores = _inflate(sequence_scores, self.num_classes, 1)
+ sequence_scores += log_softmax_output
+ scores, candidates = paddle.topk(
+ paddle.reshape(sequence_scores, [batch_size, -1]),
+ beam_width,
+ axis=1)
+
+ # Reshape input = (bk, 1) and sequence_scores = (bk, 1)
+ y_prev = paddle.reshape(
+ candidates % self.num_classes, shape=[batch_size * beam_width])
+ sequence_scores = paddle.reshape(
+ scores, shape=[batch_size * beam_width, 1])
+
+ # Update fields for next timestep
+ pos_index = paddle.expand_as(pos_index, candidates)
+ predecessors = paddle.cast(
+ candidates / self.num_classes + pos_index, dtype='int64')
+ predecessors = paddle.reshape(
+ predecessors, shape=[batch_size * beam_width, 1])
+ state = paddle.index_select(
+ state, index=predecessors.squeeze(), axis=1)
+
+            # Update sequence scores and erase scores for the EOS symbol so that they aren't expanded
+ stored_scores.append(sequence_scores.clone())
+ y_prev = paddle.reshape(y_prev, shape=[-1, 1])
+ eos_prev = paddle.full_like(y_prev, fill_value=eos)
+ mask = eos_prev == y_prev
+ mask = paddle.nonzero(mask)
+ if mask.dim() > 0:
+ sequence_scores = sequence_scores.numpy()
+ mask = mask.numpy()
+ sequence_scores[mask] = -float('inf')
+ sequence_scores = paddle.to_tensor(sequence_scores)
+
+ # Cache results for backtracking
+ stored_predecessors.append(predecessors)
+ y_prev = paddle.squeeze(y_prev)
+ stored_emitted_symbols.append(y_prev)
+
+ # Do backtracking to return the optimal values
+        #====== backtrack ======#
+ # Initialize return variables given different types
+ p = list()
+ l = [[self.max_len_labels] * beam_width for _ in range(batch_size)
+ ] # Placeholder for lengths of top-k sequences
+
+ # the last step output of the beams are not sorted
+ # thus they are sorted here
+ sorted_score, sorted_idx = paddle.topk(
+ paddle.reshape(
+ stored_scores[-1], shape=[batch_size, beam_width]),
+ beam_width)
+
+ # initialize the sequence scores with the sorted last step beam scores
+ s = sorted_score.clone()
+
+ batch_eos_found = [0] * batch_size # the number of EOS found
+ # in the backward loop below for each batch
+ t = self.max_len_labels - 1
+ # initialize the back pointer with the sorted order of the last step beams.
+ # add pos_index for indexing variable with b*k as the first dimension.
+ t_predecessors = paddle.reshape(
+ sorted_idx + pos_index.expand_as(sorted_idx),
+ shape=[batch_size * beam_width])
+ while t >= 0:
+ # Re-order the variables with the back pointer
+ current_symbol = paddle.index_select(
+ stored_emitted_symbols[t], index=t_predecessors, axis=0)
+ t_predecessors = paddle.index_select(
+ stored_predecessors[t].squeeze(), index=t_predecessors, axis=0)
+ eos_indices = stored_emitted_symbols[t] == eos
+ eos_indices = paddle.nonzero(eos_indices)
+
+ if eos_indices.dim() > 0:
+ for i in range(eos_indices.shape[0] - 1, -1, -1):
+ # Indices of the EOS symbol for both variables
+ # with b*k as the first dimension, and b, k for
+ # the first two dimensions
+ idx = eos_indices[i]
+ b_idx = int(idx[0] / beam_width)
+ # The indices of the replacing position
+ # according to the replacement strategy noted above
+ res_k_idx = beam_width - (batch_eos_found[b_idx] %
+ beam_width) - 1
+ batch_eos_found[b_idx] += 1
+ res_idx = b_idx * beam_width + res_k_idx
+
+ # Replace the old information in return variables
+ # with the new ended sequence information
+ t_predecessors[res_idx] = stored_predecessors[t][idx[0]]
+ current_symbol[res_idx] = stored_emitted_symbols[t][idx[0]]
+ s[b_idx, res_k_idx] = stored_scores[t][idx[0], 0]
+ l[b_idx][res_k_idx] = t + 1
+
+ # record the back tracked results
+ p.append(current_symbol)
+ t -= 1
+
+ # Sort and re-order again as the added ended sequences may change
+ # the order (very unlikely)
+ s, re_sorted_idx = s.topk(beam_width)
+ for b_idx in range(batch_size):
+ l[b_idx] = [
+ l[b_idx][k_idx.item()] for k_idx in re_sorted_idx[b_idx, :]
+ ]
+
+ re_sorted_idx = paddle.reshape(
+ re_sorted_idx + pos_index.expand_as(re_sorted_idx),
+ [batch_size * beam_width])
+
+ # Reverse the sequences and re-order at the same time
+ # It is reversed because the backtracking happens in reverse time order
+ p = [
+ paddle.reshape(
+ paddle.index_select(step, re_sorted_idx, 0),
+ shape=[batch_size, beam_width, -1]) for step in reversed(p)
+ ]
+ p = paddle.concat(p, -1)[:, 0, :]
+ return p, paddle.ones_like(p)
+
+
+class AttentionUnit(nn.Layer):
+ def __init__(self, sDim, xDim, attDim):
+ super(AttentionUnit, self).__init__()
+
+ self.sDim = sDim
+ self.xDim = xDim
+ self.attDim = attDim
+
+ self.sEmbed = nn.Linear(sDim, attDim)
+ self.xEmbed = nn.Linear(xDim, attDim)
+ self.wEmbed = nn.Linear(attDim, 1)
+
+ def forward(self, x, sPrev):
+ batch_size, T, _ = x.shape # [b x T x xDim]
+ x = paddle.reshape(x, [-1, self.xDim]) # [(b x T) x xDim]
+ xProj = self.xEmbed(x) # [(b x T) x attDim]
+ xProj = paddle.reshape(xProj, [batch_size, T, -1]) # [b x T x attDim]
+
+ sPrev = sPrev.squeeze(0)
+ sProj = self.sEmbed(sPrev) # [b x attDim]
+ sProj = paddle.unsqueeze(sProj, 1) # [b x 1 x attDim]
+ sProj = paddle.expand(sProj,
+ [batch_size, T, self.attDim]) # [b x T x attDim]
+
+ sumTanh = paddle.tanh(sProj + xProj)
+ sumTanh = paddle.reshape(sumTanh, [-1, self.attDim])
+
+ vProj = self.wEmbed(sumTanh) # [(b x T) x 1]
+ vProj = paddle.reshape(vProj, [batch_size, T])
+ alpha = F.softmax(
+ vProj, axis=1) # attention weights for each sample in the minibatch
+ return alpha
+
+
+class DecoderUnit(nn.Layer):
+ def __init__(self, sDim, xDim, yDim, attDim):
+ super(DecoderUnit, self).__init__()
+ self.sDim = sDim
+ self.xDim = xDim
+ self.yDim = yDim
+ self.attDim = attDim
+ self.emdDim = attDim
+
+ self.attention_unit = AttentionUnit(sDim, xDim, attDim)
+ self.tgt_embedding = nn.Embedding(
+ yDim + 1, self.emdDim, weight_attr=nn.initializer.Normal(
+                std=0.01))  # the extra (last) index is used for the start token
+ self.gru = nn.GRUCell(input_size=xDim + self.emdDim, hidden_size=sDim)
+ self.fc = nn.Linear(
+ sDim,
+ yDim,
+ weight_attr=nn.initializer.Normal(std=0.01),
+ bias_attr=nn.initializer.Constant(value=0))
+ self.embed_fc = nn.Linear(300, self.sDim)
+
+ def get_initial_state(self, embed, tile_times=1):
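+        # embed: [N, 300] semantic embedding from the ASTER embedder; the result is the initial
+        # GRU hidden state, tiled `tile_times` times (the beam width) for beam search.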
+ assert embed.shape[1] == 300
+ state = self.embed_fc(embed) # N * sDim
+ if tile_times != 1:
+ state = state.unsqueeze(1)
+ trans_state = paddle.transpose(state, perm=[1, 0, 2])
+ state = paddle.tile(trans_state, repeat_times=[tile_times, 1, 1])
+ trans_state = paddle.transpose(state, perm=[1, 0, 2])
+ state = paddle.reshape(trans_state, shape=[-1, self.sDim])
+ state = state.unsqueeze(0) # 1 * N * sDim
+ return state
+
+ def forward(self, x, sPrev, yPrev):
+ # x: feature sequence from the image decoder.
+ batch_size, T, _ = x.shape
+ alpha = self.attention_unit(x, sPrev)
+ context = paddle.squeeze(paddle.matmul(alpha.unsqueeze(1), x), axis=1)
+ yPrev = paddle.cast(yPrev, dtype="int64")
+ yProj = self.tgt_embedding(yPrev)
+
+ concat_context = paddle.concat([yProj, context], 1)
+ concat_context = paddle.squeeze(concat_context, 1)
+ sPrev = paddle.squeeze(sPrev, 0)
+ output, state = self.gru(concat_context, sPrev)
+ output = paddle.squeeze(output, axis=1)
+ output = self.fc(output)
+ return output, state
\ No newline at end of file
diff --git a/backend/ppocr/modeling/heads/rec_att_head.py b/backend/ppocr/modeling/heads/rec_att_head.py
index 0d222714..ab8b119f 100644
--- a/backend/ppocr/modeling/heads/rec_att_head.py
+++ b/backend/ppocr/modeling/heads/rec_att_head.py
@@ -38,7 +38,7 @@ def _char_to_onehot(self, input_char, onehot_dim):
return input_ont_hot
def forward(self, inputs, targets=None, batch_max_length=25):
- batch_size = inputs.shape[0]
+ batch_size = paddle.shape(inputs)[0]
num_steps = batch_max_length
hidden = paddle.zeros((batch_size, self.hidden_size))
@@ -53,7 +53,6 @@ def forward(self, inputs, targets=None, batch_max_length=25):
output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
output = paddle.concat(output_hiddens, axis=1)
probs = self.generator(output)
-
else:
targets = paddle.zeros(shape=[batch_size], dtype="int32")
probs = None
@@ -75,7 +74,8 @@ def forward(self, inputs, targets=None, batch_max_length=25):
probs_step, axis=1)], axis=1)
next_input = probs_step.argmax(axis=1)
targets = next_input
-
+ if not self.training:
+ probs = paddle.nn.functional.softmax(probs, axis=2)
return probs
diff --git a/backend/ppocr/modeling/heads/rec_ctc_head.py b/backend/ppocr/modeling/heads/rec_ctc_head.py
index 69d4ef50..6c1cf065 100755
--- a/backend/ppocr/modeling/heads/rec_ctc_head.py
+++ b/backend/ppocr/modeling/heads/rec_ctc_head.py
@@ -23,32 +23,65 @@
from paddle.nn import functional as F
-def get_para_bias_attr(l2_decay, k, name):
+def get_para_bias_attr(l2_decay, k):
regularizer = paddle.regularizer.L2Decay(l2_decay)
stdv = 1.0 / math.sqrt(k * 1.0)
initializer = nn.initializer.Uniform(-stdv, stdv)
- weight_attr = ParamAttr(
- regularizer=regularizer, initializer=initializer, name=name + "_w_attr")
- bias_attr = ParamAttr(
- regularizer=regularizer, initializer=initializer, name=name + "_b_attr")
+ weight_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
+ bias_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
return [weight_attr, bias_attr]
class CTCHead(nn.Layer):
- def __init__(self, in_channels, out_channels, fc_decay=0.0004, **kwargs):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ fc_decay=0.0004,
+ mid_channels=None,
+ return_feats=False,
+ **kwargs):
super(CTCHead, self).__init__()
- weight_attr, bias_attr = get_para_bias_attr(
- l2_decay=fc_decay, k=in_channels, name='ctc_fc')
- self.fc = nn.Linear(
- in_channels,
- out_channels,
- weight_attr=weight_attr,
- bias_attr=bias_attr,
- name='ctc_fc')
+ if mid_channels is None:
+ weight_attr, bias_attr = get_para_bias_attr(
+ l2_decay=fc_decay, k=in_channels)
+ self.fc = nn.Linear(
+ in_channels,
+ out_channels,
+ weight_attr=weight_attr,
+ bias_attr=bias_attr)
+ else:
+ weight_attr1, bias_attr1 = get_para_bias_attr(
+ l2_decay=fc_decay, k=in_channels)
+ self.fc1 = nn.Linear(
+ in_channels,
+ mid_channels,
+ weight_attr=weight_attr1,
+ bias_attr=bias_attr1)
+
+ weight_attr2, bias_attr2 = get_para_bias_attr(
+ l2_decay=fc_decay, k=mid_channels)
+ self.fc2 = nn.Linear(
+ mid_channels,
+ out_channels,
+ weight_attr=weight_attr2,
+ bias_attr=bias_attr2)
self.out_channels = out_channels
+ self.mid_channels = mid_channels
+ self.return_feats = return_feats
+
+ def forward(self, x, targets=None):
+ if self.mid_channels is None:
+ predicts = self.fc(x)
+ else:
+ x = self.fc1(x)
+ predicts = self.fc2(x)
- def forward(self, x, labels=None):
- predicts = self.fc(x)
+ if self.return_feats:
+ result = (x, predicts)
+ else:
+ result = predicts
if not self.training:
predicts = F.softmax(predicts, axis=2)
- return predicts
+ result = predicts
+
+ return result
diff --git a/backend/ppocr/modeling/heads/rec_multi_head.py b/backend/ppocr/modeling/heads/rec_multi_head.py
new file mode 100644
index 00000000..ef78bf98
--- /dev/null
+++ b/backend/ppocr/modeling/heads/rec_multi_head.py
@@ -0,0 +1,73 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import paddle
+from paddle import ParamAttr
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+from ppocr.modeling.necks.rnn import Im2Seq, EncoderWithRNN, EncoderWithFC, SequenceEncoder, EncoderWithSVTR
+from .rec_ctc_head import CTCHead
+from .rec_sar_head import SARHead
+
+
+class MultiHead(nn.Layer):
+ def __init__(self, in_channels, out_channels_list, **kwargs):
+ super().__init__()
+ self.head_list = kwargs.pop('head_list')
+ self.gtc_head = 'sar'
+ assert len(self.head_list) >= 2
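+        # head_list comes from the YAML config; an illustrative (assumed) layout:
+        #   [{'CTCHead': {'Neck': {'name': 'svtr', ...}, 'Head': {'fc_decay': 0.00001}}},
+        #    {'SARHead': {'enc_dim': 512, 'max_text_length': 25}}]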
+ for idx, head_name in enumerate(self.head_list):
+ name = list(head_name)[0]
+ if name == 'SARHead':
+ # sar head
+ sar_args = self.head_list[idx][name]
+ self.sar_head = eval(name)(in_channels=in_channels, \
+ out_channels=out_channels_list['SARLabelDecode'], **sar_args)
+ elif name == 'CTCHead':
+ # ctc neck
+ self.encoder_reshape = Im2Seq(in_channels)
+ neck_args = self.head_list[idx][name]['Neck']
+ encoder_type = neck_args.pop('name')
+ self.encoder = encoder_type
+ self.ctc_encoder = SequenceEncoder(in_channels=in_channels, \
+ encoder_type=encoder_type, **neck_args)
+ # ctc head
+ head_args = self.head_list[idx][name]['Head']
+ self.ctc_head = eval(name)(in_channels=self.ctc_encoder.out_channels, \
+ out_channels=out_channels_list['CTCLabelDecode'], **head_args)
+ else:
+ raise NotImplementedError(
+ '{} is not supported in MultiHead yet'.format(name))
+
+ def forward(self, x, targets=None):
+ ctc_encoder = self.ctc_encoder(x)
+ ctc_out = self.ctc_head(ctc_encoder, targets)
+ head_out = dict()
+ head_out['ctc'] = ctc_out
+ head_out['ctc_neck'] = ctc_encoder
+ # eval mode
+ if not self.training:
+ return ctc_out
+ if self.gtc_head == 'sar':
+ sar_out = self.sar_head(x, targets[1:])
+ head_out['sar'] = sar_out
+ return head_out
+ else:
+ return head_out
diff --git a/backend/ppocr/modeling/heads/rec_nrtr_head.py b/backend/ppocr/modeling/heads/rec_nrtr_head.py
new file mode 100644
index 00000000..38ba0c91
--- /dev/null
+++ b/backend/ppocr/modeling/heads/rec_nrtr_head.py
@@ -0,0 +1,826 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import paddle
+import copy
+from paddle import nn
+import paddle.nn.functional as F
+from paddle.nn import LayerList
+from paddle.nn.initializer import XavierNormal as xavier_uniform_
+from paddle.nn import Dropout, Linear, LayerNorm, Conv2D
+import numpy as np
+from ppocr.modeling.heads.multiheadAttention import MultiheadAttention
+from paddle.nn.initializer import Constant as constant_
+from paddle.nn.initializer import XavierNormal as xavier_normal_
+
+zeros_ = constant_(value=0.)
+ones_ = constant_(value=1.)
+
+
+class Transformer(nn.Layer):
+ """A transformer model. User is able to modify the attributes as needed. The architechture
+ is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
+ Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
+ Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
+ Processing Systems, pages 6000-6010.
+
+ Args:
+ d_model: the number of expected features in the encoder/decoder inputs (default=512).
+ nhead: the number of heads in the multiheadattention models (default=8).
+ num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
+ num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
+ dim_feedforward: the dimension of the feedforward network model (default=2048).
+ dropout: the dropout value (default=0.1).
+ custom_encoder: custom encoder (default=None).
+ custom_decoder: custom decoder (default=None).
+
+ """
+
+ def __init__(self,
+ d_model=512,
+ nhead=8,
+ num_encoder_layers=6,
+ beam_size=0,
+ num_decoder_layers=6,
+ dim_feedforward=1024,
+ attention_dropout_rate=0.0,
+ residual_dropout_rate=0.1,
+ custom_encoder=None,
+ custom_decoder=None,
+ in_channels=0,
+ out_channels=0,
+ scale_embedding=True):
+ super(Transformer, self).__init__()
+ self.out_channels = out_channels + 1
+ self.embedding = Embeddings(
+ d_model=d_model,
+ vocab=self.out_channels,
+ padding_idx=0,
+ scale_embedding=scale_embedding)
+ self.positional_encoding = PositionalEncoding(
+ dropout=residual_dropout_rate,
+ dim=d_model, )
+ if custom_encoder is not None:
+ self.encoder = custom_encoder
+ else:
+ if num_encoder_layers > 0:
+ encoder_layer = TransformerEncoderLayer(
+ d_model, nhead, dim_feedforward, attention_dropout_rate,
+ residual_dropout_rate)
+ self.encoder = TransformerEncoder(encoder_layer,
+ num_encoder_layers)
+ else:
+ self.encoder = None
+
+ if custom_decoder is not None:
+ self.decoder = custom_decoder
+ else:
+ decoder_layer = TransformerDecoderLayer(
+ d_model, nhead, dim_feedforward, attention_dropout_rate,
+ residual_dropout_rate)
+ self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers)
+
+ self._reset_parameters()
+ self.beam_size = beam_size
+ self.d_model = d_model
+ self.nhead = nhead
+ self.tgt_word_prj = nn.Linear(
+ d_model, self.out_channels, bias_attr=False)
+ w0 = np.random.normal(0.0, d_model**-0.5,
+ (d_model, self.out_channels)).astype(np.float32)
+ self.tgt_word_prj.weight.set_value(w0)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+
+ if isinstance(m, nn.Conv2D):
+ xavier_normal_(m.weight)
+ if m.bias is not None:
+ zeros_(m.bias)
+
+ def forward_train(self, src, tgt):
+ tgt = tgt[:, :-1]
+
+ tgt_key_padding_mask = self.generate_padding_mask(tgt)
+ tgt = self.embedding(tgt).transpose([1, 0, 2])
+ tgt = self.positional_encoding(tgt)
+ tgt_mask = self.generate_square_subsequent_mask(tgt.shape[0])
+
+ if self.encoder is not None:
+ src = self.positional_encoding(src.transpose([1, 0, 2]))
+ memory = self.encoder(src)
+ else:
+ memory = src.squeeze(2).transpose([2, 0, 1])
+ output = self.decoder(
+ tgt,
+ memory,
+ tgt_mask=tgt_mask,
+ memory_mask=None,
+ tgt_key_padding_mask=tgt_key_padding_mask,
+ memory_key_padding_mask=None)
+ output = output.transpose([1, 0, 2])
+ logit = self.tgt_word_prj(output)
+ return logit
+
+ def forward(self, src, targets=None):
+ """Take in and process masked source/target sequences.
+ Args:
+ src: the sequence to the encoder (required).
+ tgt: the sequence to the decoder (required).
+ Shape:
+ - src: :math:`(S, N, E)`.
+ - tgt: :math:`(T, N, E)`.
+ Examples:
+ >>> output = transformer_model(src, tgt)
+ """
+
+ if self.training:
+ max_len = targets[1].max()
+ tgt = targets[0][:, :2 + max_len]
+ return self.forward_train(src, tgt)
+ else:
+ if self.beam_size > 0:
+ return self.forward_beam(src)
+ else:
+ return self.forward_test(src)
+
+ def forward_test(self, src):
+ bs = paddle.shape(src)[0]
+ if self.encoder is not None:
+ src = self.positional_encoding(paddle.transpose(src, [1, 0, 2]))
+ memory = self.encoder(src)
+ else:
+ memory = paddle.transpose(paddle.squeeze(src, 2), [2, 0, 1])
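+        # Greedy decoding: every sequence starts from the start-of-sequence id (2) and decoding
+        # stops early once all samples predict the end-of-sequence id (3).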
+ dec_seq = paddle.full((bs, 1), 2, dtype=paddle.int64)
+ dec_prob = paddle.full((bs, 1), 1., dtype=paddle.float32)
+ for len_dec_seq in range(1, 25):
+ dec_seq_embed = paddle.transpose(self.embedding(dec_seq), [1, 0, 2])
+ dec_seq_embed = self.positional_encoding(dec_seq_embed)
+ tgt_mask = self.generate_square_subsequent_mask(
+ paddle.shape(dec_seq_embed)[0])
+ output = self.decoder(
+ dec_seq_embed,
+ memory,
+ tgt_mask=tgt_mask,
+ memory_mask=None,
+ tgt_key_padding_mask=None,
+ memory_key_padding_mask=None)
+ dec_output = paddle.transpose(output, [1, 0, 2])
+ dec_output = dec_output[:, -1, :]
+ word_prob = F.softmax(self.tgt_word_prj(dec_output), axis=1)
+ preds_idx = paddle.argmax(word_prob, axis=1)
+ if paddle.equal_all(
+ preds_idx,
+ paddle.full(
+ paddle.shape(preds_idx), 3, dtype='int64')):
+ break
+ preds_prob = paddle.max(word_prob, axis=1)
+ dec_seq = paddle.concat(
+ [dec_seq, paddle.reshape(preds_idx, [-1, 1])], axis=1)
+ dec_prob = paddle.concat(
+ [dec_prob, paddle.reshape(preds_prob, [-1, 1])], axis=1)
+ return [dec_seq, dec_prob]
+
+ def forward_beam(self, images):
+ ''' Translation work in one batch '''
+
+ def get_inst_idx_to_tensor_position_map(inst_idx_list):
+ ''' Indicate the position of an instance in a tensor. '''
+ return {
+ inst_idx: tensor_position
+ for tensor_position, inst_idx in enumerate(inst_idx_list)
+ }
+
+ def collect_active_part(beamed_tensor, curr_active_inst_idx,
+ n_prev_active_inst, n_bm):
+ ''' Collect tensor parts associated to active instances. '''
+
+ beamed_tensor_shape = paddle.shape(beamed_tensor)
+ n_curr_active_inst = len(curr_active_inst_idx)
+ new_shape = (n_curr_active_inst * n_bm, beamed_tensor_shape[1],
+ beamed_tensor_shape[2])
+
+ beamed_tensor = beamed_tensor.reshape([n_prev_active_inst, -1])
+ beamed_tensor = beamed_tensor.index_select(
+ curr_active_inst_idx, axis=0)
+ beamed_tensor = beamed_tensor.reshape(new_shape)
+
+ return beamed_tensor
+
+ def collate_active_info(src_enc, inst_idx_to_position_map,
+ active_inst_idx_list):
+ # Sentences which are still active are collected,
+ # so the decoder will not run on completed sentences.
+
+ n_prev_active_inst = len(inst_idx_to_position_map)
+ active_inst_idx = [
+ inst_idx_to_position_map[k] for k in active_inst_idx_list
+ ]
+ active_inst_idx = paddle.to_tensor(active_inst_idx, dtype='int64')
+ active_src_enc = collect_active_part(
+ src_enc.transpose([1, 0, 2]), active_inst_idx,
+ n_prev_active_inst, n_bm).transpose([1, 0, 2])
+ active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
+ active_inst_idx_list)
+ return active_src_enc, active_inst_idx_to_position_map
+
+ def beam_decode_step(inst_dec_beams, len_dec_seq, enc_output,
+ inst_idx_to_position_map, n_bm,
+ memory_key_padding_mask):
+ ''' Decode and update beam status, and then return active beam idx '''
+
+ def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
+ dec_partial_seq = [
+ b.get_current_state() for b in inst_dec_beams if not b.done
+ ]
+ dec_partial_seq = paddle.stack(dec_partial_seq)
+ dec_partial_seq = dec_partial_seq.reshape([-1, len_dec_seq])
+ return dec_partial_seq
+
+ def predict_word(dec_seq, enc_output, n_active_inst, n_bm,
+ memory_key_padding_mask):
+ dec_seq = paddle.transpose(self.embedding(dec_seq), [1, 0, 2])
+ dec_seq = self.positional_encoding(dec_seq)
+ tgt_mask = self.generate_square_subsequent_mask(
+ paddle.shape(dec_seq)[0])
+ dec_output = self.decoder(
+ dec_seq,
+ enc_output,
+ tgt_mask=tgt_mask,
+ tgt_key_padding_mask=None,
+ memory_key_padding_mask=memory_key_padding_mask, )
+ dec_output = paddle.transpose(dec_output, [1, 0, 2])
+ dec_output = dec_output[:,
+ -1, :] # Pick the last step: (bh * bm) * d_h
+ word_prob = F.softmax(self.tgt_word_prj(dec_output), axis=1)
+ word_prob = paddle.reshape(word_prob, [n_active_inst, n_bm, -1])
+ return word_prob
+
+ def collect_active_inst_idx_list(inst_beams, word_prob,
+ inst_idx_to_position_map):
+ active_inst_idx_list = []
+ for inst_idx, inst_position in inst_idx_to_position_map.items():
+ is_inst_complete = inst_beams[inst_idx].advance(word_prob[
+ inst_position])
+ if not is_inst_complete:
+ active_inst_idx_list += [inst_idx]
+
+ return active_inst_idx_list
+
+ n_active_inst = len(inst_idx_to_position_map)
+ dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
+ word_prob = predict_word(dec_seq, enc_output, n_active_inst, n_bm,
+ None)
+ # Update the beam with predicted word prob information and collect incomplete instances
+ active_inst_idx_list = collect_active_inst_idx_list(
+ inst_dec_beams, word_prob, inst_idx_to_position_map)
+ return active_inst_idx_list
+
+ def collect_hypothesis_and_scores(inst_dec_beams, n_best):
+ all_hyp, all_scores = [], []
+ for inst_idx in range(len(inst_dec_beams)):
+ scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores()
+ all_scores += [scores[:n_best]]
+ hyps = [
+ inst_dec_beams[inst_idx].get_hypothesis(i)
+ for i in tail_idxs[:n_best]
+ ]
+ all_hyp += [hyps]
+ return all_hyp, all_scores
+
+ with paddle.no_grad():
+ #-- Encode
+ if self.encoder is not None:
+ src = self.positional_encoding(images.transpose([1, 0, 2]))
+ src_enc = self.encoder(src)
+ else:
+ src_enc = images.squeeze(2).transpose([0, 2, 1])
+
+ n_bm = self.beam_size
+ src_shape = paddle.shape(src_enc)
+ inst_dec_beams = [Beam(n_bm) for _ in range(1)]
+ active_inst_idx_list = list(range(1))
+ # Repeat data for beam search
+ src_enc = paddle.tile(src_enc, [1, n_bm, 1])
+ inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
+ active_inst_idx_list)
+ # Decode
+ for len_dec_seq in range(1, 25):
+ src_enc_copy = src_enc.clone()
+ active_inst_idx_list = beam_decode_step(
+ inst_dec_beams, len_dec_seq, src_enc_copy,
+ inst_idx_to_position_map, n_bm, None)
+ if not active_inst_idx_list:
+                    break  # all instances have finished their path to EOS
+ src_enc, inst_idx_to_position_map = collate_active_info(
+ src_enc_copy, inst_idx_to_position_map,
+ active_inst_idx_list)
+ batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams,
+ 1)
+ result_hyp = []
+ hyp_scores = []
+ for bs_hyp, score in zip(batch_hyp, batch_scores):
+ l = len(bs_hyp[0])
+ bs_hyp_pad = bs_hyp[0] + [3] * (25 - l)
+ result_hyp.append(bs_hyp_pad)
+ score = float(score) / l
+ hyp_score = [score for _ in range(25)]
+ hyp_scores.append(hyp_score)
+ return [
+ paddle.to_tensor(
+ np.array(result_hyp), dtype=paddle.int64),
+ paddle.to_tensor(hyp_scores)
+ ]
+
+ def generate_square_subsequent_mask(self, sz):
+ """Generate a square mask for the sequence. The masked positions are filled with float('-inf').
+ Unmasked positions are filled with float(0.0).
+ """
+ mask = paddle.zeros([sz, sz], dtype='float32')
+ mask_inf = paddle.triu(
+ paddle.full(
+ shape=[sz, sz], dtype='float32', fill_value='-inf'),
+ diagonal=1)
+ mask = mask + mask_inf
+ return mask
+
+ def generate_padding_mask(self, x):
+ padding_mask = paddle.equal(x, paddle.to_tensor(0, dtype=x.dtype))
+ return padding_mask
+
+ def _reset_parameters(self):
+ """Initiate parameters in the transformer model."""
+
+ for p in self.parameters():
+ if p.dim() > 1:
+ xavier_uniform_(p)
+
+
+class TransformerEncoder(nn.Layer):
+ """TransformerEncoder is a stack of N encoder layers
+ Args:
+ encoder_layer: an instance of the TransformerEncoderLayer() class (required).
+ num_layers: the number of sub-encoder-layers in the encoder (required).
+ norm: the layer normalization component (optional).
+ """
+
+ def __init__(self, encoder_layer, num_layers):
+ super(TransformerEncoder, self).__init__()
+ self.layers = _get_clones(encoder_layer, num_layers)
+ self.num_layers = num_layers
+
+ def forward(self, src):
+ """Pass the input through the endocder layers in turn.
+ Args:
+ src: the sequnce to the encoder (required).
+ mask: the mask for the src sequence (optional).
+ src_key_padding_mask: the mask for the src keys per batch (optional).
+ """
+ output = src
+
+ for i in range(self.num_layers):
+ output = self.layers[i](output,
+ src_mask=None,
+ src_key_padding_mask=None)
+
+ return output
+
+
+class TransformerDecoder(nn.Layer):
+ """TransformerDecoder is a stack of N decoder layers
+
+ Args:
+ decoder_layer: an instance of the TransformerDecoderLayer() class (required).
+ num_layers: the number of sub-decoder-layers in the decoder (required).
+ norm: the layer normalization component (optional).
+
+ """
+
+ def __init__(self, decoder_layer, num_layers):
+ super(TransformerDecoder, self).__init__()
+ self.layers = _get_clones(decoder_layer, num_layers)
+ self.num_layers = num_layers
+
+ def forward(self,
+ tgt,
+ memory,
+ tgt_mask=None,
+ memory_mask=None,
+ tgt_key_padding_mask=None,
+ memory_key_padding_mask=None):
+ """Pass the inputs (and mask) through the decoder layer in turn.
+
+ Args:
+ tgt: the sequence to the decoder (required).
+            memory: the sequence from the last layer of the encoder (required).
+ tgt_mask: the mask for the tgt sequence (optional).
+ memory_mask: the mask for the memory sequence (optional).
+ tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
+ memory_key_padding_mask: the mask for the memory keys per batch (optional).
+ """
+ output = tgt
+ for i in range(self.num_layers):
+ output = self.layers[i](
+ output,
+ memory,
+ tgt_mask=tgt_mask,
+ memory_mask=memory_mask,
+ tgt_key_padding_mask=tgt_key_padding_mask,
+ memory_key_padding_mask=memory_key_padding_mask)
+
+ return output
+
+
+class TransformerEncoderLayer(nn.Layer):
+ """TransformerEncoderLayer is made up of self-attn and feedforward network.
+ This standard encoder layer is based on the paper "Attention Is All You Need".
+ Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
+ Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
+ Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
+ in a different way during application.
+
+ Args:
+ d_model: the number of expected features in the input (required).
+ nhead: the number of heads in the multiheadattention models (required).
+ dim_feedforward: the dimension of the feedforward network model (default=2048).
+ dropout: the dropout value (default=0.1).
+
+ """
+
+ def __init__(self,
+ d_model,
+ nhead,
+ dim_feedforward=2048,
+ attention_dropout_rate=0.0,
+ residual_dropout_rate=0.1):
+ super(TransformerEncoderLayer, self).__init__()
+ self.self_attn = MultiheadAttention(
+ d_model, nhead, dropout=attention_dropout_rate)
+
+ self.conv1 = Conv2D(
+ in_channels=d_model,
+ out_channels=dim_feedforward,
+ kernel_size=(1, 1))
+ self.conv2 = Conv2D(
+ in_channels=dim_feedforward,
+ out_channels=d_model,
+ kernel_size=(1, 1))
+
+ self.norm1 = LayerNorm(d_model)
+ self.norm2 = LayerNorm(d_model)
+ self.dropout1 = Dropout(residual_dropout_rate)
+ self.dropout2 = Dropout(residual_dropout_rate)
+
+ def forward(self, src, src_mask=None, src_key_padding_mask=None):
+ """Pass the input through the endocder layer.
+ Args:
+ src: the sequnce to the encoder layer (required).
+ src_mask: the mask for the src sequence (optional).
+ src_key_padding_mask: the mask for the src keys per batch (optional).
+ """
+ src2 = self.self_attn(
+ src,
+ src,
+ src,
+ attn_mask=src_mask,
+ key_padding_mask=src_key_padding_mask)
+ src = src + self.dropout1(src2)
+ src = self.norm1(src)
+
+ src = paddle.transpose(src, [1, 2, 0])
+ src = paddle.unsqueeze(src, 2)
+ src2 = self.conv2(F.relu(self.conv1(src)))
+ src2 = paddle.squeeze(src2, 2)
+ src2 = paddle.transpose(src2, [2, 0, 1])
+ src = paddle.squeeze(src, 2)
+ src = paddle.transpose(src, [2, 0, 1])
+
+ src = src + self.dropout2(src2)
+ src = self.norm2(src)
+ return src
+
+
+class TransformerDecoderLayer(nn.Layer):
+ """TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
+ This standard decoder layer is based on the paper "Attention Is All You Need".
+ Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
+ Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
+ Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
+ in a different way during application.
+
+ Args:
+ d_model: the number of expected features in the input (required).
+ nhead: the number of heads in the multiheadattention models (required).
+ dim_feedforward: the dimension of the feedforward network model (default=2048).
+ dropout: the dropout value (default=0.1).
+
+ """
+
+ def __init__(self,
+ d_model,
+ nhead,
+ dim_feedforward=2048,
+ attention_dropout_rate=0.0,
+ residual_dropout_rate=0.1):
+ super(TransformerDecoderLayer, self).__init__()
+ self.self_attn = MultiheadAttention(
+ d_model, nhead, dropout=attention_dropout_rate)
+ self.multihead_attn = MultiheadAttention(
+ d_model, nhead, dropout=attention_dropout_rate)
+
+ self.conv1 = Conv2D(
+ in_channels=d_model,
+ out_channels=dim_feedforward,
+ kernel_size=(1, 1))
+ self.conv2 = Conv2D(
+ in_channels=dim_feedforward,
+ out_channels=d_model,
+ kernel_size=(1, 1))
+
+ self.norm1 = LayerNorm(d_model)
+ self.norm2 = LayerNorm(d_model)
+ self.norm3 = LayerNorm(d_model)
+ self.dropout1 = Dropout(residual_dropout_rate)
+ self.dropout2 = Dropout(residual_dropout_rate)
+ self.dropout3 = Dropout(residual_dropout_rate)
+
+ def forward(self,
+ tgt,
+ memory,
+ tgt_mask=None,
+ memory_mask=None,
+ tgt_key_padding_mask=None,
+ memory_key_padding_mask=None):
+ """Pass the inputs (and mask) through the decoder layer.
+
+ Args:
+ tgt: the sequence to the decoder layer (required).
+            memory: the sequence from the last layer of the encoder (required).
+ tgt_mask: the mask for the tgt sequence (optional).
+ memory_mask: the mask for the memory sequence (optional).
+ tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
+ memory_key_padding_mask: the mask for the memory keys per batch (optional).
+
+ """
+ tgt2 = self.self_attn(
+ tgt,
+ tgt,
+ tgt,
+ attn_mask=tgt_mask,
+ key_padding_mask=tgt_key_padding_mask)
+ tgt = tgt + self.dropout1(tgt2)
+ tgt = self.norm1(tgt)
+ tgt2 = self.multihead_attn(
+ tgt,
+ memory,
+ memory,
+ attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask)
+ tgt = tgt + self.dropout2(tgt2)
+ tgt = self.norm2(tgt)
+
+ # default
+ tgt = paddle.transpose(tgt, [1, 2, 0])
+ tgt = paddle.unsqueeze(tgt, 2)
+ tgt2 = self.conv2(F.relu(self.conv1(tgt)))
+ tgt2 = paddle.squeeze(tgt2, 2)
+ tgt2 = paddle.transpose(tgt2, [2, 0, 1])
+ tgt = paddle.squeeze(tgt, 2)
+ tgt = paddle.transpose(tgt, [2, 0, 1])
+
+ tgt = tgt + self.dropout3(tgt2)
+ tgt = self.norm3(tgt)
+ return tgt
+
+
+def _get_clones(module, N):
+ return LayerList([copy.deepcopy(module) for i in range(N)])
+
+
+class PositionalEncoding(nn.Layer):
+ """Inject some information about the relative or absolute position of the tokens
+ in the sequence. The positional encodings have the same dimension as
+ the embeddings, so that the two can be summed. Here, we use sine and cosine
+ functions of different frequencies.
+ .. math::
+ \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model))
+ \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model))
+        \text{where pos is the word position and i is the embed idx}
+ Args:
+ d_model: the embed dim (required).
+ dropout: the dropout value (default=0.1).
+ max_len: the max. length of the incoming sequence (default=5000).
+ Examples:
+ >>> pos_encoder = PositionalEncoding(d_model)
+ """
+
+ def __init__(self, dropout, dim, max_len=5000):
+ super(PositionalEncoding, self).__init__()
+ self.dropout = nn.Dropout(p=dropout)
+
+ pe = paddle.zeros([max_len, dim])
+ position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1)
+ div_term = paddle.exp(
+ paddle.arange(0, dim, 2).astype('float32') *
+ (-math.log(10000.0) / dim))
+ pe[:, 0::2] = paddle.sin(position * div_term)
+ pe[:, 1::2] = paddle.cos(position * div_term)
+ pe = paddle.unsqueeze(pe, 0)
+ pe = paddle.transpose(pe, [1, 0, 2])
+ self.register_buffer('pe', pe)
+
+ def forward(self, x):
+ """Inputs of forward function
+ Args:
+ x: the sequence fed to the positional encoder model (required).
+ Shape:
+ x: [sequence length, batch size, embed dim]
+ output: [sequence length, batch size, embed dim]
+ Examples:
+ >>> output = pos_encoder(x)
+ """
+ x = x + self.pe[:paddle.shape(x)[0], :]
+ return self.dropout(x)
+
+
+class PositionalEncoding_2d(nn.Layer):
+ """Inject some information about the relative or absolute position of the tokens
+ in the sequence. The positional encodings have the same dimension as
+ the embeddings, so that the two can be summed. Here, we use sine and cosine
+ functions of different frequencies.
+ .. math::
+ \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model))
+ \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model))
+        \text{where pos is the word position and i is the embed idx}
+ Args:
+ d_model: the embed dim (required).
+ dropout: the dropout value (default=0.1).
+ max_len: the max. length of the incoming sequence (default=5000).
+ Examples:
+ >>> pos_encoder = PositionalEncoding(d_model)
+ """
+
+ def __init__(self, dropout, dim, max_len=5000):
+ super(PositionalEncoding_2d, self).__init__()
+ self.dropout = nn.Dropout(p=dropout)
+
+ pe = paddle.zeros([max_len, dim])
+ position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1)
+ div_term = paddle.exp(
+ paddle.arange(0, dim, 2).astype('float32') *
+ (-math.log(10000.0) / dim))
+ pe[:, 0::2] = paddle.sin(position * div_term)
+ pe[:, 1::2] = paddle.cos(position * div_term)
+ pe = paddle.transpose(paddle.unsqueeze(pe, 0), [1, 0, 2])
+ self.register_buffer('pe', pe)
+
+ self.avg_pool_1 = nn.AdaptiveAvgPool2D((1, 1))
+ self.linear1 = nn.Linear(dim, dim)
+        self.linear1.weight.set_value(paddle.ones_like(self.linear1.weight))
+ self.avg_pool_2 = nn.AdaptiveAvgPool2D((1, 1))
+ self.linear2 = nn.Linear(dim, dim)
+        self.linear2.weight.set_value(paddle.ones_like(self.linear2.weight))
+
+ def forward(self, x):
+ """Inputs of forward function
+ Args:
+ x: the sequence fed to the positional encoder model (required).
+ Shape:
+ x: [sequence length, batch size, embed dim]
+ output: [sequence length, batch size, embed dim]
+ Examples:
+ >>> output = pos_encoder(x)
+ """
+ w_pe = self.pe[:paddle.shape(x)[-1], :]
+ w1 = self.linear1(self.avg_pool_1(x).squeeze()).unsqueeze(0)
+ w_pe = w_pe * w1
+ w_pe = paddle.transpose(w_pe, [1, 2, 0])
+ w_pe = paddle.unsqueeze(w_pe, 2)
+
+        h_pe = self.pe[:paddle.shape(x)[-2], :]
+ w2 = self.linear2(self.avg_pool_2(x).squeeze()).unsqueeze(0)
+ h_pe = h_pe * w2
+ h_pe = paddle.transpose(h_pe, [1, 2, 0])
+ h_pe = paddle.unsqueeze(h_pe, 3)
+
+ x = x + w_pe + h_pe
+ x = paddle.transpose(
+ paddle.reshape(x,
+ [x.shape[0], x.shape[1], x.shape[2] * x.shape[3]]),
+ [2, 0, 1])
+
+ return self.dropout(x)
+
+
+class Embeddings(nn.Layer):
+ def __init__(self, d_model, vocab, padding_idx, scale_embedding):
+ super(Embeddings, self).__init__()
+ self.embedding = nn.Embedding(vocab, d_model, padding_idx=padding_idx)
+ w0 = np.random.normal(0.0, d_model**-0.5,
+ (vocab, d_model)).astype(np.float32)
+ self.embedding.weight.set_value(w0)
+ self.d_model = d_model
+ self.scale_embedding = scale_embedding
+
+ def forward(self, x):
+ if self.scale_embedding:
+ x = self.embedding(x)
+ return x * math.sqrt(self.d_model)
+ return self.embedding(x)
+
+
+class Beam():
+ ''' Beam search '''
+
+ def __init__(self, size, device=False):
+
+ self.size = size
+ self._done = False
+ # The score for each translation on the beam.
+ self.scores = paddle.zeros((size, ), dtype=paddle.float32)
+ self.all_scores = []
+ # The backpointers at each time-step.
+ self.prev_ks = []
+ # The outputs at each time-step.
+ self.next_ys = [paddle.full((size, ), 0, dtype=paddle.int64)]
+ self.next_ys[0][0] = 2
+
+ def get_current_state(self):
+ "Get the outputs for the current timestep."
+ return self.get_tentative_hypothesis()
+
+ def get_current_origin(self):
+ "Get the backpointers for the current timestep."
+ return self.prev_ks[-1]
+
+ @property
+ def done(self):
+ return self._done
+
+ def advance(self, word_prob):
+ "Update beam status and check if finished or not."
+ num_words = word_prob.shape[1]
+
+ # Sum the previous scores.
+ if len(self.prev_ks) > 0:
+ beam_lk = word_prob + self.scores.unsqueeze(1).expand_as(word_prob)
+ else:
+ beam_lk = word_prob[0]
+
+ flat_beam_lk = beam_lk.reshape([-1])
+ best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True,
+ True) # 1st sort
+ self.all_scores.append(self.scores)
+ self.scores = best_scores
+ # bestScoresId is flattened as a (beam x word) array,
+ # so we need to calculate which word and beam each score came from
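+        # e.g. with a beam of size 5 over 100 words, flat index 237 maps to beam 2, word 37.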
+ prev_k = best_scores_id // num_words
+ self.prev_ks.append(prev_k)
+ self.next_ys.append(best_scores_id - prev_k * num_words)
+ # End condition is when top-of-beam is EOS.
+ if self.next_ys[-1][0] == 3:
+ self._done = True
+ self.all_scores.append(self.scores)
+
+ return self._done
+
+ def sort_scores(self):
+ "Sort the scores."
+ return self.scores, paddle.to_tensor(
+ [i for i in range(int(self.scores.shape[0]))], dtype='int32')
+
+ def get_the_best_score_and_idx(self):
+ "Get the score of the best in the beam."
+ scores, ids = self.sort_scores()
+ return scores[1], ids[1]
+
+ def get_tentative_hypothesis(self):
+ "Get the decoded sequence for the current timestep."
+ if len(self.next_ys) == 1:
+ dec_seq = self.next_ys[0].unsqueeze(1)
+ else:
+ _, keys = self.sort_scores()
+ hyps = [self.get_hypothesis(k) for k in keys]
+ hyps = [[2] + h for h in hyps]
+ dec_seq = paddle.to_tensor(hyps, dtype='int64')
+ return dec_seq
+
+ def get_hypothesis(self, k):
+ """ Walk back to construct the full hypothesis. """
+ hyp = []
+ for j in range(len(self.prev_ks) - 1, -1, -1):
+ hyp.append(self.next_ys[j + 1][k])
+ k = self.prev_ks[j][k]
+ return list(map(lambda x: x.item(), hyp[::-1]))
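+
+
+# A minimal sketch of how this Beam could be driven, assuming token id 2 is the
+# start symbol and 3 the end symbol (as hard-coded above); the vocabulary size
+# and the random scores below are illustrative, not taken from the real decoder.
+if __name__ == "__main__":
+    import paddle
+    import paddle.nn.functional as F
+
+    beam = Beam(size=3)
+    vocab_size = 10
+    for _ in range(4):
+        if beam.done:
+            break
+        # fake per-beam log-probabilities for the next token
+        word_prob = F.log_softmax(paddle.randn([3, vocab_size]), axis=-1)
+        beam.advance(word_prob)
+    # partial hypotheses for each beam, each prefixed with the start token
+    print(beam.get_current_state())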
diff --git a/backend/ppocr/modeling/heads/rec_pren_head.py b/backend/ppocr/modeling/heads/rec_pren_head.py
new file mode 100644
index 00000000..c9e4b3e9
--- /dev/null
+++ b/backend/ppocr/modeling/heads/rec_pren_head.py
@@ -0,0 +1,34 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from paddle import nn
+from paddle.nn import functional as F
+
+
+class PRENHead(nn.Layer):
+ def __init__(self, in_channels, out_channels, **kwargs):
+ super(PRENHead, self).__init__()
+ self.linear = nn.Linear(in_channels, out_channels)
+
+ def forward(self, x, targets=None):
+ predicts = self.linear(x)
+
+ if not self.training:
+ predicts = F.softmax(predicts, axis=2)
+
+ return predicts
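+
+
+# A minimal usage sketch, assuming the PREN neck has already aggregated features
+# into a (batch, max_len, in_channels) tensor; the sizes and the class count
+# below are illustrative assumptions, not values fixed by this head.
+if __name__ == "__main__":
+    import paddle
+
+    head = PRENHead(in_channels=384, out_channels=38)
+    head.eval()
+    x = paddle.randn([1, 25, 384])   # (batch, max_len, feature dim)
+    probs = head(x)                  # softmax is applied in eval mode
+    print(probs.shape)               # [1, 25, 38]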
diff --git a/backend/ppocr/modeling/heads/rec_sar_head.py b/backend/ppocr/modeling/heads/rec_sar_head.py
new file mode 100644
index 00000000..0e6b3440
--- /dev/null
+++ b/backend/ppocr/modeling/heads/rec_sar_head.py
@@ -0,0 +1,410 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is referenced from:
+https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/encoders/sar_encoder.py
+https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/decoders/sar_decoder.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import paddle
+from paddle import ParamAttr
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+class SAREncoder(nn.Layer):
+ """
+ Args:
+ enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
+ enc_drop_rnn (float): Dropout probability of RNN layer in encoder.
+ enc_gru (bool): If True, use GRU, else LSTM in encoder.
+ d_model (int): Dim of channels from backbone.
+ d_enc (int): Dim of encoder RNN layer.
+ mask (bool): If True, mask padding in RNN sequence.
+ """
+
+ def __init__(self,
+ enc_bi_rnn=False,
+ enc_drop_rnn=0.1,
+ enc_gru=False,
+ d_model=512,
+ d_enc=512,
+ mask=True,
+ **kwargs):
+ super().__init__()
+ assert isinstance(enc_bi_rnn, bool)
+ assert isinstance(enc_drop_rnn, (int, float))
+ assert 0 <= enc_drop_rnn < 1.0
+ assert isinstance(enc_gru, bool)
+ assert isinstance(d_model, int)
+ assert isinstance(d_enc, int)
+ assert isinstance(mask, bool)
+
+ self.enc_bi_rnn = enc_bi_rnn
+ self.enc_drop_rnn = enc_drop_rnn
+ self.mask = mask
+
+ # LSTM Encoder
+ if enc_bi_rnn:
+ direction = 'bidirectional'
+ else:
+ direction = 'forward'
+ kwargs = dict(
+ input_size=d_model,
+ hidden_size=d_enc,
+ num_layers=2,
+ time_major=False,
+ dropout=enc_drop_rnn,
+ direction=direction)
+ if enc_gru:
+ self.rnn_encoder = nn.GRU(**kwargs)
+ else:
+ self.rnn_encoder = nn.LSTM(**kwargs)
+
+ # global feature transformation
+ encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
+ self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size)
+
+ def forward(self, feat, img_metas=None):
+ if img_metas is not None:
+ assert len(img_metas[0]) == feat.shape[0]
+
+ valid_ratios = None
+ if img_metas is not None and self.mask:
+ valid_ratios = img_metas[-1]
+
+ h_feat = feat.shape[2] # bsz c h w
+ feat_v = F.max_pool2d(
+ feat, kernel_size=(h_feat, 1), stride=1, padding=0)
+ feat_v = feat_v.squeeze(2) # bsz * C * W
+ feat_v = paddle.transpose(feat_v, perm=[0, 2, 1]) # bsz * W * C
+ holistic_feat = self.rnn_encoder(feat_v)[0] # bsz * T * C
+
+ if valid_ratios is not None:
+ valid_hf = []
+ T = holistic_feat.shape[1]
+ for i in range(len(valid_ratios)):
+ valid_step = min(T, math.ceil(T * valid_ratios[i])) - 1
+ valid_hf.append(holistic_feat[i, valid_step, :])
+ valid_hf = paddle.stack(valid_hf, axis=0)
+ else:
+ valid_hf = holistic_feat[:, -1, :] # bsz * C
+ holistic_feat = self.linear(valid_hf) # bsz * C
+
+ return holistic_feat
+
+
+class BaseDecoder(nn.Layer):
+ def __init__(self, **kwargs):
+ super().__init__()
+
+ def forward_train(self, feat, out_enc, targets, img_metas):
+ raise NotImplementedError
+
+ def forward_test(self, feat, out_enc, img_metas):
+ raise NotImplementedError
+
+ def forward(self,
+ feat,
+ out_enc,
+ label=None,
+ img_metas=None,
+ train_mode=True):
+ self.train_mode = train_mode
+
+ if train_mode:
+ return self.forward_train(feat, out_enc, label, img_metas)
+ return self.forward_test(feat, out_enc, img_metas)
+
+
+class ParallelSARDecoder(BaseDecoder):
+ """
+ Args:
+ out_channels (int): Output class number.
+ enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
+ dec_bi_rnn (bool): If True, use bidirectional RNN in decoder.
+ dec_drop_rnn (float): Dropout of RNN layer in decoder.
+ dec_gru (bool): If True, use GRU, else LSTM in decoder.
+ d_model (int): Dim of channels from backbone.
+ d_enc (int): Dim of encoder RNN layer.
+ d_k (int): Dim of channels of attention module.
+ pred_dropout (float): Dropout probability of prediction layer.
+ max_seq_len (int): Maximum sequence length for decoding.
+ mask (bool): If True, mask padding in feature map.
+ start_idx (int): Index of start token.
+ padding_idx (int): Index of padding token.
+ pred_concat (bool): If True, concat glimpse feature from
+ attention with holistic feature and hidden state.
+ """
+
+ def __init__(
+ self,
+ out_channels, # 90 + unknown + start + padding
+ enc_bi_rnn=False,
+ dec_bi_rnn=False,
+ dec_drop_rnn=0.0,
+ dec_gru=False,
+ d_model=512,
+ d_enc=512,
+ d_k=64,
+ pred_dropout=0.1,
+ max_text_length=30,
+ mask=True,
+ pred_concat=True,
+ **kwargs):
+ super().__init__()
+
+ self.num_classes = out_channels
+ self.enc_bi_rnn = enc_bi_rnn
+ self.d_k = d_k
+ self.start_idx = out_channels - 2
+ self.padding_idx = out_channels - 1
+ self.max_seq_len = max_text_length
+ self.mask = mask
+ self.pred_concat = pred_concat
+
+ encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
+ decoder_rnn_out_size = encoder_rnn_out_size * (int(dec_bi_rnn) + 1)
+
+ # 2D attention layer
+ self.conv1x1_1 = nn.Linear(decoder_rnn_out_size, d_k)
+ self.conv3x3_1 = nn.Conv2D(
+ d_model, d_k, kernel_size=3, stride=1, padding=1)
+ self.conv1x1_2 = nn.Linear(d_k, 1)
+
+ # Decoder RNN layer
+ if dec_bi_rnn:
+ direction = 'bidirectional'
+ else:
+ direction = 'forward'
+
+ kwargs = dict(
+ input_size=encoder_rnn_out_size,
+ hidden_size=encoder_rnn_out_size,
+ num_layers=2,
+ time_major=False,
+ dropout=dec_drop_rnn,
+ direction=direction)
+ if dec_gru:
+ self.rnn_decoder = nn.GRU(**kwargs)
+ else:
+ self.rnn_decoder = nn.LSTM(**kwargs)
+
+ # Decoder input embedding
+ self.embedding = nn.Embedding(
+ self.num_classes,
+ encoder_rnn_out_size,
+ padding_idx=self.padding_idx)
+
+ # Prediction layer
+ self.pred_dropout = nn.Dropout(pred_dropout)
+ pred_num_classes = self.num_classes - 1
+ if pred_concat:
+ fc_in_channel = decoder_rnn_out_size + d_model + encoder_rnn_out_size
+ else:
+ fc_in_channel = d_model
+ self.prediction = nn.Linear(fc_in_channel, pred_num_classes)
+
+ def _2d_attention(self,
+ decoder_input,
+ feat,
+ holistic_feat,
+ valid_ratios=None):
+
+ y = self.rnn_decoder(decoder_input)[0]
+ # y: bsz * (seq_len + 1) * hidden_size
+
+ attn_query = self.conv1x1_1(y) # bsz * (seq_len + 1) * attn_size
+ bsz, seq_len, attn_size = attn_query.shape
+ attn_query = paddle.unsqueeze(attn_query, axis=[3, 4])
+ # (bsz, seq_len + 1, attn_size, 1, 1)
+
+ attn_key = self.conv3x3_1(feat)
+ # bsz * attn_size * h * w
+ attn_key = attn_key.unsqueeze(1)
+ # bsz * 1 * attn_size * h * w
+
+ attn_weight = paddle.tanh(paddle.add(attn_key, attn_query))
+
+ # bsz * (seq_len + 1) * attn_size * h * w
+ attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 3, 4, 2])
+ # bsz * (seq_len + 1) * h * w * attn_size
+ attn_weight = self.conv1x1_2(attn_weight)
+ # bsz * (seq_len + 1) * h * w * 1
+ bsz, T, h, w, c = attn_weight.shape
+ assert c == 1
+
+ if valid_ratios is not None:
+ # cal mask of attention weight
+ for i in range(len(valid_ratios)):
+ valid_width = min(w, math.ceil(w * valid_ratios[i]))
+ if valid_width < w:
+ attn_weight[i, :, :, valid_width:, :] = float('-inf')
+
+ attn_weight = paddle.reshape(attn_weight, [bsz, T, -1])
+ attn_weight = F.softmax(attn_weight, axis=-1)
+
+ attn_weight = paddle.reshape(attn_weight, [bsz, T, h, w, c])
+ attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 4, 2, 3])
+ # attn_weight: bsz * T * c * h * w
+ # feat: bsz * c * h * w
+ attn_feat = paddle.sum(paddle.multiply(feat.unsqueeze(1), attn_weight),
+ (3, 4),
+ keepdim=False)
+ # bsz * (seq_len + 1) * C
+
+ # Linear transformation
+ if self.pred_concat:
+ hf_c = holistic_feat.shape[-1]
+ holistic_feat = paddle.expand(
+ holistic_feat, shape=[bsz, seq_len, hf_c])
+ y = self.prediction(paddle.concat((y, attn_feat, holistic_feat), 2))
+ else:
+ y = self.prediction(attn_feat)
+ # bsz * (seq_len + 1) * num_classes
+ if self.train_mode:
+ y = self.pred_dropout(y)
+
+ return y
+
+ def forward_train(self, feat, out_enc, label, img_metas):
+ '''
+ img_metas: [label, valid_ratio]
+ '''
+ if img_metas is not None:
+ assert len(img_metas[0]) == feat.shape[0]
+
+ valid_ratios = None
+ if img_metas is not None and self.mask:
+ valid_ratios = img_metas[-1]
+
+ lab_embedding = self.embedding(label)
+ # bsz * seq_len * emb_dim
+ out_enc = out_enc.unsqueeze(1)
+ # bsz * 1 * emb_dim
+ in_dec = paddle.concat((out_enc, lab_embedding), axis=1)
+ # bsz * (seq_len + 1) * C
+ out_dec = self._2d_attention(
+ in_dec, feat, out_enc, valid_ratios=valid_ratios)
+ # bsz * (seq_len + 1) * num_classes
+
+ return out_dec[:, 1:, :] # bsz * seq_len * num_classes
+
+ def forward_test(self, feat, out_enc, img_metas):
+ if img_metas is not None:
+ assert len(img_metas[0]) == feat.shape[0]
+
+ valid_ratios = None
+ if img_metas is not None and self.mask:
+ valid_ratios = img_metas[-1]
+
+ seq_len = self.max_seq_len
+ bsz = feat.shape[0]
+ start_token = paddle.full(
+ (bsz, ), fill_value=self.start_idx, dtype='int64')
+ # bsz
+ start_token = self.embedding(start_token)
+ # bsz * emb_dim
+ emb_dim = start_token.shape[1]
+ start_token = start_token.unsqueeze(1)
+ start_token = paddle.expand(start_token, shape=[bsz, seq_len, emb_dim])
+ # bsz * seq_len * emb_dim
+ out_enc = out_enc.unsqueeze(1)
+ # bsz * 1 * emb_dim
+ decoder_input = paddle.concat((out_enc, start_token), axis=1)
+ # bsz * (seq_len + 1) * emb_dim
+
+ outputs = []
+ for i in range(1, seq_len + 1):
+ decoder_output = self._2d_attention(
+ decoder_input, feat, out_enc, valid_ratios=valid_ratios)
+ char_output = decoder_output[:, i, :] # bsz * num_classes
+ char_output = F.softmax(char_output, -1)
+ outputs.append(char_output)
+ max_idx = paddle.argmax(char_output, axis=1, keepdim=False)
+ char_embedding = self.embedding(max_idx) # bsz * emb_dim
+ if i < seq_len:
+ decoder_input[:, i + 1, :] = char_embedding
+
+ outputs = paddle.stack(outputs, 1) # bsz * seq_len * num_classes
+
+ return outputs
+
+
+class SARHead(nn.Layer):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ enc_dim=512,
+ max_text_length=30,
+ enc_bi_rnn=False,
+ enc_drop_rnn=0.1,
+ enc_gru=False,
+ dec_bi_rnn=False,
+ dec_drop_rnn=0.0,
+ dec_gru=False,
+ d_k=512,
+ pred_dropout=0.1,
+ pred_concat=True,
+ **kwargs):
+ super(SARHead, self).__init__()
+
+ # encoder module
+ self.encoder = SAREncoder(
+ enc_bi_rnn=enc_bi_rnn,
+ enc_drop_rnn=enc_drop_rnn,
+ enc_gru=enc_gru,
+ d_model=in_channels,
+ d_enc=enc_dim)
+
+ # decoder module
+ self.decoder = ParallelSARDecoder(
+ out_channels=out_channels,
+ enc_bi_rnn=enc_bi_rnn,
+ dec_bi_rnn=dec_bi_rnn,
+ dec_drop_rnn=dec_drop_rnn,
+ dec_gru=dec_gru,
+ d_model=in_channels,
+ d_enc=enc_dim,
+ d_k=d_k,
+ pred_dropout=pred_dropout,
+ max_text_length=max_text_length,
+ pred_concat=pred_concat)
+
+ def forward(self, feat, targets=None):
+ '''
+ img_metas: [label, valid_ratio]
+ '''
+ holistic_feat = self.encoder(feat, targets) # bsz c
+
+ if self.training:
+ label = targets[0] # label
+ label = paddle.to_tensor(label, dtype='int64')
+ final_out = self.decoder(
+ feat, holistic_feat, label, img_metas=targets)
+ else:
+ final_out = self.decoder(
+ feat,
+ holistic_feat,
+ label=None,
+ img_metas=targets,
+ train_mode=False)
+ # (bsz, seq_len, num_classes)
+
+ return final_out
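+
+
+# A minimal eval-mode sketch for SARHead; the feature shape, class count and the
+# short max_text_length are illustrative assumptions. Passing targets=None skips
+# the optional valid-ratio masking path at inference time.
+if __name__ == "__main__":
+    import paddle
+
+    head = SARHead(in_channels=512, out_channels=93, max_text_length=5)
+    head.eval()
+    feat = paddle.randn([2, 512, 6, 40])   # (batch, C, H, W) backbone features
+    probs = head(feat)                     # greedy decoding, no labels needed
+    print(probs.shape)                     # [2, 5, 92]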
diff --git a/backend/ppocr/modeling/heads/rec_srn_head.py b/backend/ppocr/modeling/heads/rec_srn_head.py
index d2c7fc02..8d59e471 100644
--- a/backend/ppocr/modeling/heads/rec_srn_head.py
+++ b/backend/ppocr/modeling/heads/rec_srn_head.py
@@ -250,7 +250,8 @@ def __init__(self, in_channels, out_channels, max_text_length, num_heads,
self.gsrm.wrap_encoder1.prepare_decoder.emb0 = self.gsrm.wrap_encoder0.prepare_decoder.emb0
- def forward(self, inputs, others):
+ def forward(self, inputs, targets=None):
+ others = targets[-4:]
encoder_word_pos = others[0]
gsrm_word_pos = others[1]
gsrm_slf_attn_bias1 = others[2]
diff --git a/backend/ppocr/modeling/heads/self_attention.py b/backend/ppocr/modeling/heads/self_attention.py
index 51d5198f..6c27fdbe 100644
--- a/backend/ppocr/modeling/heads/self_attention.py
+++ b/backend/ppocr/modeling/heads/self_attention.py
@@ -285,8 +285,7 @@ def __init__(self, process_cmd, d_model, dropout_rate):
elif cmd == "n": # add layer normalization
self.functors.append(
self.add_sublayer(
- "layer_norm_%d" % len(
- self.sublayers(include_sublayers=False)),
+ "layer_norm_%d" % len(self.sublayers()),
paddle.nn.LayerNorm(
normalized_shape=d_model,
weight_attr=fluid.ParamAttr(
@@ -320,9 +319,7 @@ def __init__(self,
self.src_emb_dim = src_emb_dim
self.src_max_len = src_max_len
self.emb = paddle.nn.Embedding(
- num_embeddings=self.src_max_len,
- embedding_dim=self.src_emb_dim,
- sparse=True)
+ num_embeddings=self.src_max_len, embedding_dim=self.src_emb_dim)
self.dropout_rate = dropout_rate
def forward(self, src_word, src_pos):
diff --git a/backend/ppocr/modeling/heads/table_att_head.py b/backend/ppocr/modeling/heads/table_att_head.py
new file mode 100644
index 00000000..e354f40d
--- /dev/null
+++ b/backend/ppocr/modeling/heads/table_att_head.py
@@ -0,0 +1,246 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+import numpy as np
+
+
+class TableAttentionHead(nn.Layer):
+ def __init__(self,
+ in_channels,
+ hidden_size,
+ loc_type,
+ in_max_len=488,
+ max_text_length=100,
+ max_elem_length=800,
+ max_cell_num=500,
+ **kwargs):
+ super(TableAttentionHead, self).__init__()
+ self.input_size = in_channels[-1]
+ self.hidden_size = hidden_size
+ self.elem_num = 30
+ self.max_text_length = max_text_length
+ self.max_elem_length = max_elem_length
+ self.max_cell_num = max_cell_num
+
+ self.structure_attention_cell = AttentionGRUCell(
+ self.input_size, hidden_size, self.elem_num, use_gru=False)
+ self.structure_generator = nn.Linear(hidden_size, self.elem_num)
+ self.loc_type = loc_type
+ self.in_max_len = in_max_len
+
+ if self.loc_type == 1:
+ self.loc_generator = nn.Linear(hidden_size, 4)
+ else:
+ if self.in_max_len == 640:
+ self.loc_fea_trans = nn.Linear(400, self.max_elem_length + 1)
+ elif self.in_max_len == 800:
+ self.loc_fea_trans = nn.Linear(625, self.max_elem_length + 1)
+ else:
+ self.loc_fea_trans = nn.Linear(256, self.max_elem_length + 1)
+ self.loc_generator = nn.Linear(self.input_size + hidden_size, 4)
+
+ def _char_to_onehot(self, input_char, onehot_dim):
+ input_one_hot = F.one_hot(input_char, onehot_dim)
+ return input_one_hot
+
+ def forward(self, inputs, targets=None):
+ # Both the if and else branches are needed when assigning a variable here;
+ # if the variable is modified in only one branch, the modification will not take effect.
+ fea = inputs[-1]
+ if len(fea.shape) == 3:
+ pass
+ else:
+ last_shape = int(np.prod(fea.shape[2:]))  # flatten the spatial dims (H*W) into one axis
+ fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape])
+ fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels)
+ batch_size = fea.shape[0]
+
+ hidden = paddle.zeros((batch_size, self.hidden_size))
+ output_hiddens = []
+ if self.training and targets is not None:
+ structure = targets[0]
+ for i in range(self.max_elem_length + 1):
+ elem_onehots = self._char_to_onehot(
+ structure[:, i], onehot_dim=self.elem_num)
+ (outputs, hidden), alpha = self.structure_attention_cell(
+ hidden, fea, elem_onehots)
+ output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
+ output = paddle.concat(output_hiddens, axis=1)
+ structure_probs = self.structure_generator(output)
+ if self.loc_type == 1:
+ loc_preds = self.loc_generator(output)
+ loc_preds = F.sigmoid(loc_preds)
+ else:
+ loc_fea = fea.transpose([0, 2, 1])
+ loc_fea = self.loc_fea_trans(loc_fea)
+ loc_fea = loc_fea.transpose([0, 2, 1])
+ loc_concat = paddle.concat([output, loc_fea], axis=2)
+ loc_preds = self.loc_generator(loc_concat)
+ loc_preds = F.sigmoid(loc_preds)
+ else:
+ temp_elem = paddle.zeros(shape=[batch_size], dtype="int32")
+ structure_probs = None
+ loc_preds = None
+ elem_onehots = None
+ outputs = None
+ alpha = None
+ max_elem_length = paddle.to_tensor(self.max_elem_length)
+ i = 0
+ while i < max_elem_length + 1:
+ elem_onehots = self._char_to_onehot(
+ temp_elem, onehot_dim=self.elem_num)
+ (outputs, hidden), alpha = self.structure_attention_cell(
+ hidden, fea, elem_onehots)
+ output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
+ structure_probs_step = self.structure_generator(outputs)
+ temp_elem = structure_probs_step.argmax(axis=1, dtype="int32")
+ i += 1
+
+ output = paddle.concat(output_hiddens, axis=1)
+ structure_probs = self.structure_generator(output)
+ structure_probs = F.softmax(structure_probs)
+ if self.loc_type == 1:
+ loc_preds = self.loc_generator(output)
+ loc_preds = F.sigmoid(loc_preds)
+ else:
+ loc_fea = fea.transpose([0, 2, 1])
+ loc_fea = self.loc_fea_trans(loc_fea)
+ loc_fea = loc_fea.transpose([0, 2, 1])
+ loc_concat = paddle.concat([output, loc_fea], axis=2)
+ loc_preds = self.loc_generator(loc_concat)
+ loc_preds = F.sigmoid(loc_preds)
+ return {'structure_probs': structure_probs, 'loc_preds': loc_preds}
+
+
+class AttentionGRUCell(nn.Layer):
+ def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
+ super(AttentionGRUCell, self).__init__()
+ self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
+ self.h2h = nn.Linear(hidden_size, hidden_size)
+ self.score = nn.Linear(hidden_size, 1, bias_attr=False)
+ self.rnn = nn.GRUCell(
+ input_size=input_size + num_embeddings, hidden_size=hidden_size)
+ self.hidden_size = hidden_size
+
+ def forward(self, prev_hidden, batch_H, char_onehots):
+ batch_H_proj = self.i2h(batch_H)
+ prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden), axis=1)
+ res = paddle.add(batch_H_proj, prev_hidden_proj)
+ res = paddle.tanh(res)
+ e = self.score(res)
+ alpha = F.softmax(e, axis=1)
+ alpha = paddle.transpose(alpha, [0, 2, 1])
+ context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
+ concat_context = paddle.concat([context, char_onehots], 1)
+ cur_hidden = self.rnn(concat_context, prev_hidden)
+ return cur_hidden, alpha
+
+
+class AttentionLSTM(nn.Layer):
+ def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
+ super(AttentionLSTM, self).__init__()
+ self.input_size = in_channels
+ self.hidden_size = hidden_size
+ self.num_classes = out_channels
+
+ self.attention_cell = AttentionLSTMCell(
+ in_channels, hidden_size, out_channels, use_gru=False)
+ self.generator = nn.Linear(hidden_size, out_channels)
+
+ def _char_to_onehot(self, input_char, onehot_dim):
+ input_one_hot = F.one_hot(input_char, onehot_dim)
+ return input_one_hot
+
+ def forward(self, inputs, targets=None, batch_max_length=25):
+ batch_size = inputs.shape[0]
+ num_steps = batch_max_length
+
+ hidden = (paddle.zeros((batch_size, self.hidden_size)), paddle.zeros(
+ (batch_size, self.hidden_size)))
+ output_hiddens = []
+
+ if targets is not None:
+ for i in range(num_steps):
+ # one-hot vectors for the i-th char
+ char_onehots = self._char_to_onehot(
+ targets[:, i], onehot_dim=self.num_classes)
+ hidden, alpha = self.attention_cell(hidden, inputs,
+ char_onehots)
+
+ hidden = (hidden[1][0], hidden[1][1])
+ output_hiddens.append(paddle.unsqueeze(hidden[0], axis=1))
+ output = paddle.concat(output_hiddens, axis=1)
+ probs = self.generator(output)
+
+ else:
+ targets = paddle.zeros(shape=[batch_size], dtype="int32")
+ probs = None
+
+ for i in range(num_steps):
+ char_onehots = self._char_to_onehot(
+ targets, onehot_dim=self.num_classes)
+ hidden, alpha = self.attention_cell(hidden, inputs,
+ char_onehots)
+ probs_step = self.generator(hidden[0])
+ hidden = (hidden[1][0], hidden[1][1])
+ if probs is None:
+ probs = paddle.unsqueeze(probs_step, axis=1)
+ else:
+ probs = paddle.concat(
+ [probs, paddle.unsqueeze(
+ probs_step, axis=1)], axis=1)
+
+ next_input = probs_step.argmax(axis=1)
+
+ targets = next_input
+
+ return probs
+
+
+class AttentionLSTMCell(nn.Layer):
+ def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
+ super(AttentionLSTMCell, self).__init__()
+ self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
+ self.h2h = nn.Linear(hidden_size, hidden_size)
+ self.score = nn.Linear(hidden_size, 1, bias_attr=False)
+ if not use_gru:
+ self.rnn = nn.LSTMCell(
+ input_size=input_size + num_embeddings, hidden_size=hidden_size)
+ else:
+ self.rnn = nn.GRUCell(
+ input_size=input_size + num_embeddings, hidden_size=hidden_size)
+
+ self.hidden_size = hidden_size
+
+ def forward(self, prev_hidden, batch_H, char_onehots):
+ batch_H_proj = self.i2h(batch_H)
+ prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden[0]), axis=1)
+ res = paddle.add(batch_H_proj, prev_hidden_proj)
+ res = paddle.tanh(res)
+ e = self.score(res)
+
+ alpha = F.softmax(e, axis=1)
+ alpha = paddle.transpose(alpha, [0, 2, 1])
+ context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
+ concat_context = paddle.concat([context, char_onehots], 1)
+ cur_hidden = self.rnn(concat_context, prev_hidden)
+
+ return cur_hidden, alpha
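+
+
+# A minimal eval-mode sketch for TableAttentionHead, assuming a single backbone
+# feature map as input; the channel width, hidden size and the tiny
+# max_elem_length are illustrative assumptions that keep the greedy loop short.
+if __name__ == "__main__":
+    import paddle
+
+    head = TableAttentionHead(
+        in_channels=[96], hidden_size=64, loc_type=1, max_elem_length=5)
+    head.eval()
+    feats = [paddle.randn([1, 96, 8, 32])]   # last-level feature map (N, C, H, W)
+    out = head(feats)
+    print(out['structure_probs'].shape)      # [1, 6, 30]
+    print(out['loc_preds'].shape)            # [1, 6, 4]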
diff --git a/backend/ppocr/modeling/necks/__init__.py b/backend/ppocr/modeling/necks/__init__.py
index 405e062b..e10b082d 100644
--- a/backend/ppocr/modeling/necks/__init__.py
+++ b/backend/ppocr/modeling/necks/__init__.py
@@ -14,12 +14,21 @@
__all__ = ['build_neck']
+
def build_neck(config):
- from .db_fpn import DBFPN
+ from .db_fpn import DBFPN, RSEFPN, LKPAN
from .east_fpn import EASTFPN
from .sast_fpn import SASTFPN
from .rnn import SequenceEncoder
- support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder']
+ from .pg_fpn import PGFPN
+ from .table_fpn import TableFPN
+ from .fpn import FPN
+ from .fce_fpn import FCEFPN
+ from .pren_fpn import PRENFPN
+ support_dict = [
+ 'FPN', 'FCEFPN', 'LKPAN', 'DBFPN', 'RSEFPN', 'EASTFPN', 'SASTFPN',
+ 'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN'
+ ]
module_name = config.pop('name')
assert module_name in support_dict, Exception('neck only support {}'.format(
diff --git a/backend/ppocr/modeling/necks/db_fpn.py b/backend/ppocr/modeling/necks/db_fpn.py
index 710023f3..93ed2dbf 100644
--- a/backend/ppocr/modeling/necks/db_fpn.py
+++ b/backend/ppocr/modeling/necks/db_fpn.py
@@ -20,6 +20,88 @@
from paddle import nn
import paddle.nn.functional as F
from paddle import ParamAttr
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../../..')))
+
+from ppocr.modeling.backbones.det_mobilenet_v3 import SEModule
+
+
+class DSConv(nn.Layer):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ padding,
+ stride=1,
+ groups=None,
+ if_act=True,
+ act="relu",
+ **kwargs):
+ super(DSConv, self).__init__()
+ if groups is None:
+ groups = in_channels
+ self.if_act = if_act
+ self.act = act
+ self.conv1 = nn.Conv2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ groups=groups,
+ bias_attr=False)
+
+ self.bn1 = nn.BatchNorm(num_channels=in_channels, act=None)
+
+ self.conv2 = nn.Conv2D(
+ in_channels=in_channels,
+ out_channels=int(in_channels * 4),
+ kernel_size=1,
+ stride=1,
+ bias_attr=False)
+
+ self.bn2 = nn.BatchNorm(num_channels=int(in_channels * 4), act=None)
+
+ self.conv3 = nn.Conv2D(
+ in_channels=int(in_channels * 4),
+ out_channels=out_channels,
+ kernel_size=1,
+ stride=1,
+ bias_attr=False)
+ self._c = [in_channels, out_channels]
+ if in_channels != out_channels:
+ self.conv_end = nn.Conv2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ stride=1,
+ bias_attr=False)
+
+ def forward(self, inputs):
+
+ x = self.conv1(inputs)
+ x = self.bn1(x)
+
+ x = self.conv2(x)
+ x = self.bn2(x)
+ if self.if_act:
+ if self.act == "relu":
+ x = F.relu(x)
+ elif self.act == "hardswish":
+ x = F.hardswish(x)
+ else:
+ print("The activation function({}) is selected incorrectly.".
+ format(self.act))
+ exit()
+
+ x = self.conv3(x)
+ if self._c[0] != self._c[1]:
+ x = x + self.conv_end(inputs)
+ return x
class DBFPN(nn.Layer):
@@ -32,61 +114,53 @@ def __init__(self, in_channels, out_channels, **kwargs):
in_channels=in_channels[0],
out_channels=self.out_channels,
kernel_size=1,
- weight_attr=ParamAttr(
- name='conv2d_51.w_0', initializer=weight_attr),
+ weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.in3_conv = nn.Conv2D(
in_channels=in_channels[1],
out_channels=self.out_channels,
kernel_size=1,
- weight_attr=ParamAttr(
- name='conv2d_50.w_0', initializer=weight_attr),
+ weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.in4_conv = nn.Conv2D(
in_channels=in_channels[2],
out_channels=self.out_channels,
kernel_size=1,
- weight_attr=ParamAttr(
- name='conv2d_49.w_0', initializer=weight_attr),
+ weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.in5_conv = nn.Conv2D(
in_channels=in_channels[3],
out_channels=self.out_channels,
kernel_size=1,
- weight_attr=ParamAttr(
- name='conv2d_48.w_0', initializer=weight_attr),
+ weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.p5_conv = nn.Conv2D(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
- weight_attr=ParamAttr(
- name='conv2d_52.w_0', initializer=weight_attr),
+ weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.p4_conv = nn.Conv2D(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
- weight_attr=ParamAttr(
- name='conv2d_53.w_0', initializer=weight_attr),
+ weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.p3_conv = nn.Conv2D(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
- weight_attr=ParamAttr(
- name='conv2d_54.w_0', initializer=weight_attr),
+ weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
self.p2_conv = nn.Conv2D(
in_channels=self.out_channels,
out_channels=self.out_channels // 4,
kernel_size=3,
padding=1,
- weight_attr=ParamAttr(
- name='conv2d_55.w_0', initializer=weight_attr),
+ weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
def forward(self, x):
@@ -114,3 +188,171 @@ def forward(self, x):
fuse = paddle.concat([p5, p4, p3, p2], axis=1)
return fuse
+
+
+class RSELayer(nn.Layer):
+ def __init__(self, in_channels, out_channels, kernel_size, shortcut=True):
+ super(RSELayer, self).__init__()
+ weight_attr = paddle.nn.initializer.KaimingUniform()
+ self.out_channels = out_channels
+ self.in_conv = nn.Conv2D(
+ in_channels=in_channels,
+ out_channels=self.out_channels,
+ kernel_size=kernel_size,
+ padding=int(kernel_size // 2),
+ weight_attr=ParamAttr(initializer=weight_attr),
+ bias_attr=False)
+ self.se_block = SEModule(self.out_channels)
+ self.shortcut = shortcut
+
+ def forward(self, ins):
+ x = self.in_conv(ins)
+ if self.shortcut:
+ out = x + self.se_block(x)
+ else:
+ out = self.se_block(x)
+ return out
+
+
+class RSEFPN(nn.Layer):
+ def __init__(self, in_channels, out_channels, shortcut=True, **kwargs):
+ super(RSEFPN, self).__init__()
+ self.out_channels = out_channels
+ self.ins_conv = nn.LayerList()
+ self.inp_conv = nn.LayerList()
+
+ for i in range(len(in_channels)):
+ self.ins_conv.append(
+ RSELayer(
+ in_channels[i],
+ out_channels,
+ kernel_size=1,
+ shortcut=shortcut))
+ self.inp_conv.append(
+ RSELayer(
+ out_channels,
+ out_channels // 4,
+ kernel_size=3,
+ shortcut=shortcut))
+
+ def forward(self, x):
+ c2, c3, c4, c5 = x
+
+ in5 = self.ins_conv[3](c5)
+ in4 = self.ins_conv[2](c4)
+ in3 = self.ins_conv[1](c3)
+ in2 = self.ins_conv[0](c2)
+
+ out4 = in4 + F.upsample(
+ in5, scale_factor=2, mode="nearest", align_mode=1) # 1/16
+ out3 = in3 + F.upsample(
+ out4, scale_factor=2, mode="nearest", align_mode=1) # 1/8
+ out2 = in2 + F.upsample(
+ out3, scale_factor=2, mode="nearest", align_mode=1) # 1/4
+
+ p5 = self.inp_conv[3](in5)
+ p4 = self.inp_conv[2](out4)
+ p3 = self.inp_conv[1](out3)
+ p2 = self.inp_conv[0](out2)
+
+ p5 = F.upsample(p5, scale_factor=8, mode="nearest", align_mode=1)
+ p4 = F.upsample(p4, scale_factor=4, mode="nearest", align_mode=1)
+ p3 = F.upsample(p3, scale_factor=2, mode="nearest", align_mode=1)
+
+ fuse = paddle.concat([p5, p4, p3, p2], axis=1)
+ return fuse
+
+
+class LKPAN(nn.Layer):
+ def __init__(self, in_channels, out_channels, mode='large', **kwargs):
+ super(LKPAN, self).__init__()
+ self.out_channels = out_channels
+ weight_attr = paddle.nn.initializer.KaimingUniform()
+
+ self.ins_conv = nn.LayerList()
+ self.inp_conv = nn.LayerList()
+ # pan head
+ self.pan_head_conv = nn.LayerList()
+ self.pan_lat_conv = nn.LayerList()
+
+ if mode.lower() == 'lite':
+ p_layer = DSConv
+ elif mode.lower() == 'large':
+ p_layer = nn.Conv2D
+ else:
+ raise ValueError(
+ "mode can only be one of ['lite', 'large'], but received {}".
+ format(mode))
+
+ for i in range(len(in_channels)):
+ self.ins_conv.append(
+ nn.Conv2D(
+ in_channels=in_channels[i],
+ out_channels=self.out_channels,
+ kernel_size=1,
+ weight_attr=ParamAttr(initializer=weight_attr),
+ bias_attr=False))
+
+ self.inp_conv.append(
+ p_layer(
+ in_channels=self.out_channels,
+ out_channels=self.out_channels // 4,
+ kernel_size=9,
+ padding=4,
+ weight_attr=ParamAttr(initializer=weight_attr),
+ bias_attr=False))
+
+ if i > 0:
+ self.pan_head_conv.append(
+ nn.Conv2D(
+ in_channels=self.out_channels // 4,
+ out_channels=self.out_channels // 4,
+ kernel_size=3,
+ padding=1,
+ stride=2,
+ weight_attr=ParamAttr(initializer=weight_attr),
+ bias_attr=False))
+ self.pan_lat_conv.append(
+ p_layer(
+ in_channels=self.out_channels // 4,
+ out_channels=self.out_channels // 4,
+ kernel_size=9,
+ padding=4,
+ weight_attr=ParamAttr(initializer=weight_attr),
+ bias_attr=False))
+
+ def forward(self, x):
+ c2, c3, c4, c5 = x
+
+ in5 = self.ins_conv[3](c5)
+ in4 = self.ins_conv[2](c4)
+ in3 = self.ins_conv[1](c3)
+ in2 = self.ins_conv[0](c2)
+
+ out4 = in4 + F.upsample(
+ in5, scale_factor=2, mode="nearest", align_mode=1) # 1/16
+ out3 = in3 + F.upsample(
+ out4, scale_factor=2, mode="nearest", align_mode=1) # 1/8
+ out2 = in2 + F.upsample(
+ out3, scale_factor=2, mode="nearest", align_mode=1) # 1/4
+
+ f5 = self.inp_conv[3](in5)
+ f4 = self.inp_conv[2](out4)
+ f3 = self.inp_conv[1](out3)
+ f2 = self.inp_conv[0](out2)
+
+ pan3 = f3 + self.pan_head_conv[0](f2)
+ pan4 = f4 + self.pan_head_conv[1](pan3)
+ pan5 = f5 + self.pan_head_conv[2](pan4)
+
+ p2 = self.pan_lat_conv[0](f2)
+ p3 = self.pan_lat_conv[1](pan3)
+ p4 = self.pan_lat_conv[2](pan4)
+ p5 = self.pan_lat_conv[3](pan5)
+
+ p5 = F.upsample(p5, scale_factor=8, mode="nearest", align_mode=1)
+ p4 = F.upsample(p4, scale_factor=4, mode="nearest", align_mode=1)
+ p3 = F.upsample(p3, scale_factor=2, mode="nearest", align_mode=1)
+
+ fuse = paddle.concat([p5, p4, p3, p2], axis=1)
+ return fuse
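+
+
+# A minimal sketch exercising the two new DB necks; the channel list matches a
+# MobileNetV3-style backbone only as an illustrative assumption, and the spatial
+# sizes are chosen so each pyramid level is exactly half the previous one.
+if __name__ == "__main__":
+    import paddle
+
+    feats = [
+        paddle.randn([1, 16, 40, 40]),   # stride 4
+        paddle.randn([1, 24, 20, 20]),   # stride 8
+        paddle.randn([1, 56, 10, 10]),   # stride 16
+        paddle.randn([1, 480, 5, 5]),    # stride 32
+    ]
+    rse = RSEFPN(in_channels=[16, 24, 56, 480], out_channels=96)
+    lk = LKPAN(in_channels=[16, 24, 56, 480], out_channels=256, mode='large')
+    print(rse(feats).shape)              # [1, 96, 40, 40]
+    print(lk(feats).shape)               # [1, 256, 40, 40]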
diff --git a/backend/ppocr/modeling/necks/fce_fpn.py b/backend/ppocr/modeling/necks/fce_fpn.py
new file mode 100644
index 00000000..954e964e
--- /dev/null
+++ b/backend/ppocr/modeling/necks/fce_fpn.py
@@ -0,0 +1,280 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is referenced from:
+https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.3/ppdet/modeling/necks/fpn.py
+"""
+
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle import ParamAttr
+from paddle.nn.initializer import XavierUniform
+from paddle.nn.initializer import Normal
+from paddle.regularizer import L2Decay
+
+__all__ = ['FCEFPN']
+
+
+class ConvNormLayer(nn.Layer):
+ def __init__(self,
+ ch_in,
+ ch_out,
+ filter_size,
+ stride,
+ groups=1,
+ norm_type='bn',
+ norm_decay=0.,
+ norm_groups=32,
+ lr_scale=1.,
+ freeze_norm=False,
+ initializer=Normal(
+ mean=0., std=0.01)):
+ super(ConvNormLayer, self).__init__()
+ assert norm_type in ['bn', 'sync_bn', 'gn']
+
+ bias_attr = False
+
+ self.conv = nn.Conv2D(
+ in_channels=ch_in,
+ out_channels=ch_out,
+ kernel_size=filter_size,
+ stride=stride,
+ padding=(filter_size - 1) // 2,
+ groups=groups,
+ weight_attr=ParamAttr(
+ initializer=initializer, learning_rate=1.),
+ bias_attr=bias_attr)
+
+ norm_lr = 0. if freeze_norm else 1.
+ param_attr = ParamAttr(
+ learning_rate=norm_lr,
+ regularizer=L2Decay(norm_decay) if norm_decay is not None else None)
+ bias_attr = ParamAttr(
+ learning_rate=norm_lr,
+ regularizer=L2Decay(norm_decay) if norm_decay is not None else None)
+ if norm_type == 'bn':
+ self.norm = nn.BatchNorm2D(
+ ch_out, weight_attr=param_attr, bias_attr=bias_attr)
+ elif norm_type == 'sync_bn':
+ self.norm = nn.SyncBatchNorm(
+ ch_out, weight_attr=param_attr, bias_attr=bias_attr)
+ elif norm_type == 'gn':
+ self.norm = nn.GroupNorm(
+ num_groups=norm_groups,
+ num_channels=ch_out,
+ weight_attr=param_attr,
+ bias_attr=bias_attr)
+
+ def forward(self, inputs):
+ out = self.conv(inputs)
+ out = self.norm(out)
+ return out
+
+
+class FCEFPN(nn.Layer):
+ """
+ Feature Pyramid Network, see https://arxiv.org/abs/1612.03144
+ Args:
+ in_channels (list[int]): input channels of each level which can be
+ derived from the output shape of backbone by from_config
+ out_channels (list[int]): output channel of each level
+ spatial_scales (list[float]): the spatial scales between input feature
+ maps and original input image which can be derived from the output
+ shape of backbone by from_config
+ has_extra_convs (bool): whether to add extra conv to the last level.
+ default False
+ extra_stage (int): the number of extra stages added to the last level.
+ default 1
+ use_c5 (bool): Whether to use c5 as the input of extra stage,
+ otherwise p5 is used. default True
+ norm_type (string|None): The normalization type in FPN module. If
+ norm_type is None, norm will not be used after conv and if
+ norm_type is string, bn, gn, sync_bn are available. default None
+ norm_decay (float): weight decay for normalization layer weights.
+ default 0.
+ freeze_norm (bool): whether to freeze normalization layer.
+ default False
+ relu_before_extra_convs (bool): whether to add relu before extra convs.
+ default False
+
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ spatial_scales=[0.25, 0.125, 0.0625, 0.03125],
+ has_extra_convs=False,
+ extra_stage=1,
+ use_c5=True,
+ norm_type=None,
+ norm_decay=0.,
+ freeze_norm=False,
+ relu_before_extra_convs=True):
+ super(FCEFPN, self).__init__()
+ self.out_channels = out_channels
+ for s in range(extra_stage):
+ spatial_scales = spatial_scales + [spatial_scales[-1] / 2.]
+ self.spatial_scales = spatial_scales
+ self.has_extra_convs = has_extra_convs
+ self.extra_stage = extra_stage
+ self.use_c5 = use_c5
+ self.relu_before_extra_convs = relu_before_extra_convs
+ self.norm_type = norm_type
+ self.norm_decay = norm_decay
+ self.freeze_norm = freeze_norm
+
+ self.lateral_convs = []
+ self.fpn_convs = []
+ fan = out_channels * 3 * 3
+
+ # stage index 0,1,2,3 stands for res2,res3,res4,res5 on ResNet Backbone
+ # 0 <= st_stage < ed_stage <= 3
+ st_stage = 4 - len(in_channels)
+ ed_stage = st_stage + len(in_channels) - 1
+ for i in range(st_stage, ed_stage + 1):
+ if i == 3:
+ lateral_name = 'fpn_inner_res5_sum'
+ else:
+ lateral_name = 'fpn_inner_res{}_sum_lateral'.format(i + 2)
+ in_c = in_channels[i - st_stage]
+ if self.norm_type is not None:
+ lateral = self.add_sublayer(
+ lateral_name,
+ ConvNormLayer(
+ ch_in=in_c,
+ ch_out=out_channels,
+ filter_size=1,
+ stride=1,
+ norm_type=self.norm_type,
+ norm_decay=self.norm_decay,
+ freeze_norm=self.freeze_norm,
+ initializer=XavierUniform(fan_out=in_c)))
+ else:
+ lateral = self.add_sublayer(
+ lateral_name,
+ nn.Conv2D(
+ in_channels=in_c,
+ out_channels=out_channels,
+ kernel_size=1,
+ weight_attr=ParamAttr(
+ initializer=XavierUniform(fan_out=in_c))))
+ self.lateral_convs.append(lateral)
+
+ for i in range(st_stage, ed_stage + 1):
+ fpn_name = 'fpn_res{}_sum'.format(i + 2)
+ if self.norm_type is not None:
+ fpn_conv = self.add_sublayer(
+ fpn_name,
+ ConvNormLayer(
+ ch_in=out_channels,
+ ch_out=out_channels,
+ filter_size=3,
+ stride=1,
+ norm_type=self.norm_type,
+ norm_decay=self.norm_decay,
+ freeze_norm=self.freeze_norm,
+ initializer=XavierUniform(fan_out=fan)))
+ else:
+ fpn_conv = self.add_sublayer(
+ fpn_name,
+ nn.Conv2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ padding=1,
+ weight_attr=ParamAttr(
+ initializer=XavierUniform(fan_out=fan))))
+ self.fpn_convs.append(fpn_conv)
+
+ # add extra conv levels for RetinaNet(use_c5)/FCOS(use_p5)
+ if self.has_extra_convs:
+ for i in range(self.extra_stage):
+ lvl = ed_stage + 1 + i
+ if i == 0 and self.use_c5:
+ in_c = in_channels[-1]
+ else:
+ in_c = out_channels
+ extra_fpn_name = 'fpn_{}'.format(lvl + 2)
+ if self.norm_type is not None:
+ extra_fpn_conv = self.add_sublayer(
+ extra_fpn_name,
+ ConvNormLayer(
+ ch_in=in_c,
+ ch_out=out_channels,
+ filter_size=3,
+ stride=2,
+ norm_type=self.norm_type,
+ norm_decay=self.norm_decay,
+ freeze_norm=self.freeze_norm,
+ initializer=XavierUniform(fan_out=fan)))
+ else:
+ extra_fpn_conv = self.add_sublayer(
+ extra_fpn_name,
+ nn.Conv2D(
+ in_channels=in_c,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ weight_attr=ParamAttr(
+ initializer=XavierUniform(fan_out=fan))))
+ self.fpn_convs.append(extra_fpn_conv)
+
+ @classmethod
+ def from_config(cls, cfg, input_shape):
+ return {
+ 'in_channels': [i.channels for i in input_shape],
+ 'spatial_scales': [1.0 / i.stride for i in input_shape],
+ }
+
+ def forward(self, body_feats):
+ laterals = []
+ num_levels = len(body_feats)
+
+ for i in range(num_levels):
+ laterals.append(self.lateral_convs[i](body_feats[i]))
+
+ for i in range(1, num_levels):
+ lvl = num_levels - i
+ upsample = F.interpolate(
+ laterals[lvl],
+ scale_factor=2.,
+ mode='nearest', )
+ laterals[lvl - 1] += upsample
+
+ fpn_output = []
+ for lvl in range(num_levels):
+ fpn_output.append(self.fpn_convs[lvl](laterals[lvl]))
+
+ if self.extra_stage > 0:
+ # use max pool to get more levels on top of outputs (Faster R-CNN, Mask R-CNN)
+ if not self.has_extra_convs:
+ assert self.extra_stage == 1, 'extra_stage should be 1 if FPN has no extra convs'
+ fpn_output.append(F.max_pool2d(fpn_output[-1], 1, stride=2))
+ # add extra conv levels for RetinaNet(use_c5)/FCOS(use_p5)
+ else:
+ if self.use_c5:
+ extra_source = body_feats[-1]
+ else:
+ extra_source = fpn_output[-1]
+ fpn_output.append(self.fpn_convs[num_levels](extra_source))
+
+ for i in range(1, self.extra_stage):
+ if self.relu_before_extra_convs:
+ fpn_output.append(self.fpn_convs[num_levels + i](F.relu(
+ fpn_output[-1])))
+ else:
+ fpn_output.append(self.fpn_convs[num_levels + i](
+ fpn_output[-1]))
+ return fpn_output
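+
+
+# A minimal sketch of FCEFPN on a four-level ResNet-style pyramid; the channel
+# widths and spatial sizes below are illustrative assumptions. With the default
+# extra_stage=1 and has_extra_convs=False, a max-pooled fifth level is appended.
+if __name__ == "__main__":
+    import paddle
+
+    neck = FCEFPN(in_channels=[256, 512, 1024, 2048], out_channels=256)
+    body_feats = [
+        paddle.randn([1, 256, 32, 32]),
+        paddle.randn([1, 512, 16, 16]),
+        paddle.randn([1, 1024, 8, 8]),
+        paddle.randn([1, 2048, 4, 4]),
+    ]
+    outs = neck(body_feats)
+    print([tuple(o.shape) for o in outs])   # five levels, all 256 channels wide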
diff --git a/backend/ppocr/modeling/necks/fpn.py b/backend/ppocr/modeling/necks/fpn.py
new file mode 100644
index 00000000..48c85b1e
--- /dev/null
+++ b/backend/ppocr/modeling/necks/fpn.py
@@ -0,0 +1,138 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is referenced from:
+https://github.com/whai362/PSENet/blob/python3/models/neck/fpn.py
+"""
+
+import paddle.nn as nn
+import paddle
+import math
+import paddle.nn.functional as F
+
+
+class Conv_BN_ReLU(nn.Layer):
+ def __init__(self,
+ in_planes,
+ out_planes,
+ kernel_size=1,
+ stride=1,
+ padding=0):
+ super(Conv_BN_ReLU, self).__init__()
+ self.conv = nn.Conv2D(
+ in_planes,
+ out_planes,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ bias_attr=False)
+ self.bn = nn.BatchNorm2D(out_planes, momentum=0.1)
+ self.relu = nn.ReLU()
+
+ for m in self.sublayers():
+ if isinstance(m, nn.Conv2D):
+ n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
+ m.weight = paddle.create_parameter(
+ shape=m.weight.shape,
+ dtype='float32',
+ default_initializer=paddle.nn.initializer.Normal(
+ 0, math.sqrt(2. / n)))
+ elif isinstance(m, nn.BatchNorm2D):
+ m.weight = paddle.create_parameter(
+ shape=m.weight.shape,
+ dtype='float32',
+ default_initializer=paddle.nn.initializer.Constant(1.0))
+ m.bias = paddle.create_parameter(
+ shape=m.bias.shape,
+ dtype='float32',
+ default_initializer=paddle.nn.initializer.Constant(0.0))
+
+ def forward(self, x):
+ return self.relu(self.bn(self.conv(x)))
+
+
+class FPN(nn.Layer):
+ def __init__(self, in_channels, out_channels):
+ super(FPN, self).__init__()
+
+ # Top layer
+ self.toplayer_ = Conv_BN_ReLU(
+ in_channels[3], out_channels, kernel_size=1, stride=1, padding=0)
+ # Lateral layers
+ self.latlayer1_ = Conv_BN_ReLU(
+ in_channels[2], out_channels, kernel_size=1, stride=1, padding=0)
+
+ self.latlayer2_ = Conv_BN_ReLU(
+ in_channels[1], out_channels, kernel_size=1, stride=1, padding=0)
+
+ self.latlayer3_ = Conv_BN_ReLU(
+ in_channels[0], out_channels, kernel_size=1, stride=1, padding=0)
+
+ # Smooth layers
+ self.smooth1_ = Conv_BN_ReLU(
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ self.smooth2_ = Conv_BN_ReLU(
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ self.smooth3_ = Conv_BN_ReLU(
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ self.out_channels = out_channels * 4
+ for m in self.sublayers():
+ if isinstance(m, nn.Conv2D):
+ n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
+ m.weight = paddle.create_parameter(
+ shape=m.weight.shape,
+ dtype='float32',
+ default_initializer=paddle.nn.initializer.Normal(
+ 0, math.sqrt(2. / n)))
+ elif isinstance(m, nn.BatchNorm2D):
+ m.weight = paddle.create_parameter(
+ shape=m.weight.shape,
+ dtype='float32',
+ default_initializer=paddle.nn.initializer.Constant(1.0))
+ m.bias = paddle.create_parameter(
+ shape=m.bias.shape,
+ dtype='float32',
+ default_initializer=paddle.nn.initializer.Constant(0.0))
+
+ def _upsample(self, x, scale=1):
+ return F.upsample(x, scale_factor=scale, mode='bilinear')
+
+ def _upsample_add(self, x, y, scale=1):
+ return F.upsample(x, scale_factor=scale, mode='bilinear') + y
+
+ def forward(self, x):
+ f2, f3, f4, f5 = x
+ p5 = self.toplayer_(f5)
+
+ f4 = self.latlayer1_(f4)
+ p4 = self._upsample_add(p5, f4, 2)
+ p4 = self.smooth1_(p4)
+
+ f3 = self.latlayer2_(f3)
+ p3 = self._upsample_add(p4, f3, 2)
+ p3 = self.smooth2_(p3)
+
+ f2 = self.latlayer3_(f2)
+ p2 = self._upsample_add(p3, f2, 2)
+ p2 = self.smooth3_(p2)
+
+ p3 = self._upsample(p3, 2)
+ p4 = self._upsample(p4, 4)
+ p5 = self._upsample(p5, 8)
+
+ fuse = paddle.concat([p2, p3, p4, p5], axis=1)
+ return fuse
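+
+
+# A minimal sketch of the PSENet-style FPN; the channel widths follow a
+# ResNet18-like backbone only as an illustrative assumption. Every level is
+# upsampled back to 1/4 resolution and concatenated, so out_channels is
+# multiplied by four in the fused map.
+if __name__ == "__main__":
+    import paddle
+
+    feats = [
+        paddle.randn([1, 64, 40, 40]),
+        paddle.randn([1, 128, 20, 20]),
+        paddle.randn([1, 256, 10, 10]),
+        paddle.randn([1, 512, 5, 5]),
+    ]
+    neck = FPN(in_channels=[64, 128, 256, 512], out_channels=128)
+    print(neck(feats).shape)             # [1, 512, 40, 40]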
diff --git a/backend/ppocr/modeling/necks/pg_fpn.py b/backend/ppocr/modeling/necks/pg_fpn.py
new file mode 100644
index 00000000..3f64539f
--- /dev/null
+++ b/backend/ppocr/modeling/necks/pg_fpn.py
@@ -0,0 +1,314 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+import paddle.nn.functional as F
+from paddle import ParamAttr
+
+
+class ConvBNLayer(nn.Layer):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ groups=1,
+ is_vd_mode=False,
+ act=None,
+ name=None):
+ super(ConvBNLayer, self).__init__()
+
+ self.is_vd_mode = is_vd_mode
+ self._pool2d_avg = nn.AvgPool2D(
+ kernel_size=2, stride=2, padding=0, ceil_mode=True)
+ self._conv = nn.Conv2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=(kernel_size - 1) // 2,
+ groups=groups,
+ weight_attr=ParamAttr(name=name + "_weights"),
+ bias_attr=False)
+ if name == "conv1":
+ bn_name = "bn_" + name
+ else:
+ bn_name = "bn" + name[3:]
+ self._batch_norm = nn.BatchNorm(
+ out_channels,
+ act=act,
+ param_attr=ParamAttr(name=bn_name + '_scale'),
+ bias_attr=ParamAttr(bn_name + '_offset'),
+ moving_mean_name=bn_name + '_mean',
+ moving_variance_name=bn_name + '_variance',
+ use_global_stats=False)
+
+ def forward(self, inputs):
+ y = self._conv(inputs)
+ y = self._batch_norm(y)
+ return y
+
+
+class DeConvBNLayer(nn.Layer):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size=4,
+ stride=2,
+ padding=1,
+ groups=1,
+ if_act=True,
+ act=None,
+ name=None):
+ super(DeConvBNLayer, self).__init__()
+
+ self.if_act = if_act
+ self.act = act
+ self.deconv = nn.Conv2DTranspose(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ groups=groups,
+ weight_attr=ParamAttr(name=name + '_weights'),
+ bias_attr=False)
+ self.bn = nn.BatchNorm(
+ num_channels=out_channels,
+ act=act,
+ param_attr=ParamAttr(name="bn_" + name + "_scale"),
+ bias_attr=ParamAttr(name="bn_" + name + "_offset"),
+ moving_mean_name="bn_" + name + "_mean",
+ moving_variance_name="bn_" + name + "_variance",
+ use_global_stats=False)
+
+ def forward(self, x):
+ x = self.deconv(x)
+ x = self.bn(x)
+ return x
+
+
+class PGFPN(nn.Layer):
+ def __init__(self, in_channels, **kwargs):
+ super(PGFPN, self).__init__()
+ num_inputs = [2048, 2048, 1024, 512, 256]
+ num_outputs = [256, 256, 192, 192, 128]
+ self.out_channels = 128
+ self.conv_bn_layer_1 = ConvBNLayer(
+ in_channels=3,
+ out_channels=32,
+ kernel_size=3,
+ stride=1,
+ act=None,
+ name='FPN_d1')
+ self.conv_bn_layer_2 = ConvBNLayer(
+ in_channels=64,
+ out_channels=64,
+ kernel_size=3,
+ stride=1,
+ act=None,
+ name='FPN_d2')
+ self.conv_bn_layer_3 = ConvBNLayer(
+ in_channels=256,
+ out_channels=128,
+ kernel_size=3,
+ stride=1,
+ act=None,
+ name='FPN_d3')
+ self.conv_bn_layer_4 = ConvBNLayer(
+ in_channels=32,
+ out_channels=64,
+ kernel_size=3,
+ stride=2,
+ act=None,
+ name='FPN_d4')
+ self.conv_bn_layer_5 = ConvBNLayer(
+ in_channels=64,
+ out_channels=64,
+ kernel_size=3,
+ stride=1,
+ act='relu',
+ name='FPN_d5')
+ self.conv_bn_layer_6 = ConvBNLayer(
+ in_channels=64,
+ out_channels=128,
+ kernel_size=3,
+ stride=2,
+ act=None,
+ name='FPN_d6')
+ self.conv_bn_layer_7 = ConvBNLayer(
+ in_channels=128,
+ out_channels=128,
+ kernel_size=3,
+ stride=1,
+ act='relu',
+ name='FPN_d7')
+ self.conv_bn_layer_8 = ConvBNLayer(
+ in_channels=128,
+ out_channels=128,
+ kernel_size=1,
+ stride=1,
+ act=None,
+ name='FPN_d8')
+
+ self.conv_h0 = ConvBNLayer(
+ in_channels=num_inputs[0],
+ out_channels=num_outputs[0],
+ kernel_size=1,
+ stride=1,
+ act=None,
+ name="conv_h{}".format(0))
+ self.conv_h1 = ConvBNLayer(
+ in_channels=num_inputs[1],
+ out_channels=num_outputs[1],
+ kernel_size=1,
+ stride=1,
+ act=None,
+ name="conv_h{}".format(1))
+ self.conv_h2 = ConvBNLayer(
+ in_channels=num_inputs[2],
+ out_channels=num_outputs[2],
+ kernel_size=1,
+ stride=1,
+ act=None,
+ name="conv_h{}".format(2))
+ self.conv_h3 = ConvBNLayer(
+ in_channels=num_inputs[3],
+ out_channels=num_outputs[3],
+ kernel_size=1,
+ stride=1,
+ act=None,
+ name="conv_h{}".format(3))
+ self.conv_h4 = ConvBNLayer(
+ in_channels=num_inputs[4],
+ out_channels=num_outputs[4],
+ kernel_size=1,
+ stride=1,
+ act=None,
+ name="conv_h{}".format(4))
+
+ self.dconv0 = DeConvBNLayer(
+ in_channels=num_outputs[0],
+ out_channels=num_outputs[0 + 1],
+ name="dconv_{}".format(0))
+ self.dconv1 = DeConvBNLayer(
+ in_channels=num_outputs[1],
+ out_channels=num_outputs[1 + 1],
+ act=None,
+ name="dconv_{}".format(1))
+ self.dconv2 = DeConvBNLayer(
+ in_channels=num_outputs[2],
+ out_channels=num_outputs[2 + 1],
+ act=None,
+ name="dconv_{}".format(2))
+ self.dconv3 = DeConvBNLayer(
+ in_channels=num_outputs[3],
+ out_channels=num_outputs[3 + 1],
+ act=None,
+ name="dconv_{}".format(3))
+ self.conv_g1 = ConvBNLayer(
+ in_channels=num_outputs[1],
+ out_channels=num_outputs[1],
+ kernel_size=3,
+ stride=1,
+ act='relu',
+ name="conv_g{}".format(1))
+ self.conv_g2 = ConvBNLayer(
+ in_channels=num_outputs[2],
+ out_channels=num_outputs[2],
+ kernel_size=3,
+ stride=1,
+ act='relu',
+ name="conv_g{}".format(2))
+ self.conv_g3 = ConvBNLayer(
+ in_channels=num_outputs[3],
+ out_channels=num_outputs[3],
+ kernel_size=3,
+ stride=1,
+ act='relu',
+ name="conv_g{}".format(3))
+ self.conv_g4 = ConvBNLayer(
+ in_channels=num_outputs[4],
+ out_channels=num_outputs[4],
+ kernel_size=3,
+ stride=1,
+ act='relu',
+ name="conv_g{}".format(4))
+ self.convf = ConvBNLayer(
+ in_channels=num_outputs[4],
+ out_channels=num_outputs[4],
+ kernel_size=1,
+ stride=1,
+ act=None,
+ name="conv_f{}".format(4))
+
+ def forward(self, x):
+ c0, c1, c2, c3, c4, c5, c6 = x
+ # FPN_Down_Fusion
+ f = [c0, c1, c2]
+ g = [None, None, None]
+ h = [None, None, None]
+ h[0] = self.conv_bn_layer_1(f[0])
+ h[1] = self.conv_bn_layer_2(f[1])
+ h[2] = self.conv_bn_layer_3(f[2])
+
+ g[0] = self.conv_bn_layer_4(h[0])
+ g[1] = paddle.add(g[0], h[1])
+ g[1] = F.relu(g[1])
+ g[1] = self.conv_bn_layer_5(g[1])
+ g[1] = self.conv_bn_layer_6(g[1])
+
+ g[2] = paddle.add(g[1], h[2])
+ g[2] = F.relu(g[2])
+ g[2] = self.conv_bn_layer_7(g[2])
+ f_down = self.conv_bn_layer_8(g[2])
+
+ # FPN UP Fusion
+ f1 = [c6, c5, c4, c3, c2]
+ g = [None, None, None, None, None]
+ h = [None, None, None, None, None]
+ h[0] = self.conv_h0(f1[0])
+ h[1] = self.conv_h1(f1[1])
+ h[2] = self.conv_h2(f1[2])
+ h[3] = self.conv_h3(f1[3])
+ h[4] = self.conv_h4(f1[4])
+
+ g[0] = self.dconv0(h[0])
+ g[1] = paddle.add(g[0], h[1])
+ g[1] = F.relu(g[1])
+ g[1] = self.conv_g1(g[1])
+ g[1] = self.dconv1(g[1])
+
+ g[2] = paddle.add(g[1], h[2])
+ g[2] = F.relu(g[2])
+ g[2] = self.conv_g2(g[2])
+ g[2] = self.dconv2(g[2])
+
+ g[3] = paddle.add(g[2], h[3])
+ g[3] = F.relu(g[3])
+ g[3] = self.conv_g3(g[3])
+ g[3] = self.dconv3(g[3])
+
+ g[4] = paddle.add(x=g[3], y=h[4])
+ g[4] = F.relu(g[4])
+ g[4] = self.conv_g4(g[4])
+ f_up = self.convf(g[4])
+ f_common = paddle.add(f_down, f_up)
+ f_common = F.relu(f_common)
+ return f_common
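+
+
+# A minimal sketch of PGFPN's seven-level input contract (the first entry is a
+# three-channel, image-scale map, followed by six backbone stages); the channel
+# widths are dictated by the hard-coded layers above, while the 64x64 spatial
+# size is an illustrative assumption.
+if __name__ == "__main__":
+    import paddle
+
+    x = [
+        paddle.randn([1, 3, 64, 64]),     # c0: image-scale input
+        paddle.randn([1, 64, 32, 32]),    # c1
+        paddle.randn([1, 256, 16, 16]),   # c2
+        paddle.randn([1, 512, 8, 8]),     # c3
+        paddle.randn([1, 1024, 4, 4]),    # c4
+        paddle.randn([1, 2048, 2, 2]),    # c5
+        paddle.randn([1, 2048, 1, 1]),    # c6
+    ]
+    neck = PGFPN(in_channels=[3, 64, 256, 512, 1024, 2048, 2048])
+    print(neck(x).shape)                  # [1, 128, 16, 16]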
diff --git a/backend/ppocr/modeling/necks/pren_fpn.py b/backend/ppocr/modeling/necks/pren_fpn.py
new file mode 100644
index 00000000..afbdcea8
--- /dev/null
+++ b/backend/ppocr/modeling/necks/pren_fpn.py
@@ -0,0 +1,163 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Code is referenced from:
+https://github.com/RuijieJ/pren/blob/main/Nets/Aggregation.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+import paddle.nn.functional as F
+
+
+class PoolAggregate(nn.Layer):
+ def __init__(self, n_r, d_in, d_middle=None, d_out=None):
+ super(PoolAggregate, self).__init__()
+ if not d_middle:
+ d_middle = d_in
+ if not d_out:
+ d_out = d_in
+
+ self.d_in = d_in
+ self.d_middle = d_middle
+ self.d_out = d_out
+ self.act = nn.Swish()
+
+ self.n_r = n_r
+ self.aggs = self._build_aggs()
+
+ def _build_aggs(self):
+ aggs = []
+ for i in range(self.n_r):
+ aggs.append(
+ self.add_sublayer(
+ '{}'.format(i),
+ nn.Sequential(
+ ('conv1', nn.Conv2D(
+ self.d_in, self.d_middle, 3, 2, 1, bias_attr=False)
+ ), ('bn1', nn.BatchNorm(self.d_middle)),
+ ('act', self.act), ('conv2', nn.Conv2D(
+ self.d_middle, self.d_out, 3, 2, 1, bias_attr=False
+ )), ('bn2', nn.BatchNorm(self.d_out)))))
+ return aggs
+
+ def forward(self, x):
+ b = x.shape[0]
+ outs = []
+ for agg in self.aggs:
+ y = agg(x)
+ p = F.adaptive_avg_pool2d(y, 1)
+ outs.append(p.reshape((b, 1, self.d_out)))
+ out = paddle.concat(outs, 1)
+ return out
+
+
+class WeightAggregate(nn.Layer):
+ def __init__(self, n_r, d_in, d_middle=None, d_out=None):
+ super(WeightAggregate, self).__init__()
+ if not d_middle:
+ d_middle = d_in
+ if not d_out:
+ d_out = d_in
+
+ self.n_r = n_r
+ self.d_out = d_out
+ self.act = nn.Swish()
+
+ self.conv_n = nn.Sequential(
+ ('conv1', nn.Conv2D(
+ d_in, d_in, 3, 1, 1,
+ bias_attr=False)), ('bn1', nn.BatchNorm(d_in)),
+ ('act1', self.act), ('conv2', nn.Conv2D(
+ d_in, n_r, 1, bias_attr=False)), ('bn2', nn.BatchNorm(n_r)),
+ ('act2', nn.Sigmoid()))
+ self.conv_d = nn.Sequential(
+ ('conv1', nn.Conv2D(
+ d_in, d_middle, 3, 1, 1,
+ bias_attr=False)), ('bn1', nn.BatchNorm(d_middle)),
+ ('act1', self.act), ('conv2', nn.Conv2D(
+ d_middle, d_out, 1,
+ bias_attr=False)), ('bn2', nn.BatchNorm(d_out)))
+
+ def forward(self, x):
+ b, _, h, w = x.shape
+
+ hmaps = self.conv_n(x)
+ fmaps = self.conv_d(x)
+ r = paddle.bmm(
+ hmaps.reshape((b, self.n_r, h * w)),
+ fmaps.reshape((b, self.d_out, h * w)).transpose((0, 2, 1)))
+ return r
+
+
+class GCN(nn.Layer):
+ def __init__(self, d_in, n_in, d_out=None, n_out=None, dropout=0.1):
+ super(GCN, self).__init__()
+ if not d_out:
+ d_out = d_in
+ if not n_out:
+ n_out = d_in
+
+ self.conv_n = nn.Conv1D(n_in, n_out, 1)
+ self.linear = nn.Linear(d_in, d_out)
+ self.dropout = nn.Dropout(dropout)
+ self.act = nn.Swish()
+
+ def forward(self, x):
+ x = self.conv_n(x)
+ x = self.dropout(self.linear(x))
+ return self.act(x)
+
+
+class PRENFPN(nn.Layer):
+ def __init__(self, in_channels, n_r, d_model, max_len, dropout):
+ super(PRENFPN, self).__init__()
+ assert len(in_channels) == 3, "in_channels' length must be 3."
+ c1, c2, c3 = in_channels # the depths are from big to small
+ # build fpn
+ assert d_model % 3 == 0, "{} can't be divided by 3.".format(d_model)
+ self.agg_p1 = PoolAggregate(n_r, c1, d_out=d_model // 3)
+ self.agg_p2 = PoolAggregate(n_r, c2, d_out=d_model // 3)
+ self.agg_p3 = PoolAggregate(n_r, c3, d_out=d_model // 3)
+
+ self.agg_w1 = WeightAggregate(n_r, c1, 4 * c1, d_model // 3)
+ self.agg_w2 = WeightAggregate(n_r, c2, 4 * c2, d_model // 3)
+ self.agg_w3 = WeightAggregate(n_r, c3, 4 * c3, d_model // 3)
+
+ self.gcn_pool = GCN(d_model, n_r, d_model, max_len, dropout)
+ self.gcn_weight = GCN(d_model, n_r, d_model, max_len, dropout)
+
+ self.out_channels = d_model
+
+ def forward(self, inputs):
+ f3, f5, f7 = inputs
+
+ rp1 = self.agg_p1(f3)
+ rp2 = self.agg_p2(f5)
+ rp3 = self.agg_p3(f7)
+ rp = paddle.concat([rp1, rp2, rp3], 2) # [b,nr,d]
+
+ rw1 = self.agg_w1(f3)
+ rw2 = self.agg_w2(f5)
+ rw3 = self.agg_w3(f7)
+ rw = paddle.concat([rw1, rw2, rw3], 2) # [b,nr,d]
+
+ y1 = self.gcn_pool(rp)
+ y2 = self.gcn_weight(rw)
+ y = 0.5 * (y1 + y2)
+ return y # [b,max_len,d]
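
Note: a minimal shape walk-through of the PRENFPN neck added above. The batch size, channel widths, `n_r`, `d_model` and `max_len` below are illustrative assumptions, not values taken from any shipped PREN config.

```python
import paddle
from ppocr.modeling.necks.pren_fpn import PRENFPN

# Illustrative sizes only; real values come from the PREN config file.
neck = PRENFPN(in_channels=[64, 128, 256], n_r=5, d_model=384, max_len=25, dropout=0.1)

f3 = paddle.rand([2, 64, 16, 64])    # assumed stride-8 backbone feature
f5 = paddle.rand([2, 128, 8, 32])    # assumed stride-16 backbone feature
f7 = paddle.rand([2, 256, 4, 16])    # assumed stride-32 backbone feature

y = neck([f3, f5, f7])
print(y.shape)  # [2, 25, 384] -> [batch, max_len, d_model]
```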
diff --git a/backend/ppocr/modeling/necks/rnn.py b/backend/ppocr/modeling/necks/rnn.py
index de87b3d9..c8a774b8 100644
--- a/backend/ppocr/modeling/necks/rnn.py
+++ b/backend/ppocr/modeling/necks/rnn.py
@@ -16,9 +16,11 @@
from __future__ import division
from __future__ import print_function
+import paddle
from paddle import nn
from ppocr.modeling.heads.rec_ctc_head import get_para_bias_attr
+from ppocr.modeling.backbones.rec_svtrnet import Block, ConvBNLayer, trunc_normal_, zeros_, ones_
class Im2Seq(nn.Layer):
@@ -51,7 +53,7 @@ def __init__(self, in_channels, hidden_size):
super(EncoderWithFC, self).__init__()
self.out_channels = hidden_size
weight_attr, bias_attr = get_para_bias_attr(
- l2_decay=0.00001, k=in_channels, name='reduce_encoder_fea')
+ l2_decay=0.00001, k=in_channels)
self.fc = nn.Linear(
in_channels,
hidden_size,
@@ -64,29 +66,126 @@ def forward(self, x):
return x
+class EncoderWithSVTR(nn.Layer):
+ def __init__(
+ self,
+ in_channels,
+ dims=64, # XS
+ depth=2,
+ hidden_dims=120,
+ use_guide=False,
+ num_heads=8,
+ qkv_bias=True,
+ mlp_ratio=2.0,
+ drop_rate=0.1,
+ attn_drop_rate=0.1,
+ drop_path=0.,
+ qk_scale=None):
+ super(EncoderWithSVTR, self).__init__()
+ self.depth = depth
+ self.use_guide = use_guide
+ self.conv1 = ConvBNLayer(
+ in_channels, in_channels // 8, padding=1, act=nn.Swish)
+ self.conv2 = ConvBNLayer(
+ in_channels // 8, hidden_dims, kernel_size=1, act=nn.Swish)
+
+ self.svtr_block = nn.LayerList([
+ Block(
+ dim=hidden_dims,
+ num_heads=num_heads,
+ mixer='Global',
+ HW=None,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ act_layer=nn.Swish,
+ attn_drop=attn_drop_rate,
+ drop_path=drop_path,
+ norm_layer='nn.LayerNorm',
+ epsilon=1e-05,
+ prenorm=False) for i in range(depth)
+ ])
+ self.norm = nn.LayerNorm(hidden_dims, epsilon=1e-6)
+ self.conv3 = ConvBNLayer(
+ hidden_dims, in_channels, kernel_size=1, act=nn.Swish)
+ # last conv-nxn; its input is the concatenation of the block input and the conv3 output
+ self.conv4 = ConvBNLayer(
+ 2 * in_channels, in_channels // 8, padding=1, act=nn.Swish)
+
+ self.conv1x1 = ConvBNLayer(
+ in_channels // 8, dims, kernel_size=1, act=nn.Swish)
+ self.out_channels = dims
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ zeros_(m.bias)
+ elif isinstance(m, nn.LayerNorm):
+ zeros_(m.bias)
+ ones_(m.weight)
+
+ def forward(self, x):
+ # for use guide
+ if self.use_guide:
+ z = x.clone()
+ z.stop_gradient = True
+ else:
+ z = x
+ # for short cut
+ h = z
+ # reduce dim
+ z = self.conv1(z)
+ z = self.conv2(z)
+ # SVTR global block
+ B, C, H, W = z.shape
+ z = z.flatten(2).transpose([0, 2, 1])
+ for blk in self.svtr_block:
+ z = blk(z)
+ z = self.norm(z)
+ # last stage
+ z = z.reshape([0, H, W, C]).transpose([0, 3, 1, 2])
+ z = self.conv3(z)
+ z = paddle.concat((h, z), axis=1)
+ z = self.conv1x1(self.conv4(z))
+ return z
+
+
class SequenceEncoder(nn.Layer):
def __init__(self, in_channels, encoder_type, hidden_size=48, **kwargs):
super(SequenceEncoder, self).__init__()
self.encoder_reshape = Im2Seq(in_channels)
self.out_channels = self.encoder_reshape.out_channels
+ self.encoder_type = encoder_type
if encoder_type == 'reshape':
self.only_reshape = True
else:
support_encoder_dict = {
'reshape': Im2Seq,
'fc': EncoderWithFC,
- 'rnn': EncoderWithRNN
+ 'rnn': EncoderWithRNN,
+ 'svtr': EncoderWithSVTR
}
assert encoder_type in support_encoder_dict, '{} must in {}'.format(
encoder_type, support_encoder_dict.keys())
-
- self.encoder = support_encoder_dict[encoder_type](
- self.encoder_reshape.out_channels, hidden_size)
+ if encoder_type == "svtr":
+ self.encoder = support_encoder_dict[encoder_type](
+ self.encoder_reshape.out_channels, **kwargs)
+ else:
+ self.encoder = support_encoder_dict[encoder_type](
+ self.encoder_reshape.out_channels, hidden_size)
self.out_channels = self.encoder.out_channels
self.only_reshape = False
def forward(self, x):
- x = self.encoder_reshape(x)
- if not self.only_reshape:
+ if self.encoder_type != 'svtr':
+ x = self.encoder_reshape(x)
+ if not self.only_reshape:
+ x = self.encoder(x)
+ return x
+ else:
x = self.encoder(x)
- return x
+ x = self.encoder_reshape(x)
+ return x
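
Note: with `encoder_type: svtr` the order is reversed compared to the other encoders — the SVTR mixing blocks run on the 4-D feature map first and `Im2Seq` runs afterwards. A minimal sketch; the channel count and feature-map size are assumptions, and it presumes a height-1 backbone output as `Im2Seq` expects.

```python
import paddle
from ppocr.modeling.necks.rnn import SequenceEncoder

feat = paddle.rand([2, 512, 1, 25])   # assumed backbone output [N, C, H, W] with H == 1

neck = SequenceEncoder(in_channels=512, encoder_type='svtr',
                       dims=64, depth=2, hidden_dims=120)
seq = neck(feat)      # EncoderWithSVTR first, then Im2Seq
print(seq.shape)      # [2, 25, 64] -> [N, W, dims]
```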
diff --git a/backend/ppocr/modeling/necks/table_fpn.py b/backend/ppocr/modeling/necks/table_fpn.py
new file mode 100644
index 00000000..734f15af
--- /dev/null
+++ b/backend/ppocr/modeling/necks/table_fpn.py
@@ -0,0 +1,110 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+import paddle.nn.functional as F
+from paddle import ParamAttr
+
+
+class TableFPN(nn.Layer):
+ def __init__(self, in_channels, out_channels, **kwargs):
+ super(TableFPN, self).__init__()
+ self.out_channels = 512
+ weight_attr = paddle.nn.initializer.KaimingUniform()
+ self.in2_conv = nn.Conv2D(
+ in_channels=in_channels[0],
+ out_channels=self.out_channels,
+ kernel_size=1,
+ weight_attr=ParamAttr(initializer=weight_attr),
+ bias_attr=False)
+ self.in3_conv = nn.Conv2D(
+ in_channels=in_channels[1],
+ out_channels=self.out_channels,
+ kernel_size=1,
stride=1,
+ weight_attr=ParamAttr(initializer=weight_attr),
+ bias_attr=False)
+ self.in4_conv = nn.Conv2D(
+ in_channels=in_channels[2],
+ out_channels=self.out_channels,
+ kernel_size=1,
+ weight_attr=ParamAttr(initializer=weight_attr),
+ bias_attr=False)
+ self.in5_conv = nn.Conv2D(
+ in_channels=in_channels[3],
+ out_channels=self.out_channels,
+ kernel_size=1,
+ weight_attr=ParamAttr(initializer=weight_attr),
+ bias_attr=False)
+ self.p5_conv = nn.Conv2D(
+ in_channels=self.out_channels,
+ out_channels=self.out_channels // 4,
+ kernel_size=3,
+ padding=1,
+ weight_attr=ParamAttr(initializer=weight_attr),
+ bias_attr=False)
+ self.p4_conv = nn.Conv2D(
+ in_channels=self.out_channels,
+ out_channels=self.out_channels // 4,
+ kernel_size=3,
+ padding=1,
+ weight_attr=ParamAttr(initializer=weight_attr),
+ bias_attr=False)
+ self.p3_conv = nn.Conv2D(
+ in_channels=self.out_channels,
+ out_channels=self.out_channels // 4,
+ kernel_size=3,
+ padding=1,
+ weight_attr=ParamAttr(initializer=weight_attr),
+ bias_attr=False)
+ self.p2_conv = nn.Conv2D(
+ in_channels=self.out_channels,
+ out_channels=self.out_channels // 4,
+ kernel_size=3,
+ padding=1,
+ weight_attr=ParamAttr(initializer=weight_attr),
+ bias_attr=False)
+ self.fuse_conv = nn.Conv2D(
+ in_channels=self.out_channels * 4,
+ out_channels=512,
+ kernel_size=3,
+ padding=1,
+ weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False)
+
+ def forward(self, x):
+ c2, c3, c4, c5 = x
+
+ in5 = self.in5_conv(c5)
+ in4 = self.in4_conv(c4)
+ in3 = self.in3_conv(c3)
+ in2 = self.in2_conv(c2)
+
+ out4 = in4 + F.upsample(
+ in5, size=in4.shape[2:4], mode="nearest", align_mode=1) # 1/16
+ out3 = in3 + F.upsample(
+ out4, size=in3.shape[2:4], mode="nearest", align_mode=1) # 1/8
+ out2 = in2 + F.upsample(
+ out3, size=in2.shape[2:4], mode="nearest", align_mode=1) # 1/4
+
+ p4 = F.upsample(out4, size=in5.shape[2:4], mode="nearest", align_mode=1)
+ p3 = F.upsample(out3, size=in5.shape[2:4], mode="nearest", align_mode=1)
+ p2 = F.upsample(out2, size=in5.shape[2:4], mode="nearest", align_mode=1)
+ fuse = paddle.concat([in5, p4, p3, p2], axis=1)
+ fuse_conv = self.fuse_conv(fuse) * 0.005
+ return [c5 + fuse_conv]
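
Note: TableFPN folds all four backbone stages back onto the coarsest map and returns a single, residually corrected `c5`. A toy shape check; the ResNet-like stage channels are assumed for illustration only.

```python
import paddle
from ppocr.modeling.necks.table_fpn import TableFPN

# Assumed ResNet-style stage channels; only the shapes matter here.
fpn = TableFPN(in_channels=[64, 128, 256, 512], out_channels=512)
c2 = paddle.rand([1, 64, 80, 80])
c3 = paddle.rand([1, 128, 40, 40])
c4 = paddle.rand([1, 256, 20, 20])
c5 = paddle.rand([1, 512, 10, 10])

out = fpn([c2, c3, c4, c5])
print(out[0].shape)  # [1, 512, 10, 10] -> same resolution as c5
```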
diff --git a/backend/ppocr/modeling/transforms/__init__.py b/backend/ppocr/modeling/transforms/__init__.py
index 78eaeccc..405ab3cc 100755
--- a/backend/ppocr/modeling/transforms/__init__.py
+++ b/backend/ppocr/modeling/transforms/__init__.py
@@ -17,8 +17,9 @@
def build_transform(config):
from .tps import TPS
+ from .stn import STN_ON
- support_dict = ['TPS']
+ support_dict = ['TPS', 'STN_ON']
module_name = config.pop('name')
assert module_name in support_dict, Exception(
diff --git a/backend/ppocr/modeling/transforms/stn.py b/backend/ppocr/modeling/transforms/stn.py
new file mode 100644
index 00000000..6f2bdda0
--- /dev/null
+++ b/backend/ppocr/modeling/transforms/stn.py
@@ -0,0 +1,135 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is referred from:
+https://github.com/ayumiymk/aster.pytorch/blob/master/lib/models/stn_head.py
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import paddle
+from paddle import nn, ParamAttr
+from paddle.nn import functional as F
+import numpy as np
+
+from .tps_spatial_transformer import TPSSpatialTransformer
+
+
+def conv3x3_block(in_channels, out_channels, stride=1):
+ n = 3 * 3 * out_channels
+ w = math.sqrt(2. / n)
+ conv_layer = nn.Conv2D(
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ weight_attr=nn.initializer.Normal(
+ mean=0.0, std=w),
+ bias_attr=nn.initializer.Constant(0))
+ block = nn.Sequential(conv_layer, nn.BatchNorm2D(out_channels), nn.ReLU())
+ return block
+
+
+class STN(nn.Layer):
+ def __init__(self, in_channels, num_ctrlpoints, activation='none'):
+ super(STN, self).__init__()
+ self.in_channels = in_channels
+ self.num_ctrlpoints = num_ctrlpoints
+ self.activation = activation
+ self.stn_convnet = nn.Sequential(
+ conv3x3_block(in_channels, 32), #32x64
+ nn.MaxPool2D(
+ kernel_size=2, stride=2),
+ conv3x3_block(32, 64), #16x32
+ nn.MaxPool2D(
+ kernel_size=2, stride=2),
+ conv3x3_block(64, 128), # 8*16
+ nn.MaxPool2D(
+ kernel_size=2, stride=2),
+ conv3x3_block(128, 256), # 4*8
+ nn.MaxPool2D(
+ kernel_size=2, stride=2),
+ conv3x3_block(256, 256), # 2*4,
+ nn.MaxPool2D(
+ kernel_size=2, stride=2),
+ conv3x3_block(256, 256)) # 1*2
+ self.stn_fc1 = nn.Sequential(
+ nn.Linear(
+ 2 * 256,
+ 512,
+ weight_attr=nn.initializer.Normal(0, 0.001),
+ bias_attr=nn.initializer.Constant(0)),
+ nn.BatchNorm1D(512),
+ nn.ReLU())
+ fc2_bias = self.init_stn()
+ self.stn_fc2 = nn.Linear(
+ 512,
+ num_ctrlpoints * 2,
+ weight_attr=nn.initializer.Constant(0.0),
+ bias_attr=nn.initializer.Assign(fc2_bias))
+
+ def init_stn(self):
+ margin = 0.01
+ sampling_num_per_side = int(self.num_ctrlpoints / 2)
+ ctrl_pts_x = np.linspace(margin, 1. - margin, sampling_num_per_side)
+ ctrl_pts_y_top = np.ones(sampling_num_per_side) * margin
+ ctrl_pts_y_bottom = np.ones(sampling_num_per_side) * (1 - margin)
+ ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
+ ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
+ ctrl_points = np.concatenate(
+ [ctrl_pts_top, ctrl_pts_bottom], axis=0).astype(np.float32)
+ if self.activation == 'none':
+ pass
+ elif self.activation == 'sigmoid':
+ ctrl_points = -np.log(1. / ctrl_points - 1.)
+ ctrl_points = paddle.to_tensor(ctrl_points)
+ fc2_bias = paddle.reshape(
+ ctrl_points, shape=[ctrl_points.shape[0] * ctrl_points.shape[1]])
+ return fc2_bias
+
+ def forward(self, x):
+ x = self.stn_convnet(x)
+ batch_size, _, h, w = x.shape
+ x = paddle.reshape(x, shape=(batch_size, -1))
+ img_feat = self.stn_fc1(x)
+ x = self.stn_fc2(0.1 * img_feat)
+ if self.activation == 'sigmoid':
+ x = F.sigmoid(x)
+ x = paddle.reshape(x, shape=[-1, self.num_ctrlpoints, 2])
+ return img_feat, x
+
+
+class STN_ON(nn.Layer):
+ def __init__(self, in_channels, tps_inputsize, tps_outputsize,
+ num_control_points, tps_margins, stn_activation):
+ super(STN_ON, self).__init__()
+ self.tps = TPSSpatialTransformer(
+ output_image_size=tuple(tps_outputsize),
+ num_control_points=num_control_points,
+ margins=tuple(tps_margins))
+ self.stn_head = STN(in_channels=in_channels,
+ num_ctrlpoints=num_control_points,
+ activation=stn_activation)
+ self.tps_inputsize = tps_inputsize
+ self.out_channels = in_channels
+
+ def forward(self, image):
+ stn_input = paddle.nn.functional.interpolate(
+ image, self.tps_inputsize, mode="bilinear", align_corners=True)
+ stn_img_feat, ctrl_points = self.stn_head(stn_input)
+ x, _ = self.tps(image, ctrl_points)
+ return x
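
Note: a minimal forward sketch of the STN_ON transform added above. The input size, control-point count and TPS sizes below mirror typical ASTER-style settings but are assumptions, not values read from a config.

```python
import paddle
from ppocr.modeling.transforms.stn import STN_ON

# Assumed ASTER-like settings for illustration.
stn = STN_ON(in_channels=3,
             tps_inputsize=[32, 64],
             tps_outputsize=[32, 100],
             num_control_points=20,
             tps_margins=[0.05, 0.05],
             stn_activation='none')

img = paddle.rand([2, 3, 64, 256])
rectified = stn(img)
print(rectified.shape)  # [2, 3, 32, 100] -> resampled to tps_outputsize
```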
diff --git a/backend/ppocr/modeling/transforms/tps.py b/backend/ppocr/modeling/transforms/tps.py
index 78338edf..9bdab0f8 100644
--- a/backend/ppocr/modeling/transforms/tps.py
+++ b/backend/ppocr/modeling/transforms/tps.py
@@ -11,6 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+"""
+This code is referred from:
+https://github.com/clovaai/deep-text-recognition-benchmark/blob/master/modules/transformation.py
+"""
from __future__ import absolute_import
from __future__ import division
@@ -230,15 +234,9 @@ def build_P_paddle(self, I_r_size):
def build_inv_delta_C_paddle(self, C):
""" Return inv_delta_C which is needed to calculate T """
F = self.F
- hat_C = paddle.zeros((F, F), dtype='float64') # F x F
- for i in range(0, F):
- for j in range(i, F):
- if i == j:
- hat_C[i, j] = 1
- else:
- r = paddle.norm(C[i] - C[j])
- hat_C[i, j] = r
- hat_C[j, i] = r
+ hat_eye = paddle.eye(F, dtype='float64') # F x F
+ hat_C = paddle.norm(
+ C.reshape([1, F, 2]) - C.reshape([F, 1, 2]), axis=2) + hat_eye
hat_C = (hat_C**2) * paddle.log(hat_C)
delta_C = paddle.concat( # F+3 x F+3
[
diff --git a/backend/ppocr/modeling/transforms/tps_spatial_transformer.py b/backend/ppocr/modeling/transforms/tps_spatial_transformer.py
new file mode 100644
index 00000000..cb1cb10a
--- /dev/null
+++ b/backend/ppocr/modeling/transforms/tps_spatial_transformer.py
@@ -0,0 +1,156 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is referred from:
+https://github.com/ayumiymk/aster.pytorch/blob/master/lib/models/tps_spatial_transformer.py
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import paddle
+from paddle import nn, ParamAttr
+from paddle.nn import functional as F
+import numpy as np
+import itertools
+
+
+def grid_sample(input, grid, canvas=None):
+ input.stop_gradient = False
+ output = F.grid_sample(input, grid)
+ if canvas is None:
+ return output
+ else:
+ input_mask = paddle.ones(shape=input.shape)
+ output_mask = F.grid_sample(input_mask, grid)
+ padded_output = output * output_mask + canvas * (1 - output_mask)
+ return padded_output
+
+
+# phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2
+def compute_partial_repr(input_points, control_points):
+ N = input_points.shape[0]
+ M = control_points.shape[0]
+ pairwise_diff = paddle.reshape(
+ input_points, shape=[N, 1, 2]) - paddle.reshape(
+ control_points, shape=[1, M, 2])
+ # original implementation, very slow
+ # pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance
+ pairwise_diff_square = pairwise_diff * pairwise_diff
+ pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :,
+ 1]
+ repr_matrix = 0.5 * pairwise_dist * paddle.log(pairwise_dist)
+ # fix numerical error for 0 * log(0), substitute all nan with 0
+ mask = np.array(repr_matrix != repr_matrix)
+ repr_matrix[mask] = 0
+ return repr_matrix
+
+
+# output_ctrl_pts are specified, according to our task.
+def build_output_control_points(num_control_points, margins):
+ margin_x, margin_y = margins
+ num_ctrl_pts_per_side = num_control_points // 2
+ ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x, num_ctrl_pts_per_side)
+ ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y
+ ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y)
+ ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
+ ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
+ output_ctrl_pts_arr = np.concatenate(
+ [ctrl_pts_top, ctrl_pts_bottom], axis=0)
+ output_ctrl_pts = paddle.to_tensor(output_ctrl_pts_arr)
+ return output_ctrl_pts
+
+
+class TPSSpatialTransformer(nn.Layer):
+ def __init__(self,
+ output_image_size=None,
+ num_control_points=None,
+ margins=None):
+ super(TPSSpatialTransformer, self).__init__()
+ self.output_image_size = output_image_size
+ self.num_control_points = num_control_points
+ self.margins = margins
+
+ self.target_height, self.target_width = output_image_size
+ target_control_points = build_output_control_points(num_control_points,
+ margins)
+ N = num_control_points
+
+ # create padded kernel matrix
+ forward_kernel = paddle.zeros(shape=[N + 3, N + 3])
+ target_control_partial_repr = compute_partial_repr(
+ target_control_points, target_control_points)
+ target_control_partial_repr = paddle.cast(target_control_partial_repr,
+ forward_kernel.dtype)
+ forward_kernel[:N, :N] = target_control_partial_repr
+ forward_kernel[:N, -3] = 1
+ forward_kernel[-3, :N] = 1
+ target_control_points = paddle.cast(target_control_points,
+ forward_kernel.dtype)
+ forward_kernel[:N, -2:] = target_control_points
+ forward_kernel[-2:, :N] = paddle.transpose(
+ target_control_points, perm=[1, 0])
+ # compute inverse matrix
+ inverse_kernel = paddle.inverse(forward_kernel)
+
+ # create target coordinate matrix
+ HW = self.target_height * self.target_width
+ target_coordinate = list(
+ itertools.product(
+ range(self.target_height), range(self.target_width)))
+ target_coordinate = paddle.to_tensor(target_coordinate) # HW x 2
+ Y, X = paddle.split(
+ target_coordinate, target_coordinate.shape[1], axis=1)
+ Y = Y / (self.target_height - 1)
+ X = X / (self.target_width - 1)
+ target_coordinate = paddle.concat(
+ [X, Y], axis=1) # convert from (y, x) to (x, y)
+ target_coordinate_partial_repr = compute_partial_repr(
+ target_coordinate, target_control_points)
+ target_coordinate_repr = paddle.concat(
+ [
+ target_coordinate_partial_repr, paddle.ones(shape=[HW, 1]),
+ target_coordinate
+ ],
+ axis=1)
+
+ # register precomputed matrices
+ self.inverse_kernel = inverse_kernel
+ self.padding_matrix = paddle.zeros(shape=[3, 2])
+ self.target_coordinate_repr = target_coordinate_repr
+ self.target_control_points = target_control_points
+
+ def forward(self, input, source_control_points):
+ assert source_control_points.ndimension() == 3
+ assert source_control_points.shape[1] == self.num_control_points
+ assert source_control_points.shape[2] == 2
+ batch_size = paddle.shape(source_control_points)[0]
+
+ padding_matrix = paddle.expand(
+ self.padding_matrix, shape=[batch_size, 3, 2])
+ Y = paddle.concat([source_control_points, padding_matrix], 1)
+ mapping_matrix = paddle.matmul(self.inverse_kernel, Y)
+ source_coordinate = paddle.matmul(self.target_coordinate_repr,
+ mapping_matrix)
+
+ grid = paddle.reshape(
+ source_coordinate,
+ shape=[-1, self.target_height, self.target_width, 2])
+ grid = paddle.clip(grid, 0,
+ 1) # the source_control_points may be out of [0, 1].
+ # the input to grid_sample is normalized [-1, 1], but what we get is [0, 1]
+ grid = 2.0 * grid - 1.0
+ output_maps = grid_sample(input, grid, canvas=None)
+ return output_maps, source_coordinate
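
Note: `compute_partial_repr` evaluates the TPS radial basis phi(r) = r^2 * log(r), but computes it as 0.5 * d * log(d) with d = r^2, which is algebraically identical (0.5 * r^2 * log(r^2) = r^2 * log(r)); the NaNs that appear when r = 0 are then zeroed out. A quick NumPy check of that identity:

```python
import numpy as np

r = np.array([0.3, 1.0, 2.5])
d = r ** 2                        # the squared distances the code works with
print(0.5 * d * np.log(d))        # what compute_partial_repr evaluates
print(r ** 2 * np.log(r))         # phi(r) = r^2 * log(r)
# both print approximately [-0.1084  0.      5.7268]
```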
diff --git a/backend/ppocr/optimizer/__init__.py b/backend/ppocr/optimizer/__init__.py
index c729103a..a6bd2ebb 100644
--- a/backend/ppocr/optimizer/__init__.py
+++ b/backend/ppocr/optimizer/__init__.py
@@ -25,15 +25,12 @@
def build_lr_scheduler(lr_config, epochs, step_each_epoch):
from . import learning_rate
lr_config.update({'epochs': epochs, 'step_each_epoch': step_each_epoch})
- if 'name' in lr_config:
- lr_name = lr_config.pop('name')
- lr = getattr(learning_rate, lr_name)(**lr_config)()
- else:
- lr = lr_config['learning_rate']
+ lr_name = lr_config.pop('name', 'Const')
+ lr = getattr(learning_rate, lr_name)(**lr_config)()
return lr
-def build_optimizer(config, epochs, step_each_epoch, parameters):
+def build_optimizer(config, epochs, step_each_epoch, model):
from . import regularizer, optimizer
config = copy.deepcopy(config)
# step1 build lr
@@ -42,8 +39,12 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
# step2 build regularization
if 'regularizer' in config and config['regularizer'] is not None:
reg_config = config.pop('regularizer')
- reg_name = reg_config.pop('name') + 'Decay'
+ reg_name = reg_config.pop('name')
+ if not hasattr(regularizer, reg_name):
+ reg_name += 'Decay'
reg = getattr(regularizer, reg_name)(**reg_config)()
+ elif 'weight_decay' in config:
+ reg = config.pop('weight_decay')
else:
reg = None
@@ -58,4 +59,4 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
weight_decay=reg,
grad_clip=grad_clip,
**config)
- return optim(parameters), lr
+ return optim(model), lr
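
Note: `build_optimizer` now takes the whole model instead of a parameter list, so each optimizer wrapper can filter trainable parameters itself. A hedged call sketch — the config keys mirror the usual Optimizer section of a YAML file but are placeholders, and the Linear layer is only a stand-in for a real model.

```python
import paddle
from ppocr.optimizer import build_optimizer

model = paddle.nn.Linear(8, 8)   # stand-in for a real detection/recognition model

# Placeholder config following the usual Optimizer section layout.
optim_config = {
    'name': 'Adam',
    'beta1': 0.9,
    'beta2': 0.999,
    'lr': {'name': 'Const', 'learning_rate': 0.001},
    'regularizer': {'name': 'L2', 'factor': 1e-05},
}
optimizer, lr = build_optimizer(
    optim_config, epochs=100, step_each_epoch=200, model=model)
# lr is whatever the scheduler produced (with Const and no warmup, a plain float)
```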
diff --git a/backend/ppocr/optimizer/learning_rate.py b/backend/ppocr/optimizer/learning_rate.py
index e1b10992..fe251f36 100644
--- a/backend/ppocr/optimizer/learning_rate.py
+++ b/backend/ppocr/optimizer/learning_rate.py
@@ -18,7 +18,7 @@
from __future__ import unicode_literals
from paddle.optimizer import lr
-from .lr_scheduler import CyclicalCosineDecay
+from .lr_scheduler import CyclicalCosineDecay, OneCycleDecay
class Linear(object):
@@ -226,3 +226,85 @@ def __call__(self):
end_lr=self.learning_rate,
last_epoch=self.last_epoch)
return learning_rate
+
+
+class OneCycle(object):
+ """
+ One Cycle learning rate decay
+ Args:
+ max_lr(float): upper learning rate boundary
+ epochs(int): total training epochs
+ step_each_epoch(int): steps each epoch
+ anneal_strategy(str): {'cos', 'linear'} Specifies the annealing strategy: 'cos' for cosine annealing, 'linear' for linear annealing.
+ Default: 'cos'
+ three_phase(bool): If True, use a third phase of the schedule to annihilate the learning rate according to 'final_div_factor'
+ instead of modifying the second phase (the first two phases will be symmetrical about the step indicated by 'pct_start').
+ last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
+ """
+
+ def __init__(self,
+ max_lr,
+ epochs,
+ step_each_epoch,
+ anneal_strategy='cos',
+ three_phase=False,
+ warmup_epoch=0,
+ last_epoch=-1,
+ **kwargs):
+ super(OneCycle, self).__init__()
+ self.max_lr = max_lr
+ self.epochs = epochs
+ self.steps_per_epoch = step_each_epoch
+ self.anneal_strategy = anneal_strategy
+ self.three_phase = three_phase
+ self.last_epoch = last_epoch
+ self.warmup_epoch = round(warmup_epoch * step_each_epoch)
+
+ def __call__(self):
+ learning_rate = OneCycleDecay(
+ max_lr=self.max_lr,
+ epochs=self.epochs,
+ steps_per_epoch=self.steps_per_epoch,
+ anneal_strategy=self.anneal_strategy,
+ three_phase=self.three_phase,
+ last_epoch=self.last_epoch)
+ if self.warmup_epoch > 0:
+ learning_rate = lr.LinearWarmup(
+ learning_rate=learning_rate,
+ warmup_steps=self.warmup_epoch,
+ start_lr=0.0,
+ end_lr=self.max_lr,
+ last_epoch=self.last_epoch)
+ return learning_rate
+
+
+class Const(object):
+ """
+ Const learning rate decay
+ Args:
+ learning_rate(float): initial learning rate
+ step_each_epoch(int): steps each epoch
+ last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
+ """
+
+ def __init__(self,
+ learning_rate,
+ step_each_epoch,
+ warmup_epoch=0,
+ last_epoch=-1,
+ **kwargs):
+ super(Const, self).__init__()
+ self.learning_rate = learning_rate
+ self.last_epoch = last_epoch
+ self.warmup_epoch = round(warmup_epoch * step_each_epoch)
+
+ def __call__(self):
+ learning_rate = self.learning_rate
+ if self.warmup_epoch > 0:
+ learning_rate = lr.LinearWarmup(
+ learning_rate=learning_rate,
+ warmup_steps=self.warmup_epoch,
+ start_lr=0.0,
+ end_lr=self.learning_rate,
+ last_epoch=self.last_epoch)
+ return learning_rate
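
Note: both new classes follow the same callable pattern as the existing schedulers, so `build_lr_scheduler` can instantiate them by name (and now falls back to `Const` when no name is given). A small sketch with made-up training sizes:

```python
from ppocr.optimizer.learning_rate import OneCycle, Const

# Made-up training sizes, purely for illustration.
one_cycle = OneCycle(max_lr=0.001, epochs=50, step_each_epoch=100,
                     anneal_strategy='cos', warmup_epoch=2)()
const_lr = Const(learning_rate=0.001, step_each_epoch=100)()

print(type(one_cycle).__name__)  # LinearWarmup (wrapping an OneCycleDecay)
print(const_lr)                  # 0.001 -- a plain float when warmup_epoch == 0
```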
diff --git a/backend/ppocr/optimizer/lr_scheduler.py b/backend/ppocr/optimizer/lr_scheduler.py
index 21aec737..f62f1f3b 100644
--- a/backend/ppocr/optimizer/lr_scheduler.py
+++ b/backend/ppocr/optimizer/lr_scheduler.py
@@ -47,3 +47,116 @@ def get_lr(self):
lr = self.eta_min + 0.5 * (self.base_lr - self.eta_min) * \
(1 + math.cos(math.pi * reletive_epoch / self.cycle))
return lr
+
+
+class OneCycleDecay(LRScheduler):
+ """
+ One Cycle learning rate decay
+ A learning rate schedule described in https://arxiv.org/abs/1708.07120
+ Code referred from https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
+ """
+
+ def __init__(self,
+ max_lr,
+ epochs=None,
+ steps_per_epoch=None,
+ pct_start=0.3,
+ anneal_strategy='cos',
+ div_factor=25.,
+ final_div_factor=1e4,
+ three_phase=False,
+ last_epoch=-1,
+ verbose=False):
+
+ # Validate total_steps
+ if epochs <= 0 or not isinstance(epochs, int):
+ raise ValueError(
+ "Expected positive integer epochs, but got {}".format(epochs))
+ if steps_per_epoch <= 0 or not isinstance(steps_per_epoch, int):
+ raise ValueError(
+ "Expected positive integer steps_per_epoch, but got {}".format(
+ steps_per_epoch))
+ self.total_steps = epochs * steps_per_epoch
+
+ self.max_lr = max_lr
+ self.initial_lr = self.max_lr / div_factor
+ self.min_lr = self.initial_lr / final_div_factor
+
+ if three_phase:
+ self._schedule_phases = [
+ {
+ 'end_step': float(pct_start * self.total_steps) - 1,
+ 'start_lr': self.initial_lr,
+ 'end_lr': self.max_lr,
+ },
+ {
+ 'end_step': float(2 * pct_start * self.total_steps) - 2,
+ 'start_lr': self.max_lr,
+ 'end_lr': self.initial_lr,
+ },
+ {
+ 'end_step': self.total_steps - 1,
+ 'start_lr': self.initial_lr,
+ 'end_lr': self.min_lr,
+ },
+ ]
+ else:
+ self._schedule_phases = [
+ {
+ 'end_step': float(pct_start * self.total_steps) - 1,
+ 'start_lr': self.initial_lr,
+ 'end_lr': self.max_lr,
+ },
+ {
+ 'end_step': self.total_steps - 1,
+ 'start_lr': self.max_lr,
+ 'end_lr': self.min_lr,
+ },
+ ]
+
+ # Validate pct_start
+ if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
+ raise ValueError(
+ "Expected float between 0 and 1 pct_start, but got {}".format(
+ pct_start))
+
+ # Validate anneal_strategy
+ if anneal_strategy not in ['cos', 'linear']:
+ raise ValueError(
+ "anneal_strategy must by one of 'cos' or 'linear', instead got {}".
+ format(anneal_strategy))
+ elif anneal_strategy == 'cos':
+ self.anneal_func = self._annealing_cos
+ elif anneal_strategy == 'linear':
+ self.anneal_func = self._annealing_linear
+
+ super(OneCycleDecay, self).__init__(max_lr, last_epoch, verbose)
+
+ def _annealing_cos(self, start, end, pct):
+ "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
+ cos_out = math.cos(math.pi * pct) + 1
+ return end + (start - end) / 2.0 * cos_out
+
+ def _annealing_linear(self, start, end, pct):
+ "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."
+ return (end - start) * pct + start
+
+ def get_lr(self):
+ computed_lr = 0.0
+ step_num = self.last_epoch
+
+ if step_num > self.total_steps:
+ raise ValueError(
+ "Tried to step {} times. The specified number of total steps is {}"
+ .format(step_num + 1, self.total_steps))
+ start_step = 0
+ for i, phase in enumerate(self._schedule_phases):
+ end_step = phase['end_step']
+ if step_num <= end_step or i == len(self._schedule_phases) - 1:
+ pct = (step_num - start_step) / (end_step - start_step)
+ computed_lr = self.anneal_func(phase['start_lr'],
+ phase['end_lr'], pct)
+ break
+ start_step = phase['end_step']
+
+ return computed_lr
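
Note: to make the phase logic above concrete, a tiny trace of `get_lr` over a made-up 10-step schedule; the numbers are only for illustration.

```python
from ppocr.optimizer.lr_scheduler import OneCycleDecay

# Made-up schedule: 2 epochs x 5 steps = 10 total steps, max_lr = 1.0,
# so initial_lr = 1/25 = 0.04 and min_lr = 0.04 / 1e4 = 4e-06.
sched = OneCycleDecay(max_lr=1.0, epochs=2, steps_per_epoch=5)
lrs = []
for _ in range(10):
    lrs.append(round(sched.get_lr(), 4))
    sched.step()
print(lrs)  # rises 0.04 -> 1.0 over the first ~30% of steps (pct_start=0.3),
            # then cosine-anneals down towards 4e-06
```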
diff --git a/backend/ppocr/optimizer/optimizer.py b/backend/ppocr/optimizer/optimizer.py
index 8215b92d..dd8544e2 100644
--- a/backend/ppocr/optimizer/optimizer.py
+++ b/backend/ppocr/optimizer/optimizer.py
@@ -42,13 +42,16 @@ def __init__(self,
self.weight_decay = weight_decay
self.grad_clip = grad_clip
- def __call__(self, parameters):
+ def __call__(self, model):
+ train_params = [
+ param for param in model.parameters() if param.trainable is True
+ ]
opt = optim.Momentum(
learning_rate=self.learning_rate,
momentum=self.momentum,
weight_decay=self.weight_decay,
grad_clip=self.grad_clip,
- parameters=parameters)
+ parameters=train_params)
return opt
@@ -75,7 +78,10 @@ def __init__(self,
self.name = name
self.lazy_mode = lazy_mode
- def __call__(self, parameters):
+ def __call__(self, model):
+ train_params = [
+ param for param in model.parameters() if param.trainable is True
+ ]
opt = optim.Adam(
learning_rate=self.learning_rate,
beta1=self.beta1,
@@ -85,7 +91,7 @@ def __call__(self, parameters):
grad_clip=self.grad_clip,
name=self.name,
lazy_mode=self.lazy_mode,
- parameters=parameters)
+ parameters=train_params)
return opt
@@ -117,7 +123,10 @@ def __init__(self,
self.weight_decay = weight_decay
self.grad_clip = grad_clip
- def __call__(self, parameters):
+ def __call__(self, model):
+ train_params = [
+ param for param in model.parameters() if param.trainable is True
+ ]
opt = optim.RMSProp(
learning_rate=self.learning_rate,
momentum=self.momentum,
@@ -125,5 +134,101 @@ def __call__(self, parameters):
epsilon=self.epsilon,
weight_decay=self.weight_decay,
grad_clip=self.grad_clip,
- parameters=parameters)
+ parameters=train_params)
return opt
+
+
+class Adadelta(object):
+ def __init__(self,
+ learning_rate=0.001,
+ epsilon=1e-08,
+ rho=0.95,
+ parameter_list=None,
+ weight_decay=None,
+ grad_clip=None,
+ name=None,
+ **kwargs):
+ self.learning_rate = learning_rate
+ self.epsilon = epsilon
+ self.rho = rho
+ self.parameter_list = parameter_list
+ self.learning_rate = learning_rate
+ self.weight_decay = weight_decay
+ self.grad_clip = grad_clip
+ self.name = name
+
+ def __call__(self, model):
+ train_params = [
+ param for param in model.parameters() if param.trainable is True
+ ]
+ opt = optim.Adadelta(
+ learning_rate=self.learning_rate,
+ epsilon=self.epsilon,
+ rho=self.rho,
+ weight_decay=self.weight_decay,
+ grad_clip=self.grad_clip,
+ name=self.name,
+ parameters=train_params)
+ return opt
+
+
+class AdamW(object):
+ def __init__(self,
+ learning_rate=0.001,
+ beta1=0.9,
+ beta2=0.999,
+ epsilon=1e-8,
+ weight_decay=0.01,
+ multi_precision=False,
+ grad_clip=None,
+ no_weight_decay_name=None,
+ one_dim_param_no_weight_decay=False,
+ name=None,
+ lazy_mode=False,
+ **args):
+ super().__init__()
+ self.learning_rate = learning_rate
+ self.beta1 = beta1
+ self.beta2 = beta2
+ self.epsilon = epsilon
+ self.grad_clip = grad_clip
+ self.weight_decay = 0.01 if weight_decay is None else weight_decay
+ self.grad_clip = grad_clip
+ self.name = name
+ self.lazy_mode = lazy_mode
+ self.multi_precision = multi_precision
+ self.no_weight_decay_name_list = no_weight_decay_name.split(
+ ) if no_weight_decay_name else []
+ self.one_dim_param_no_weight_decay = one_dim_param_no_weight_decay
+
+ def __call__(self, model):
+ parameters = [
+ param for param in model.parameters() if param.trainable is True
+ ]
+
+ self.no_weight_decay_param_name_list = [
+ p.name for n, p in model.named_parameters()
+ if any(nd in n for nd in self.no_weight_decay_name_list)
+ ]
+
+ if self.one_dim_param_no_weight_decay:
+ self.no_weight_decay_param_name_list += [
+ p.name for n, p in model.named_parameters() if len(p.shape) == 1
+ ]
+
+ opt = optim.AdamW(
+ learning_rate=self.learning_rate,
+ beta1=self.beta1,
+ beta2=self.beta2,
+ epsilon=self.epsilon,
+ parameters=parameters,
+ weight_decay=self.weight_decay,
+ multi_precision=self.multi_precision,
+ grad_clip=self.grad_clip,
+ name=self.name,
+ lazy_mode=self.lazy_mode,
+ apply_decay_param_fun=self._apply_decay_param_fun)
+ return opt
+
+ def _apply_decay_param_fun(self, name):
+ return name not in self.no_weight_decay_param_name_list
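
Note: the `AdamW` wrapper excludes parameters from weight decay by name via `apply_decay_param_fun`. A brief sketch of the filtering; the tiny model and the 'norm' pattern are arbitrary stand-ins, not anything prescribed by the code above.

```python
from paddle import nn
from ppocr.optimizer.optimizer import AdamW

class TinyModel(nn.Layer):           # arbitrary stand-in model
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)
        self.norm = nn.LayerNorm(8)

model = TinyModel()
builder = AdamW(learning_rate=0.001,
                weight_decay=0.05,
                no_weight_decay_name='norm',         # example pattern, matched against parameter names
                one_dim_param_no_weight_decay=True)  # also exempt biases / norm weights
opt = builder(model)                 # builds paddle.optimizer.AdamW with apply_decay_param_fun

for n, p in model.named_parameters():
    # True -> decay applied; norm.* and 1-D parameters come out False
    print(n, builder._apply_decay_param_fun(p.name))
```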
diff --git a/backend/ppocr/optimizer/regularizer.py b/backend/ppocr/optimizer/regularizer.py
index c6396f33..2ce68f71 100644
--- a/backend/ppocr/optimizer/regularizer.py
+++ b/backend/ppocr/optimizer/regularizer.py
@@ -29,24 +29,23 @@ class L1Decay(object):
def __init__(self, factor=0.0):
super(L1Decay, self).__init__()
- self.regularization_coeff = factor
+ self.coeff = factor
def __call__(self):
- reg = paddle.regularizer.L1Decay(self.regularization_coeff)
+ reg = paddle.regularizer.L1Decay(self.coeff)
return reg
class L2Decay(object):
"""
- L2 Weight Decay Regularization, which encourages the weights to be sparse.
+ L2 Weight Decay Regularization, which helps prevent the model from overfitting.
Args:
factor(float): regularization coeff. Default:0.0.
"""
def __init__(self, factor=0.0):
super(L2Decay, self).__init__()
- self.regularization_coeff = factor
+ self.coeff = float(factor)
def __call__(self):
- reg = paddle.regularizer.L2Decay(self.regularization_coeff)
- return reg
+ return self.coeff
\ No newline at end of file
diff --git a/backend/ppocr/postprocess/__init__.py b/backend/ppocr/postprocess/__init__.py
index 0156e438..f50b5f1c 100644
--- a/backend/ppocr/postprocess/__init__.py
+++ b/backend/ppocr/postprocess/__init__.py
@@ -21,21 +21,38 @@
__all__ = ['build_post_process']
+from .db_postprocess import DBPostProcess, DistillationDBPostProcess
+from .east_postprocess import EASTPostProcess
+from .sast_postprocess import SASTPostProcess
+from .fce_postprocess import FCEPostProcess
+from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
+ DistillationCTCLabelDecode, TableLabelDecode, NRTRLabelDecode, SARLabelDecode, \
+ SEEDLabelDecode, PRENLabelDecode
+from .cls_postprocess import ClsPostProcess
+from .pg_postprocess import PGPostProcess
+from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
+from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess
-def build_post_process(config, global_config=None):
- from .db_postprocess import DBPostProcess
- from .east_postprocess import EASTPostProcess
- from .sast_postprocess import SASTPostProcess
- from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode
- from .cls_postprocess import ClsPostProcess
+def build_post_process(config, global_config=None):
support_dict = [
- 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
- 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode'
+ 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'FCEPostProcess',
+ 'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode',
+ 'PGPostProcess', 'DistillationCTCLabelDecode', 'TableLabelDecode',
+ 'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
+ 'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
+ 'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode',
+ 'DistillationSARLabelDecode'
]
+ if config['name'] == 'PSEPostProcess':
+ from .pse_postprocess import PSEPostProcess
+ support_dict.append('PSEPostProcess')
+
config = copy.deepcopy(config)
module_name = config.pop('name')
+ if module_name == "None":
+ return
if global_config is not None:
config.update(global_config)
assert module_name in support_dict, Exception(
diff --git a/backend/ppocr/postprocess/cls_postprocess.py b/backend/ppocr/postprocess/cls_postprocess.py
index 77e7f46d..9a27ba08 100644
--- a/backend/ppocr/postprocess/cls_postprocess.py
+++ b/backend/ppocr/postprocess/cls_postprocess.py
@@ -17,17 +17,26 @@
class ClsPostProcess(object):
""" Convert between text-label and text-index """
- def __init__(self, label_list, **kwargs):
+ def __init__(self, label_list=None, key=None, **kwargs):
super(ClsPostProcess, self).__init__()
self.label_list = label_list
+ self.key = key
def __call__(self, preds, label=None, *args, **kwargs):
+ if self.key is not None:
+ preds = preds[self.key]
+
+ label_list = self.label_list
+ if label_list is None:
+ label_list = {idx: idx for idx in range(preds.shape[-1])}
+
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
+
pred_idxs = preds.argmax(axis=1)
- decode_out = [(self.label_list[idx], preds[i, idx])
+ decode_out = [(label_list[idx], preds[i, idx])
for i, idx in enumerate(pred_idxs)]
if label is None:
return decode_out
- label = [(self.label_list[idx], 1.0) for idx in label]
+ label = [(label_list[idx], 1.0) for idx in label]
return decode_out, label
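
Note: `ClsPostProcess` now accepts a `key` to pick one entry out of dict-shaped (e.g. distillation) outputs, and falls back to raw class indices when no `label_list` is given. A minimal decode example with made-up scores:

```python
import numpy as np
from ppocr.postprocess.cls_postprocess import ClsPostProcess

post = ClsPostProcess(label_list=['0', '180'])
preds = np.array([[0.9, 0.1], [0.2, 0.8]], dtype=np.float32)
print(post(preds))  # [('0', 0.9), ('180', 0.8)]
```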
diff --git a/backend/ppocr/postprocess/db_postprocess.py b/backend/ppocr/postprocess/db_postprocess.py
index 91729e0a..6542a1bf 100755
--- a/backend/ppocr/postprocess/db_postprocess.py
+++ b/backend/ppocr/postprocess/db_postprocess.py
@@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+"""
+This code is referred from:
+https://github.com/WenmuZhou/DBNet.pytorch/blob/master/post_processing/seg_detector_representer.py
+"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -34,12 +37,18 @@ def __init__(self,
max_candidates=1000,
unclip_ratio=2.0,
use_dilation=False,
+ score_mode="fast",
**kwargs):
self.thresh = thresh
self.box_thresh = box_thresh
self.max_candidates = max_candidates
self.unclip_ratio = unclip_ratio
self.min_size = 3
+ self.score_mode = score_mode
+ assert score_mode in [
+ "slow", "fast"
+ ], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
+
self.dilation_kernel = None if not use_dilation else np.array(
[[1, 1], [1, 1]])
@@ -69,7 +78,10 @@ def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
if sside < self.min_size:
continue
points = np.array(points)
- score = self.box_score_fast(pred, points.reshape(-1, 2))
+ if self.score_mode == "fast":
+ score = self.box_score_fast(pred, points.reshape(-1, 2))
+ else:
+ score = self.box_score_slow(pred, contour)
if self.box_thresh > score:
continue
@@ -120,12 +132,15 @@ def get_mini_boxes(self, contour):
return box, min(bounding_box[1])
def box_score_fast(self, bitmap, _box):
+ '''
+ box_score_fast: use the mean score inside the (rotated) box as the box score
+ '''
h, w = bitmap.shape[:2]
box = _box.copy()
- xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1)
- xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int), 0, w - 1)
- ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int), 0, h - 1)
- ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int), 0, h - 1)
+ xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int32), 0, w - 1)
+ xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int32), 0, w - 1)
+ ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int32), 0, h - 1)
+ ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int32), 0, h - 1)
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
box[:, 0] = box[:, 0] - xmin
@@ -133,6 +148,27 @@ def box_score_fast(self, bitmap, _box):
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
+ def box_score_slow(self, bitmap, contour):
+ '''
+ box_score_slow: use the mean score inside the polygon as the box score
+ '''
+ h, w = bitmap.shape[:2]
+ contour = contour.copy()
+ contour = np.reshape(contour, (-1, 2))
+
+ xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
+ xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
+ ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
+ ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
+
+ mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
+
+ contour[:, 0] = contour[:, 0] - xmin
+ contour[:, 1] = contour[:, 1] - ymin
+
+ cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
+ return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
+
def __call__(self, outs_dict, shape_list):
pred = outs_dict['maps']
if isinstance(pred, paddle.Tensor):
@@ -154,3 +190,31 @@ def __call__(self, outs_dict, shape_list):
boxes_batch.append({'points': boxes})
return boxes_batch
+
+
+class DistillationDBPostProcess(object):
+ def __init__(self,
+ model_name=["student"],
+ key=None,
+ thresh=0.3,
+ box_thresh=0.6,
+ max_candidates=1000,
+ unclip_ratio=1.5,
+ use_dilation=False,
+ score_mode="fast",
+ **kwargs):
+ self.model_name = model_name
+ self.key = key
+ self.post_process = DBPostProcess(
+ thresh=thresh,
+ box_thresh=box_thresh,
+ max_candidates=max_candidates,
+ unclip_ratio=unclip_ratio,
+ use_dilation=use_dilation,
+ score_mode=score_mode)
+
+ def __call__(self, predicts, shape_list):
+ results = {}
+ for k in self.model_name:
+ results[k] = self.post_process(predicts[k], shape_list=shape_list)
+ return results
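
Note: the difference between the two score modes is the region the mean is taken over — `fast` scores the rotated box produced by `get_mini_boxes`, while `slow` scores the exact contour. A toy comparison on a triangular region; the values are approximate and only illustrative.

```python
import cv2
import numpy as np
from ppocr.postprocess.db_postprocess import DBPostProcess

post = DBPostProcess()   # score_mode defaults to "fast"

# Probability map that is bright only inside a triangular region.
tri = np.array([[0, 0], [7, 0], [0, 7]], dtype=np.int32)
mask = np.zeros((8, 8), dtype=np.uint8)
cv2.fillPoly(mask, tri.reshape(1, -1, 2), 1)
prob = mask.astype(np.float32)

rect = cv2.boxPoints(cv2.minAreaRect(tri))               # rotated rect, as get_mini_boxes would return
print(post.box_score_fast(prob, rect))                   # mean inside the rect, diluted by empty corners
print(post.box_score_slow(prob, tri.reshape(-1, 1, 2)))  # mean inside the exact contour, ~1.0
```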
diff --git a/backend/ppocr/postprocess/east_postprocess.py b/backend/ppocr/postprocess/east_postprocess.py
index ceee727a..c194c81c 100755
--- a/backend/ppocr/postprocess/east_postprocess.py
+++ b/backend/ppocr/postprocess/east_postprocess.py
@@ -29,6 +29,7 @@ class EASTPostProcess(object):
"""
The post process for EAST.
"""
+
def __init__(self,
score_thresh=0.8,
cover_thresh=0.1,
@@ -38,11 +39,6 @@ def __init__(self,
self.score_thresh = score_thresh
self.cover_thresh = cover_thresh
self.nms_thresh = nms_thresh
-
- # c++ la-nms is faster, but only support python 3.5
- self.is_python35 = False
- if sys.version_info.major == 3 and sys.version_info.minor == 5:
- self.is_python35 = True
def restore_rectangle_quad(self, origin, geometry):
"""
@@ -64,6 +60,7 @@ def detect(self,
"""
restore text boxes from score map and geo map
"""
+
score_map = score_map[0]
geo_map = np.swapaxes(geo_map, 1, 0)
geo_map = np.swapaxes(geo_map, 1, 2)
@@ -79,10 +76,14 @@ def detect(self,
boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32)
boxes[:, :8] = text_box_restored.reshape((-1, 8))
boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]]
- if self.is_python35:
+
+ try:
import lanms
boxes = lanms.merge_quadrangle_n9(boxes, nms_thresh)
- else:
+ except:
+ print(
+ 'you should install lanms by pip3 install lanms-nova to speed up nms_locality'
+ )
boxes = nms_locality(boxes.astype(np.float64), nms_thresh)
if boxes.shape[0] == 0:
return []
@@ -139,4 +140,4 @@ def __call__(self, outs_dict, shape_list):
continue
boxes_norm.append(box)
dt_boxes_list.append({'points': np.array(boxes_norm)})
- return dt_boxes_list
\ No newline at end of file
+ return dt_boxes_list
diff --git a/backend/ppocr/postprocess/fce_postprocess.py b/backend/ppocr/postprocess/fce_postprocess.py
new file mode 100755
index 00000000..8e0716f9
--- /dev/null
+++ b/backend/ppocr/postprocess/fce_postprocess.py
@@ -0,0 +1,241 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is referred from:
+https://github.com/open-mmlab/mmocr/blob/v0.3.0/mmocr/models/textdet/postprocess/wrapper.py
+"""
+
+import cv2
+import paddle
+import numpy as np
+from numpy.fft import ifft
+from ppocr.utils.poly_nms import poly_nms, valid_boundary
+
+
+def fill_hole(input_mask):
+ h, w = input_mask.shape
+ canvas = np.zeros((h + 2, w + 2), np.uint8)
+ canvas[1:h + 1, 1:w + 1] = input_mask.copy()
+
+ mask = np.zeros((h + 4, w + 4), np.uint8)
+
+ cv2.floodFill(canvas, mask, (0, 0), 1)
+ canvas = canvas[1:h + 1, 1:w + 1].astype(np.bool_)
+
+ return ~canvas | input_mask
+
+
+def fourier2poly(fourier_coeff, num_reconstr_points=50):
+ """ Inverse Fourier transform
+ Args:
+ fourier_coeff (ndarray): Fourier coefficients shaped (n, 2k+1),
+ with n and k being candidates number and Fourier degree
+ respectively.
+ num_reconstr_points (int): Number of reconstructed polygon points.
+ Returns:
+ Polygons (ndarray): The reconstructed polygons shaped (n, n')
+ """
+
+ a = np.zeros((len(fourier_coeff), num_reconstr_points), dtype='complex')
+ k = (len(fourier_coeff[0]) - 1) // 2
+
+ a[:, 0:k + 1] = fourier_coeff[:, k:]
+ a[:, -k:] = fourier_coeff[:, :k]
+
+ poly_complex = ifft(a) * num_reconstr_points
+ polygon = np.zeros((len(fourier_coeff), num_reconstr_points, 2))
+ polygon[:, :, 0] = poly_complex.real
+ polygon[:, :, 1] = poly_complex.imag
+ return polygon.astype('int32').reshape((len(fourier_coeff), -1))
+
+
+class FCEPostProcess(object):
+ """
+ The post process for FCENet.
+ """
+
+ def __init__(self,
+ scales,
+ fourier_degree=5,
+ num_reconstr_points=50,
+ decoding_type='fcenet',
+ score_thr=0.3,
+ nms_thr=0.1,
+ alpha=1.0,
+ beta=1.0,
+ box_type='poly',
+ **kwargs):
+
+ self.scales = scales
+ self.fourier_degree = fourier_degree
+ self.num_reconstr_points = num_reconstr_points
+ self.decoding_type = decoding_type
+ self.score_thr = score_thr
+ self.nms_thr = nms_thr
+ self.alpha = alpha
+ self.beta = beta
+ self.box_type = box_type
+
+ def __call__(self, preds, shape_list):
+ score_maps = []
+ for key, value in preds.items():
+ if isinstance(value, paddle.Tensor):
+ value = value.numpy()
+ cls_res = value[:, :4, :, :]
+ reg_res = value[:, 4:, :, :]
+ score_maps.append([cls_res, reg_res])
+
+ return self.get_boundary(score_maps, shape_list)
+
+ def resize_boundary(self, boundaries, scale_factor):
+ """Rescale boundaries via scale_factor.
+
+ Args:
+ boundaries (list[list[float]]): The boundary list. Each boundary
+ with size 2k+1 with k>=4.
+ scale_factor(ndarray): The scale factor of size (4,).
+
+ Returns:
+ boundaries (list[list[float]]): The scaled boundaries.
+ """
+ boxes = []
+ scores = []
+ for b in boundaries:
+ sz = len(b)
+ valid_boundary(b, True)
+ scores.append(b[-1])
+ b = (np.array(b[:sz - 1]) *
+ (np.tile(scale_factor[:2], int(
+ (sz - 1) / 2)).reshape(1, sz - 1))).flatten().tolist()
+ boxes.append(np.array(b).reshape([-1, 2]))
+
+ return np.array(boxes, dtype=np.float32), scores
+
+ def get_boundary(self, score_maps, shape_list):
+ assert len(score_maps) == len(self.scales)
+ boundaries = []
+ for idx, score_map in enumerate(score_maps):
+ scale = self.scales[idx]
+ boundaries = boundaries + self._get_boundary_single(score_map,
+ scale)
+
+ # nms
+ boundaries = poly_nms(boundaries, self.nms_thr)
+ boundaries, scores = self.resize_boundary(
+ boundaries, (1 / shape_list[0, 2:]).tolist()[::-1])
+
+ boxes_batch = [dict(points=boundaries, scores=scores)]
+ return boxes_batch
+
+ def _get_boundary_single(self, score_map, scale):
+ assert len(score_map) == 2
+ assert score_map[1].shape[1] == 4 * self.fourier_degree + 2
+
+ return self.fcenet_decode(
+ preds=score_map,
+ fourier_degree=self.fourier_degree,
+ num_reconstr_points=self.num_reconstr_points,
+ scale=scale,
+ alpha=self.alpha,
+ beta=self.beta,
+ box_type=self.box_type,
+ score_thr=self.score_thr,
+ nms_thr=self.nms_thr)
+
+ def fcenet_decode(self,
+ preds,
+ fourier_degree,
+ num_reconstr_points,
+ scale,
+ alpha=1.0,
+ beta=2.0,
+ box_type='poly',
+ score_thr=0.3,
+ nms_thr=0.1):
+ """Decoding predictions of FCENet to instances.
+
+ Args:
+ preds (list(Tensor)): The head output tensors.
+ fourier_degree (int): The maximum Fourier transform degree k.
+ num_reconstr_points (int): The points number of the polygon
+ reconstructed from predicted Fourier coefficients.
+ scale (int): The down-sample scale of the prediction.
+ alpha (float) : The parameter to calculate final scores. Score_{final}
+ = (Score_{text region} ^ alpha)
+ * (Score_{text center region}^ beta)
+ beta (float) : The parameter to calculate final score.
+ box_type (str): Boundary encoding type 'poly' or 'quad'.
+ score_thr (float) : The threshold used to filter out the final
+ candidates.
+ nms_thr (float) : The threshold of nms.
+
+ Returns:
+ boundaries (list[list[float]]): The instance boundary and confidence
+ list.
+ """
+ assert isinstance(preds, list)
+ assert len(preds) == 2
+ assert box_type in ['poly', 'quad']
+
+ cls_pred = preds[0][0]
+ tr_pred = cls_pred[0:2]
+ tcl_pred = cls_pred[2:]
+
+ reg_pred = preds[1][0].transpose([1, 2, 0])
+ x_pred = reg_pred[:, :, :2 * fourier_degree + 1]
+ y_pred = reg_pred[:, :, 2 * fourier_degree + 1:]
+
+ score_pred = (tr_pred[1]**alpha) * (tcl_pred[1]**beta)
+ tr_pred_mask = (score_pred) > score_thr
+ tr_mask = fill_hole(tr_pred_mask)
+
+ tr_contours, _ = cv2.findContours(
+ tr_mask.astype(np.uint8), cv2.RETR_TREE,
+ cv2.CHAIN_APPROX_SIMPLE) # opencv4
+
+ mask = np.zeros_like(tr_mask)
+ boundaries = []
+ for cont in tr_contours:
+ deal_map = mask.copy().astype(np.int8)
+ cv2.drawContours(deal_map, [cont], -1, 1, -1)
+
+ score_map = score_pred * deal_map
+ score_mask = score_map > 0
+ xy_text = np.argwhere(score_mask)
+ dxy = xy_text[:, 1] + xy_text[:, 0] * 1j
+
+ x, y = x_pred[score_mask], y_pred[score_mask]
+ c = x + y * 1j
+ c[:, fourier_degree] = c[:, fourier_degree] + dxy
+ c *= scale
+
+ polygons = fourier2poly(c, num_reconstr_points)
+ score = score_map[score_mask].reshape(-1, 1)
+ polygons = poly_nms(np.hstack((polygons, score)).tolist(), nms_thr)
+
+ boundaries = boundaries + polygons
+
+ boundaries = poly_nms(boundaries, nms_thr)
+
+ if box_type == 'quad':
+ new_boundaries = []
+ for boundary in boundaries:
+ poly = np.array(boundary[:-1]).reshape(-1, 2).astype(np.float32)
+ score = boundary[-1]
+ points = cv2.boxPoints(cv2.minAreaRect(poly))
+ points = np.int0(points)
+ new_boundaries.append(points.reshape(-1).tolist() + [score])
+ boundaries = new_boundaries
+
+ return boundaries
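
Note: `fourier2poly` expects complex coefficients ordered [c_-k, ..., c_-1, c_0, c_1, ..., c_k] and reconstructs contour points via an inverse FFT. A worked check with degree k = 1, where c_0 encodes the centre and c_1 the radius, so the output should trace a circle:

```python
import numpy as np
from ppocr.postprocess.fce_postprocess import fourier2poly

# One candidate, k = 1: c_-1 = 0, c_0 = 50 + 50j (centre), c_1 = 20 (radius).
coeff = np.array([[0 + 0j, 50 + 50j, 20 + 0j]])
poly = fourier2poly(coeff, num_reconstr_points=8).reshape(-1, 2)
print(poly)
# [[70 50]
#  [64 64]
#  [50 70]
#  [35 64]  ... and so on around a circle of radius 20 centred at (50, 50)
```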
diff --git a/backend/ppocr/postprocess/locality_aware_nms.py b/backend/ppocr/postprocess/locality_aware_nms.py
index 53280cc1..d305ef68 100644
--- a/backend/ppocr/postprocess/locality_aware_nms.py
+++ b/backend/ppocr/postprocess/locality_aware_nms.py
@@ -1,5 +1,6 @@
"""
Locality aware nms.
+This code is referred from: https://github.com/songdejia/EAST/blob/master/locality_aware_nms.py
"""
import numpy as np
diff --git a/backend/ppocr/postprocess/pg_postprocess.py b/backend/ppocr/postprocess/pg_postprocess.py
new file mode 100644
index 00000000..0b145518
--- /dev/null
+++ b/backend/ppocr/postprocess/pg_postprocess.py
@@ -0,0 +1,52 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import sys
+
+__dir__ = os.path.dirname(__file__)
+sys.path.append(__dir__)
+sys.path.append(os.path.join(__dir__, '..'))
+from ppocr.utils.e2e_utils.pgnet_pp_utils import PGNet_PostProcess
+
+
+class PGPostProcess(object):
+ """
+ The post process for PGNet.
+ """
+
+ def __init__(self, character_dict_path, valid_set, score_thresh, mode,
+ **kwargs):
+ self.character_dict_path = character_dict_path
+ self.valid_set = valid_set
+ self.score_thresh = score_thresh
+ self.mode = mode
+
+ # c++ la-nms is faster, but only support python 3.5
+ self.is_python35 = False
+ if sys.version_info.major == 3 and sys.version_info.minor == 5:
+ self.is_python35 = True
+
+ def __call__(self, outs_dict, shape_list):
+ post = PGNet_PostProcess(self.character_dict_path, self.valid_set,
+ self.score_thresh, outs_dict, shape_list)
+ if self.mode == 'fast':
+ data = post.pg_postprocess_fast()
+ else:
+ data = post.pg_postprocess_slow()
+ return data
diff --git a/backend/ppocr/postprocess/pse_postprocess/__init__.py b/backend/ppocr/postprocess/pse_postprocess/__init__.py
new file mode 100644
index 00000000..680473bf
--- /dev/null
+++ b/backend/ppocr/postprocess/pse_postprocess/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .pse_postprocess import PSEPostProcess
\ No newline at end of file
diff --git a/backend/ppocr/postprocess/pse_postprocess/pse/README.md b/backend/ppocr/postprocess/pse_postprocess/pse/README.md
new file mode 100644
index 00000000..6a19d5d1
--- /dev/null
+++ b/backend/ppocr/postprocess/pse_postprocess/pse/README.md
@@ -0,0 +1,6 @@
+## Compile
+This code is referred from:
+https://github.com/whai362/PSENet/blob/python3/models/post_processing/pse
+```python
+python3 setup.py build_ext --inplace
+```
diff --git a/backend/ppocr/postprocess/pse_postprocess/pse/__init__.py b/backend/ppocr/postprocess/pse_postprocess/pse/__init__.py
new file mode 100644
index 00000000..1903a914
--- /dev/null
+++ b/backend/ppocr/postprocess/pse_postprocess/pse/__init__.py
@@ -0,0 +1,29 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+import os
+import subprocess
+
+python_path = sys.executable
+
+ori_path = os.getcwd()
+os.chdir('ppocr/postprocess/pse_postprocess/pse')
+if subprocess.call(
+ '{} setup.py build_ext --inplace'.format(python_path), shell=True) != 0:
+ raise RuntimeError(
+        'Cannot compile pse: {}. If your system is Windows, you need to install all the default components of `Desktop development with C++` in Visual Studio 2019+.'.
+ format(os.path.dirname(os.path.realpath(__file__))))
+os.chdir(ori_path)
+
+from .pse import pse
diff --git a/backend/ppocr/postprocess/pse_postprocess/pse/pse.pyx b/backend/ppocr/postprocess/pse_postprocess/pse/pse.pyx
new file mode 100644
index 00000000..b2be49e9
--- /dev/null
+++ b/backend/ppocr/postprocess/pse_postprocess/pse/pse.pyx
@@ -0,0 +1,70 @@
+
+import numpy as np
+import cv2
+cimport numpy as np
+cimport cython
+cimport libcpp
+cimport libcpp.pair
+cimport libcpp.queue
+from libcpp.pair cimport *
+from libcpp.queue cimport *
+
+@cython.boundscheck(False)
+@cython.wraparound(False)
+cdef np.ndarray[np.int32_t, ndim=2] _pse(np.ndarray[np.uint8_t, ndim=3] kernels,
+ np.ndarray[np.int32_t, ndim=2] label,
+ int kernel_num,
+ int label_num,
+ float min_area=0):
+ cdef np.ndarray[np.int32_t, ndim=2] pred
+ pred = np.zeros((label.shape[0], label.shape[1]), dtype=np.int32)
+
+ for label_idx in range(1, label_num):
+ if np.sum(label == label_idx) < min_area:
+ label[label == label_idx] = 0
+
+ cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] que = \
+ queue[libcpp.pair.pair[np.int16_t,np.int16_t]]()
+ cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] nxt_que = \
+ queue[libcpp.pair.pair[np.int16_t,np.int16_t]]()
+ cdef np.int16_t* dx = [-1, 1, 0, 0]
+ cdef np.int16_t* dy = [0, 0, -1, 1]
+ cdef np.int16_t tmpx, tmpy
+
+ points = np.array(np.where(label > 0)).transpose((1, 0))
+ for point_idx in range(points.shape[0]):
+ tmpx, tmpy = points[point_idx, 0], points[point_idx, 1]
+ que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy))
+ pred[tmpx, tmpy] = label[tmpx, tmpy]
+
+ cdef libcpp.pair.pair[np.int16_t,np.int16_t] cur
+ cdef int cur_label
+ for kernel_idx in range(kernel_num - 1, -1, -1):
+ while not que.empty():
+ cur = que.front()
+ que.pop()
+ cur_label = pred[cur.first, cur.second]
+
+ is_edge = True
+ for j in range(4):
+ tmpx = cur.first + dx[j]
+ tmpy = cur.second + dy[j]
+ if tmpx < 0 or tmpx >= label.shape[0] or tmpy < 0 or tmpy >= label.shape[1]:
+ continue
+ if kernels[kernel_idx, tmpx, tmpy] == 0 or pred[tmpx, tmpy] > 0:
+ continue
+
+ que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy))
+ pred[tmpx, tmpy] = cur_label
+ is_edge = False
+ if is_edge:
+ nxt_que.push(cur)
+
+ que, nxt_que = nxt_que, que
+
+ return pred
+
+def pse(kernels, min_area):
+ kernel_num = kernels.shape[0]
+ label_num, label = cv2.connectedComponents(kernels[-1], connectivity=4)
+ return _pse(kernels[:-1], label, kernel_num, label_num, min_area)
\ No newline at end of file
diff --git a/backend/ppocr/postprocess/pse_postprocess/pse/setup.py b/backend/ppocr/postprocess/pse_postprocess/pse/setup.py
new file mode 100644
index 00000000..03746782
--- /dev/null
+++ b/backend/ppocr/postprocess/pse_postprocess/pse/setup.py
@@ -0,0 +1,14 @@
+from distutils.core import setup, Extension
+from Cython.Build import cythonize
+import numpy
+
+setup(ext_modules=cythonize(Extension(
+ 'pse',
+ sources=['pse.pyx'],
+ language='c++',
+ include_dirs=[numpy.get_include()],
+ library_dirs=[],
+ libraries=[],
+ extra_compile_args=['-O3'],
+ extra_link_args=[]
+)))
diff --git a/backend/ppocr/postprocess/pse_postprocess/pse_postprocess.py b/backend/ppocr/postprocess/pse_postprocess/pse_postprocess.py
new file mode 100755
index 00000000..34f1b8c9
--- /dev/null
+++ b/backend/ppocr/postprocess/pse_postprocess/pse_postprocess.py
@@ -0,0 +1,118 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is adapted from:
+https://github.com/whai362/PSENet/blob/python3/models/head/psenet_head.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import cv2
+import paddle
+from paddle.nn import functional as F
+
+from ppocr.postprocess.pse_postprocess.pse import pse
+
+
+class PSEPostProcess(object):
+ """
+ The post process for PSE.
+ """
+
+ def __init__(self,
+ thresh=0.5,
+ box_thresh=0.85,
+ min_area=16,
+ box_type='quad',
+ scale=4,
+ **kwargs):
+        assert box_type in ['quad', 'poly'], 'Only quad and poly are supported'
+ self.thresh = thresh
+ self.box_thresh = box_thresh
+ self.min_area = min_area
+ self.box_type = box_type
+ self.scale = scale
+
+ def __call__(self, outs_dict, shape_list):
+ pred = outs_dict['maps']
+ if not isinstance(pred, paddle.Tensor):
+ pred = paddle.to_tensor(pred)
+ pred = F.interpolate(
+ pred, scale_factor=4 // self.scale, mode='bilinear')
+
+ score = F.sigmoid(pred[:, 0, :, :])
+
+ kernels = (pred > self.thresh).astype('float32')
+ text_mask = kernels[:, 0, :, :]
+ kernels[:, 0:, :, :] = kernels[:, 0:, :, :] * text_mask
+
+ score = score.numpy()
+ kernels = kernels.numpy().astype(np.uint8)
+
+ boxes_batch = []
+ for batch_index in range(pred.shape[0]):
+ boxes, scores = self.boxes_from_bitmap(score[batch_index],
+ kernels[batch_index],
+ shape_list[batch_index])
+
+ boxes_batch.append({'points': boxes, 'scores': scores})
+ return boxes_batch
+
+ def boxes_from_bitmap(self, score, kernels, shape):
+ label = pse(kernels, self.min_area)
+ return self.generate_box(score, label, shape)
+
+ def generate_box(self, score, label, shape):
+ src_h, src_w, ratio_h, ratio_w = shape
+ label_num = np.max(label) + 1
+
+ boxes = []
+ scores = []
+ for i in range(1, label_num):
+ ind = label == i
+ points = np.array(np.where(ind)).transpose((1, 0))[:, ::-1]
+
+ if points.shape[0] < self.min_area:
+ label[ind] = 0
+ continue
+
+ score_i = np.mean(score[ind])
+ if score_i < self.box_thresh:
+ label[ind] = 0
+ continue
+
+ if self.box_type == 'quad':
+ rect = cv2.minAreaRect(points)
+ bbox = cv2.boxPoints(rect)
+ elif self.box_type == 'poly':
+ box_height = np.max(points[:, 1]) + 10
+ box_width = np.max(points[:, 0]) + 10
+
+ mask = np.zeros((box_height, box_width), np.uint8)
+ mask[points[:, 1], points[:, 0]] = 255
+
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL,
+ cv2.CHAIN_APPROX_SIMPLE)
+ bbox = np.squeeze(contours[0], 1)
+ else:
+ raise NotImplementedError
+
+ bbox[:, 0] = np.clip(np.round(bbox[:, 0] / ratio_w), 0, src_w)
+ bbox[:, 1] = np.clip(np.round(bbox[:, 1] / ratio_h), 0, src_h)
+ boxes.append(bbox)
+ scores.append(score_i)
+ return boxes, scores
diff --git a/backend/ppocr/postprocess/rec_postprocess.py b/backend/ppocr/postprocess/rec_postprocess.py
index b0517982..bf0fd890 100644
--- a/backend/ppocr/postprocess/rec_postprocess.py
+++ b/backend/ppocr/postprocess/rec_postprocess.py
@@ -11,54 +11,34 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import numpy as np
-import string
import paddle
from paddle.nn import functional as F
+import re
class BaseRecLabelDecode(object):
""" Convert between text-label and text-index """
- def __init__(self,
- character_dict_path=None,
- character_type='ch',
- use_space_char=False):
- support_character_type = [
- 'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
- 'it', 'es', 'pt', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs_latin',
- 'oc', 'rs_cyrillic', 'bg', 'uk', 'be', 'te', 'kn', 'ch_tra', 'hi',
- 'mr', 'ne', 'EN'
- ]
- assert character_type in support_character_type, "Only {} are supported now but get {}".format(
- support_character_type, character_type)
-
+ def __init__(self, character_dict_path=None, use_space_char=False):
self.beg_str = "sos"
self.end_str = "eos"
- if character_type == "en":
+ self.character_str = []
+ if character_dict_path is None:
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
- elif character_type == "EN_symbol":
- # same with ASTER setting (use 94 char).
- self.character_str = string.printable[:-6]
- dict_character = list(self.character_str)
- elif character_type in support_character_type:
- self.character_str = ""
- assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format(
- character_type)
+ else:
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode('utf-8').strip("\n").strip("\r\n")
- self.character_str += line
+ self.character_str.append(line)
if use_space_char:
- self.character_str += " "
+ self.character_str.append(" ")
dict_character = list(self.character_str)
- else:
- raise NotImplementedError
- self.character_type = character_type
dict_character = self.add_special_char(dict_character)
self.dict = {}
for i, char in enumerate(dict_character):
@@ -74,24 +54,26 @@ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
ignored_tokens = self.get_ignored_tokens()
batch_size = len(text_index)
for batch_idx in range(batch_size):
- char_list = []
- conf_list = []
- for idx in range(len(text_index[batch_idx])):
- if text_index[batch_idx][idx] in ignored_tokens:
- continue
- if is_remove_duplicate:
- # only for predict
- if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
- batch_idx][idx]:
- continue
- char_list.append(self.character[int(text_index[batch_idx][
- idx])])
- if text_prob is not None:
- conf_list.append(text_prob[batch_idx][idx])
- else:
- conf_list.append(1)
+ selection = np.ones(len(text_index[batch_idx]), dtype=bool)
+ if is_remove_duplicate:
+ selection[1:] = text_index[batch_idx][1:] != text_index[
+ batch_idx][:-1]
+ for ignored_token in ignored_tokens:
+ selection &= text_index[batch_idx] != ignored_token
+
+ char_list = [
+ self.character[text_id]
+ for text_id in text_index[batch_idx][selection]
+ ]
+ if text_prob is not None:
+ conf_list = text_prob[batch_idx][selection]
+ else:
+ conf_list = [1] * len(selection)
+ if len(conf_list) == 0:
+ conf_list = [0]
+
text = ''.join(char_list)
- result_list.append((text, np.mean(conf_list)))
+ result_list.append((text, np.mean(conf_list).tolist()))
return result_list
def get_ignored_tokens(self):
@@ -101,15 +83,14 @@ def get_ignored_tokens(self):
class CTCLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
- def __init__(self,
- character_dict_path=None,
- character_type='ch',
- use_space_char=False,
+ def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(CTCLabelDecode, self).__init__(character_dict_path,
- character_type, use_space_char)
+ use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
+ if isinstance(preds, tuple) or isinstance(preds, list):
+ preds = preds[-1]
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
preds_idx = preds.argmax(axis=2)
@@ -125,16 +106,111 @@ def add_special_char(self, dict_character):
return dict_character
-class AttnLabelDecode(BaseRecLabelDecode):
- """ Convert between text-label and text-index """
+class DistillationCTCLabelDecode(CTCLabelDecode):
+ """
+    Convert between text-label and text-index
+ """
def __init__(self,
character_dict_path=None,
- character_type='ch',
use_space_char=False,
+ model_name=["student"],
+ key=None,
+ multi_head=False,
+ **kwargs):
+ super(DistillationCTCLabelDecode, self).__init__(character_dict_path,
+ use_space_char)
+ if not isinstance(model_name, list):
+ model_name = [model_name]
+ self.model_name = model_name
+
+ self.key = key
+ self.multi_head = multi_head
+
+ def __call__(self, preds, label=None, *args, **kwargs):
+ output = dict()
+ for name in self.model_name:
+ pred = preds[name]
+ if self.key is not None:
+ pred = pred[self.key]
+ if self.multi_head and isinstance(pred, dict):
+ pred = pred['ctc']
+ output[name] = super().__call__(pred, label=label, *args, **kwargs)
+ return output
+
+
+class NRTRLabelDecode(BaseRecLabelDecode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self, character_dict_path=None, use_space_char=True, **kwargs):
+ super(NRTRLabelDecode, self).__init__(character_dict_path,
+ use_space_char)
+
+ def __call__(self, preds, label=None, *args, **kwargs):
+
+ if len(preds) == 2:
+ preds_id = preds[0]
+ preds_prob = preds[1]
+ if isinstance(preds_id, paddle.Tensor):
+ preds_id = preds_id.numpy()
+ if isinstance(preds_prob, paddle.Tensor):
+ preds_prob = preds_prob.numpy()
+ if preds_id[0][0] == 2:
+ preds_idx = preds_id[:, 1:]
+ preds_prob = preds_prob[:, 1:]
+ else:
+ preds_idx = preds_id
+ text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
+ if label is None:
+ return text
+ label = self.decode(label[:, 1:])
+ else:
+ if isinstance(preds, paddle.Tensor):
+ preds = preds.numpy()
+ preds_idx = preds.argmax(axis=2)
+ preds_prob = preds.max(axis=2)
+ text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
+ if label is None:
+ return text
+ label = self.decode(label[:, 1:])
+ return text, label
+
+ def add_special_char(self, dict_character):
+        dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
+ return dict_character
+
+ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+ """ convert text-index into text-label. """
+ result_list = []
+ batch_size = len(text_index)
+ for batch_idx in range(batch_size):
+ char_list = []
+ conf_list = []
+ for idx in range(len(text_index[batch_idx])):
+ if text_index[batch_idx][idx] == 3: # end
+ break
+ try:
+ char_list.append(self.character[int(text_index[batch_idx][
+ idx])])
+ except:
+ continue
+ if text_prob is not None:
+ conf_list.append(text_prob[batch_idx][idx])
+ else:
+ conf_list.append(1)
+ text = ''.join(char_list)
+ result_list.append((text.lower(), np.mean(conf_list).tolist()))
+ return result_list
+
+
+class AttnLabelDecode(BaseRecLabelDecode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(AttnLabelDecode, self).__init__(character_dict_path,
- character_type, use_space_char)
+ use_space_char)
def add_special_char(self, dict_character):
self.beg_str = "sos"
@@ -169,7 +245,7 @@ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
else:
conf_list.append(1)
text = ''.join(char_list)
- result_list.append((text, np.mean(conf_list)))
+ result_list.append((text, np.mean(conf_list).tolist()))
return result_list
def __call__(self, preds, label=None, *args, **kwargs):
@@ -208,16 +284,95 @@ def get_beg_end_flag_idx(self, beg_or_end):
return idx
+class SEEDLabelDecode(BaseRecLabelDecode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self, character_dict_path=None, use_space_char=False,
+ **kwargs):
+ super(SEEDLabelDecode, self).__init__(character_dict_path,
+ use_space_char)
+
+ def add_special_char(self, dict_character):
+ self.padding_str = "padding"
+ self.end_str = "eos"
+ self.unknown = "unknown"
+ dict_character = dict_character + [
+ self.end_str, self.padding_str, self.unknown
+ ]
+ return dict_character
+
+ def get_ignored_tokens(self):
+ end_idx = self.get_beg_end_flag_idx("eos")
+ return [end_idx]
+
+ def get_beg_end_flag_idx(self, beg_or_end):
+ if beg_or_end == "sos":
+ idx = np.array(self.dict[self.beg_str])
+ elif beg_or_end == "eos":
+ idx = np.array(self.dict[self.end_str])
+ else:
+ assert False, "unsupport type %s in get_beg_end_flag_idx" % beg_or_end
+ return idx
+
+ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+ """ convert text-index into text-label. """
+ result_list = []
+ [end_idx] = self.get_ignored_tokens()
+ batch_size = len(text_index)
+ for batch_idx in range(batch_size):
+ char_list = []
+ conf_list = []
+ for idx in range(len(text_index[batch_idx])):
+ if int(text_index[batch_idx][idx]) == int(end_idx):
+ break
+ if is_remove_duplicate:
+ # only for predict
+ if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
+ batch_idx][idx]:
+ continue
+ char_list.append(self.character[int(text_index[batch_idx][
+ idx])])
+ if text_prob is not None:
+ conf_list.append(text_prob[batch_idx][idx])
+ else:
+ conf_list.append(1)
+ text = ''.join(char_list)
+ result_list.append((text, np.mean(conf_list).tolist()))
+ return result_list
+
+ def __call__(self, preds, label=None, *args, **kwargs):
+ """
+ text = self.decode(text)
+ if label is None:
+ return text
+ else:
+ label = self.decode(label, is_remove_duplicate=False)
+ return text, label
+ """
+ preds_idx = preds["rec_pred"]
+ if isinstance(preds_idx, paddle.Tensor):
+ preds_idx = preds_idx.numpy()
+ if "rec_pred_scores" in preds:
+ preds_idx = preds["rec_pred"]
+ preds_prob = preds["rec_pred_scores"]
+ else:
+ preds_idx = preds["rec_pred"].argmax(axis=2)
+ preds_prob = preds["rec_pred"].max(axis=2)
+ text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
+ if label is None:
+ return text
+ label = self.decode(label, is_remove_duplicate=False)
+ return text, label
+
+
class SRNLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
- def __init__(self,
- character_dict_path=None,
- character_type='en',
- use_space_char=False,
+ def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(SRNLabelDecode, self).__init__(character_dict_path,
- character_type, use_space_char)
+ use_space_char)
+ self.max_text_length = kwargs.get('max_text_length', 25)
def __call__(self, preds, label=None, *args, **kwargs):
pred = preds['predict']
@@ -229,9 +384,9 @@ def __call__(self, preds, label=None, *args, **kwargs):
preds_idx = np.argmax(pred, axis=1)
preds_prob = np.max(pred, axis=1)
- preds_idx = np.reshape(preds_idx, [-1, 25])
+ preds_idx = np.reshape(preds_idx, [-1, self.max_text_length])
- preds_prob = np.reshape(preds_prob, [-1, 25])
+ preds_prob = np.reshape(preds_prob, [-1, self.max_text_length])
text = self.decode(preds_idx, preds_prob)
@@ -266,7 +421,7 @@ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
conf_list.append(1)
text = ''.join(char_list)
- result_list.append((text, np.mean(conf_list)))
+ result_list.append((text, np.mean(conf_list).tolist()))
return result_list
def add_special_char(self, dict_character):
@@ -287,3 +442,313 @@ def get_beg_end_flag_idx(self, beg_or_end):
assert False, "unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end
return idx
+
+
+class TableLabelDecode(object):
+ """ """
+
+ def __init__(self, character_dict_path, **kwargs):
+ list_character, list_elem = self.load_char_elem_dict(
+ character_dict_path)
+ list_character = self.add_special_char(list_character)
+ list_elem = self.add_special_char(list_elem)
+ self.dict_character = {}
+ self.dict_idx_character = {}
+ for i, char in enumerate(list_character):
+ self.dict_idx_character[i] = char
+ self.dict_character[char] = i
+ self.dict_elem = {}
+ self.dict_idx_elem = {}
+ for i, elem in enumerate(list_elem):
+ self.dict_idx_elem[i] = elem
+ self.dict_elem[elem] = i
+
+ def load_char_elem_dict(self, character_dict_path):
+ list_character = []
+ list_elem = []
+ with open(character_dict_path, "rb") as fin:
+ lines = fin.readlines()
+ substr = lines[0].decode('utf-8').strip("\n").strip("\r\n").split(
+ "\t")
+ character_num = int(substr[0])
+ elem_num = int(substr[1])
+ for cno in range(1, 1 + character_num):
+ character = lines[cno].decode('utf-8').strip("\n").strip("\r\n")
+ list_character.append(character)
+ for eno in range(1 + character_num, 1 + character_num + elem_num):
+ elem = lines[eno].decode('utf-8').strip("\n").strip("\r\n")
+ list_elem.append(elem)
+ return list_character, list_elem
+
+ def add_special_char(self, list_character):
+ self.beg_str = "sos"
+ self.end_str = "eos"
+ list_character = [self.beg_str] + list_character + [self.end_str]
+ return list_character
+
+ def __call__(self, preds):
+ structure_probs = preds['structure_probs']
+ loc_preds = preds['loc_preds']
+ if isinstance(structure_probs, paddle.Tensor):
+ structure_probs = structure_probs.numpy()
+ if isinstance(loc_preds, paddle.Tensor):
+ loc_preds = loc_preds.numpy()
+ structure_idx = structure_probs.argmax(axis=2)
+ structure_probs = structure_probs.max(axis=2)
+ structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(
+ structure_idx, structure_probs, 'elem')
+ res_html_code_list = []
+ res_loc_list = []
+ batch_num = len(structure_str)
+ for bno in range(batch_num):
+ res_loc = []
+ for sno in range(len(structure_str[bno])):
+ text = structure_str[bno][sno]
+                if text in ['<td>', '<td']:
+                    pos = structure_pos[bno][sno]
+                    res_loc.append(loc_preds[bno, pos])
+            res_html_code = ''.join(structure_str[bno])
+            res_loc = np.array(res_loc)
+            res_html_code_list.append(res_html_code)
+            res_loc_list.append(res_loc)
+        return {
+            'res_html_code': res_html_code_list,
+            'res_loc': res_loc_list,
+            'res_score_list': result_score_list,
+            'res_elem_idx_list': result_elem_idx_list,
+            'structure_str_list': structure_str
+        }
+
+    def decode(self, text_index, structure_probs, char_or_elem):
+        """ convert text-index into text-label. """
+        if char_or_elem == "char":
+            current_dict = self.dict_idx_character
+        else:
+            current_dict = self.dict_idx_elem
+        ignored_tokens = self.get_ignored_tokens('elem')
+        beg_idx, end_idx = ignored_tokens
+
+        result_list = []
+        result_pos_list = []
+        result_score_list = []
+        result_elem_idx_list = []
+        batch_size = len(text_index)
+        for batch_idx in range(batch_size):
+            char_list = []
+            elem_pos_list = []
+            elem_idx_list = []
+            score_list = []
+            for idx in range(len(text_index[batch_idx])):
+                tmp_elem_idx = int(text_index[batch_idx][idx])
+                if idx > 0 and tmp_elem_idx == end_idx:
+ break
+ if tmp_elem_idx in ignored_tokens:
+ continue
+
+ char_list.append(current_dict[tmp_elem_idx])
+ elem_pos_list.append(idx)
+ score_list.append(structure_probs[batch_idx, idx])
+ elem_idx_list.append(tmp_elem_idx)
+ result_list.append(char_list)
+ result_pos_list.append(elem_pos_list)
+ result_score_list.append(score_list)
+ result_elem_idx_list.append(elem_idx_list)
+ return result_list, result_pos_list, result_score_list, result_elem_idx_list
+
+ def get_ignored_tokens(self, char_or_elem):
+ beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
+ end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
+ return [beg_idx, end_idx]
+
+ def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
+ if char_or_elem == "char":
+ if beg_or_end == "beg":
+ idx = self.dict_character[self.beg_str]
+ elif beg_or_end == "end":
+ idx = self.dict_character[self.end_str]
+ else:
+ assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
+ % beg_or_end
+ elif char_or_elem == "elem":
+ if beg_or_end == "beg":
+ idx = self.dict_elem[self.beg_str]
+ elif beg_or_end == "end":
+ idx = self.dict_elem[self.end_str]
+ else:
+ assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
+ % beg_or_end
+ else:
+ assert False, "Unsupport type %s in char_or_elem" \
+ % char_or_elem
+ return idx
+
+
+class SARLabelDecode(BaseRecLabelDecode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self, character_dict_path=None, use_space_char=False,
+ **kwargs):
+ super(SARLabelDecode, self).__init__(character_dict_path,
+ use_space_char)
+
+ self.rm_symbol = kwargs.get('rm_symbol', False)
+
+ def add_special_char(self, dict_character):
+        beg_end_str = "<BOS/EOS>"
+        unknown_str = "<UKN>"
+        padding_str = "<PAD>"
+ dict_character = dict_character + [unknown_str]
+ self.unknown_idx = len(dict_character) - 1
+ dict_character = dict_character + [beg_end_str]
+ self.start_idx = len(dict_character) - 1
+ self.end_idx = len(dict_character) - 1
+ dict_character = dict_character + [padding_str]
+ self.padding_idx = len(dict_character) - 1
+ return dict_character
+
+ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+ """ convert text-index into text-label. """
+ result_list = []
+ ignored_tokens = self.get_ignored_tokens()
+
+ batch_size = len(text_index)
+ for batch_idx in range(batch_size):
+ char_list = []
+ conf_list = []
+ for idx in range(len(text_index[batch_idx])):
+ if text_index[batch_idx][idx] in ignored_tokens:
+ continue
+ if int(text_index[batch_idx][idx]) == int(self.end_idx):
+ if text_prob is None and idx == 0:
+ continue
+ else:
+ break
+ if is_remove_duplicate:
+ # only for predict
+ if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
+ batch_idx][idx]:
+ continue
+ char_list.append(self.character[int(text_index[batch_idx][
+ idx])])
+ if text_prob is not None:
+ conf_list.append(text_prob[batch_idx][idx])
+ else:
+ conf_list.append(1)
+ text = ''.join(char_list)
+ if self.rm_symbol:
+ comp = re.compile('[^A-Z^a-z^0-9^\u4e00-\u9fa5]')
+ text = text.lower()
+ text = comp.sub('', text)
+ result_list.append((text, np.mean(conf_list).tolist()))
+ return result_list
+
+ def __call__(self, preds, label=None, *args, **kwargs):
+ if isinstance(preds, paddle.Tensor):
+ preds = preds.numpy()
+ preds_idx = preds.argmax(axis=2)
+ preds_prob = preds.max(axis=2)
+
+ text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
+
+ if label is None:
+ return text
+ label = self.decode(label, is_remove_duplicate=False)
+ return text, label
+
+ def get_ignored_tokens(self):
+ return [self.padding_idx]
+
+
+class DistillationSARLabelDecode(SARLabelDecode):
+ """
+    Convert between text-label and text-index
+ """
+
+ def __init__(self,
+ character_dict_path=None,
+ use_space_char=False,
+ model_name=["student"],
+ key=None,
+ multi_head=False,
+ **kwargs):
+ super(DistillationSARLabelDecode, self).__init__(character_dict_path,
+ use_space_char)
+ if not isinstance(model_name, list):
+ model_name = [model_name]
+ self.model_name = model_name
+
+ self.key = key
+ self.multi_head = multi_head
+
+ def __call__(self, preds, label=None, *args, **kwargs):
+ output = dict()
+ for name in self.model_name:
+ pred = preds[name]
+ if self.key is not None:
+ pred = pred[self.key]
+ if self.multi_head and isinstance(pred, dict):
+ pred = pred['sar']
+ output[name] = super().__call__(pred, label=label, *args, **kwargs)
+ return output
+
+
+class PRENLabelDecode(BaseRecLabelDecode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self, character_dict_path=None, use_space_char=False,
+ **kwargs):
+ super(PRENLabelDecode, self).__init__(character_dict_path,
+ use_space_char)
+
+ def add_special_char(self, dict_character):
+        padding_str = '<PAD>'  # 0
+        end_str = '<EOS>'  # 1
+        unknown_str = '<UNK>'  # 2
+
+ dict_character = [padding_str, end_str, unknown_str] + dict_character
+ self.padding_idx = 0
+ self.end_idx = 1
+ self.unknown_idx = 2
+
+ return dict_character
+
+ def decode(self, text_index, text_prob=None):
+ """ convert text-index into text-label. """
+ result_list = []
+ batch_size = len(text_index)
+
+ for batch_idx in range(batch_size):
+ char_list = []
+ conf_list = []
+ for idx in range(len(text_index[batch_idx])):
+ if text_index[batch_idx][idx] == self.end_idx:
+ break
+ if text_index[batch_idx][idx] in \
+ [self.padding_idx, self.unknown_idx]:
+ continue
+ char_list.append(self.character[int(text_index[batch_idx][
+ idx])])
+ if text_prob is not None:
+ conf_list.append(text_prob[batch_idx][idx])
+ else:
+ conf_list.append(1)
+
+ text = ''.join(char_list)
+ if len(text) > 0:
+ result_list.append((text, np.mean(conf_list).tolist()))
+ else:
+ # here confidence of empty recog result is 1
+ result_list.append(('', 1))
+ return result_list
+
+ def __call__(self, preds, label=None, *args, **kwargs):
+ preds = preds.numpy()
+ preds_idx = preds.argmax(axis=2)
+ preds_prob = preds.max(axis=2)
+ text = self.decode(preds_idx, preds_prob)
+ if label is None:
+ return text
+ label = self.decode(label)
+ return text, label
diff --git a/backend/ppocr/postprocess/sast_postprocess.py b/backend/ppocr/postprocess/sast_postprocess.py
index f011e7e5..bee75c05 100755
--- a/backend/ppocr/postprocess/sast_postprocess.py
+++ b/backend/ppocr/postprocess/sast_postprocess.py
@@ -18,6 +18,7 @@
import os
import sys
+
__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..'))
@@ -49,12 +50,12 @@ def __init__(self,
self.shrink_ratio_of_width = shrink_ratio_of_width
self.expand_scale = expand_scale
self.tcl_map_thresh = tcl_map_thresh
-
+
# c++ la-nms is faster, but only support python 3.5
self.is_python35 = False
if sys.version_info.major == 3 and sys.version_info.minor == 5:
self.is_python35 = True
-
+
def point_pair2poly(self, point_pair_list):
"""
Transfer vertical point_pairs into poly point in clockwise.
@@ -66,31 +67,42 @@ def point_pair2poly(self, point_pair_list):
point_list[idx] = point_pair[0]
point_list[point_num - 1 - idx] = point_pair[1]
return np.array(point_list).reshape(-1, 2)
-
- def shrink_quad_along_width(self, quad, begin_width_ratio=0., end_width_ratio=1.):
+
+ def shrink_quad_along_width(self,
+ quad,
+ begin_width_ratio=0.,
+ end_width_ratio=1.):
"""
Generate shrink_quad_along_width.
"""
- ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
+ ratio_pair = np.array(
+ [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
-
+
def expand_poly_along_width(self, poly, shrink_ratio_of_width=0.3):
"""
expand poly along width.
"""
point_num = poly.shape[0]
- left_quad = np.array([poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
+ left_quad = np.array(
+ [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
- (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
- left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio, 1.0)
- right_quad = np.array([poly[point_num // 2 - 2], poly[point_num // 2 - 1],
- poly[point_num // 2], poly[point_num // 2 + 1]], dtype=np.float32)
+ (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
+ left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio,
+ 1.0)
+ right_quad = np.array(
+ [
+ poly[point_num // 2 - 2], poly[point_num // 2 - 1],
+ poly[point_num // 2], poly[point_num // 2 + 1]
+ ],
+ dtype=np.float32)
right_ratio = 1.0 + \
- shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
- (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
- right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0, right_ratio)
+ shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
+ (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
+ right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0,
+ right_ratio)
poly[0] = left_quad_expand[0]
poly[-1] = left_quad_expand[-1]
poly[point_num // 2 - 1] = right_quad_expand[1]
@@ -100,7 +112,7 @@ def expand_poly_along_width(self, poly, shrink_ratio_of_width=0.3):
def restore_quad(self, tcl_map, tcl_map_thresh, tvo_map):
"""Restore quad."""
xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
- xy_text = xy_text[:, ::-1] # (n, 2)
+ xy_text = xy_text[:, ::-1] # (n, 2)
# Sort the text boxes via the y axis
xy_text = xy_text[np.argsort(xy_text[:, 1])]
@@ -112,7 +124,7 @@ def restore_quad(self, tcl_map, tcl_map_thresh, tvo_map):
point_num = int(tvo_map.shape[-1] / 2)
assert point_num == 4
tvo_map = tvo_map[xy_text[:, 1], xy_text[:, 0], :]
- xy_text_tile = np.tile(xy_text, (1, point_num)) # (n, point_num * 2)
+ xy_text_tile = np.tile(xy_text, (1, point_num)) # (n, point_num * 2)
quads = xy_text_tile - tvo_map
return scores, quads, xy_text
@@ -121,14 +133,12 @@ def quad_area(self, quad):
"""
compute area of a quad.
"""
- edge = [
- (quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]),
- (quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]),
- (quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]),
- (quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1])
- ]
+ edge = [(quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]),
+ (quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]),
+ (quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]),
+ (quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1])]
return np.sum(edge) / 2.
-
+
def nms(self, dets):
if self.is_python35:
import lanms
@@ -141,7 +151,7 @@ def cluster_by_quads_tco(self, tcl_map, tcl_map_thresh, quads, tco_map):
"""
Cluster pixels in tcl_map based on quads.
"""
- instance_count = quads.shape[0] + 1 # contain background
+ instance_count = quads.shape[0] + 1 # contain background
instance_label_map = np.zeros(tcl_map.shape[:2], dtype=np.int32)
if instance_count == 1:
return instance_count, instance_label_map
@@ -149,18 +159,19 @@ def cluster_by_quads_tco(self, tcl_map, tcl_map_thresh, quads, tco_map):
# predict text center
xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
n = xy_text.shape[0]
- xy_text = xy_text[:, ::-1] # (n, 2)
- tco = tco_map[xy_text[:, 1], xy_text[:, 0], :] # (n, 2)
+ xy_text = xy_text[:, ::-1] # (n, 2)
+ tco = tco_map[xy_text[:, 1], xy_text[:, 0], :] # (n, 2)
pred_tc = xy_text - tco
-
+
# get gt text center
m = quads.shape[0]
- gt_tc = np.mean(quads, axis=1) # (m, 2)
+ gt_tc = np.mean(quads, axis=1) # (m, 2)
- pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :], (1, m, 1)) # (n, m, 2)
- gt_tc_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1)) # (n, m, 2)
- dist_mat = np.linalg.norm(pred_tc_tile - gt_tc_tile, axis=2) # (n, m)
- xy_text_assign = np.argmin(dist_mat, axis=1) + 1 # (n,)
+ pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :],
+ (1, m, 1)) # (n, m, 2)
+ gt_tc_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1)) # (n, m, 2)
+ dist_mat = np.linalg.norm(pred_tc_tile - gt_tc_tile, axis=2) # (n, m)
+ xy_text_assign = np.argmin(dist_mat, axis=1) + 1 # (n,)
instance_label_map[xy_text[:, 1], xy_text[:, 0]] = xy_text_assign
return instance_count, instance_label_map
@@ -169,26 +180,47 @@ def estimate_sample_pts_num(self, quad, xy_text):
"""
Estimate sample points number.
"""
- eh = (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2])) / 2.0
- ew = (np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])) / 2.0
+ eh = (np.linalg.norm(quad[0] - quad[3]) +
+ np.linalg.norm(quad[1] - quad[2])) / 2.0
+ ew = (np.linalg.norm(quad[0] - quad[1]) +
+ np.linalg.norm(quad[2] - quad[3])) / 2.0
dense_sample_pts_num = max(2, int(ew))
- dense_xy_center_line = xy_text[np.linspace(0, xy_text.shape[0] - 1, dense_sample_pts_num,
- endpoint=True, dtype=np.float32).astype(np.int32)]
-
- dense_xy_center_line_diff = dense_xy_center_line[1:] - dense_xy_center_line[:-1]
- estimate_arc_len = np.sum(np.linalg.norm(dense_xy_center_line_diff, axis=1))
+ dense_xy_center_line = xy_text[np.linspace(
+ 0,
+ xy_text.shape[0] - 1,
+ dense_sample_pts_num,
+ endpoint=True,
+ dtype=np.float32).astype(np.int32)]
+
+ dense_xy_center_line_diff = dense_xy_center_line[
+ 1:] - dense_xy_center_line[:-1]
+ estimate_arc_len = np.sum(
+ np.linalg.norm(
+ dense_xy_center_line_diff, axis=1))
sample_pts_num = max(2, int(estimate_arc_len / eh))
return sample_pts_num
- def detect_sast(self, tcl_map, tvo_map, tbo_map, tco_map, ratio_w, ratio_h, src_w, src_h,
- shrink_ratio_of_width=0.3, tcl_map_thresh=0.5, offset_expand=1.0, out_strid=4.0):
+ def detect_sast(self,
+ tcl_map,
+ tvo_map,
+ tbo_map,
+ tco_map,
+ ratio_w,
+ ratio_h,
+ src_w,
+ src_h,
+ shrink_ratio_of_width=0.3,
+ tcl_map_thresh=0.5,
+ offset_expand=1.0,
+ out_strid=4.0):
"""
first resize the tcl_map, tvo_map and tbo_map to the input_size, then restore the polys
"""
# restore quad
- scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh, tvo_map)
+ scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh,
+ tvo_map)
dets = np.hstack((quads, scores)).astype(np.float32, copy=False)
dets = self.nms(dets)
if dets.shape[0] == 0:
@@ -202,7 +234,8 @@ def detect_sast(self, tcl_map, tvo_map, tbo_map, tco_map, ratio_w, ratio_h, src_
# instance segmentation
# instance_count, instance_label_map = cv2.connectedComponents(tcl_map.astype(np.uint8), connectivity=8)
- instance_count, instance_label_map = self.cluster_by_quads_tco(tcl_map, tcl_map_thresh, quads, tco_map)
+ instance_count, instance_label_map = self.cluster_by_quads_tco(
+ tcl_map, tcl_map_thresh, quads, tco_map)
# restore single poly with tcl instance.
poly_list = []
@@ -212,10 +245,10 @@ def detect_sast(self, tcl_map, tvo_map, tbo_map, tco_map, ratio_w, ratio_h, src_
q_area = quad_areas[instance_idx - 1]
if q_area < 5:
continue
-
+
#
- len1 = float(np.linalg.norm(quad[0] -quad[1]))
- len2 = float(np.linalg.norm(quad[1] -quad[2]))
+ len1 = float(np.linalg.norm(quad[0] - quad[1]))
+ len2 = float(np.linalg.norm(quad[1] - quad[2]))
min_len = min(len1, len2)
if min_len < 3:
continue
@@ -225,16 +258,18 @@ def detect_sast(self, tcl_map, tvo_map, tbo_map, tco_map, ratio_w, ratio_h, src_
continue
# filter low confidence instance
- xy_text_scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
+ xy_text_scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.1:
- # if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.05:
+ # if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.05:
continue
# sort xy_text
- left_center_pt = np.array([[(quad[0, 0] + quad[-1, 0]) / 2.0,
- (quad[0, 1] + quad[-1, 1]) / 2.0]]) # (1, 2)
- right_center_pt = np.array([[(quad[1, 0] + quad[2, 0]) / 2.0,
- (quad[1, 1] + quad[2, 1]) / 2.0]]) # (1, 2)
+ left_center_pt = np.array(
+ [[(quad[0, 0] + quad[-1, 0]) / 2.0,
+ (quad[0, 1] + quad[-1, 1]) / 2.0]]) # (1, 2)
+ right_center_pt = np.array(
+ [[(quad[1, 0] + quad[2, 0]) / 2.0,
+ (quad[1, 1] + quad[2, 1]) / 2.0]]) # (1, 2)
proj_unit_vec = (right_center_pt - left_center_pt) / \
(np.linalg.norm(right_center_pt - left_center_pt) + 1e-6)
proj_value = np.sum(xy_text * proj_unit_vec, axis=1)
@@ -245,33 +280,45 @@ def detect_sast(self, tcl_map, tvo_map, tbo_map, tco_map, ratio_w, ratio_h, src_
sample_pts_num = self.estimate_sample_pts_num(quad, xy_text)
else:
sample_pts_num = self.sample_pts_num
- xy_center_line = xy_text[np.linspace(0, xy_text.shape[0] - 1, sample_pts_num,
- endpoint=True, dtype=np.float32).astype(np.int32)]
+ xy_center_line = xy_text[np.linspace(
+ 0,
+ xy_text.shape[0] - 1,
+ sample_pts_num,
+ endpoint=True,
+ dtype=np.float32).astype(np.int32)]
point_pair_list = []
for x, y in xy_center_line:
# get corresponding offset
offset = tbo_map[y, x, :].reshape(2, 2)
if offset_expand != 1.0:
- offset_length = np.linalg.norm(offset, axis=1, keepdims=True)
- expand_length = np.clip(offset_length * (offset_expand - 1), a_min=0.5, a_max=3.0)
+ offset_length = np.linalg.norm(
+ offset, axis=1, keepdims=True)
+ expand_length = np.clip(
+ offset_length * (offset_expand - 1),
+ a_min=0.5,
+ a_max=3.0)
offset_detal = offset / offset_length * expand_length
- offset = offset + offset_detal
- # original point
+ offset = offset + offset_detal
+ # original point
ori_yx = np.array([y, x], dtype=np.float32)
- point_pair = (ori_yx + offset)[:, ::-1]* out_strid / np.array([ratio_w, ratio_h]).reshape(-1, 2)
+ point_pair = (ori_yx + offset)[:, ::-1] * out_strid / np.array(
+ [ratio_w, ratio_h]).reshape(-1, 2)
point_pair_list.append(point_pair)
# ndarry: (x, 2), expand poly along width
detected_poly = self.point_pair2poly(point_pair_list)
- detected_poly = self.expand_poly_along_width(detected_poly, shrink_ratio_of_width)
- detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
- detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
+ detected_poly = self.expand_poly_along_width(detected_poly,
+ shrink_ratio_of_width)
+ detected_poly[:, 0] = np.clip(
+ detected_poly[:, 0], a_min=0, a_max=src_w)
+ detected_poly[:, 1] = np.clip(
+ detected_poly[:, 1], a_min=0, a_max=src_h)
poly_list.append(detected_poly)
return poly_list
- def __call__(self, outs_dict, shape_list):
+ def __call__(self, outs_dict, shape_list):
score_list = outs_dict['f_score']
border_list = outs_dict['f_border']
tvo_list = outs_dict['f_tvo']
@@ -281,20 +328,28 @@ def __call__(self, outs_dict, shape_list):
border_list = border_list.numpy()
tvo_list = tvo_list.numpy()
tco_list = tco_list.numpy()
-
+
img_num = len(shape_list)
poly_lists = []
for ino in range(img_num):
- p_score = score_list[ino].transpose((1,2,0))
- p_border = border_list[ino].transpose((1,2,0))
- p_tvo = tvo_list[ino].transpose((1,2,0))
- p_tco = tco_list[ino].transpose((1,2,0))
+ p_score = score_list[ino].transpose((1, 2, 0))
+ p_border = border_list[ino].transpose((1, 2, 0))
+ p_tvo = tvo_list[ino].transpose((1, 2, 0))
+ p_tco = tco_list[ino].transpose((1, 2, 0))
src_h, src_w, ratio_h, ratio_w = shape_list[ino]
- poly_list = self.detect_sast(p_score, p_tvo, p_border, p_tco, ratio_w, ratio_h, src_w, src_h,
- shrink_ratio_of_width=self.shrink_ratio_of_width,
- tcl_map_thresh=self.tcl_map_thresh, offset_expand=self.expand_scale)
+ poly_list = self.detect_sast(
+ p_score,
+ p_tvo,
+ p_border,
+ p_tco,
+ ratio_w,
+ ratio_h,
+ src_w,
+ src_h,
+ shrink_ratio_of_width=self.shrink_ratio_of_width,
+ tcl_map_thresh=self.tcl_map_thresh,
+ offset_expand=self.expand_scale)
poly_lists.append({'points': np.array(poly_list)})
return poly_lists
-
diff --git a/backend/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py b/backend/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py
new file mode 100644
index 00000000..1d55d13d
--- /dev/null
+++ b/backend/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py
@@ -0,0 +1,51 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddle
+
+
+class VQAReTokenLayoutLMPostProcess(object):
+ """ Convert between text-label and text-index """
+
+ def __init__(self, **kwargs):
+ super(VQAReTokenLayoutLMPostProcess, self).__init__()
+
+ def __call__(self, preds, label=None, *args, **kwargs):
+ if label is not None:
+ return self._metric(preds, label)
+ else:
+ return self._infer(preds, *args, **kwargs)
+
+ def _metric(self, preds, label):
+ return preds['pred_relations'], label[6], label[5]
+
+ def _infer(self, preds, *args, **kwargs):
+ ser_results = kwargs['ser_results']
+ entity_idx_dict_batch = kwargs['entity_idx_dict_batch']
+ pred_relations = preds['pred_relations']
+
+ # merge relations and ocr info
+ results = []
+ for pred_relation, ser_result, entity_idx_dict in zip(
+ pred_relations, ser_results, entity_idx_dict_batch):
+ result = []
+ used_tail_id = []
+ for relation in pred_relation:
+ if relation['tail_id'] in used_tail_id:
+ continue
+ used_tail_id.append(relation['tail_id'])
+ ocr_info_head = ser_result[entity_idx_dict[relation['head_id']]]
+ ocr_info_tail = ser_result[entity_idx_dict[relation['tail_id']]]
+ result.append((ocr_info_head, ocr_info_tail))
+ results.append(result)
+ return results
diff --git a/backend/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py b/backend/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py
new file mode 100644
index 00000000..782cdea6
--- /dev/null
+++ b/backend/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py
@@ -0,0 +1,93 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import paddle
+from ppocr.utils.utility import load_vqa_bio_label_maps
+
+
+class VQASerTokenLayoutLMPostProcess(object):
+ """ Convert between text-label and text-index """
+
+ def __init__(self, class_path, **kwargs):
+ super(VQASerTokenLayoutLMPostProcess, self).__init__()
+ label2id_map, self.id2label_map = load_vqa_bio_label_maps(class_path)
+
+ self.label2id_map_for_draw = dict()
+ for key in label2id_map:
+ if key.startswith("I-"):
+ self.label2id_map_for_draw[key] = label2id_map["B" + key[1:]]
+ else:
+ self.label2id_map_for_draw[key] = label2id_map[key]
+
+ self.id2label_map_for_show = dict()
+ for key in self.label2id_map_for_draw:
+ val = self.label2id_map_for_draw[key]
+ if key == "O":
+ self.id2label_map_for_show[val] = key
+ if key.startswith("B-") or key.startswith("I-"):
+ self.id2label_map_for_show[val] = key[2:]
+ else:
+ self.id2label_map_for_show[val] = key
+
+ def __call__(self, preds, batch=None, *args, **kwargs):
+ if isinstance(preds, paddle.Tensor):
+ preds = preds.numpy()
+
+ if batch is not None:
+ return self._metric(preds, batch[1])
+ else:
+ return self._infer(preds, **kwargs)
+
+ def _metric(self, preds, label):
+ pred_idxs = preds.argmax(axis=2)
+ decode_out_list = [[] for _ in range(pred_idxs.shape[0])]
+ label_decode_out_list = [[] for _ in range(pred_idxs.shape[0])]
+
+ for i in range(pred_idxs.shape[0]):
+ for j in range(pred_idxs.shape[1]):
+ if label[i, j] != -100:
+ label_decode_out_list[i].append(self.id2label_map[label[i,
+ j]])
+ decode_out_list[i].append(self.id2label_map[pred_idxs[i,
+ j]])
+ return decode_out_list, label_decode_out_list
+
+ def _infer(self, preds, attention_masks, segment_offset_ids, ocr_infos):
+ results = []
+
+ for pred, attention_mask, segment_offset_id, ocr_info in zip(
+ preds, attention_masks, segment_offset_ids, ocr_infos):
+ pred = np.argmax(pred, axis=1)
+ pred = [self.id2label_map[idx] for idx in pred]
+
+ for idx in range(len(segment_offset_id)):
+ if idx == 0:
+ start_id = 0
+ else:
+ start_id = segment_offset_id[idx - 1]
+
+ end_id = segment_offset_id[idx]
+
+ curr_pred = pred[start_id:end_id]
+ curr_pred = [self.label2id_map_for_draw[p] for p in curr_pred]
+
+ if len(curr_pred) <= 0:
+ pred_id = 0
+ else:
+ counts = np.bincount(curr_pred)
+ pred_id = np.argmax(counts)
+ ocr_info[idx]["pred_id"] = int(pred_id)
+ ocr_info[idx]["pred"] = self.id2label_map_for_show[int(pred_id)]
+ results.append(ocr_info)
+ return results
diff --git a/backend/ppocr/utils/dict/arabic_dict.txt b/backend/ppocr/utils/dict/arabic_dict.txt
new file mode 100644
index 00000000..916d421c
--- /dev/null
+++ b/backend/ppocr/utils/dict/arabic_dict.txt
@@ -0,0 +1,161 @@
+!
+#
+$
+%
+&
+'
+(
++
+,
+-
+.
+/
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+:
+?
+@
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+Q
+R
+S
+T
+U
+V
+W
+X
+Y
+Z
+_
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+É
+é
+ء
+آ
+أ
+ؤ
+إ
+ئ
+ا
+ب
+ة
+ت
+ث
+ج
+ح
+خ
+د
+ذ
+ر
+ز
+س
+ش
+ص
+ض
+ط
+ظ
+ع
+غ
+ف
+ق
+ك
+ل
+م
+ن
+ه
+و
+ى
+ي
+ً
+ٌ
+ٍ
+َ
+ُ
+ِ
+ّ
+ْ
+ٓ
+ٔ
+ٰ
+ٱ
+ٹ
+پ
+چ
+ڈ
+ڑ
+ژ
+ک
+ڭ
+گ
+ں
+ھ
+ۀ
+ہ
+ۂ
+ۃ
+ۆ
+ۇ
+ۈ
+ۋ
+ی
+ې
+ے
+ۓ
+ە
+١
+٢
+٣
+٤
+٥
+٦
+٧
+٨
+٩
diff --git a/backend/ppocr/utils/ppocr_keys_v1.txt b/backend/ppocr/utils/dict/ch_dict.txt
similarity index 100%
rename from backend/ppocr/utils/ppocr_keys_v1.txt
rename to backend/ppocr/utils/dict/ch_dict.txt
diff --git a/backend/ppocr/utils/dict/ch_tra_dict.txt b/backend/ppocr/utils/dict/chinese_cht_dict.txt
similarity index 100%
rename from backend/ppocr/utils/dict/ch_tra_dict.txt
rename to backend/ppocr/utils/dict/chinese_cht_dict.txt
diff --git a/backend/ppocr/utils/dict/cyrillic_dict.txt b/backend/ppocr/utils/dict/cyrillic_dict.txt
new file mode 100644
index 00000000..2b6f6649
--- /dev/null
+++ b/backend/ppocr/utils/dict/cyrillic_dict.txt
@@ -0,0 +1,163 @@
+
+!
+#
+$
+%
+&
+'
+(
++
+,
+-
+.
+/
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+:
+?
+@
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+Q
+R
+S
+T
+U
+V
+W
+X
+Y
+Z
+_
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+É
+é
+Ё
+Є
+І
+Ј
+Љ
+Ў
+А
+Б
+В
+Г
+Д
+Е
+Ж
+З
+И
+Й
+К
+Л
+М
+Н
+О
+П
+Р
+С
+Т
+У
+Ф
+Х
+Ц
+Ч
+Ш
+Щ
+Ъ
+Ы
+Ь
+Э
+Ю
+Я
+а
+б
+в
+г
+д
+е
+ж
+з
+и
+й
+к
+л
+м
+н
+о
+п
+р
+с
+т
+у
+ф
+х
+ц
+ч
+ш
+щ
+ъ
+ы
+ь
+э
+ю
+я
+ё
+ђ
+є
+і
+ј
+љ
+њ
+ћ
+ў
+џ
+Ґ
+ґ
diff --git a/backend/ppocr/utils/dict/devanagari_dict.txt b/backend/ppocr/utils/dict/devanagari_dict.txt
new file mode 100644
index 00000000..f5592306
--- /dev/null
+++ b/backend/ppocr/utils/dict/devanagari_dict.txt
@@ -0,0 +1,167 @@
+
+!
+#
+$
+%
+&
+'
+(
++
+,
+-
+.
+/
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+:
+?
+@
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+Q
+R
+S
+T
+U
+V
+W
+X
+Y
+Z
+_
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+É
+é
+ँ
+ं
+ः
+अ
+आ
+इ
+ई
+उ
+ऊ
+ऋ
+ए
+ऐ
+ऑ
+ओ
+औ
+क
+ख
+ग
+घ
+ङ
+च
+छ
+ज
+झ
+ञ
+ट
+ठ
+ड
+ढ
+ण
+त
+थ
+द
+ध
+न
+ऩ
+प
+फ
+ब
+भ
+म
+य
+र
+ऱ
+ल
+ळ
+व
+श
+ष
+स
+ह
+़
+ा
+ि
+ी
+ु
+ू
+ृ
+ॅ
+े
+ै
+ॉ
+ो
+ौ
+्
+॒
+क़
+ख़
+ग़
+ज़
+ड़
+ढ़
+फ़
+ॠ
+।
+०
+१
+२
+३
+४
+५
+६
+७
+८
+९
+॰
diff --git a/backend/ppocr/utils/dict/en_dict.txt b/backend/ppocr/utils/dict/en_dict.txt
index 6fbd99f4..7677d31b 100644
--- a/backend/ppocr/utils/dict/en_dict.txt
+++ b/backend/ppocr/utils/dict/en_dict.txt
@@ -8,32 +8,13 @@
7
8
9
-a
-b
-c
-d
-e
-f
-g
-h
-i
-j
-k
-l
-m
-n
-o
-p
-q
-r
-s
-t
-u
-v
-w
-x
-y
-z
+:
+;
+<
+=
+>
+?
+@
A
B
C
@@ -60,4 +41,55 @@ W
X
Y
Z
+[
+\
+]
+^
+_
+`
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+{
+|
+}
+~
+!
+"
+#
+$
+%
+&
+'
+(
+)
+*
++
+,
+-
+.
+/
diff --git a/backend/ppocr/utils/dict/ka_dict.txt b/backend/ppocr/utils/dict/ka_dict.txt
new file mode 100644
index 00000000..d506b691
--- /dev/null
+++ b/backend/ppocr/utils/dict/ka_dict.txt
@@ -0,0 +1,153 @@
+k
+a
+_
+i
+m
+g
+/
+1
+2
+I
+L
+S
+V
+R
+C
+0
+v
+l
+6
+4
+8
+.
+j
+p
+ಗ
+ು
+ಣ
+ಪ
+ಡ
+ಿ
+ಸ
+ಲ
+ಾ
+ದ
+್
+7
+5
+3
+ವ
+ಷ
+ಬ
+ಹ
+ೆ
+9
+ಅ
+ಳ
+ನ
+ರ
+ಉ
+ಕ
+ಎ
+ೇ
+ಂ
+ೈ
+ೊ
+ೀ
+ಯ
+ೋ
+ತ
+ಶ
+ಭ
+ಧ
+ಚ
+ಜ
+ೂ
+ಮ
+ಒ
+ೃ
+ಥ
+ಇ
+ಟ
+ಖ
+ಆ
+ಞ
+ಫ
+-
+ಢ
+ಊ
+ಓ
+ಐ
+ಃ
+ಘ
+ಝ
+ೌ
+ಠ
+ಛ
+ಔ
+ಏ
+ಈ
+ಋ
+೨
+೦
+೧
+೮
+೯
+೪
+,
+೫
+೭
+೩
+೬
+ಙ
+s
+c
+e
+n
+w
+o
+u
+t
+d
+E
+A
+T
+B
+Z
+N
+G
+O
+q
+z
+r
+x
+P
+K
+M
+J
+U
+D
+f
+F
+h
+b
+W
+Y
+y
+H
+X
+Q
+'
+#
+&
+!
+@
+$
+:
+%
+é
+É
+(
+?
++
+
diff --git a/backend/ppocr/utils/dict/kie_dict/xfund_class_list.txt b/backend/ppocr/utils/dict/kie_dict/xfund_class_list.txt
new file mode 100644
index 00000000..faded9f9
--- /dev/null
+++ b/backend/ppocr/utils/dict/kie_dict/xfund_class_list.txt
@@ -0,0 +1,4 @@
+OTHER
+QUESTION
+ANSWER
+HEADER
diff --git a/backend/ppocr/utils/dict/latin_dict.txt b/backend/ppocr/utils/dict/latin_dict.txt
new file mode 100644
index 00000000..e166bf33
--- /dev/null
+++ b/backend/ppocr/utils/dict/latin_dict.txt
@@ -0,0 +1,185 @@
+
+!
+"
+#
+$
+%
+&
+'
+(
+)
+*
++
+,
+-
+.
+/
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+:
+;
+<
+=
+>
+?
+@
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+Q
+R
+S
+T
+U
+V
+W
+X
+Y
+Z
+[
+]
+_
+`
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+{
+}
+¡
+£
+§
+ª
+«
+
+°
+²
+³
+´
+µ
+·
+º
+»
+¿
+À
+Á
+Â
+Ä
+Å
+Ç
+È
+É
+Ê
+Ë
+Ì
+Í
+Î
+Ï
+Ò
+Ó
+Ô
+Õ
+Ö
+Ú
+Ü
+Ý
+ß
+à
+á
+â
+ã
+ä
+å
+æ
+ç
+è
+é
+ê
+ë
+ì
+í
+î
+ï
+ñ
+ò
+ó
+ô
+õ
+ö
+ø
+ù
+ú
+û
+ü
+ý
+ą
+Ć
+ć
+Č
+č
+Đ
+đ
+ę
+ı
+Ł
+ł
+ō
+Œ
+œ
+Š
+š
+Ÿ
+Ž
+ž
+ʒ
+β
+δ
+ε
+з
+Ṡ
+‘
+€
+™
diff --git a/backend/ppocr/utils/dict/layout_dict/layout_cdla_dict.txt b/backend/ppocr/utils/dict/layout_dict/layout_cdla_dict.txt
new file mode 100644
index 00000000..8be0f486
--- /dev/null
+++ b/backend/ppocr/utils/dict/layout_dict/layout_cdla_dict.txt
@@ -0,0 +1,10 @@
+text
+title
+figure
+figure_caption
+table
+table_caption
+header
+footer
+reference
+equation
\ No newline at end of file
diff --git a/backend/ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt b/backend/ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt
new file mode 100644
index 00000000..ca6acf4e
--- /dev/null
+++ b/backend/ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt
@@ -0,0 +1,5 @@
+text
+title
+list
+table
+figure
\ No newline at end of file
diff --git a/backend/ppocr/utils/dict/layout_dict/layout_table_dict.txt b/backend/ppocr/utils/dict/layout_dict/layout_table_dict.txt
new file mode 100644
index 00000000..faea15ea
--- /dev/null
+++ b/backend/ppocr/utils/dict/layout_dict/layout_table_dict.txt
@@ -0,0 +1 @@
+table
\ No newline at end of file
diff --git a/backend/ppocr/utils/dict/pu_dict.txt b/backend/ppocr/utils/dict/pu_dict.txt
new file mode 100644
index 00000000..9500fae6
--- /dev/null
+++ b/backend/ppocr/utils/dict/pu_dict.txt
@@ -0,0 +1,130 @@
+p
+u
+_
+i
+m
+g
+/
+8
+I
+L
+S
+V
+R
+C
+2
+0
+1
+v
+a
+l
+6
+7
+4
+5
+.
+j
+
+q
+e
+s
+t
+ã
+o
+x
+9
+c
+n
+r
+z
+ç
+õ
+3
+A
+U
+d
+º
+ô
+
+,
+E
+;
+ó
+á
+b
+D
+?
+ú
+ê
+-
+h
+P
+f
+à
+N
+í
+O
+M
+G
+É
+é
+â
+F
+:
+T
+Á
+"
+Q
+)
+W
+J
+B
+H
+(
+ö
+%
+Ö
+«
+w
+K
+y
+!
+k
+]
+'
+Z
++
+Ç
+Õ
+Y
+À
+X
+µ
+»
+ª
+Í
+ü
+ä
+´
+è
+ñ
+ß
+ï
+Ú
+ë
+Ô
+Ï
+Ó
+[
+Ì
+<
+Â
+ò
+§
+³
+ø
+å
+#
+$
+&
+@
diff --git a/backend/ppocr/utils/dict/rs_dict.txt b/backend/ppocr/utils/dict/rs_dict.txt
new file mode 100644
index 00000000..d1ce46d2
--- /dev/null
+++ b/backend/ppocr/utils/dict/rs_dict.txt
@@ -0,0 +1,91 @@
+r
+s
+_
+i
+m
+g
+/
+1
+I
+L
+S
+V
+R
+C
+2
+0
+v
+a
+l
+7
+5
+8
+6
+.
+j
+p
+
+t
+d
+9
+3
+e
+š
+4
+k
+u
+ć
+c
+n
+đ
+o
+z
+č
+b
+ž
+f
+Z
+T
+h
+M
+F
+O
+Š
+B
+H
+A
+E
+Đ
+Ž
+D
+P
+G
+Č
+K
+U
+N
+J
+Ć
+w
+y
+W
+x
+Y
+X
+q
+Q
+#
+&
+$
+,
+-
+%
+'
+@
+!
+:
+?
+(
+É
+é
++
diff --git a/backend/ppocr/utils/dict/rsc_dict.txt b/backend/ppocr/utils/dict/rsc_dict.txt
new file mode 100644
index 00000000..95dd4636
--- /dev/null
+++ b/backend/ppocr/utils/dict/rsc_dict.txt
@@ -0,0 +1,134 @@
+r
+s
+c
+_
+i
+m
+g
+/
+5
+I
+L
+S
+V
+R
+C
+2
+0
+1
+v
+a
+l
+9
+7
+8
+.
+j
+p
+м
+а
+с
+и
+р
+ћ
+е
+ш
+3
+4
+о
+г
+н
+з
+в
+л
+6
+т
+ж
+у
+к
+п
+њ
+д
+ч
+С
+ј
+ф
+ц
+љ
+х
+О
+И
+А
+б
+Ш
+К
+ђ
+џ
+М
+В
+З
+Д
+Р
+У
+Н
+Т
+Б
+?
+П
+Х
+Ј
+Ц
+Г
+Љ
+Л
+Ф
+e
+n
+w
+E
+F
+A
+N
+f
+o
+b
+M
+G
+t
+y
+W
+k
+P
+u
+H
+B
+T
+z
+h
+O
+Y
+d
+U
+K
+D
+x
+X
+J
+Z
+Q
+q
+'
+-
+@
+é
+#
+!
+,
+%
+$
+:
+&
++
+(
+É
+
diff --git a/backend/ppocr/utils/dict/ru_dict.txt b/backend/ppocr/utils/dict/ru_dict.txt
index 3b0cf3a8..aff9c16e 100644
--- a/backend/ppocr/utils/dict/ru_dict.txt
+++ b/backend/ppocr/utils/dict/ru_dict.txt
@@ -1,65 +1,16 @@
-к
-в
-а
-з
-и
-у
-р
-о
-н
-я
-х
-п
-л
-ы
-г
-е
-т
-м
-д
-ж
-ш
-ь
-с
-ё
-б
-й
-ч
-ю
-ц
-щ
-М
-э
-ф
-А
-ъ
-С
-Ф
-Ю
-В
-К
-Т
-Н
-О
-Э
-У
-И
-Г
-Л
-Р
-Д
-Б
-Ш
-П
-З
-Х
-Е
-Ж
-Я
-Ц
-Ч
-Й
-Щ
+
+!
+#
+$
+%
+&
+'
+(
++
+,
+-
+.
+/
0
1
2
@@ -70,32 +21,9 @@
7
8
9
-a
-b
-c
-d
-e
-f
-g
-h
-i
-j
-k
-l
-m
-n
-o
-p
-q
-r
-s
-t
-u
-v
-w
-x
-y
-z
+:
+?
+@
A
B
C
@@ -122,4 +50,114 @@ W
X
Y
Z
-
+_
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+É
+é
+Ё
+Є
+І
+Ј
+Љ
+Ў
+А
+Б
+В
+Г
+Д
+Е
+Ж
+З
+И
+Й
+К
+Л
+М
+Н
+О
+П
+Р
+С
+Т
+У
+Ф
+Х
+Ц
+Ч
+Ш
+Щ
+Ъ
+Ы
+Ь
+Э
+Ю
+Я
+а
+б
+в
+г
+д
+е
+ж
+з
+и
+й
+к
+л
+м
+н
+о
+п
+р
+с
+т
+у
+ф
+х
+ц
+ч
+ш
+щ
+ъ
+ы
+ь
+э
+ю
+я
+ё
+ђ
+є
+і
+ј
+љ
+њ
+ћ
+ў
+џ
+Ґ
+ґ
diff --git a/backend/ppocr/utils/ic15_dict.txt b/backend/ppocr/utils/dict/spin_dict.txt
similarity index 51%
rename from backend/ppocr/utils/ic15_dict.txt
rename to backend/ppocr/utils/dict/spin_dict.txt
index 47406036..8ee8347f 100644
--- a/backend/ppocr/utils/ic15_dict.txt
+++ b/backend/ppocr/utils/dict/spin_dict.txt
@@ -33,4 +33,36 @@ v
w
x
y
-z
\ No newline at end of file
+z
+:
+(
+'
+-
+,
+%
+>
+.
+[
+?
+)
+"
+=
+_
+*
+]
+;
+&
++
+$
+@
+/
+|
+!
+<
+#
+`
+{
+~
+\
+}
+^
\ No newline at end of file
diff --git a/backend/ppocr/utils/dict/ta_dict.txt b/backend/ppocr/utils/dict/ta_dict.txt
index d1bae501..19d81892 100644
--- a/backend/ppocr/utils/dict/ta_dict.txt
+++ b/backend/ppocr/utils/dict/ta_dict.txt
@@ -22,7 +22,7 @@ l
8
.
j
-p
+p
ப
ூ
த
diff --git a/backend/ppocr/utils/dict/table_dict.txt b/backend/ppocr/utils/dict/table_dict.txt
new file mode 100644
index 00000000..2ef028c7
--- /dev/null
+++ b/backend/ppocr/utils/dict/table_dict.txt
@@ -0,0 +1,277 @@
+←
+
+☆
+─
+α
+
+
+⋅
+$
+ω
+ψ
+χ
+(
+υ
+≥
+σ
+,
+ρ
+ε
+0
+■
+4
+8
+✗
+b
+<
+✓
+Ψ
+Ω
+€
+D
+3
+Π
+H
+║
+
+L
+Φ
+Χ
+θ
+P
+κ
+λ
+μ
+T
+ξ
+X
+β
+γ
+δ
+\
+ζ
+η
+`
+d
+
+h
+f
+l
+Θ
+p
+√
+t
+
+x
+Β
+Γ
+Δ
+|
+ǂ
+ɛ
+j
+̧
+➢
+
+̌
+′
+«
+△
+▲
+#
+
+'
+Ι
++
+¶
+/
+▼
+⇑
+□
+·
+7
+▪
+;
+?
+➔
+∩
+C
+÷
+G
+⇒
+K
+
+O
+S
+С
+W
+Α
+[
+○
+_
+●
+‡
+c
+z
+g
+
+o
+
+〈
+〉
+s
+⩽
+w
+φ
+ʹ
+{
+»
+∣
+̆
+e
+ˆ
+∈
+τ
+◆
+ι
+∅
+∆
+∙
+∘
+Ø
+ß
+✔
+∞
+∑
+−
+×
+◊
+∗
+∖
+˃
+˂
+∫
+"
+i
+&
+π
+↔
+*
+∥
+æ
+∧
+.
+⁄
+ø
+Q
+∼
+6
+⁎
+:
+★
+>
+a
+B
+≈
+F
+J
+̄
+N
+♯
+R
+V
+
+―
+Z
+♣
+^
+¤
+¥
+§
+
+¢
+£
+≦
+
+≤
+‖
+Λ
+©
+n
+↓
+→
+↑
+r
+°
+±
+v
+
+♂
+k
+♀
+~
+ᅟ
+̇
+@
+”
+♦
+ł
+®
+⊕
+„
+!
+
+%
+⇓
+)
+-
+1
+5
+9
+=
+А
+A
+‰
+⋆
+Σ
+E
+◦
+I
+※
+M
+m
+̨
+⩾
+†
+
+•
+U
+Y
+
+]
+̸
+2
+‐
+–
+‒
+̂
+—
+̀
+́
+’
+‘
+⋮
+⋯
+̊
+“
+̈
+≧
+q
+u
+ı
+y
+
+
+̃
+}
+ν
diff --git a/backend/ppocr/utils/dict/table_master_structure_dict.txt b/backend/ppocr/utils/dict/table_master_structure_dict.txt
new file mode 100644
index 00000000..95ab2539
--- /dev/null
+++ b/backend/ppocr/utils/dict/table_master_structure_dict.txt
@@ -0,0 +1,39 @@
+
+
+ |
+
+
+
+
+
+
+ |
+ colspan="2"
+ colspan="3"
+
+
+ rowspan="2"
+ colspan="4"
+ colspan="6"
+ rowspan="3"
+ colspan="9"
+ colspan="10"
+ colspan="7"
+ rowspan="4"
+ rowspan="5"
+ rowspan="9"
+ colspan="8"
+ rowspan="8"
+ rowspan="6"
+ rowspan="7"
+ rowspan="10"
+
+
+
+
+
+
+
+
diff --git a/backend/ppocr/utils/dict/table_structure_dict.txt b/backend/ppocr/utils/dict/table_structure_dict.txt
new file mode 100644
index 00000000..8edb10b8
--- /dev/null
+++ b/backend/ppocr/utils/dict/table_structure_dict.txt
@@ -0,0 +1,28 @@
+
+
+
+ |
+
+
+
+
+
+ colspan="2"
+ colspan="3"
+ rowspan="2"
+ colspan="4"
+ colspan="6"
+ rowspan="3"
+ colspan="9"
+ colspan="10"
+ colspan="7"
+ rowspan="4"
+ rowspan="5"
+ rowspan="9"
+ colspan="8"
+ rowspan="8"
+ rowspan="6"
+ rowspan="7"
+ rowspan="10"
\ No newline at end of file
diff --git a/backend/ppocr/utils/dict/table_structure_dict_ch.txt b/backend/ppocr/utils/dict/table_structure_dict_ch.txt
new file mode 100644
index 00000000..0c59c0e9
--- /dev/null
+++ b/backend/ppocr/utils/dict/table_structure_dict_ch.txt
@@ -0,0 +1,48 @@
+
+
+ |
+
+
+
+
+ |
+ |
+ colspan="2"
+ colspan="3"
+ colspan="4"
+ colspan="5"
+ colspan="6"
+ colspan="7"
+ colspan="8"
+ colspan="9"
+ colspan="10"
+ colspan="11"
+ colspan="12"
+ colspan="13"
+ colspan="14"
+ colspan="15"
+ colspan="16"
+ colspan="17"
+ colspan="18"
+ colspan="19"
+ colspan="20"
+ rowspan="2"
+ rowspan="3"
+ rowspan="4"
+ rowspan="5"
+ rowspan="6"
+ rowspan="7"
+ rowspan="8"
+ rowspan="9"
+ rowspan="10"
+ rowspan="11"
+ rowspan="12"
+ rowspan="13"
+ rowspan="14"
+ rowspan="15"
+ rowspan="16"
+ rowspan="17"
+ rowspan="18"
+ rowspan="19"
+ rowspan="20"
diff --git a/backend/ppocr/utils/dict/xi_dict.txt b/backend/ppocr/utils/dict/xi_dict.txt
new file mode 100644
index 00000000..f195f1ea
--- /dev/null
+++ b/backend/ppocr/utils/dict/xi_dict.txt
@@ -0,0 +1,110 @@
+x
+i
+_
+m
+g
+/
+1
+0
+I
+L
+S
+V
+R
+C
+2
+v
+a
+l
+3
+6
+4
+5
+.
+j
+p
+
+Q
+u
+e
+r
+o
+8
+7
+n
+c
+9
+t
+b
+é
+q
+d
+ó
+y
+F
+s
+,
+O
+í
+T
+f
+"
+U
+M
+h
+:
+P
+H
+A
+E
+D
+z
+N
+á
+ñ
+ú
+%
+;
+è
++
+Y
+-
+B
+G
+(
+)
+¿
+?
+w
+¡
+!
+X
+É
+K
+k
+Á
+ü
+Ú
+«
+»
+J
+'
+ö
+W
+Z
+º
+Ö
+
+[
+]
+Ç
+ç
+à
+ä
+û
+ò
+Í
+ê
+ô
+ø
+ª
diff --git a/backend/ppocr/utils/e2e_metric/Deteval.py b/backend/ppocr/utils/e2e_metric/Deteval.py
new file mode 100755
index 00000000..45567a7d
--- /dev/null
+++ b/backend/ppocr/utils/e2e_metric/Deteval.py
@@ -0,0 +1,574 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import scipy.io as io
+from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
+
+
+def get_socre_A(gt_dir, pred_dict):
+ allInputs = 1
+
+ def input_reading_mod(pred_dict):
+ """This helper reads input from txt files"""
+ det = []
+ n = len(pred_dict)
+ for i in range(n):
+ points = pred_dict[i]['points']
+ text = pred_dict[i]['texts']
+ point = ",".join(map(str, points.reshape(-1, )))
+ det.append([point, text])
+ return det
+
+ def gt_reading_mod(gt_dict):
+ """This helper reads groundtruths from mat files"""
+ gt = []
+ n = len(gt_dict)
+ for i in range(n):
+ points = gt_dict[i]['points'].tolist()
+ h = len(points)
+ text = gt_dict[i]['text']
+ xx = [
+ np.array(
+ ['x:'], dtype='<U2'), 0, np.array(
+ ['y:'], dtype='<U2'), 0, np.array([text]), np.array(['c'])
+ ]
+ t_x, t_y = [], []
+ for j in range(h):
+ t_x.append(points[j][0])
+ t_y.append(points[j][1])
+ xx[1] = np.array([t_x], dtype='int16')
+ xx[3] = np.array([t_y], dtype='int16')
+ gt.append(xx)
+ return gt
+
+ def detection_filtering(detections, groundtruths, threshold=0.5):
+ for gt_id, gt in enumerate(groundtruths):
+ if (gt[5] == '#') and (gt[1].shape[1] > 1):
+ gt_x = list(map(int, np.squeeze(gt[1])))
+ gt_y = list(map(int, np.squeeze(gt[3])))
+ for det_id, detection in enumerate(detections):
+ detection_orig = detection
+ detection = [float(x) for x in detection[0].split(',')]
+ detection = list(map(int, detection))
+ det_x = detection[0::2]
+ det_y = detection[1::2]
+ det_gt_iou = iod(det_x, det_y, gt_x, gt_y)
+ if det_gt_iou > threshold:
+ detections[det_id] = []
+
+ detections[:] = [item for item in detections if item != []]
+ return detections
+
+ def sigma_calculation(det_x, det_y, gt_x, gt_y):
+ """
+ sigma = inter_area / gt_area
+ """
+ return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
+ area(gt_x, gt_y)), 2)
+
+ def tau_calculation(det_x, det_y, gt_x, gt_y):
+ if area(det_x, det_y) == 0.0:
+ return 0
+ return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
+ area(det_x, det_y)), 2)
+
+ ##############################Initialization###################################
+ # global_sigma = []
+ # global_tau = []
+ # global_pred_str = []
+ # global_gt_str = []
+ ###############################################################################
+
+ for input_id in range(allInputs):
+ if (input_id != '.DS_Store') and (input_id != 'Pascal_result.txt') and (
+ input_id != 'Pascal_result_curved.txt') and (input_id != 'Pascal_result_non_curved.txt') and (
+ input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \
+ and (input_id != 'Deteval_result_non_curved.txt'):
+ detections = input_reading_mod(pred_dict)
+ groundtruths = gt_reading_mod(gt_dir)
+ detections = detection_filtering(
+ detections,
+ groundtruths) # filters detections overlapping with DC area
+ dc_id = []
+ for i in range(len(groundtruths)):
+ if groundtruths[i][5] == '#':
+ dc_id.append(i)
+ cnt = 0
+ for a in dc_id:
+ num = a - cnt
+ del groundtruths[num]
+ cnt += 1
+
+ local_sigma_table = np.zeros((len(groundtruths), len(detections)))
+ local_tau_table = np.zeros((len(groundtruths), len(detections)))
+ local_pred_str = {}
+ local_gt_str = {}
+
+ for gt_id, gt in enumerate(groundtruths):
+ if len(detections) > 0:
+ for det_id, detection in enumerate(detections):
+ detection_orig = detection
+ detection = [float(x) for x in detection[0].split(',')]
+ detection = list(map(int, detection))
+ pred_seq_str = detection_orig[1].strip()
+ det_x = detection[0::2]
+ det_y = detection[1::2]
+ gt_x = list(map(int, np.squeeze(gt[1])))
+ gt_y = list(map(int, np.squeeze(gt[3])))
+ gt_seq_str = str(gt[4].tolist()[0])
+
+ local_sigma_table[gt_id, det_id] = sigma_calculation(
+ det_x, det_y, gt_x, gt_y)
+ local_tau_table[gt_id, det_id] = tau_calculation(
+ det_x, det_y, gt_x, gt_y)
+ local_pred_str[det_id] = pred_seq_str
+ local_gt_str[gt_id] = gt_seq_str
+
+ global_sigma = local_sigma_table
+ global_tau = local_tau_table
+ global_pred_str = local_pred_str
+ global_gt_str = local_gt_str
+
+ single_data = {}
+ single_data['sigma'] = global_sigma
+ single_data['global_tau'] = global_tau
+ single_data['global_pred_str'] = global_pred_str
+ single_data['global_gt_str'] = global_gt_str
+ return single_data
+
+
+def get_socre_B(gt_dir, img_id, pred_dict):
+ allInputs = 1
+
+ def input_reading_mod(pred_dict):
+ """This helper reads input from txt files"""
+ det = []
+ n = len(pred_dict)
+ for i in range(n):
+ points = pred_dict[i]['points']
+ text = pred_dict[i]['texts']
+ point = ",".join(map(str, points.reshape(-1, )))
+ det.append([point, text])
+ return det
+
+ def gt_reading_mod(gt_dir, gt_id):
+ gt = io.loadmat('%s/poly_gt_img%s.mat' % (gt_dir, gt_id))
+ gt = gt['polygt']
+ return gt
+
+ def detection_filtering(detections, groundtruths, threshold=0.5):
+ for gt_id, gt in enumerate(groundtruths):
+ if (gt[5] == '#') and (gt[1].shape[1] > 1):
+ gt_x = list(map(int, np.squeeze(gt[1])))
+ gt_y = list(map(int, np.squeeze(gt[3])))
+ for det_id, detection in enumerate(detections):
+ detection_orig = detection
+ detection = [float(x) for x in detection[0].split(',')]
+ detection = list(map(int, detection))
+ det_x = detection[0::2]
+ det_y = detection[1::2]
+ det_gt_iou = iod(det_x, det_y, gt_x, gt_y)
+ if det_gt_iou > threshold:
+ detections[det_id] = []
+
+ detections[:] = [item for item in detections if item != []]
+ return detections
+
+ def sigma_calculation(det_x, det_y, gt_x, gt_y):
+ """
+ sigma = inter_area / gt_area
+ """
+ return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
+ area(gt_x, gt_y)), 2)
+
+ def tau_calculation(det_x, det_y, gt_x, gt_y):
+ if area(det_x, det_y) == 0.0:
+ return 0
+ return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
+ area(det_x, det_y)), 2)
+
+ ##############################Initialization###################################
+ # global_sigma = []
+ # global_tau = []
+ # global_pred_str = []
+ # global_gt_str = []
+ ###############################################################################
+
+ for input_id in range(allInputs):
+ if (input_id != '.DS_Store') and (input_id != 'Pascal_result.txt') and (
+ input_id != 'Pascal_result_curved.txt') and (input_id != 'Pascal_result_non_curved.txt') and (
+ input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \
+ and (input_id != 'Deteval_result_non_curved.txt'):
+ detections = input_reading_mod(pred_dict)
+ groundtruths = gt_reading_mod(gt_dir, img_id).tolist()
+ detections = detection_filtering(
+ detections,
+ groundtruths) # filters detections overlapping with DC area
+ dc_id = []
+ for i in range(len(groundtruths)):
+ if groundtruths[i][5] == '#':
+ dc_id.append(i)
+ cnt = 0
+ for a in dc_id:
+ num = a - cnt
+ del groundtruths[num]
+ cnt += 1
+
+ local_sigma_table = np.zeros((len(groundtruths), len(detections)))
+ local_tau_table = np.zeros((len(groundtruths), len(detections)))
+ local_pred_str = {}
+ local_gt_str = {}
+
+ for gt_id, gt in enumerate(groundtruths):
+ if len(detections) > 0:
+ for det_id, detection in enumerate(detections):
+ detection_orig = detection
+ detection = [float(x) for x in detection[0].split(',')]
+ detection = list(map(int, detection))
+ pred_seq_str = detection_orig[1].strip()
+ det_x = detection[0::2]
+ det_y = detection[1::2]
+ gt_x = list(map(int, np.squeeze(gt[1])))
+ gt_y = list(map(int, np.squeeze(gt[3])))
+ gt_seq_str = str(gt[4].tolist()[0])
+
+ local_sigma_table[gt_id, det_id] = sigma_calculation(
+ det_x, det_y, gt_x, gt_y)
+ local_tau_table[gt_id, det_id] = tau_calculation(
+ det_x, det_y, gt_x, gt_y)
+ local_pred_str[det_id] = pred_seq_str
+ local_gt_str[gt_id] = gt_seq_str
+
+ global_sigma = local_sigma_table
+ global_tau = local_tau_table
+ global_pred_str = local_pred_str
+ global_gt_str = local_gt_str
+
+ single_data = {}
+ single_data['sigma'] = global_sigma
+ single_data['global_tau'] = global_tau
+ single_data['global_pred_str'] = global_pred_str
+ single_data['global_gt_str'] = global_gt_str
+ return single_data
+
+
+def combine_results(all_data):
+ tr = 0.7
+ tp = 0.6
+ fsc_k = 0.8
+ k = 2
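+ # DetEval matching constants: tr is the sigma (recall) threshold, tp the
+ # tau (precision) threshold, fsc_k the score assigned to split/merge
+ # matches, and k the minimum fragment count for one-to-many / many-to-one.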
+ global_sigma = []
+ global_tau = []
+ global_pred_str = []
+ global_gt_str = []
+ for data in all_data:
+ global_sigma.append(data['sigma'])
+ global_tau.append(data['global_tau'])
+ global_pred_str.append(data['global_pred_str'])
+ global_gt_str.append(data['global_gt_str'])
+
+ global_accumulative_recall = 0
+ global_accumulative_precision = 0
+ total_num_gt = 0
+ total_num_det = 0
+ hit_str_count = 0
+ hit_count = 0
+
+ def one_to_one(local_sigma_table, local_tau_table,
+ local_accumulative_recall, local_accumulative_precision,
+ global_accumulative_recall, global_accumulative_precision,
+ gt_flag, det_flag, idy):
+ hit_str_num = 0
+ for gt_id in range(num_gt):
+ gt_matching_qualified_sigma_candidates = np.where(
+ local_sigma_table[gt_id, :] > tr)
+ gt_matching_num_qualified_sigma_candidates = gt_matching_qualified_sigma_candidates[
+ 0].shape[0]
+ gt_matching_qualified_tau_candidates = np.where(
+ local_tau_table[gt_id, :] > tp)
+ gt_matching_num_qualified_tau_candidates = gt_matching_qualified_tau_candidates[
+ 0].shape[0]
+
+ det_matching_qualified_sigma_candidates = np.where(
+ local_sigma_table[:, gt_matching_qualified_sigma_candidates[0]]
+ > tr)
+ det_matching_num_qualified_sigma_candidates = det_matching_qualified_sigma_candidates[
+ 0].shape[0]
+ det_matching_qualified_tau_candidates = np.where(
+ local_tau_table[:, gt_matching_qualified_tau_candidates[0]] >
+ tp)
+ det_matching_num_qualified_tau_candidates = det_matching_qualified_tau_candidates[
+ 0].shape[0]
+
+ if (gt_matching_num_qualified_sigma_candidates == 1) and (gt_matching_num_qualified_tau_candidates == 1) and \
+ (det_matching_num_qualified_sigma_candidates == 1) and (
+ det_matching_num_qualified_tau_candidates == 1):
+ global_accumulative_recall = global_accumulative_recall + 1.0
+ global_accumulative_precision = global_accumulative_precision + 1.0
+ local_accumulative_recall = local_accumulative_recall + 1.0
+ local_accumulative_precision = local_accumulative_precision + 1.0
+
+ gt_flag[0, gt_id] = 1
+ matched_det_id = np.where(local_sigma_table[gt_id, :] > tr)
+ # recg start
+ gt_str_cur = global_gt_str[idy][gt_id]
+ pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[
+ 0]]
+ if pred_str_cur == gt_str_cur:
+ hit_str_num += 1
+ else:
+ if pred_str_cur.lower() == gt_str_cur.lower():
+ hit_str_num += 1
+ # recg end
+ det_flag[0, matched_det_id] = 1
+ return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
+
+ def one_to_many(local_sigma_table, local_tau_table,
+ local_accumulative_recall, local_accumulative_precision,
+ global_accumulative_recall, global_accumulative_precision,
+ gt_flag, det_flag, idy):
+ hit_str_num = 0
+ for gt_id in range(num_gt):
+ # skip the following if the groundtruth was matched
+ if gt_flag[0, gt_id] > 0:
+ continue
+
+ non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0)
+ num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0]
+
+ if num_non_zero_in_sigma >= k:
+ #### search for all detections that overlap with this groundtruth
+ qualified_tau_candidates = np.where((local_tau_table[
+ gt_id, :] >= tp) & (det_flag[0, :] == 0))
+ num_qualified_tau_candidates = qualified_tau_candidates[
+ 0].shape[0]
+
+ if num_qualified_tau_candidates == 1:
+ if ((local_tau_table[gt_id, qualified_tau_candidates] >= tp)
+ and
+ (local_sigma_table[gt_id, qualified_tau_candidates] >=
+ tr)):
+ # became an one-to-one case
+ global_accumulative_recall = global_accumulative_recall + 1.0
+ global_accumulative_precision = global_accumulative_precision + 1.0
+ local_accumulative_recall = local_accumulative_recall + 1.0
+ local_accumulative_precision = local_accumulative_precision + 1.0
+
+ gt_flag[0, gt_id] = 1
+ det_flag[0, qualified_tau_candidates] = 1
+ # recg start
+ gt_str_cur = global_gt_str[idy][gt_id]
+ pred_str_cur = global_pred_str[idy][
+ qualified_tau_candidates[0].tolist()[0]]
+ if pred_str_cur == gt_str_cur:
+ hit_str_num += 1
+ else:
+ if pred_str_cur.lower() == gt_str_cur.lower():
+ hit_str_num += 1
+ # recg end
+ elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates])
+ >= tr):
+ gt_flag[0, gt_id] = 1
+ det_flag[0, qualified_tau_candidates] = 1
+ # recg start
+ gt_str_cur = global_gt_str[idy][gt_id]
+ pred_str_cur = global_pred_str[idy][
+ qualified_tau_candidates[0].tolist()[0]]
+ if pred_str_cur == gt_str_cur:
+ hit_str_num += 1
+ else:
+ if pred_str_cur.lower() == gt_str_cur.lower():
+ hit_str_num += 1
+ # recg end
+
+ global_accumulative_recall = global_accumulative_recall + fsc_k
+ global_accumulative_precision = global_accumulative_precision + num_qualified_tau_candidates * fsc_k
+
+ local_accumulative_recall = local_accumulative_recall + fsc_k
+ local_accumulative_precision = local_accumulative_precision + num_qualified_tau_candidates * fsc_k
+
+ return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
+
+ def many_to_one(local_sigma_table, local_tau_table,
+ local_accumulative_recall, local_accumulative_precision,
+ global_accumulative_recall, global_accumulative_precision,
+ gt_flag, det_flag, idy):
+ hit_str_num = 0
+ for det_id in range(num_det):
+ # skip the following if the detection was matched
+ if det_flag[0, det_id] > 0:
+ continue
+
+ non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0)
+ num_non_zero_in_tau = non_zero_in_tau[0].shape[0]
+
+ if num_non_zero_in_tau >= k:
+ #### search for all groundtruths that overlap with this detection
+ qualified_sigma_candidates = np.where((
+ local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0))
+ num_qualified_sigma_candidates = qualified_sigma_candidates[
+ 0].shape[0]
+
+ if num_qualified_sigma_candidates == 1:
+ if ((local_tau_table[qualified_sigma_candidates, det_id] >=
+ tp) and
+ (local_sigma_table[qualified_sigma_candidates, det_id]
+ >= tr)):
+ # became an one-to-one case
+ global_accumulative_recall = global_accumulative_recall + 1.0
+ global_accumulative_precision = global_accumulative_precision + 1.0
+ local_accumulative_recall = local_accumulative_recall + 1.0
+ local_accumulative_precision = local_accumulative_precision + 1.0
+
+ gt_flag[0, qualified_sigma_candidates] = 1
+ det_flag[0, det_id] = 1
+ # recg start
+ pred_str_cur = global_pred_str[idy][det_id]
+ gt_len = len(qualified_sigma_candidates[0])
+ for idx in range(gt_len):
+ ele_gt_id = qualified_sigma_candidates[0].tolist()[
+ idx]
+ if ele_gt_id not in global_gt_str[idy]:
+ continue
+ gt_str_cur = global_gt_str[idy][ele_gt_id]
+ if pred_str_cur == gt_str_cur:
+ hit_str_num += 1
+ break
+ else:
+ if pred_str_cur.lower() == gt_str_cur.lower():
+ hit_str_num += 1
+ break
+ # recg end
+ elif (np.sum(local_tau_table[qualified_sigma_candidates,
+ det_id]) >= tp):
+ det_flag[0, det_id] = 1
+ gt_flag[0, qualified_sigma_candidates] = 1
+ # recg start
+ pred_str_cur = global_pred_str[idy][det_id]
+ gt_len = len(qualified_sigma_candidates[0])
+ for idx in range(gt_len):
+ ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
+ if ele_gt_id not in global_gt_str[idy]:
+ continue
+ gt_str_cur = global_gt_str[idy][ele_gt_id]
+ if pred_str_cur == gt_str_cur:
+ hit_str_num += 1
+ break
+ else:
+ if pred_str_cur.lower() == gt_str_cur.lower():
+ hit_str_num += 1
+ break
+ # recg end
+
+ global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k
+ global_accumulative_precision = global_accumulative_precision + fsc_k
+
+ local_accumulative_recall = local_accumulative_recall + num_qualified_sigma_candidates * fsc_k
+ local_accumulative_precision = local_accumulative_precision + fsc_k
+ return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
+
+ for idx in range(len(global_sigma)):
+ local_sigma_table = np.array(global_sigma[idx])
+ local_tau_table = global_tau[idx]
+
+ num_gt = local_sigma_table.shape[0]
+ num_det = local_sigma_table.shape[1]
+
+ total_num_gt = total_num_gt + num_gt
+ total_num_det = total_num_det + num_det
+
+ local_accumulative_recall = 0
+ local_accumulative_precision = 0
+ gt_flag = np.zeros((1, num_gt))
+ det_flag = np.zeros((1, num_det))
+
+ #######first check for one-to-one case##########
+ local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
+ gt_flag, det_flag, hit_str_num = one_to_one(local_sigma_table, local_tau_table,
+ local_accumulative_recall, local_accumulative_precision,
+ global_accumulative_recall, global_accumulative_precision,
+ gt_flag, det_flag, idx)
+
+ hit_str_count += hit_str_num
+ #######then check for one-to-many case##########
+ local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
+ gt_flag, det_flag, hit_str_num = one_to_many(local_sigma_table, local_tau_table,
+ local_accumulative_recall, local_accumulative_precision,
+ global_accumulative_recall, global_accumulative_precision,
+ gt_flag, det_flag, idx)
+ hit_str_count += hit_str_num
+ #######then check for many-to-one case##########
+ local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
+ gt_flag, det_flag, hit_str_num = many_to_one(local_sigma_table, local_tau_table,
+ local_accumulative_recall, local_accumulative_precision,
+ global_accumulative_recall, global_accumulative_precision,
+ gt_flag, det_flag, idx)
+ hit_str_count += hit_str_num
+
+ try:
+ recall = global_accumulative_recall / total_num_gt
+ except ZeroDivisionError:
+ recall = 0
+
+ try:
+ precision = global_accumulative_precision / total_num_det
+ except ZeroDivisionError:
+ precision = 0
+
+ try:
+ f_score = 2 * precision * recall / (precision + recall)
+ except ZeroDivisionError:
+ f_score = 0
+
+ try:
+ seqerr = 1 - float(hit_str_count) / global_accumulative_recall
+ except ZeroDivisionError:
+ seqerr = 1
+
+ try:
+ recall_e2e = float(hit_str_count) / total_num_gt
+ except ZeroDivisionError:
+ recall_e2e = 0
+
+ try:
+ precision_e2e = float(hit_str_count) / total_num_det
+ except ZeroDivisionError:
+ precision_e2e = 0
+
+ try:
+ f_score_e2e = 2 * precision_e2e * recall_e2e / (
+ precision_e2e + recall_e2e)
+ except ZeroDivisionError:
+ f_score_e2e = 0
+
+ final = {
+ 'total_num_gt': total_num_gt,
+ 'total_num_det': total_num_det,
+ 'global_accumulative_recall': global_accumulative_recall,
+ 'hit_str_count': hit_str_count,
+ 'recall': recall,
+ 'precision': precision,
+ 'f_score': f_score,
+ 'seqerr': seqerr,
+ 'recall_e2e': recall_e2e,
+ 'precision_e2e': precision_e2e,
+ 'f_score_e2e': f_score_e2e
+ }
+ return final
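For reference, a minimal sketch of how these DetEval helpers chain together, assuming `backend/` is on PYTHONPATH; the single ground-truth box and prediction below are made-up illustrations in the format `get_socre_A` expects:

```python
import numpy as np
from ppocr.utils.e2e_metric.Deteval import get_socre_A, combine_results

# hypothetical ground truth: one quadrilateral transcribed as "hello"
gt_dict = [{'points': np.array([[10, 10], [100, 10], [100, 40], [10, 40]]),
            'text': 'hello'}]
# hypothetical prediction in the format produced by the PGNet post-process
preds = [{'points': np.array([[12, 11], [98, 11], [98, 39], [12, 39]]),
          'texts': 'hello'}]

per_image = get_socre_A(gt_dict, preds)  # sigma/tau tables for one image
metrics = combine_results([per_image])   # accumulate over all images
print(metrics['f_score'], metrics['f_score_e2e'])
```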
diff --git a/backend/ppocr/utils/e2e_metric/polygon_fast.py b/backend/ppocr/utils/e2e_metric/polygon_fast.py
new file mode 100755
index 00000000..81c9ad70
--- /dev/null
+++ b/backend/ppocr/utils/e2e_metric/polygon_fast.py
@@ -0,0 +1,83 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+from shapely.geometry import Polygon
+"""
+:param det_x: [1, N] Xs of detection's vertices
+:param det_y: [1, N] Ys of detection's vertices
+:param gt_x: [1, N] Xs of groundtruth's vertices
+:param gt_y: [1, N] Ys of groundtruth's vertices
+
+##############
+All the 'AREA' calculations in this script are handled with shapely:
+polygons are built from the x/y vertex lists and intersection/union
+areas are read directly from Polygon.area.
+"""
+
+
+def area(x, y):
+ polygon = Polygon(np.stack([x, y], axis=1))
+ return float(polygon.area)
+
+
+def approx_area_of_intersection(det_x, det_y, gt_x, gt_y):
+ """
+ This helper determine if both polygons are intersecting with each others with an approximation method.
+ Area of intersection represented by the minimum bounding rectangular [xmin, ymin, xmax, ymax]
+ """
+ det_ymax = np.max(det_y)
+ det_xmax = np.max(det_x)
+ det_ymin = np.min(det_y)
+ det_xmin = np.min(det_x)
+
+ gt_ymax = np.max(gt_y)
+ gt_xmax = np.max(gt_x)
+ gt_ymin = np.min(gt_y)
+ gt_xmin = np.min(gt_x)
+
+ all_min_ymax = np.minimum(det_ymax, gt_ymax)
+ all_max_ymin = np.maximum(det_ymin, gt_ymin)
+
+ intersect_heights = np.maximum(0.0, (all_min_ymax - all_max_ymin))
+
+ all_min_xmax = np.minimum(det_xmax, gt_xmax)
+ all_max_xmin = np.maximum(det_xmin, gt_xmin)
+ intersect_widths = np.maximum(0.0, (all_min_xmax - all_max_xmin))
+
+ return intersect_heights * intersect_widths
+
+
+def area_of_intersection(det_x, det_y, gt_x, gt_y):
+ p1 = Polygon(np.stack([det_x, det_y], axis=1)).buffer(0)
+ p2 = Polygon(np.stack([gt_x, gt_y], axis=1)).buffer(0)
+ return float(p1.intersection(p2).area)
+
+
+def area_of_union(det_x, det_y, gt_x, gt_y):
+ p1 = Polygon(np.stack([det_x, det_y], axis=1)).buffer(0)
+ p2 = Polygon(np.stack([gt_x, gt_y], axis=1)).buffer(0)
+ return float(p1.union(p2).area)
+
+
+def iou(det_x, det_y, gt_x, gt_y):
+ return area_of_intersection(det_x, det_y, gt_x, gt_y) / (
+ area_of_union(det_x, det_y, gt_x, gt_y) + 1.0)
+
+
+def iod(det_x, det_y, gt_x, gt_y):
+ """
+ This helper determines the fraction of the intersection area over the detection area.
+ """
+ return area_of_intersection(det_x, det_y, gt_x, gt_y) / (
+ area(det_x, det_y) + 1.0)
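A quick sanity check of these helpers on two overlapping unit squares (assuming `backend/` is on PYTHONPATH); note that the `+ 1.0` smoothing in the `iou`/`iod` denominators keeps both ratios below their textbook values:

```python
from ppocr.utils.e2e_metric.polygon_fast import area, area_of_intersection, iou, iod

# the detection is a unit square shifted right by 0.5 relative to the ground truth
det_x, det_y = [0.5, 1.5, 1.5, 0.5], [0.0, 0.0, 1.0, 1.0]
gt_x, gt_y = [0.0, 1.0, 1.0, 0.0], [0.0, 0.0, 1.0, 1.0]

print(area(det_x, det_y))                              # 1.0
print(area_of_intersection(det_x, det_y, gt_x, gt_y))  # 0.5
print(iou(det_x, det_y, gt_x, gt_y))                   # 0.5 / (1.5 + 1.0) = 0.2
print(iod(det_x, det_y, gt_x, gt_y))                   # 0.5 / (1.0 + 1.0) = 0.25
```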
diff --git a/backend/ppocr/utils/e2e_utils/extract_batchsize.py b/backend/ppocr/utils/e2e_utils/extract_batchsize.py
new file mode 100644
index 00000000..e99a833e
--- /dev/null
+++ b/backend/ppocr/utils/e2e_utils/extract_batchsize.py
@@ -0,0 +1,87 @@
+import paddle
+import numpy as np
+import copy
+
+
+def org_tcl_rois(batch_size, pos_lists, pos_masks, label_lists, tcl_bs):
+ """
+ """
+ pos_lists_, pos_masks_, label_lists_ = [], [], []
+ img_bs = batch_size
+ ngpu = int(batch_size / img_bs)
+ img_ids = np.array(pos_lists, dtype=np.int32)[:, 0, 0].copy()
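+ # column 0 of every pos_list row carries the image index within the global batch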
+ pos_lists_split, pos_masks_split, label_lists_split = [], [], []
+ for i in range(ngpu):
+ pos_lists_split.append([])
+ pos_masks_split.append([])
+ label_lists_split.append([])
+
+ for i in range(img_ids.shape[0]):
+ img_id = img_ids[i]
+ gpu_id = int(img_id / img_bs)
+ img_id = img_id % img_bs
+ pos_list = pos_lists[i].copy()
+ pos_list[:, 0] = img_id
+ pos_lists_split[gpu_id].append(pos_list)
+ pos_masks_split[gpu_id].append(pos_masks[i].copy())
+ label_lists_split[gpu_id].append(copy.deepcopy(label_lists[i]))
+ # repeat or delete
+ for i in range(ngpu):
+ vp_len = len(pos_lists_split[i])
+ if vp_len <= tcl_bs:
+ for j in range(0, tcl_bs - vp_len):
+ pos_list = pos_lists_split[i][j].copy()
+ pos_lists_split[i].append(pos_list)
+ pos_mask = pos_masks_split[i][j].copy()
+ pos_masks_split[i].append(pos_mask)
+ label_list = copy.deepcopy(label_lists_split[i][j])
+ label_lists_split[i].append(label_list)
+ else:
+ for j in range(0, vp_len - tcl_bs):
+ c_len = len(pos_lists_split[i])
+ pop_id = np.random.permutation(c_len)[0]
+ pos_lists_split[i].pop(pop_id)
+ pos_masks_split[i].pop(pop_id)
+ label_lists_split[i].pop(pop_id)
+ # merge
+ for i in range(ngpu):
+ pos_lists_.extend(pos_lists_split[i])
+ pos_masks_.extend(pos_masks_split[i])
+ label_lists_.extend(label_lists_split[i])
+ return pos_lists_, pos_masks_, label_lists_
+
+
+def pre_process(label_list, pos_list, pos_mask, max_text_length, max_text_nums,
+ pad_num, tcl_bs):
+ label_list = label_list.numpy()
+ batch, _, _, _ = label_list.shape
+ pos_list = pos_list.numpy()
+ pos_mask = pos_mask.numpy()
+ pos_list_t = []
+ pos_mask_t = []
+ label_list_t = []
+ for i in range(batch):
+ for j in range(max_text_nums):
+ if pos_mask[i, j].any():
+ pos_list_t.append(pos_list[i][j])
+ pos_mask_t.append(pos_mask[i][j])
+ label_list_t.append(label_list[i][j])
+ pos_list, pos_mask, label_list = org_tcl_rois(batch, pos_list_t, pos_mask_t,
+ label_list_t, tcl_bs)
+ label = []
+ tt = [l.tolist() for l in label_list]
+ for i in range(tcl_bs):
+ k = 0
+ for j in range(max_text_length):
+ if tt[i][j][0] != pad_num:
+ k += 1
+ else:
+ break
+ label.append(k)
+ label = paddle.to_tensor(label)
+ label = paddle.cast(label, dtype='int64')
+ pos_list = paddle.to_tensor(pos_list)
+ pos_mask = paddle.to_tensor(pos_mask)
+ label_list = paddle.squeeze(paddle.to_tensor(label_list), axis=2)
+ label_list = paddle.cast(label_list, dtype='int32')
+ return pos_list, pos_mask, label_list, label
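A small numpy-only sketch of how `org_tcl_rois` balances each group of TCL RoIs to exactly `tcl_bs` entries by repeating (or randomly dropping) instances; the shapes below are made up for illustration:

```python
import numpy as np
from ppocr.utils.e2e_utils.extract_batchsize import org_tcl_rois

# two images, one 4-point TCL instance each; column 0 carries the image id
pos_lists = [np.zeros((4, 3), dtype=np.int32) for _ in range(2)]
pos_lists[1][:, 0] = 1
pos_masks = [np.ones((4, 1), dtype=np.float32) for _ in range(2)]
label_lists = [np.full((8, 1), 36, dtype=np.int64) for _ in range(2)]

pos, masks, labels = org_tcl_rois(2, pos_lists, pos_masks, label_lists, tcl_bs=4)
print(len(pos))  # 4: each instance is repeated once so the group holds tcl_bs entries
```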
diff --git a/backend/ppocr/utils/e2e_utils/extract_textpoint_fast.py b/backend/ppocr/utils/e2e_utils/extract_textpoint_fast.py
new file mode 100644
index 00000000..787cd301
--- /dev/null
+++ b/backend/ppocr/utils/e2e_utils/extract_textpoint_fast.py
@@ -0,0 +1,457 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contains various CTC decoders."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import cv2
+import math
+
+import numpy as np
+from itertools import groupby
+from skimage.morphology._skeletonize import thin
+
+
+def get_dict(character_dict_path):
+ character_str = ""
+ with open(character_dict_path, "rb") as fin:
+ lines = fin.readlines()
+ for line in lines:
+ line = line.decode('utf-8').strip("\n").strip("\r\n")
+ character_str += line
+ dict_character = list(character_str)
+ return dict_character
+
+
+def softmax(logits):
+ """
+ logits: N x d
+ """
+ max_value = np.max(logits, axis=1, keepdims=True)
+ exp = np.exp(logits - max_value)
+ exp_sum = np.sum(exp, axis=1, keepdims=True)
+ dist = exp / exp_sum
+ return dist
+
+
+def get_keep_pos_idxs(labels, remove_blank=None):
+ """
+ Remove duplicates and get the position indexes of the kept items.
+ remove_blank is either None (keep blanks) or the blank index (e.g. 95).
+ """
+ duplicate_len_list = []
+ keep_pos_idx_list = []
+ keep_char_idx_list = []
+ for k, v_ in groupby(labels):
+ current_len = len(list(v_))
+ if k != remove_blank:
+ current_idx = int(sum(duplicate_len_list) + current_len // 2)
+ keep_pos_idx_list.append(current_idx)
+ keep_char_idx_list.append(k)
+ duplicate_len_list.append(current_len)
+ return keep_char_idx_list, keep_pos_idx_list
+
+
+def remove_blank(labels, blank=0):
+ new_labels = [x for x in labels if x != blank]
+ return new_labels
+
+
+def insert_blank(labels, blank=0):
+ new_labels = [blank]
+ for l in labels:
+ new_labels += [l, blank]
+ return new_labels
+
+
+def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
+ """
+ CTC greedy (best path) decoder.
+ """
+ raw_str = np.argmax(np.array(probs_seq), axis=1)
+ remove_blank_in_pos = None if keep_blank_in_idxs else blank
+ dedup_str, keep_idx_list = get_keep_pos_idxs(
+ raw_str, remove_blank=remove_blank_in_pos)
+ dst_str = remove_blank(dedup_str, blank=blank)
+ return dst_str, keep_idx_list
+
+
+def instance_ctc_greedy_decoder(gather_info, logits_map, pts_num=4):
+ _, _, C = logits_map.shape
+ ys, xs = zip(*gather_info)
+ logits_seq = logits_map[list(ys), list(xs)]
+ probs_seq = logits_seq
+ labels = np.argmax(probs_seq, axis=1)
+ dst_str = [k for k, v_ in groupby(labels) if k != C - 1]
+ detal = len(gather_info) // (pts_num - 1)
+ keep_idx_list = [0] + [detal * (i + 1) for i in range(pts_num - 2)] + [-1]
+ keep_gather_list = [gather_info[idx] for idx in keep_idx_list]
+ return dst_str, keep_gather_list
+
+
+def ctc_decoder_for_image(gather_info_list,
+ logits_map,
+ Lexicon_Table,
+ pts_num=6):
+ """
+ Run the CTC greedy decoder over every TCL instance of an image.
+ """
+ decoder_str = []
+ decoder_xys = []
+ for gather_info in gather_info_list:
+ if len(gather_info) < pts_num:
+ continue
+ dst_str, xys_list = instance_ctc_greedy_decoder(
+ gather_info, logits_map, pts_num=pts_num)
+ dst_str_readable = ''.join([Lexicon_Table[idx] for idx in dst_str])
+ if len(dst_str_readable) < 2:
+ continue
+ decoder_str.append(dst_str_readable)
+ decoder_xys.append(xys_list)
+ return decoder_str, decoder_xys
+
+
+def sort_with_direction(pos_list, f_direction):
+ """
+ f_direction: h x w x 2
+ pos_list: [[y, x], [y, x], [y, x] ...]
+ """
+
+ def sort_part_with_direction(pos_list, point_direction):
+ pos_list = np.array(pos_list).reshape(-1, 2)
+ point_direction = np.array(point_direction).reshape(-1, 2)
+ average_direction = np.mean(point_direction, axis=0, keepdims=True)
+ pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
+ sorted_list = pos_list[np.argsort(pos_proj_leng)].tolist()
+ sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
+ return sorted_list, sorted_direction
+
+ pos_list = np.array(pos_list).reshape(-1, 2)
+ point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
+ point_direction = point_direction[:, ::-1] # x, y -> y, x
+ sorted_point, sorted_direction = sort_part_with_direction(pos_list,
+ point_direction)
+
+ point_num = len(sorted_point)
+ if point_num >= 16:
+ middle_num = point_num // 2
+ first_part_point = sorted_point[:middle_num]
+ first_point_direction = sorted_direction[:middle_num]
+ sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
+ first_part_point, first_point_direction)
+
+ last_part_point = sorted_point[middle_num:]
+ last_point_direction = sorted_direction[middle_num:]
+ sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
+ last_part_point, last_point_direction)
+ sorted_point = sorted_fist_part_point + sorted_last_part_point
+ sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
+
+ return sorted_point, np.array(sorted_direction)
+
+
+def add_id(pos_list, image_id=0):
+ """
+ Add id for gather feature, for inference.
+ """
+ new_list = []
+ for item in pos_list:
+ new_list.append((image_id, item[0], item[1]))
+ return new_list
+
+
+def sort_and_expand_with_direction(pos_list, f_direction):
+ """
+ f_direction: h x w x 2
+ pos_list: [[y, x], [y, x], [y, x] ...]
+ """
+ h, w, _ = f_direction.shape
+ sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
+
+ point_num = len(sorted_list)
+ sub_direction_len = max(point_num // 3, 2)
+ left_direction = point_direction[:sub_direction_len, :]
+ right_dirction = point_direction[point_num - sub_direction_len:, :]
+
+ left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
+ left_average_len = np.linalg.norm(left_average_direction)
+ left_start = np.array(sorted_list[0])
+ left_step = left_average_direction / (left_average_len + 1e-6)
+
+ right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
+ right_average_len = np.linalg.norm(right_average_direction)
+ right_step = right_average_direction / (right_average_len + 1e-6)
+ right_start = np.array(sorted_list[-1])
+
+ append_num = max(
+ int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
+ left_list = []
+ right_list = []
+ for i in range(append_num):
+ ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
+ 'int32').tolist()
+ if ly < h and lx < w and (ly, lx) not in left_list:
+ left_list.append((ly, lx))
+ ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
+ 'int32').tolist()
+ if ry < h and rx < w and (ry, rx) not in right_list:
+ right_list.append((ry, rx))
+
+ all_list = left_list[::-1] + sorted_list + right_list
+ return all_list
+
+
+def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
+ """
+ f_direction: h x w x 2
+ pos_list: [[y, x], [y, x], [y, x] ...]
+ binary_tcl_map: h x w
+ """
+ h, w, _ = f_direction.shape
+ sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
+
+ point_num = len(sorted_list)
+ sub_direction_len = max(point_num // 3, 2)
+ left_direction = point_direction[:sub_direction_len, :]
+ right_dirction = point_direction[point_num - sub_direction_len:, :]
+
+ left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
+ left_average_len = np.linalg.norm(left_average_direction)
+ left_start = np.array(sorted_list[0])
+ left_step = left_average_direction / (left_average_len + 1e-6)
+
+ right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
+ right_average_len = np.linalg.norm(right_average_direction)
+ right_step = right_average_direction / (right_average_len + 1e-6)
+ right_start = np.array(sorted_list[-1])
+
+ append_num = max(
+ int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
+ max_append_num = 2 * append_num
+
+ left_list = []
+ right_list = []
+ for i in range(max_append_num):
+ ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
+ 'int32').tolist()
+ if ly < h and lx < w and (ly, lx) not in left_list:
+ if binary_tcl_map[ly, lx] > 0.5:
+ left_list.append((ly, lx))
+ else:
+ break
+
+ for i in range(max_append_num):
+ ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
+ 'int32').tolist()
+ if ry < h and rx < w and (ry, rx) not in right_list:
+ if binary_tcl_map[ry, rx] > 0.5:
+ right_list.append((ry, rx))
+ else:
+ break
+
+ all_list = left_list[::-1] + sorted_list + right_list
+ return all_list
+
+
+def point_pair2poly(point_pair_list):
+ """
+ Transfer vertical point_pairs into poly point in clockwise.
+ """
+ point_num = len(point_pair_list) * 2
+ point_list = [0] * point_num
+ for idx, point_pair in enumerate(point_pair_list):
+ point_list[idx] = point_pair[0]
+ point_list[point_num - 1 - idx] = point_pair[1]
+ return np.array(point_list).reshape(-1, 2)
+
+
+def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
+ ratio_pair = np.array(
+ [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
+ p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
+ p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
+ return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
+
+
+def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
+ """
+ expand poly along width.
+ """
+ point_num = poly.shape[0]
+ left_quad = np.array(
+ [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
+ left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
+ (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
+ left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
+ right_quad = np.array(
+ [
+ poly[point_num // 2 - 2], poly[point_num // 2 - 1],
+ poly[point_num // 2], poly[point_num // 2 + 1]
+ ],
+ dtype=np.float32)
+ right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
+ (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
+ right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
+ poly[0] = left_quad_expand[0]
+ poly[-1] = left_quad_expand[-1]
+ poly[point_num // 2 - 1] = right_quad_expand[1]
+ poly[point_num // 2] = right_quad_expand[2]
+ return poly
+
+
+def restore_poly(instance_yxs_list, seq_strs, p_border, ratio_w, ratio_h, src_w,
+ src_h, valid_set):
+ poly_list = []
+ keep_str_list = []
+ for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
+ if len(keep_str) < 2:
+ print('--> too short, {}'.format(keep_str))
+ continue
+
+ offset_expand = 1.0
+ if valid_set == 'totaltext':
+ offset_expand = 1.2
+
+ point_pair_list = []
+ for y, x in yx_center_line:
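+ # p_border carries 4 offset channels per pixel: two (y, x) vectors from this center-line point to the upper and lower text boundaries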
+ offset = p_border[:, y, x].reshape(2, 2) * offset_expand
+ ori_yx = np.array([y, x], dtype=np.float32)
+ point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
+ [ratio_w, ratio_h]).reshape(-1, 2)
+ point_pair_list.append(point_pair)
+
+ detected_poly = point_pair2poly(point_pair_list)
+ detected_poly = expand_poly_along_width(
+ detected_poly, shrink_ratio_of_width=0.2)
+ detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
+ detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
+
+ keep_str_list.append(keep_str)
+ if valid_set == 'partvgg':
+ middle_point = len(detected_poly) // 2
+ detected_poly = detected_poly[
+ [0, middle_point - 1, middle_point, -1], :]
+ poly_list.append(detected_poly)
+ elif valid_set == 'totaltext':
+ poly_list.append(detected_poly)
+ else:
+ print('--> Not supported format.')
+ exit(-1)
+ return poly_list, keep_str_list
+
+
+def generate_pivot_list_fast(p_score,
+ p_char_maps,
+ f_direction,
+ Lexicon_Table,
+ score_thresh=0.5):
+ """
+ Return the center-line points and decoded strings of each TCL instance, filtered with the char maps.
+ """
+ p_score = p_score[0]
+ f_direction = f_direction.transpose(1, 2, 0)
+ p_tcl_map = (p_score > score_thresh) * 1.0
+ skeleton_map = thin(p_tcl_map.astype(np.uint8))
+ instance_count, instance_label_map = cv2.connectedComponents(
+ skeleton_map.astype(np.uint8), connectivity=8)
+
+ # get TCL Instance
+ all_pos_yxs = []
+ if instance_count > 0:
+ for instance_id in range(1, instance_count):
+ pos_list = []
+ ys, xs = np.where(instance_label_map == instance_id)
+ pos_list = list(zip(ys, xs))
+
+ if len(pos_list) < 3:
+ continue
+
+ pos_list_sorted = sort_and_expand_with_direction_v2(
+ pos_list, f_direction, p_tcl_map)
+ all_pos_yxs.append(pos_list_sorted)
+
+ p_char_maps = p_char_maps.transpose([1, 2, 0])
+ decoded_str, keep_yxs_list = ctc_decoder_for_image(
+ all_pos_yxs, logits_map=p_char_maps, Lexicon_Table=Lexicon_Table)
+ return keep_yxs_list, decoded_str
+
+
+def extract_main_direction(pos_list, f_direction):
+ """
+ f_direction: h x w x 2
+ pos_list: [[y, x], [y, x], [y, x] ...]
+ """
+ pos_list = np.array(pos_list)
+ point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]
+ point_direction = point_direction[:, ::-1] # x, y -> y, x
+ average_direction = np.mean(point_direction, axis=0, keepdims=True)
+ average_direction = average_direction / (
+ np.linalg.norm(average_direction) + 1e-6)
+ return average_direction
+
+
+def sort_by_direction_with_image_id_deprecated(pos_list, f_direction):
+ """
+ f_direction: h x w x 2
+ pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
+ """
+ pos_list_full = np.array(pos_list).reshape(-1, 3)
+ pos_list = pos_list_full[:, 1:]
+ point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
+ point_direction = point_direction[:, ::-1] # x, y -> y, x
+ average_direction = np.mean(point_direction, axis=0, keepdims=True)
+ pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
+ sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
+ return sorted_list
+
+
+def sort_by_direction_with_image_id(pos_list, f_direction):
+ """
+ f_direction: h x w x 2
+ pos_list: [[y, x], [y, x], [y, x] ...]
+ """
+
+ def sort_part_with_direction(pos_list_full, point_direction):
+ pos_list_full = np.array(pos_list_full).reshape(-1, 3)
+ pos_list = pos_list_full[:, 1:]
+ point_direction = np.array(point_direction).reshape(-1, 2)
+ average_direction = np.mean(point_direction, axis=0, keepdims=True)
+ pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
+ sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
+ sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
+ return sorted_list, sorted_direction
+
+ pos_list = np.array(pos_list).reshape(-1, 3)
+ point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]] # x, y
+ point_direction = point_direction[:, ::-1] # x, y -> y, x
+ sorted_point, sorted_direction = sort_part_with_direction(pos_list,
+ point_direction)
+
+ point_num = len(sorted_point)
+ if point_num >= 16:
+ middle_num = point_num // 2
+ first_part_point = sorted_point[:middle_num]
+ first_point_direction = sorted_direction[:middle_num]
+ sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
+ first_part_point, first_point_direction)
+
+ last_part_point = sorted_point[middle_num:]
+ last_point_direction = sorted_direction[middle_num:]
+ sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
+ last_part_point, last_point_direction)
+ sorted_point = sorted_fist_part_point + sorted_last_part_point
+ sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
+
+ return sorted_point
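As a small illustration of the greedy decoding used above (toy probabilities, three classes with index 2 acting as the CTC blank):

```python
import numpy as np
from ppocr.utils.e2e_utils.extract_textpoint_fast import ctc_greedy_decoder

probs = np.array([
    [0.8, 0.1, 0.1],  # class 0
    [0.7, 0.2, 0.1],  # class 0 again (duplicate run, collapsed)
    [0.1, 0.1, 0.8],  # blank
    [0.1, 0.8, 0.1],  # class 1
    [0.1, 0.7, 0.2],  # class 1 again
    [0.1, 0.1, 0.8],  # blank
])
dst, keep_idx = ctc_greedy_decoder(probs, blank=2, keep_blank_in_idxs=True)
print(dst)       # [0, 1]: duplicates collapsed, blanks removed
print(keep_idx)  # [1, 2, 4, 5]: one representative position per run
```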
diff --git a/backend/ppocr/utils/e2e_utils/extract_textpoint_slow.py b/backend/ppocr/utils/e2e_utils/extract_textpoint_slow.py
new file mode 100644
index 00000000..ace46fba
--- /dev/null
+++ b/backend/ppocr/utils/e2e_utils/extract_textpoint_slow.py
@@ -0,0 +1,592 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contains various CTC decoders."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import cv2
+import math
+
+import numpy as np
+from itertools import groupby
+from skimage.morphology._skeletonize import thin
+
+
+def get_dict(character_dict_path):
+ character_str = ""
+ with open(character_dict_path, "rb") as fin:
+ lines = fin.readlines()
+ for line in lines:
+ line = line.decode('utf-8').strip("\n").strip("\r\n")
+ character_str += line
+ dict_character = list(character_str)
+ return dict_character
+
+
+def point_pair2poly(point_pair_list):
+ """
+ Convert vertical point pairs into polygon points in clockwise order.
+ """
+ pair_length_list = []
+ for point_pair in point_pair_list:
+ pair_length = np.linalg.norm(point_pair[0] - point_pair[1])
+ pair_length_list.append(pair_length)
+ pair_length_list = np.array(pair_length_list)
+ pair_info = (pair_length_list.max(), pair_length_list.min(),
+ pair_length_list.mean())
+
+ point_num = len(point_pair_list) * 2
+ point_list = [0] * point_num
+ for idx, point_pair in enumerate(point_pair_list):
+ point_list[idx] = point_pair[0]
+ point_list[point_num - 1 - idx] = point_pair[1]
+ return np.array(point_list).reshape(-1, 2), pair_info
+
+
+def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
+ """
+ Generate shrink_quad_along_width.
+ """
+ ratio_pair = np.array(
+ [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
+ p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
+ p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
+ return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
+
+
+def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
+ """
+ expand poly along width.
+ """
+ point_num = poly.shape[0]
+ left_quad = np.array(
+ [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
+ left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
+ (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
+ left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
+ right_quad = np.array(
+ [
+ poly[point_num // 2 - 2], poly[point_num // 2 - 1],
+ poly[point_num // 2], poly[point_num // 2 + 1]
+ ],
+ dtype=np.float32)
+ right_ratio = 1.0 + \
+ shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
+ (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
+ right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
+ poly[0] = left_quad_expand[0]
+ poly[-1] = left_quad_expand[-1]
+ poly[point_num // 2 - 1] = right_quad_expand[1]
+ poly[point_num // 2] = right_quad_expand[2]
+ return poly
+
+
+def softmax(logits):
+ """
+ logits: N x d
+ """
+ max_value = np.max(logits, axis=1, keepdims=True)
+ exp = np.exp(logits - max_value)
+ exp_sum = np.sum(exp, axis=1, keepdims=True)
+ dist = exp / exp_sum
+ return dist
+
+
+def get_keep_pos_idxs(labels, remove_blank=None):
+ """
+ Remove duplicates and get the position indexes of the kept items.
+ remove_blank is either None (keep blanks) or the blank index (e.g. 95).
+ """
+ duplicate_len_list = []
+ keep_pos_idx_list = []
+ keep_char_idx_list = []
+ for k, v_ in groupby(labels):
+ current_len = len(list(v_))
+ if k != remove_blank:
+ current_idx = int(sum(duplicate_len_list) + current_len // 2)
+ keep_pos_idx_list.append(current_idx)
+ keep_char_idx_list.append(k)
+ duplicate_len_list.append(current_len)
+ return keep_char_idx_list, keep_pos_idx_list
+
+
+def remove_blank(labels, blank=0):
+ new_labels = [x for x in labels if x != blank]
+ return new_labels
+
+
+def insert_blank(labels, blank=0):
+ new_labels = [blank]
+ for l in labels:
+ new_labels += [l, blank]
+ return new_labels
+
+
+def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
+ """
+ CTC greedy (best path) decoder.
+ """
+ raw_str = np.argmax(np.array(probs_seq), axis=1)
+ remove_blank_in_pos = None if keep_blank_in_idxs else blank
+ dedup_str, keep_idx_list = get_keep_pos_idxs(
+ raw_str, remove_blank=remove_blank_in_pos)
+ dst_str = remove_blank(dedup_str, blank=blank)
+ return dst_str, keep_idx_list
+
+
+def instance_ctc_greedy_decoder(gather_info,
+ logits_map,
+ keep_blank_in_idxs=True):
+ """
+ gather_info: [[x, y], [x, y] ...]
+ logits_map: H x W X (n_chars + 1)
+ """
+ _, _, C = logits_map.shape
+ ys, xs = zip(*gather_info)
+ logits_seq = logits_map[list(ys), list(xs)] # n x 96
+ probs_seq = softmax(logits_seq)
+ dst_str, keep_idx_list = ctc_greedy_decoder(
+ probs_seq, blank=C - 1, keep_blank_in_idxs=keep_blank_in_idxs)
+ keep_gather_list = [gather_info[idx] for idx in keep_idx_list]
+ return dst_str, keep_gather_list
+
+
+def ctc_decoder_for_image(gather_info_list, logits_map,
+ keep_blank_in_idxs=True):
+ """
+ Run the CTC greedy decoder over every TCL instance of an image.
+ """
+ decoder_results = []
+ for gather_info in gather_info_list:
+ res = instance_ctc_greedy_decoder(
+ gather_info, logits_map, keep_blank_in_idxs=keep_blank_in_idxs)
+ decoder_results.append(res)
+ return decoder_results
+
+
+def sort_with_direction(pos_list, f_direction):
+ """
+ f_direction: h x w x 2
+ pos_list: [[y, x], [y, x], [y, x] ...]
+ """
+
+ def sort_part_with_direction(pos_list, point_direction):
+ pos_list = np.array(pos_list).reshape(-1, 2)
+ point_direction = np.array(point_direction).reshape(-1, 2)
+ average_direction = np.mean(point_direction, axis=0, keepdims=True)
+ pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
+ sorted_list = pos_list[np.argsort(pos_proj_leng)].tolist()
+ sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
+ return sorted_list, sorted_direction
+
+ pos_list = np.array(pos_list).reshape(-1, 2)
+ point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
+ point_direction = point_direction[:, ::-1] # x, y -> y, x
+ sorted_point, sorted_direction = sort_part_with_direction(pos_list,
+ point_direction)
+
+ point_num = len(sorted_point)
+ if point_num >= 16:
+ middle_num = point_num // 2
+ first_part_point = sorted_point[:middle_num]
+ first_point_direction = sorted_direction[:middle_num]
+ sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
+ first_part_point, first_point_direction)
+
+ last_part_point = sorted_point[middle_num:]
+ last_point_direction = sorted_direction[middle_num:]
+ sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
+ last_part_point, last_point_direction)
+ sorted_point = sorted_fist_part_point + sorted_last_part_point
+ sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
+
+ return sorted_point, np.array(sorted_direction)
+
+
+def add_id(pos_list, image_id=0):
+ """
+ Add id for gather feature, for inference.
+ """
+ new_list = []
+ for item in pos_list:
+ new_list.append((image_id, item[0], item[1]))
+ return new_list
+
+
+def sort_and_expand_with_direction(pos_list, f_direction):
+ """
+ f_direction: h x w x 2
+ pos_list: [[y, x], [y, x], [y, x] ...]
+ """
+ h, w, _ = f_direction.shape
+ sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
+
+ # expand along
+ point_num = len(sorted_list)
+ sub_direction_len = max(point_num // 3, 2)
+ left_direction = point_direction[:sub_direction_len, :]
+ right_dirction = point_direction[point_num - sub_direction_len:, :]
+
+ left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
+ left_average_len = np.linalg.norm(left_average_direction)
+ left_start = np.array(sorted_list[0])
+ left_step = left_average_direction / (left_average_len + 1e-6)
+
+ right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
+ right_average_len = np.linalg.norm(right_average_direction)
+ right_step = right_average_direction / (right_average_len + 1e-6)
+ right_start = np.array(sorted_list[-1])
+
+ append_num = max(
+ int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
+ left_list = []
+ right_list = []
+ for i in range(append_num):
+ ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
+ 'int32').tolist()
+ if ly < h and lx < w and (ly, lx) not in left_list:
+ left_list.append((ly, lx))
+ ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
+ 'int32').tolist()
+ if ry < h and rx < w and (ry, rx) not in right_list:
+ right_list.append((ry, rx))
+
+ all_list = left_list[::-1] + sorted_list + right_list
+ return all_list
+
+
+def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
+ """
+ f_direction: h x w x 2
+ pos_list: [[y, x], [y, x], [y, x] ...]
+ binary_tcl_map: h x w
+ """
+ h, w, _ = f_direction.shape
+ sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
+
+ # expand along
+ point_num = len(sorted_list)
+ sub_direction_len = max(point_num // 3, 2)
+ left_direction = point_direction[:sub_direction_len, :]
+ right_dirction = point_direction[point_num - sub_direction_len:, :]
+
+ left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
+ left_average_len = np.linalg.norm(left_average_direction)
+ left_start = np.array(sorted_list[0])
+ left_step = left_average_direction / (left_average_len + 1e-6)
+
+ right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
+ right_average_len = np.linalg.norm(right_average_direction)
+ right_step = right_average_direction / (right_average_len + 1e-6)
+ right_start = np.array(sorted_list[-1])
+
+ append_num = max(
+ int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
+ max_append_num = 2 * append_num
+
+ left_list = []
+ right_list = []
+ for i in range(max_append_num):
+ ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
+ 'int32').tolist()
+ if ly < h and lx < w and (ly, lx) not in left_list:
+ if binary_tcl_map[ly, lx] > 0.5:
+ left_list.append((ly, lx))
+ else:
+ break
+
+ for i in range(max_append_num):
+ ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
+ 'int32').tolist()
+ if ry < h and rx < w and (ry, rx) not in right_list:
+ if binary_tcl_map[ry, rx] > 0.5:
+ right_list.append((ry, rx))
+ else:
+ break
+
+ all_list = left_list[::-1] + sorted_list + right_list
+ return all_list
+
+
+def generate_pivot_list_curved(p_score,
+ p_char_maps,
+ f_direction,
+ score_thresh=0.5,
+ is_expand=True,
+ is_backbone=False,
+ image_id=0):
+ """
+ Return the center points and end points of each TCL instance, filtered with the char maps.
+ """
+ p_score = p_score[0]
+ f_direction = f_direction.transpose(1, 2, 0)
+ p_tcl_map = (p_score > score_thresh) * 1.0
+ skeleton_map = thin(p_tcl_map)
+ instance_count, instance_label_map = cv2.connectedComponents(
+ skeleton_map.astype(np.uint8), connectivity=8)
+
+ # get TCL Instance
+ all_pos_yxs = []
+ center_pos_yxs = []
+ end_points_yxs = []
+ instance_center_pos_yxs = []
+ pred_strs = []
+ if instance_count > 0:
+ for instance_id in range(1, instance_count):
+ pos_list = []
+ ys, xs = np.where(instance_label_map == instance_id)
+ pos_list = list(zip(ys, xs))
+
+ ### FIX-ME, eliminate outlier
+ if len(pos_list) < 3:
+ continue
+
+ if is_expand:
+ pos_list_sorted = sort_and_expand_with_direction_v2(
+ pos_list, f_direction, p_tcl_map)
+ else:
+ pos_list_sorted, _ = sort_with_direction(pos_list, f_direction)
+ all_pos_yxs.append(pos_list_sorted)
+
+ # use the decoder to filter background points.
+ p_char_maps = p_char_maps.transpose([1, 2, 0])
+ decode_res = ctc_decoder_for_image(
+ all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
+ for decoded_str, keep_yxs_list in decode_res:
+ if is_backbone:
+ keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
+ instance_center_pos_yxs.append(keep_yxs_list_with_id)
+ pred_strs.append(decoded_str)
+ else:
+ end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
+ center_pos_yxs.extend(keep_yxs_list)
+
+ if is_backbone:
+ return pred_strs, instance_center_pos_yxs
+ else:
+ return center_pos_yxs, end_points_yxs
+
+
+def generate_pivot_list_horizontal(p_score,
+ p_char_maps,
+ f_direction,
+ score_thresh=0.5,
+ is_backbone=False,
+ image_id=0):
+ """
+ return center point and end point of TCL instance; filter with the char maps;
+ """
+ p_score = p_score[0]
+ f_direction = f_direction.transpose(1, 2, 0)
+ p_tcl_map_bi = (p_score > score_thresh) * 1.0
+ instance_count, instance_label_map = cv2.connectedComponents(
+ p_tcl_map_bi.astype(np.uint8), connectivity=8)
+
+ # get TCL Instance
+ all_pos_yxs = []
+ center_pos_yxs = []
+ end_points_yxs = []
+ instance_center_pos_yxs = []
+
+ if instance_count > 0:
+ for instance_id in range(1, instance_count):
+ pos_list = []
+ ys, xs = np.where(instance_label_map == instance_id)
+ pos_list = list(zip(ys, xs))
+
+ ### FIX-ME, eliminate outlier
+ if len(pos_list) < 5:
+ continue
+
+ # add rule here
+ main_direction = extract_main_direction(pos_list,
+ f_direction) # y x
+ reference_direction = np.array([0, 1]).reshape([-1, 2]) # y x
+ is_h_angle = abs(np.sum(
+ main_direction * reference_direction)) < math.cos(math.pi / 180 *
+ 70)
+
+ point_yxs = np.array(pos_list)
+ max_y, max_x = np.max(point_yxs, axis=0)
+ min_y, min_x = np.min(point_yxs, axis=0)
+ is_h_len = (max_y - min_y) < 1.5 * (max_x - min_x)
+
+ pos_list_final = []
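+ # sample one center point per column for roughly horizontal instances, or per row otherwise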
+ if is_h_len:
+ xs = np.unique(xs)
+ for x in xs:
+ ys = instance_label_map[:, x].copy().reshape((-1, ))
+ y = int(np.where(ys == instance_id)[0].mean())
+ pos_list_final.append((y, x))
+ else:
+ ys = np.unique(ys)
+ for y in ys:
+ xs = instance_label_map[y, :].copy().reshape((-1, ))
+ x = int(np.where(xs == instance_id)[0].mean())
+ pos_list_final.append((y, x))
+
+ pos_list_sorted, _ = sort_with_direction(pos_list_final,
+ f_direction)
+ all_pos_yxs.append(pos_list_sorted)
+
+ # use decoder to filter background points.
+ p_char_maps = p_char_maps.transpose([1, 2, 0])
+ decode_res = ctc_decoder_for_image(
+ all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
+ for decoded_str, keep_yxs_list in decode_res:
+ if is_backbone:
+ keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
+ instance_center_pos_yxs.append(keep_yxs_list_with_id)
+ else:
+ end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
+ center_pos_yxs.extend(keep_yxs_list)
+
+ if is_backbone:
+ return instance_center_pos_yxs
+ else:
+ return center_pos_yxs, end_points_yxs
+
+
+def generate_pivot_list_slow(p_score,
+ p_char_maps,
+ f_direction,
+ score_thresh=0.5,
+ is_backbone=False,
+ is_curved=True,
+ image_id=0):
+ """
+ Wrap the curved and horizontal pivot-list generators into one entry point.
+ """
+ if is_curved:
+ return generate_pivot_list_curved(
+ p_score,
+ p_char_maps,
+ f_direction,
+ score_thresh=score_thresh,
+ is_expand=True,
+ is_backbone=is_backbone,
+ image_id=image_id)
+ else:
+ return generate_pivot_list_horizontal(
+ p_score,
+ p_char_maps,
+ f_direction,
+ score_thresh=score_thresh,
+ is_backbone=is_backbone,
+ image_id=image_id)
+
+
+# for refine module
+def extract_main_direction(pos_list, f_direction):
+ """
+ f_direction: h x w x 2
+ pos_list: [[y, x], [y, x], [y, x] ...]
+ """
+ pos_list = np.array(pos_list)
+ point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]
+ point_direction = point_direction[:, ::-1] # x, y -> y, x
+ average_direction = np.mean(point_direction, axis=0, keepdims=True)
+ average_direction = average_direction / (
+ np.linalg.norm(average_direction) + 1e-6)
+ return average_direction
+
+
+def sort_by_direction_with_image_id_deprecated(pos_list, f_direction):
+ """
+ f_direction: h x w x 2
+ pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
+ """
+ pos_list_full = np.array(pos_list).reshape(-1, 3)
+ pos_list = pos_list_full[:, 1:]
+ point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
+ point_direction = point_direction[:, ::-1] # x, y -> y, x
+ average_direction = np.mean(point_direction, axis=0, keepdims=True)
+ pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
+ sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
+ return sorted_list
+
+
+def sort_by_direction_with_image_id(pos_list, f_direction):
+ """
+ f_direction: h x w x 2
+ pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
+ """
+
+ def sort_part_with_direction(pos_list_full, point_direction):
+ pos_list_full = np.array(pos_list_full).reshape(-1, 3)
+ pos_list = pos_list_full[:, 1:]
+ point_direction = np.array(point_direction).reshape(-1, 2)
+ average_direction = np.mean(point_direction, axis=0, keepdims=True)
+ pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
+ sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
+ sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
+ return sorted_list, sorted_direction
+
+ pos_list = np.array(pos_list).reshape(-1, 3)
+ point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]] # x, y
+ point_direction = point_direction[:, ::-1] # x, y -> y, x
+ sorted_point, sorted_direction = sort_part_with_direction(pos_list,
+ point_direction)
+
+ point_num = len(sorted_point)
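+ # long instances are split in half and each half is re-sorted with its own average direction, which follows curved lines more closely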
+ if point_num >= 16:
+ middle_num = point_num // 2
+ first_part_point = sorted_point[:middle_num]
+ first_point_direction = sorted_direction[:middle_num]
+ sorted_first_part_point, sorted_first_part_direction = sort_part_with_direction(
+ first_part_point, first_point_direction)
+
+ last_part_point = sorted_point[middle_num:]
+ last_point_direction = sorted_direction[middle_num:]
+ sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
+ last_part_point, last_point_direction)
+ sorted_point = sorted_first_part_point + sorted_last_part_point
+ sorted_direction = sorted_first_part_direction + sorted_last_part_direction
+
+ return sorted_point
+
+
+def generate_pivot_list_tt_inference(p_score,
+ p_char_maps,
+ f_direction,
+ score_thresh=0.5,
+ is_backbone=False,
+ is_curved=True,
+ image_id=0):
+ """
+ Return the sorted (and expanded) point list of each TCL instance, with the image id attached.
+ """
+ p_score = p_score[0]
+ f_direction = f_direction.transpose(1, 2, 0)
+ p_tcl_map = (p_score > score_thresh) * 1.0
+ skeleton_map = thin(p_tcl_map)
+ instance_count, instance_label_map = cv2.connectedComponents(
+ skeleton_map.astype(np.uint8), connectivity=8)
+
+ # get TCL Instance
+ all_pos_yxs = []
+ if instance_count > 0:
+ for instance_id in range(1, instance_count):
+ pos_list = []
+ ys, xs = np.where(instance_label_map == instance_id)
+ pos_list = list(zip(ys, xs))
+ ### FIX-ME, eliminate outlier
+ if len(pos_list) < 3:
+ continue
+ pos_list_sorted = sort_and_expand_with_direction_v2(
+ pos_list, f_direction, p_tcl_map)
+ pos_list_sorted_with_id = add_id(pos_list_sorted, image_id=image_id)
+ all_pos_yxs.append(pos_list_sorted_with_id)
+ return all_pos_yxs
diff --git a/backend/ppocr/utils/e2e_utils/pgnet_pp_utils.py b/backend/ppocr/utils/e2e_utils/pgnet_pp_utils.py
new file mode 100644
index 00000000..a15503c0
--- /dev/null
+++ b/backend/ppocr/utils/e2e_utils/pgnet_pp_utils.py
@@ -0,0 +1,162 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import paddle
+import os
+import sys
+
+__dir__ = os.path.dirname(__file__)
+sys.path.append(__dir__)
+sys.path.append(os.path.join(__dir__, '..'))
+from extract_textpoint_slow import *
+from extract_textpoint_fast import generate_pivot_list_fast, restore_poly
+
+
+class PGNet_PostProcess(object):
+ # two different post-processing pipelines: fast and slow
+ def __init__(self, character_dict_path, valid_set, score_thresh, outs_dict,
+ shape_list):
+ self.Lexicon_Table = get_dict(character_dict_path)
+ self.valid_set = valid_set
+ self.score_thresh = score_thresh
+ self.outs_dict = outs_dict
+ self.shape_list = shape_list
+
+ def pg_postprocess_fast(self):
+ p_score = self.outs_dict['f_score']
+ p_border = self.outs_dict['f_border']
+ p_char = self.outs_dict['f_char']
+ p_direction = self.outs_dict['f_direction']
+ if isinstance(p_score, paddle.Tensor):
+ p_score = p_score[0].numpy()
+ p_border = p_border[0].numpy()
+ p_direction = p_direction[0].numpy()
+ p_char = p_char[0].numpy()
+ else:
+ p_score = p_score[0]
+ p_border = p_border[0]
+ p_direction = p_direction[0]
+ p_char = p_char[0]
+
+ src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
+ instance_yxs_list, seq_strs = generate_pivot_list_fast(
+ p_score,
+ p_char,
+ p_direction,
+ self.Lexicon_Table,
+ score_thresh=self.score_thresh)
+ poly_list, keep_str_list = restore_poly(instance_yxs_list, seq_strs,
+ p_border, ratio_w, ratio_h,
+ src_w, src_h, self.valid_set)
+ data = {
+ 'points': poly_list,
+ 'texts': keep_str_list,
+ }
+ return data
+
+ def pg_postprocess_slow(self):
+ p_score = self.outs_dict['f_score']
+ p_border = self.outs_dict['f_border']
+ p_char = self.outs_dict['f_char']
+ p_direction = self.outs_dict['f_direction']
+ if isinstance(p_score, paddle.Tensor):
+ p_score = p_score[0].numpy()
+ p_border = p_border[0].numpy()
+ p_direction = p_direction[0].numpy()
+ p_char = p_char[0].numpy()
+ else:
+ p_score = p_score[0]
+ p_border = p_border[0]
+ p_direction = p_direction[0]
+ p_char = p_char[0]
+ src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
+ is_curved = self.valid_set == "totaltext"
+ char_seq_idx_set, instance_yxs_list = generate_pivot_list_slow(
+ p_score,
+ p_char,
+ p_direction,
+ score_thresh=self.score_thresh,
+ is_backbone=True,
+ is_curved=is_curved)
+ seq_strs = []
+ for char_idx_set in char_seq_idx_set:
+ pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set])
+ seq_strs.append(pr_str)
+ poly_list = []
+ keep_str_list = []
+ all_point_list = []
+ all_point_pair_list = []
+ for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
+ if len(yx_center_line) == 1:
+ yx_center_line.append(yx_center_line[-1])
+
+ offset_expand = 1.0
+ if self.valid_set == 'totaltext':
+ offset_expand = 1.2
+
+ point_pair_list = []
+ for batch_id, y, x in yx_center_line:
+ offset = p_border[:, y, x].reshape(2, 2)
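+ # f_border stores, for each center pixel, the offsets to the upper and lower text boundary points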
+ if offset_expand != 1.0:
+ offset_length = np.linalg.norm(
+ offset, axis=1, keepdims=True)
+ expand_length = np.clip(
+ offset_length * (offset_expand - 1),
+ a_min=0.5,
+ a_max=3.0)
+ offset_delta = offset / offset_length * expand_length
+ offset = offset + offset_delta
+ ori_yx = np.array([y, x], dtype=np.float32)
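+ # the prediction maps are at 1/4 of the resized input (hence the factor 4); dividing by the resize ratios recovers original-image coordinates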
+ point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
+ [ratio_w, ratio_h]).reshape(-1, 2)
+ point_pair_list.append(point_pair)
+
+ all_point_list.append([
+ int(round(x * 4.0 / ratio_w)),
+ int(round(y * 4.0 / ratio_h))
+ ])
+ all_point_pair_list.append(point_pair.round().astype(np.int32)
+ .tolist())
+
+ detected_poly, pair_length_info = point_pair2poly(point_pair_list)
+ detected_poly = expand_poly_along_width(
+ detected_poly, shrink_ratio_of_width=0.2)
+ detected_poly[:, 0] = np.clip(
+ detected_poly[:, 0], a_min=0, a_max=src_w)
+ detected_poly[:, 1] = np.clip(
+ detected_poly[:, 1], a_min=0, a_max=src_h)
+
+ if len(keep_str) < 2:
+ continue
+
+ keep_str_list.append(keep_str)
+ detected_poly = np.round(detected_poly).astype('int32')
+ if self.valid_set == 'partvgg':
+ middle_point = len(detected_poly) // 2
+ detected_poly = detected_poly[
+ [0, middle_point - 1, middle_point, -1], :]
+ poly_list.append(detected_poly)
+ elif self.valid_set == 'totaltext':
+ poly_list.append(detected_poly)
+ else:
+ print('--> Unsupported valid_set format.')
+ exit(-1)
+ data = {
+ 'points': poly_list,
+ 'texts': keep_str_list,
+ }
+ return data
diff --git a/backend/ppocr/utils/e2e_utils/visual.py b/backend/ppocr/utils/e2e_utils/visual.py
new file mode 100644
index 00000000..e6e4fd06
--- /dev/null
+++ b/backend/ppocr/utils/e2e_utils/visual.py
@@ -0,0 +1,162 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import cv2
+import time
+
+
+def resize_image(im, max_side_len=512):
+ """
+ resize the image so that each side is a multiple of max_stride, as required by the network
+ :param im: the input image
+ :param max_side_len: limit on the longer side to avoid running out of GPU memory
+ :return: the resized image and the resize ratios (ratio_h, ratio_w)
+ """
+ h, w, _ = im.shape
+
+ resize_w = w
+ resize_h = h
+
+ if resize_h > resize_w:
+ ratio = float(max_side_len) / resize_h
+ else:
+ ratio = float(max_side_len) / resize_w
+
+ resize_h = int(resize_h * ratio)
+ resize_w = int(resize_w * ratio)
+
+ max_stride = 128
+ resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+ resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
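+ # round both sides up to a multiple of max_stride so the network's downsampling stages divide evenly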
+ im = cv2.resize(im, (int(resize_w), int(resize_h)))
+ ratio_h = resize_h / float(h)
+ ratio_w = resize_w / float(w)
+
+ return im, (ratio_h, ratio_w)
+
+
+def resize_image_min(im, max_side_len=512):
+ """
+ """
+ h, w, _ = im.shape
+
+ resize_w = w
+ resize_h = h
+
+ if resize_h < resize_w:
+ ratio = float(max_side_len) / resize_h
+ else:
+ ratio = float(max_side_len) / resize_w
+
+ resize_h = int(resize_h * ratio)
+ resize_w = int(resize_w * ratio)
+
+ max_stride = 128
+ resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+ resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
+ im = cv2.resize(im, (int(resize_w), int(resize_h)))
+ ratio_h = resize_h / float(h)
+ ratio_w = resize_w / float(w)
+ return im, (ratio_h, ratio_w)
+
+
+def resize_image_for_totaltext(im, max_side_len=512):
+ """
+ """
+ h, w, _ = im.shape
+
+ resize_w = w
+ resize_h = h
+ ratio = 1.25
+ if h * ratio > max_side_len:
+ ratio = float(max_side_len) / resize_h
+
+ resize_h = int(resize_h * ratio)
+ resize_w = int(resize_w * ratio)
+
+ max_stride = 128
+ resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+ resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
+ im = cv2.resize(im, (int(resize_w), int(resize_h)))
+ ratio_h = resize_h / float(h)
+ ratio_w = resize_w / float(w)
+ return im, (ratio_h, ratio_w)
+
+
+def point_pair2poly(point_pair_list):
+ """
+ Convert vertical point pairs into polygon points in clockwise order.
+ """
+ pair_length_list = []
+ for point_pair in point_pair_list:
+ pair_length = np.linalg.norm(point_pair[0] - point_pair[1])
+ pair_length_list.append(pair_length)
+ pair_length_list = np.array(pair_length_list)
+ pair_info = (pair_length_list.max(), pair_length_list.min(),
+ pair_length_list.mean())
+
+ point_num = len(point_pair_list) * 2
+ point_list = [0] * point_num
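+ # first points of each pair fill the list front-to-back, second points back-to-front, producing a clockwise polygon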
+ for idx, point_pair in enumerate(point_pair_list):
+ point_list[idx] = point_pair[0]
+ point_list[point_num - 1 - idx] = point_pair[1]
+ return np.array(point_list).reshape(-1, 2), pair_info
+
+
+def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
+ """
+ Shrink (or extend) a quad along its width direction using the given begin/end ratios.
+ """
+ ratio_pair = np.array(
+ [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
+ p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
+ p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
+ return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
+
+
+def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
+ """
+ expand poly along width.
+ """
+ point_num = poly.shape[0]
+ left_quad = np.array(
+ [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
+ left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
+ (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
+ left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
+ right_quad = np.array(
+ [
+ poly[point_num // 2 - 2], poly[point_num // 2 - 1],
+ poly[point_num // 2], poly[point_num // 2 + 1]
+ ],
+ dtype=np.float32)
+ right_ratio = 1.0 + \
+ shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
+ (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
+ right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
+ poly[0] = left_quad_expand[0]
+ poly[-1] = left_quad_expand[-1]
+ poly[point_num // 2 - 1] = right_quad_expand[1]
+ poly[point_num // 2] = right_quad_expand[2]
+ return poly
+
+
+def norm2(x, axis=None):
+ if axis:
+ return np.sqrt(np.sum(x**2, axis=axis))
+ return np.sqrt(np.sum(x**2))
+
+
+def cos(p1, p2):
+ return (p1 * p2).sum() / (norm2(p1) * norm2(p2))
diff --git a/backend/ppocr/utils/gen_label.py b/backend/ppocr/utils/gen_label.py
deleted file mode 100644
index 43afe9dd..00000000
--- a/backend/ppocr/utils/gen_label.py
+++ /dev/null
@@ -1,79 +0,0 @@
-#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
-#
-#Licensed under the Apache License, Version 2.0 (the "License");
-#you may not use this file except in compliance with the License.
-#You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-#Unless required by applicable law or agreed to in writing, software
-#distributed under the License is distributed on an "AS IS" BASIS,
-#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-#See the License for the specific language governing permissions and
-#limitations under the License.
-import os
-import argparse
-import json
-
-
-def gen_rec_label(input_path, out_label):
- with open(out_label, 'w') as out_file:
- with open(input_path, 'r') as f:
- for line in f.readlines():
- tmp = line.strip('\n').replace(" ", "").split(',')
- img_path, label = tmp[0], tmp[1]
- label = label.replace("\"", "")
- out_file.write(img_path + '\t' + label + '\n')
-
-
-def gen_det_label(root_path, input_dir, out_label):
- with open(out_label, 'w') as out_file:
- for label_file in os.listdir(input_dir):
- img_path = root_path + label_file[3:-4] + ".jpg"
- label = []
- with open(os.path.join(input_dir, label_file), 'r') as f:
- for line in f.readlines():
- tmp = line.strip("\n\r").replace("\xef\xbb\xbf",
- "").split(',')
- points = tmp[:8]
- s = []
- for i in range(0, len(points), 2):
- b = points[i:i + 2]
- b = [int(t) for t in b]
- s.append(b)
- result = {"transcription": tmp[8], "points": s}
- label.append(result)
-
- out_file.write(img_path + '\t' + json.dumps(
- label, ensure_ascii=False) + '\n')
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- '--mode',
- type=str,
- default="rec",
- help='Generate rec_label or det_label, can be set rec or det')
- parser.add_argument(
- '--root_path',
- type=str,
- default=".",
- help='The root directory of images.Only takes effect when mode=det ')
- parser.add_argument(
- '--input_path',
- type=str,
- default=".",
- help='Input_label or input path to be converted')
- parser.add_argument(
- '--output_label',
- type=str,
- default="out_label.txt",
- help='Output file name')
-
- args = parser.parse_args()
- if args.mode == "rec":
- print("Generate rec label")
- gen_rec_label(args.input_path, args.output_label)
- elif args.mode == "det":
- gen_det_label(args.root_path, args.input_path, args.output_label)
diff --git a/backend/ppocr/utils/iou.py b/backend/ppocr/utils/iou.py
new file mode 100644
index 00000000..35459f5f
--- /dev/null
+++ b/backend/ppocr/utils/iou.py
@@ -0,0 +1,54 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is adapted from:
+https://github.com/whai362/PSENet/blob/python3/models/loss/iou.py
+"""
+
+import paddle
+
+EPS = 1e-6
+
+
+def iou_single(a, b, mask, n_class):
+ valid = mask == 1
+ a = a.masked_select(valid)
+ b = b.masked_select(valid)
+ miou = []
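+ # per-class IoU over the valid (masked) pixels, averaged across classes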
+ for i in range(n_class):
+ if a.shape == [0] and a.shape == b.shape:
+ inter = paddle.to_tensor(0.0)
+ union = paddle.to_tensor(0.0)
+ else:
+ inter = ((a == i).logical_and(b == i)).astype('float32')
+ union = ((a == i).logical_or(b == i)).astype('float32')
+ miou.append(paddle.sum(inter) / (paddle.sum(union) + EPS))
+ miou = sum(miou) / len(miou)
+ return miou
+
+
+def iou(a, b, mask, n_class=2, reduce=True):
+ batch_size = a.shape[0]
+
+ a = a.reshape([batch_size, -1])
+ b = b.reshape([batch_size, -1])
+ mask = mask.reshape([batch_size, -1])
+
+ iou = paddle.zeros((batch_size, ), dtype='float32')
+ for i in range(batch_size):
+ iou[i] = iou_single(a[i], b[i], mask[i], n_class)
+
+ if reduce:
+ iou = paddle.mean(iou)
+ return iou
diff --git a/backend/ppocr/utils/loggers/__init__.py b/backend/ppocr/utils/loggers/__init__.py
new file mode 100644
index 00000000..b1e92f73
--- /dev/null
+++ b/backend/ppocr/utils/loggers/__init__.py
@@ -0,0 +1,3 @@
+from .vdl_logger import VDLLogger
+from .wandb_logger import WandbLogger
+from .loggers import Loggers
diff --git a/backend/ppocr/utils/loggers/base_logger.py b/backend/ppocr/utils/loggers/base_logger.py
new file mode 100644
index 00000000..3a7fc359
--- /dev/null
+++ b/backend/ppocr/utils/loggers/base_logger.py
@@ -0,0 +1,15 @@
+import os
+from abc import ABC, abstractmethod
+
+class BaseLogger(ABC):
+ def __init__(self, save_dir):
+ self.save_dir = save_dir
+ os.makedirs(self.save_dir, exist_ok=True)
+
+ @abstractmethod
+ def log_metrics(self, metrics, prefix=None):
+ pass
+
+ @abstractmethod
+ def close(self):
+ pass
\ No newline at end of file
diff --git a/backend/ppocr/utils/loggers/loggers.py b/backend/ppocr/utils/loggers/loggers.py
new file mode 100644
index 00000000..26014662
--- /dev/null
+++ b/backend/ppocr/utils/loggers/loggers.py
@@ -0,0 +1,18 @@
+from .wandb_logger import WandbLogger
+
+class Loggers(object):
+ def __init__(self, loggers):
+ super().__init__()
+ self.loggers = loggers
+
+ def log_metrics(self, metrics, prefix=None, step=None):
+ for logger in self.loggers:
+ logger.log_metrics(metrics, prefix=prefix, step=step)
+
+ def log_model(self, is_best, prefix, metadata=None):
+ for logger in self.loggers:
+ logger.log_model(is_best=is_best, prefix=prefix, metadata=metadata)
+
+ def close(self):
+ for logger in self.loggers:
+ logger.close()
\ No newline at end of file
diff --git a/backend/ppocr/utils/loggers/vdl_logger.py b/backend/ppocr/utils/loggers/vdl_logger.py
new file mode 100644
index 00000000..c345f932
--- /dev/null
+++ b/backend/ppocr/utils/loggers/vdl_logger.py
@@ -0,0 +1,21 @@
+from .base_logger import BaseLogger
+from visualdl import LogWriter
+
+class VDLLogger(BaseLogger):
+ def __init__(self, save_dir):
+ super().__init__(save_dir)
+ self.vdl_writer = LogWriter(logdir=save_dir)
+
+ def log_metrics(self, metrics, prefix=None, step=None):
+ if not prefix:
+ prefix = ""
+ updated_metrics = {prefix + "/" + k: v for k, v in metrics.items()}
+
+ for k, v in updated_metrics.items():
+ self.vdl_writer.add_scalar(k, v, step)
+
+ def log_model(self, is_best, prefix, metadata=None):
+ pass
+
+ def close(self):
+ self.vdl_writer.close()
\ No newline at end of file
diff --git a/backend/ppocr/utils/loggers/wandb_logger.py b/backend/ppocr/utils/loggers/wandb_logger.py
new file mode 100644
index 00000000..5c805f4e
--- /dev/null
+++ b/backend/ppocr/utils/loggers/wandb_logger.py
@@ -0,0 +1,78 @@
+import os
+from .base_logger import BaseLogger
+from ppocr.utils.logging import get_logger
+
+# module-level logger so the informational message in the `run` property has a logger to use
+logger = get_logger()
+
+class WandbLogger(BaseLogger):
+ def __init__(self,
+ project=None,
+ name=None,
+ id=None,
+ entity=None,
+ save_dir=None,
+ config=None,
+ **kwargs):
+ try:
+ import wandb
+ self.wandb = wandb
+ except ModuleNotFoundError:
+ raise ModuleNotFoundError(
+ "Please install wandb using `pip install wandb`"
+ )
+
+ self.project = project
+ self.name = name
+ self.id = id
+ self.save_dir = save_dir
+ self.config = config
+ self.kwargs = kwargs
+ self.entity = entity
+ self._run = None
+ self._wandb_init = dict(
+ project=self.project,
+ name=self.name,
+ id=self.id,
+ entity=self.entity,
+ dir=self.save_dir,
+ resume="allow"
+ )
+ self._wandb_init.update(**kwargs)
+
+ _ = self.run
+
+ if self.config:
+ self.run.config.update(self.config)
+
+ @property
+ def run(self):
+ if self._run is None:
+ if self.wandb.run is not None:
+ logger.info(
+ "There is a wandb run already in progress "
+ "and newly created instances of `WandbLogger` will reuse"
+ " this run. If this is not desired, call `wandb.finish()`"
+ "before instantiating `WandbLogger`."
+ )
+ self._run = self.wandb.run
+ else:
+ self._run = self.wandb.init(**self._wandb_init)
+ return self._run
+
+ def log_metrics(self, metrics, prefix=None, step=None):
+ if not prefix:
+ prefix = ""
+ updated_metrics = {prefix.lower() + "/" + k: v for k, v in metrics.items()}
+
+ self.run.log(updated_metrics, step=step)
+
+ def log_model(self, is_best, prefix, metadata=None):
+ model_path = os.path.join(self.save_dir, prefix + '.pdparams')
+ artifact = self.wandb.Artifact('model-{}'.format(self.run.id), type='model', metadata=metadata)
+ artifact.add_file(model_path, name="model_ckpt.pdparams")
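+ # the checkpoint is uploaded as a versioned artifact; "best" is added as an extra alias when is_best is True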
+
+ aliases = [prefix]
+ if is_best:
+ aliases.append("best")
+
+ self.run.log_artifact(artifact, aliases=aliases)
+
+ def close(self):
+ self.run.finish()
\ No newline at end of file
diff --git a/backend/ppocr/utils/logging.py b/backend/ppocr/utils/logging.py
index 951141db..1eac8f35 100644
--- a/backend/ppocr/utils/logging.py
+++ b/backend/ppocr/utils/logging.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,6 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+"""
+This code is adapted from:
+https://github.com/WenmuZhou/PytorchOCR/blob/master/torchocr/utils/logging.py
+"""
import os
import sys
@@ -22,7 +26,7 @@
@functools.lru_cache()
-def get_logger(name='root', log_file=None, log_level=logging.INFO):
+def get_logger(name='ppocr', log_file=None, log_level=logging.DEBUG):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
@@ -63,4 +67,5 @@ def get_logger(name='root', log_file=None, log_level=logging.INFO):
else:
logger.setLevel(logging.ERROR)
logger_initialized[name] = True
+ logger.propagate = False
return logger
diff --git a/backend/ppocr/utils/network.py b/backend/ppocr/utils/network.py
new file mode 100644
index 00000000..118d1be3
--- /dev/null
+++ b/backend/ppocr/utils/network.py
@@ -0,0 +1,84 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import tarfile
+import requests
+from tqdm import tqdm
+
+from ppocr.utils.logging import get_logger
+
+
+def download_with_progressbar(url, save_path):
+ logger = get_logger()
+ response = requests.get(url, stream=True)
+ if response.status_code == 200:
+ total_size_in_bytes = int(response.headers.get('content-length', 1))
+ block_size = 1024 # 1 Kibibyte
+ progress_bar = tqdm(
+ total=total_size_in_bytes, unit='iB', unit_scale=True)
+ with open(save_path, 'wb') as file:
+ for data in response.iter_content(block_size):
+ progress_bar.update(len(data))
+ file.write(data)
+ progress_bar.close()
+ else:
+ logger.error("Something went wrong while downloading models")
+ sys.exit(0)
+
+
+def maybe_download(model_storage_directory, url):
+ # using custom model
+ tar_file_name_list = [
+ 'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel'
+ ]
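+ # only the inference model files listed above are extracted from the downloaded tar archive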
+ if not os.path.exists(
+ os.path.join(model_storage_directory, 'inference.pdiparams')
+ ) or not os.path.exists(
+ os.path.join(model_storage_directory, 'inference.pdmodel')):
+ assert url.endswith('.tar'), 'Only tar compressed packages are supported'
+ tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
+ print('download {} to {}'.format(url, tmp_path))
+ os.makedirs(model_storage_directory, exist_ok=True)
+ download_with_progressbar(url, tmp_path)
+ with tarfile.open(tmp_path, 'r') as tarObj:
+ for member in tarObj.getmembers():
+ filename = None
+ for tar_file_name in tar_file_name_list:
+ if tar_file_name in member.name:
+ filename = tar_file_name
+ if filename is None:
+ continue
+ file = tarObj.extractfile(member)
+ with open(
+ os.path.join(model_storage_directory, filename),
+ 'wb') as f:
+ f.write(file.read())
+ os.remove(tmp_path)
+
+
+def is_link(s):
+ return s is not None and s.startswith('http')
+
+
+def confirm_model_dir_url(model_dir, default_model_dir, default_url):
+ url = default_url
+ if model_dir is None or is_link(model_dir):
+ if is_link(model_dir):
+ url = model_dir
+ file_name = url.split('/')[-1][:-4]
+ model_dir = default_model_dir
+ model_dir = os.path.join(model_dir, file_name)
+ return model_dir, url
diff --git a/backend/ppocr/utils/poly_nms.py b/backend/ppocr/utils/poly_nms.py
new file mode 100644
index 00000000..9dcb3d2c
--- /dev/null
+++ b/backend/ppocr/utils/poly_nms.py
@@ -0,0 +1,146 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from shapely.geometry import Polygon
+
+
+def points2polygon(points):
+ """Convert k points to 1 polygon.
+
+ Args:
+ points (ndarray or list): A ndarray or a list of shape (2k)
+ that indicates k points.
+
+ Returns:
+ polygon (Polygon): A polygon object.
+ """
+ if isinstance(points, list):
+ points = np.array(points)
+
+ assert isinstance(points, np.ndarray)
+ assert (points.size % 2 == 0) and (points.size >= 8)
+
+ point_mat = points.reshape([-1, 2])
+ return Polygon(point_mat)
+
+
+def poly_intersection(poly_det, poly_gt, buffer=0.0001):
+ """Calculate the intersection area between two polygon.
+
+ Args:
+ poly_det (Polygon): A polygon predicted by detector.
+ poly_gt (Polygon): A gt polygon.
+
+ Returns:
+ intersection_area (float): The intersection area between two polygons.
+ """
+ assert isinstance(poly_det, Polygon)
+ assert isinstance(poly_gt, Polygon)
+
+ if buffer == 0:
+ poly_inter = poly_det & poly_gt
+ else:
+ poly_inter = poly_det.buffer(buffer) & poly_gt.buffer(buffer)
+ return poly_inter.area, poly_inter
+
+
+def poly_union(poly_det, poly_gt):
+ """Calculate the union area between two polygon.
+
+ Args:
+ poly_det (Polygon): A polygon predicted by detector.
+ poly_gt (Polygon): A gt polygon.
+
+ Returns:
+ union_area (float): The union area between two polygons.
+ """
+ assert isinstance(poly_det, Polygon)
+ assert isinstance(poly_gt, Polygon)
+
+ area_det = poly_det.area
+ area_gt = poly_gt.area
+ area_inters, _ = poly_intersection(poly_det, poly_gt)
+ return area_det + area_gt - area_inters
+
+
+def valid_boundary(x, with_score=True):
+ num = len(x)
+ if num < 8:
+ return False
+ if num % 2 == 0 and (not with_score):
+ return True
+ if num % 2 == 1 and with_score:
+ return True
+
+ return False
+
+
+def boundary_iou(src, target):
+ """Calculate the IOU between two boundaries.
+
+ Args:
+ src (list): Source boundary.
+ target (list): Target boundary.
+
+ Returns:
+ iou (float): The iou between two boundaries.
+ """
+ assert valid_boundary(src, False)
+ assert valid_boundary(target, False)
+ src_poly = points2polygon(src)
+ target_poly = points2polygon(target)
+
+ return poly_iou(src_poly, target_poly)
+
+
+def poly_iou(poly_det, poly_gt):
+ """Calculate the IOU between two polygons.
+
+ Args:
+ poly_det (Polygon): A polygon predicted by detector.
+ poly_gt (Polygon): A gt polygon.
+
+ Returns:
+ iou (float): The IOU between two polygons.
+ """
+ assert isinstance(poly_det, Polygon)
+ assert isinstance(poly_gt, Polygon)
+ area_inters, _ = poly_intersection(poly_det, poly_gt)
+ area_union = poly_union(poly_det, poly_gt)
+ if area_union == 0:
+ return 0.0
+ return area_inters / area_union
+
+
+def poly_nms(polygons, threshold):
+ assert isinstance(polygons, list)
+
+ polygons = np.array(sorted(polygons, key=lambda x: x[-1]))
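+ # polygons are ordered by ascending confidence (last element); each round keeps the highest-scoring remaining polygon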
+
+ keep_poly = []
+ index = [i for i in range(polygons.shape[0])]
+
+ while len(index) > 0:
+ keep_poly.append(polygons[index[-1]].tolist())
+ A = polygons[index[-1]][:-1]
+ index = np.delete(index, -1)
+ iou_list = np.zeros((len(index), ))
+ for i in range(len(index)):
+ B = polygons[index[i]][:-1]
+ iou_list[i] = boundary_iou(A, B)
+ remove_index = np.where(iou_list > threshold)
+ index = np.delete(index, remove_index)
+
+ return keep_poly
diff --git a/backend/ppocr/utils/profiler.py b/backend/ppocr/utils/profiler.py
new file mode 100644
index 00000000..c4e28bc6
--- /dev/null
+++ b/backend/ppocr/utils/profiler.py
@@ -0,0 +1,110 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import paddle
+
+# A global variable to record the number of calling times for profiler
+# functions. It is used to specify the tracing range of training steps.
+_profiler_step_id = 0
+
+# A global variable to avoid parsing from string every time.
+_profiler_options = None
+
+
+class ProfilerOptions(object):
+ '''
+ Use a string to initialize a ProfilerOptions.
+ The string should be in the format: "key1=value1;key2=value2;key3=value3".
+ For example:
+ "profile_path=model.profile"
+ "batch_range=[50, 60]; profile_path=model.profile"
+ "batch_range=[50, 60]; tracer_option=OpDetail; profile_path=model.profile"
+ ProfilerOptions supports following key-value pair:
+ batch_range - an integer list, e.g. [100, 110].
+ state - a string, the optional values are 'CPU', 'GPU' or 'All'.
+ sorted_key - a string, the optional values are 'calls', 'total',
+ 'max', 'min' or 'ave'.
+ tracer_option - a string, the optional values are 'Default', 'OpDetail',
+ 'AllOpDetail'.
+ profile_path - a string, the path to save the serialized profile data,
+ which can be used to generate a timeline.
+ exit_on_finished - a boolean.
+ '''
+
+ def __init__(self, options_str):
+ assert isinstance(options_str, str)
+
+ self._options = {
+ 'batch_range': [10, 20],
+ 'state': 'All',
+ 'sorted_key': 'total',
+ 'tracer_option': 'Default',
+ 'profile_path': '/tmp/profile',
+ 'exit_on_finished': True
+ }
+ self._parse_from_string(options_str)
+
+ def _parse_from_string(self, options_str):
+ for kv in options_str.replace(' ', '').split(';'):
+ key, value = kv.split('=')
+ if key == 'batch_range':
+ value_list = value.replace('[', '').replace(']', '').split(',')
+ value_list = list(map(int, value_list))
+ if len(value_list) >= 2 and value_list[0] >= 0 and value_list[
+ 1] > value_list[0]:
+ self._options[key] = value_list
+ elif key == 'exit_on_finished':
+ self._options[key] = value.lower() in ("yes", "true", "t", "1")
+ elif key in [
+ 'state', 'sorted_key', 'tracer_option', 'profile_path'
+ ]:
+ self._options[key] = value
+
+ def __getitem__(self, name):
+ if self._options.get(name, None) is None:
+ raise ValueError(
+ "ProfilerOptions does not have an option named %s." % name)
+ return self._options[name]
+
+
+def add_profiler_step(options_str=None):
+ '''
+ Enable the operator-level timing using PaddlePaddle's profiler.
+ The profiler uses an independent variable to count the profiler steps.
+ One call of this function is treated as a profiler step.
+
+ Args:
+ options_str - a string used to initialize the ProfilerOptions.
+ Default is None, which disables the profiler.
+ '''
+ if options_str is None:
+ return
+
+ global _profiler_step_id
+ global _profiler_options
+
+ if _profiler_options is None:
+ _profiler_options = ProfilerOptions(options_str)
+
+ if _profiler_step_id == _profiler_options['batch_range'][0]:
+ paddle.utils.profiler.start_profiler(
+ _profiler_options['state'], _profiler_options['tracer_option'])
+ elif _profiler_step_id == _profiler_options['batch_range'][1]:
+ paddle.utils.profiler.stop_profiler(_profiler_options['sorted_key'],
+ _profiler_options['profile_path'])
+ if _profiler_options['exit_on_finished']:
+ sys.exit(0)
+
+ _profiler_step_id += 1
diff --git a/backend/ppocr/utils/save_load.py b/backend/ppocr/utils/save_load.py
index 02814d62..b09f1db6 100644
--- a/backend/ppocr/utils/save_load.py
+++ b/backend/ppocr/utils/save_load.py
@@ -23,7 +23,9 @@
import paddle
-__all__ = ['init_model', 'save_model', 'load_dygraph_pretrain']
+from ppocr.utils.logging import get_logger
+
+__all__ = ['load_model']
def _mkdir_if_not_exist(path, logger):
@@ -42,58 +44,74 @@ def _mkdir_if_not_exist(path, logger):
raise OSError('Failed to mkdir {}'.format(path))
-def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False):
- if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
- raise ValueError("Model pretrain path {} does not "
- "exists.".format(path))
- if load_static_weights:
- pre_state_dict = paddle.static.load_program_state(path)
- param_state_dict = {}
- model_dict = model.state_dict()
- for key in model_dict.keys():
- weight_name = model_dict[key].name
- weight_name = weight_name.replace('binarize', '').replace(
- 'thresh', '') # for DB
- if weight_name in pre_state_dict.keys():
- # logger.info('Load weight: {}, shape: {}'.format(
- # weight_name, pre_state_dict[weight_name].shape))
- if 'encoder_rnn' in key:
- # delete axis which is 1
- pre_state_dict[weight_name] = pre_state_dict[
- weight_name].squeeze()
- # change axis
- if len(pre_state_dict[weight_name].shape) > 1:
- pre_state_dict[weight_name] = pre_state_dict[
- weight_name].transpose((1, 0))
- param_state_dict[key] = pre_state_dict[weight_name]
- else:
- param_state_dict[key] = model_dict[key]
- model.set_state_dict(param_state_dict)
- return
-
- param_state_dict = paddle.load(path + '.pdparams')
- model.set_state_dict(param_state_dict)
- return
-
-
-def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
+def load_model(config, model, optimizer=None, model_type='det'):
"""
load model from checkpoint or pretrained_model
"""
- gloabl_config = config['Global']
- checkpoints = gloabl_config.get('checkpoints')
- pretrained_model = gloabl_config.get('pretrained_model')
+ logger = get_logger()
+ global_config = config['Global']
+ checkpoints = global_config.get('checkpoints')
+ pretrained_model = global_config.get('pretrained_model')
best_model_dict = {}
+
+ if model_type == 'vqa':
+ checkpoints = config['Architecture']['Backbone']['checkpoints']
+ # load vqa method metric
+ if checkpoints:
+ if os.path.exists(os.path.join(checkpoints, 'metric.states')):
+ with open(os.path.join(checkpoints, 'metric.states'),
+ 'rb') as f:
+ states_dict = pickle.load(f) if six.PY2 else pickle.load(
+ f, encoding='latin1')
+ best_model_dict = states_dict.get('best_model_dict', {})
+ if 'epoch' in states_dict:
+ best_model_dict['start_epoch'] = states_dict['epoch'] + 1
+ logger.info("resume from {}".format(checkpoints))
+
+ if optimizer is not None:
+ if checkpoints[-1] in ['/', '\\']:
+ checkpoints = checkpoints[:-1]
+ if os.path.exists(checkpoints + '.pdopt'):
+ optim_dict = paddle.load(checkpoints + '.pdopt')
+ optimizer.set_state_dict(optim_dict)
+ else:
+ logger.warning(
+ "{}.pdopt is not exists, params of optimizer is not loaded".
+ format(checkpoints))
+ return best_model_dict
+
if checkpoints:
+ if checkpoints.endswith('.pdparams'):
+ checkpoints = checkpoints.replace('.pdparams', '')
assert os.path.exists(checkpoints + ".pdparams"), \
- "Given dir {}.pdparams not exist.".format(checkpoints)
- assert os.path.exists(checkpoints + ".pdopt"), \
- "Given dir {}.pdopt not exist.".format(checkpoints)
- para_dict = paddle.load(checkpoints + '.pdparams')
- opti_dict = paddle.load(checkpoints + '.pdopt')
- model.set_state_dict(para_dict)
+ "The {}.pdparams does not exists!".format(checkpoints)
+
+ # load params from trained model
+ params = paddle.load(checkpoints + '.pdparams')
+ state_dict = model.state_dict()
+ new_state_dict = {}
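+ # copy only params whose shapes match the current model; missing or mismatched keys are logged and skipped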
+ for key, value in state_dict.items():
+ if key not in params:
+ logger.warning("{} not in loaded params {} !".format(
+ key, params.keys()))
+ continue
+ pre_value = params[key]
+ if list(value.shape) == list(pre_value.shape):
+ new_state_dict[key] = pre_value
+ else:
+ logger.warning(
+ "The shape of model params {} {} not matched with loaded params shape {} !".
+ format(key, value.shape, pre_value.shape))
+ model.set_state_dict(new_state_dict)
+
if optimizer is not None:
- optimizer.set_state_dict(opti_dict)
+ if os.path.exists(checkpoints + '.pdopt'):
+ optim_dict = paddle.load(checkpoints + '.pdopt')
+ optimizer.set_state_dict(optim_dict)
+ else:
+ logger.warning(
+ "{}.pdopt is not exists, params of optimizer is not loaded".
+ format(checkpoints))
if os.path.exists(checkpoints + '.states'):
with open(checkpoints + '.states', 'rb') as f:
@@ -102,29 +120,44 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
best_model_dict = states_dict.get('best_model_dict', {})
if 'epoch' in states_dict:
best_model_dict['start_epoch'] = states_dict['epoch'] + 1
-
logger.info("resume from {}".format(checkpoints))
elif pretrained_model:
- load_static_weights = gloabl_config.get('load_static_weights', False)
- if not isinstance(pretrained_model, list):
- pretrained_model = [pretrained_model]
- if not isinstance(load_static_weights, list):
- load_static_weights = [load_static_weights] * len(pretrained_model)
- for idx, pretrained in enumerate(pretrained_model):
- load_static = load_static_weights[idx]
- load_dygraph_pretrain(
- model, logger, path=pretrained, load_static_weights=load_static)
- logger.info("load pretrained model from {}".format(
- pretrained_model))
+ load_pretrained_params(model, pretrained_model)
else:
logger.info('train from scratch')
return best_model_dict
-def save_model(net,
+def load_pretrained_params(model, path):
+ logger = get_logger()
+ if path.endswith('.pdparams'):
+ path = path.replace('.pdparams', '')
+ assert os.path.exists(path + ".pdparams"), \
+ "The {}.pdparams does not exists!".format(path)
+
+ params = paddle.load(path + '.pdparams')
+ state_dict = model.state_dict()
+ new_state_dict = {}
+ for k1 in params.keys():
+ if k1 not in state_dict.keys():
+ logger.warning("The pretrained params {} not in model".format(k1))
+ else:
+ if list(state_dict[k1].shape) == list(params[k1].shape):
+ new_state_dict[k1] = params[k1]
+ else:
+ logger.warning(
+ "The shape of model params {} {} not matched with loaded params {} {} !".
+ format(k1, state_dict[k1].shape, k1, params[k1].shape))
+ model.set_state_dict(new_state_dict)
+ logger.info("load pretrain successful from {}".format(path))
+ return model
+
+
+def save_model(model,
optimizer,
model_path,
logger,
+ config,
is_best=False,
prefix='ppocr',
**kwargs):
@@ -133,13 +166,20 @@ def save_model(net,
"""
_mkdir_if_not_exist(model_path, logger)
model_prefix = os.path.join(model_path, prefix)
- paddle.save(net.state_dict(), model_prefix + '.pdparams')
paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')
-
+ if config['Architecture']["model_type"] != 'vqa':
+ paddle.save(model.state_dict(), model_prefix + '.pdparams')
+ metric_prefix = model_prefix
+ else:
+ if config['Global']['distributed']:
+ model._layers.backbone.model.save_pretrained(model_prefix)
+ else:
+ model.backbone.model.save_pretrained(model_prefix)
+ metric_prefix = os.path.join(model_prefix, 'metric')
# save metric and config
- with open(model_prefix + '.states', 'wb') as f:
- pickle.dump(kwargs, f, protocol=2)
if is_best:
+ with open(metric_prefix + '.states', 'wb') as f:
+ pickle.dump(kwargs, f, protocol=2)
logger.info('save best model is to {}'.format(model_prefix))
else:
logger.info("save model in {}".format(model_prefix))
diff --git a/backend/ppocr/utils/utility.py b/backend/ppocr/utils/utility.py
index 6a746314..4a25ff8b 100755
--- a/backend/ppocr/utils/utility.py
+++ b/backend/ppocr/utils/utility.py
@@ -14,7 +14,11 @@
import logging
import os
+import imghdr
import cv2
+import random
+import numpy as np
+import paddle
def print_dict(d, logger, delimiter=0):
@@ -45,23 +49,27 @@ def get_check_global_params(mode):
return check_params
+def _check_image_file(path):
+ img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif'}
+ return any([path.lower().endswith(e) for e in img_end])
+
+
def get_image_file_list(img_file):
imgs_lists = []
if img_file is None or not os.path.exists(img_file):
raise Exception("not found any img file in {}".format(img_file))
img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif'}
- if os.path.isfile(img_file) and os.path.splitext(img_file)[-1][1:].lower(
- ) in img_end:
+ if os.path.isfile(img_file) and _check_image_file(img_file):
imgs_lists.append(img_file)
elif os.path.isdir(img_file):
for single_file in os.listdir(img_file):
file_path = os.path.join(img_file, single_file)
- if os.path.isfile(file_path) and os.path.splitext(file_path)[-1][
- 1:].lower() in img_end:
+ if os.path.isfile(file_path) and _check_image_file(file_path):
imgs_lists.append(file_path)
if len(imgs_lists) == 0:
raise Exception("not found any img file in {}".format(img_file))
+ imgs_lists = sorted(imgs_lists)
return imgs_lists
@@ -78,3 +86,46 @@ def check_and_read_gif(img_path):
imgvalue = frame[:, :, ::-1]
return imgvalue, True
return None, False
+
+
+def load_vqa_bio_label_maps(label_map_path):
+ with open(label_map_path, "r", encoding='utf-8') as fin:
+ lines = fin.readlines()
+ lines = [line.strip() for line in lines]
+ if "O" not in lines:
+ lines.insert(0, "O")
+ labels = []
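+ # expand every class X into "B-X" and "I-X" BIO tags, with "O" kept for non-entity tokens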
+ for line in lines:
+ if line == "O":
+ labels.append("O")
+ else:
+ labels.append("B-" + line)
+ labels.append("I-" + line)
+ label2id_map = {label: idx for idx, label in enumerate(labels)}
+ id2label_map = {idx: label for idx, label in enumerate(labels)}
+ return label2id_map, id2label_map
+
+
+def set_seed(seed=1024):
+ random.seed(seed)
+ np.random.seed(seed)
+ paddle.seed(seed)
+
+
+class AverageMeter:
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ """reset"""
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ """update"""
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
diff --git a/backend/ppocr/utils/visual.py b/backend/ppocr/utils/visual.py
new file mode 100644
index 00000000..7a8c1674
--- /dev/null
+++ b/backend/ppocr/utils/visual.py
@@ -0,0 +1,98 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import numpy as np
+from PIL import Image, ImageDraw, ImageFont
+
+
+def draw_ser_results(image,
+ ocr_results,
+ font_path="doc/fonts/simfang.ttf",
+ font_size=18):
+ np.random.seed(2021)
+ color = (np.random.permutation(range(255)),
+ np.random.permutation(range(255)),
+ np.random.permutation(range(255)))
+ color_map = {
+ idx: (color[0][idx], color[1][idx], color[2][idx])
+ for idx in range(1, 255)
+ }
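+ # the fixed seed above gives each prediction id a stable color across runs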
+ if isinstance(image, np.ndarray):
+ image = Image.fromarray(image)
+ elif isinstance(image, str) and os.path.isfile(image):
+ image = Image.open(image).convert('RGB')
+ img_new = image.copy()
+ draw = ImageDraw.Draw(img_new)
+
+ font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
+ for ocr_info in ocr_results:
+ if ocr_info["pred_id"] not in color_map:
+ continue
+ color = color_map[ocr_info["pred_id"]]
+ text = "{}: {}".format(ocr_info["pred"], ocr_info["text"])
+
+ draw_box_txt(ocr_info["bbox"], text, draw, font, font_size, color)
+
+ img_new = Image.blend(image, img_new, 0.5)
+ return np.array(img_new)
+
+
+def draw_box_txt(bbox, text, draw, font, font_size, color):
+ # draw ocr results outline
+ bbox = ((bbox[0], bbox[1]), (bbox[2], bbox[3]))
+ draw.rectangle(bbox, fill=color)
+
+ # draw ocr results
+ start_y = max(0, bbox[0][1] - font_size)
+ tw = font.getsize(text)[0]
+ draw.rectangle(
+ [(bbox[0][0] + 1, start_y), (bbox[0][0] + tw + 1, start_y + font_size)],
+ fill=(0, 0, 255))
+ draw.text((bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font)
+
+
+def draw_re_results(image,
+ result,
+ font_path="doc/fonts/simfang.ttf",
+ font_size=18):
+ np.random.seed(0)
+ if isinstance(image, np.ndarray):
+ image = Image.fromarray(image)
+ elif isinstance(image, str) and os.path.isfile(image):
+ image = Image.open(image).convert('RGB')
+ img_new = image.copy()
+ draw = ImageDraw.Draw(img_new)
+
+ font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
+ color_head = (0, 0, 255)
+ color_tail = (255, 0, 0)
+ color_line = (0, 255, 0)
+
+ for ocr_info_head, ocr_info_tail in result:
+ draw_box_txt(ocr_info_head["bbox"], ocr_info_head["text"], draw, font,
+ font_size, color_head)
+ draw_box_txt(ocr_info_tail["bbox"], ocr_info_tail["text"], draw, font,
+ font_size, color_tail)
+
+ center_head = (
+ (ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2,
+ (ocr_info_head['bbox'][1] + ocr_info_head['bbox'][3]) // 2)
+ center_tail = (
+ (ocr_info_tail['bbox'][0] + ocr_info_tail['bbox'][2]) // 2,
+ (ocr_info_tail['bbox'][1] + ocr_info_tail['bbox'][3]) // 2)
+
+ draw.line([center_head, center_tail], fill=color_line, width=5)
+
+ img_new = Image.blend(image, img_new, 0.5)
+ return np.array(img_new)
diff --git a/backend/subfinder/linux/VideoSubFinderCli b/backend/subfinder/linux/VideoSubFinderCli
new file mode 100755
index 00000000..74ff3846
Binary files /dev/null and b/backend/subfinder/linux/VideoSubFinderCli differ
diff --git a/backend/subfinder/linux/VideoSubFinderCli.run b/backend/subfinder/linux/VideoSubFinderCli.run
new file mode 100755
index 00000000..9f2ad6a5
--- /dev/null
+++ b/backend/subfinder/linux/VideoSubFinderCli.run
@@ -0,0 +1,5 @@
+#!/bin/sh
+cd ${0%/*}
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$PWD:/lib64
+chmod +x ./VideoSubFinderCli
+./VideoSubFinderCli "$@"
diff --git a/backend/subfinder/linux/settings/general.cfg b/backend/subfinder/linux/settings/general.cfg
new file mode 100644
index 00000000..da029178
--- /dev/null
+++ b/backend/subfinder/linux/settings/general.cfg
@@ -0,0 +1,35 @@
+prefered_locale = eng
+dont_delete_unrecognized_images1 = 1
+dont_delete_unrecognized_images2 = 1
+generate_cleared_text_images_on_test = 1
+dump_debug_images = 0
+dump_debug_second_filtration_images = 0
+clear_test_images_folder = 1
+show_transformed_images_only = 0
+moderate_threshold = 0.4
+moderate_threshold_for_NEdges = 0.3
+segment_width = 8
+segment_height = 3
+minimum_segments_count = 2
+min_sum_color_diff = 500
+between_text_distace = 0.05
+text_centre_offset = 0.1
+min_points_number = 30
+min_points_density = 0.3
+min_symbol_height = 0.02
+min_symbol_density = 0.2
+min_NEdges_points_density = 0.25
+threads = 4
+sub_frame_length = 6
+text_procent = 0.3
+min_text_len_(in_procent) = 0.022
+sub_square_error = 0.3
+vedges_points_line_error = 0.35
+clear_image_logical = 0
+clean_rgb_images_after_run = 0
+def_string_for_empty_sub = sub duration: %sub_duration%
+min_sub_duration = 0
+txt_dw = 5
+txt_dy = 5
+fount_size_ocr_lbl = 8
+fount_size_ocr_btn = 10
diff --git a/backend/subfinder/windows/VideoSubFinderWXW.exe b/backend/subfinder/windows/VideoSubFinderWXW.exe
new file mode 100644
index 00000000..3234c138
Binary files /dev/null and b/backend/subfinder/windows/VideoSubFinderWXW.exe differ
diff --git a/backend/subfinder/windows/avcodec-58.dll b/backend/subfinder/windows/avcodec-58.dll
new file mode 100644
index 00000000..cd120d4f
Binary files /dev/null and b/backend/subfinder/windows/avcodec-58.dll differ
diff --git a/backend/subfinder/windows/avdevice-58.dll b/backend/subfinder/windows/avdevice-58.dll
new file mode 100644
index 00000000..66093493
Binary files /dev/null and b/backend/subfinder/windows/avdevice-58.dll differ
diff --git a/backend/subfinder/windows/avfilter-7.dll b/backend/subfinder/windows/avfilter-7.dll
new file mode 100644
index 00000000..21f6cf3e
Binary files /dev/null and b/backend/subfinder/windows/avfilter-7.dll differ
diff --git a/backend/subfinder/windows/avformat-58.dll b/backend/subfinder/windows/avformat-58.dll
new file mode 100644
index 00000000..dcc948fa
Binary files /dev/null and b/backend/subfinder/windows/avformat-58.dll differ
diff --git a/backend/subfinder/windows/avutil-56.dll b/backend/subfinder/windows/avutil-56.dll
new file mode 100644
index 00000000..b90eee56
Binary files /dev/null and b/backend/subfinder/windows/avutil-56.dll differ
diff --git a/backend/subfinder/windows/bitmaps/left_na.bmp b/backend/subfinder/windows/bitmaps/left_na.bmp
new file mode 100644
index 00000000..a30b40d2
Binary files /dev/null and b/backend/subfinder/windows/bitmaps/left_na.bmp differ
diff --git a/backend/subfinder/windows/bitmaps/left_od.bmp b/backend/subfinder/windows/bitmaps/left_od.bmp
new file mode 100644
index 00000000..3c26b886
Binary files /dev/null and b/backend/subfinder/windows/bitmaps/left_od.bmp differ
diff --git a/backend/subfinder/windows/bitmaps/right_na.bmp b/backend/subfinder/windows/bitmaps/right_na.bmp
new file mode 100644
index 00000000..35112740
Binary files /dev/null and b/backend/subfinder/windows/bitmaps/right_na.bmp differ
diff --git a/backend/subfinder/windows/bitmaps/right_od.bmp b/backend/subfinder/windows/bitmaps/right_od.bmp
new file mode 100644
index 00000000..3e3ce81e
Binary files /dev/null and b/backend/subfinder/windows/bitmaps/right_od.bmp differ
diff --git a/backend/subfinder/windows/bitmaps/sb_la.bmp b/backend/subfinder/windows/bitmaps/sb_la.bmp
new file mode 100644
index 00000000..5bb65be4
Binary files /dev/null and b/backend/subfinder/windows/bitmaps/sb_la.bmp differ
diff --git a/backend/subfinder/windows/bitmaps/sb_lc.bmp b/backend/subfinder/windows/bitmaps/sb_lc.bmp
new file mode 100644
index 00000000..d92e411d
Binary files /dev/null and b/backend/subfinder/windows/bitmaps/sb_lc.bmp differ
diff --git a/backend/subfinder/windows/bitmaps/sb_ra.bmp b/backend/subfinder/windows/bitmaps/sb_ra.bmp
new file mode 100644
index 00000000..d55a4758
Binary files /dev/null and b/backend/subfinder/windows/bitmaps/sb_ra.bmp differ
diff --git a/backend/subfinder/windows/bitmaps/sb_rc.bmp b/backend/subfinder/windows/bitmaps/sb_rc.bmp
new file mode 100644
index 00000000..609256d3
Binary files /dev/null and b/backend/subfinder/windows/bitmaps/sb_rc.bmp differ
diff --git a/backend/subfinder/windows/bitmaps/sb_t.bmp b/backend/subfinder/windows/bitmaps/sb_t.bmp
new file mode 100644
index 00000000..ecdd46d7
Binary files /dev/null and b/backend/subfinder/windows/bitmaps/sb_t.bmp differ
diff --git a/backend/subfinder/windows/bitmaps/tb_pause.bmp b/backend/subfinder/windows/bitmaps/tb_pause.bmp
new file mode 100644
index 00000000..2a15732a
Binary files /dev/null and b/backend/subfinder/windows/bitmaps/tb_pause.bmp differ
diff --git a/backend/subfinder/windows/bitmaps/tb_run.bmp b/backend/subfinder/windows/bitmaps/tb_run.bmp
new file mode 100644
index 00000000..69253adc
Binary files /dev/null and b/backend/subfinder/windows/bitmaps/tb_run.bmp differ
diff --git a/backend/subfinder/windows/bitmaps/tb_stop.bmp b/backend/subfinder/windows/bitmaps/tb_stop.bmp
new file mode 100644
index 00000000..d73a638f
Binary files /dev/null and b/backend/subfinder/windows/bitmaps/tb_stop.bmp differ
diff --git a/backend/subfinder/windows/cudart64_110.dll b/backend/subfinder/windows/cudart64_110.dll
new file mode 100644
index 00000000..6e795cf5
Binary files /dev/null and b/backend/subfinder/windows/cudart64_110.dll differ
diff --git a/backend/subfinder/windows/finished.wav b/backend/subfinder/windows/finished.wav
new file mode 100644
index 00000000..d6e8baa8
Binary files /dev/null and b/backend/subfinder/windows/finished.wav differ
diff --git a/backend/subfinder/windows/nppc64_11.dll b/backend/subfinder/windows/nppc64_11.dll
new file mode 100644
index 00000000..064adce4
Binary files /dev/null and b/backend/subfinder/windows/nppc64_11.dll differ
diff --git a/backend/subfinder/windows/nppicc64_11.dll b/backend/subfinder/windows/nppicc64_11.dll
new file mode 100644
index 00000000..3abe17b7
Binary files /dev/null and b/backend/subfinder/windows/nppicc64_11.dll differ
diff --git a/backend/subfinder/windows/nppig64_11.dll b/backend/subfinder/windows/nppig64_11.dll
new file mode 100644
index 00000000..0859df57
Binary files /dev/null and b/backend/subfinder/windows/nppig64_11.dll differ
diff --git a/backend/subfinder/windows/opencv_videoio_ffmpeg430_64.dll b/backend/subfinder/windows/opencv_videoio_ffmpeg430_64.dll
new file mode 100644
index 00000000..af1ae6a5
Binary files /dev/null and b/backend/subfinder/windows/opencv_videoio_ffmpeg430_64.dll differ
diff --git a/backend/subfinder/windows/opencv_world430.dll b/backend/subfinder/windows/opencv_world430.dll
new file mode 100644
index 00000000..2e47847f
Binary files /dev/null and b/backend/subfinder/windows/opencv_world430.dll differ
diff --git a/backend/subfinder/windows/postproc-55.dll b/backend/subfinder/windows/postproc-55.dll
new file mode 100644
index 00000000..5692b480
Binary files /dev/null and b/backend/subfinder/windows/postproc-55.dll differ
diff --git a/backend/subfinder/windows/previous_video.inf b/backend/subfinder/windows/previous_video.inf
new file mode 100644
index 00000000..05936eb7
--- /dev/null
+++ b/backend/subfinder/windows/previous_video.inf
@@ -0,0 +1,4 @@
+C:\Users\fangyao\Downloads\test.mp4
+0
+53766
+0
\ No newline at end of file
diff --git a/backend/subfinder/windows/settings/eng/locale.cfg b/backend/subfinder/windows/settings/eng/locale.cfg
new file mode 100644
index 00000000..d7d24fe7
--- /dev/null
+++ b/backend/subfinder/windows/settings/eng/locale.cfg
@@ -0,0 +1,102 @@
+label_text_alignment = Text Alignment
+ocr_label_msd_text = Min Sub Duration
+ocr_label_jsact_text = Join Subs And Correct Time
+ocr_label_clear_txt_folders = Clear TXT Folders Before Run
+ocr_button_ccti_text = Create Cleared TXTImages
+ocr_button_csftr_text = Create Sub From TXTResults
+ocr_button_cesfcti_text = Create Empty Sub From Cleared TXTImages
+ocr_button_ces_text = Create Empty Sub From RGBImages
+ocr_button_join_text = Join TXTImages
+ocr_button_test_text = Test
+ocr_label_save_each_substring_separately = Save Each Substring Separately
+ocr_label_save_scaled_images = Save Scaled Images
+ssp_label_parameters_influencing_image_processing = Parameters Influencing Image Processing
+ssp_label_ocl_and_multiframe_image_stream_processing = OCR and Multiframe Image Stream Processing
+ssp_oi_group_global_image_processing_settings = Global Image Processing Settings
+ssp_oi_property_use_ocl = Use OCL In OpenCV
+ssp_oi_property_use_cuda_gpu = Use CUDA GPU Acceleration
+ssp_oi_property_image_scale_for_clear_image = Image Scale For Clear Image
+ssp_oi_property_cpu_kmeans_initial_loop_iterations = CPU kmeans initial loop iterations
+ssp_oi_property_cpu_kmeans_loop_iterations = CPU kmeans loop iterations
+ssp_oi_property_cuda_kmeans_initial_loop_iterations = CUDA kmeans initial loop iterations
+ssp_oi_property_cuda_kmeans_loop_iterations = CUDA kmeans loop iterations
+ssp_oi_property_generate_cleared_text_images_on_test = Generate Cleared Text Images On Test Button
+ssp_oi_property_dump_debug_images = Dump Debug Images
+ssp_oi_property_dump_debug_second_filtration_images = Dump Debug Secondary Processing Images
+ssp_oi_property_clear_test_images_folder = Clear Test Images Folder
+ssp_oi_property_show_transformed_images_only = Show Transformed Images Only
+ssp_oi_group_initial_image_processing = Initial Image Processing
+ssp_oi_sub_group_settings_for_sobel_operators = Settings For Sobel Operators
+ssp_oi_property_moderate_threshold = Moderate Threshold
+ssp_oi_property_moderate_nedges_threshold = Moderate NEdges Threshold
+ssp_oi_sub_group_settings_for_color_filtering = Settings For Color Filtering
+ssp_oi_property_segment_width = Line Segment Width
+ssp_oi_property_min_segments_count = Min Segments Count
+ssp_oi_property_min_sum_color_difference = Min Sum Color Difference
+ssp_oi_group_secondary_image_processing = Secondary Image Processing
+ssp_oi_sub_group_settings_for_linear_filtering = Settings For Linear Filtering
+ssp_oi_property_line_height = Line Segment Height
+ssp_oi_property_max_between_text_distance = Max Between Text Distance
+ssp_oi_property_max_text_center_offset = Max Text Offset
+ssp_oi_property_max_text_center_percent_offset = Max Text Center Percent Offset
+ssp_oi_sub_group_settings_for_color_border_points = Settings For Color Border Points
+ssp_oi_property_min_points_number = Min Points Number
+ssp_oi_property_min_points_density = Min Points Density
+ssp_oi_property_min_symbol_height = Min Symbol Height (in % to Full Image Height)
+ssp_oi_property_min_symbol_density = Min Symbol Density (in % to Its Size)
+ssp_oi_property_min_vedges_points_density = Min VEdges points density
+ssp_oi_property_min_nedges_points_density = Min NEdges points density
+ssp_oi_property_min_sum_multiple_color_difference = Min Sum Multiple Color Difference
+ssp_oi_group_tertiary_image_processing = Tertiary Image Processing
+ssp_oi_property_min_vedges_points_density_per_half_line = Min VEdges points density (per half line)
+ssp_oi_property_min_hedges_points_density_per_half_line = Min HEdges points density (per half line)
+ssp_oi_property_min_nedges_points_density_per_half_line = Min NEdges points density (per half line)
+ssp_oim_group_ocr_settings = OCR settings
+ssp_oim_property_clear_images_logical = Clear Images Logical;(don't use on Hieroglyph or Arabic subs)
+ssp_oim_property_clear_rgbimages_after_search_subtitles = Clear RGBImages after search subtitles
+ssp_oim_property_using_isaimages_for_getting_txt_areas = Use ISAImages for getting TXT areas
+ssp_oim_property_using_ilaimages_for_getting_txt_areas = Use ILAImages for getting TXT areas
+label_ILA_images_for_getting_txt_symbols_areas = Use ILAImages for getting TXT symbols areas
+label_use_ILA_images_before_clear_txt_images_from_borders = Use ILAImages before clear TXT images from borders
+ssp_oim_property_validate_and_compare_cleared_txt_images = (NotRealized) Validate And Compare Cleared TXTImages
+ssp_oim_property_dont_delete_unrecognized_images_first = Don't Delete Unrecognized Images (First)
+ssp_oim_property_dont_delete_unrecognized_images_second = Don't Delete Unrecognized Images (Second)
+ssp_oim_property_default_string_for_empty_sub = Default string for empty sub
+ssp_oim_group_settings_for_multiframe_image_processing = Settings For Multi-Frame Image Processing
+ssp_oim_sub_group_settings_for_sub_detection = Settings For Sub Detection
+ssp_oim_property_threads = Number Of Parallel Tasks;(For Run Search)
+ssp_ocr_threads = Number Of Parallel Tasks;(For Create Cleared TXTImages)
+ssp_oim_property_sub_frames_length = Sub Frames Length
+ssp_oim_property_use_ILA_images_for_search_subtitles = Use ILAImages for search subtitles
+ssp_oim_property_use_ISA_images_for_search_subtitles = Analyze ISAImages for sub presence
+ssp_oim_property_replace_ISA_by_filtered_version = Replace ISAImages by filtered version
+ssp_oim_property_max_dl_down = Max luminance diff from down for IL image generation
+ssp_oim_property_max_dl_up = Max luminance diff from up for IL image generation
+ssp_oim_sub_group_settings_for_comparing_subs = Settings For Comparing Subs
+ssp_oim_property_vedges_points_line_error = VEdges Points line error
+ssp_oim_property_ila_points_line_error = ILA Points line error
+ssp_oim_sub_group_settings_for_checking_sub = Settings For Checking Sub
+ssp_oim_property_text_percent = Text Percent
+ssp_oim_property_min_text_length = Min Text Length
+ssp_oim_property_use_gradient_images_for_clear_txt_images = Use Gradient Images For Clear TXTImages
+ssp_oim_property_use_ILA_images_for_clear_txt_images = Use ILAImages For Clear TXTImages
+ssp_oim_property_clear_txt_images_by_main_color = Clear TXTImages By Main Color
+ssp_oi_property_moderate_threshold_for_scaled_image = Moderate Threshold For Scaled Image
+ssp_oim_property_remove_wide_symbols = Remove too wide symbols;(don't use for Arabic or handwritten subs)
+ssp_hw_device = FFMPEG HW Devices
+label_filter_descr = FFMPEG Video Filters
+label_settings_file = Current Settings File
+label_playback_sound = Playback Sound On Task Finished
+label_border_is_darker = Characters Border Is Darker
+label_extend_by_grey_color = Extend By Grey Color;(try to use in case of subs with unstable luminance)
+label_allow_min_luminance = Allow Min Luminance;(used only if "Extend By Grey Color" is set)
+ssp_oim_sub_group_settings_for_update_video_color = Settings For Update Video Color
+label_video_contrast = Video Contrast
+label_video_gamma = Video Gamma
+label_pixel_color = Pixel Color;(By 'Left Mouse Click' in Video Box)
+label_use_filter_color = Use Filter Colors;(Use 'Ctrl+Enter' for add New Line);(Press and hold 'T'/'R'/'Y'/'U' button in Video Box for check)
+label_use_outline_filter_color = Use Outline Filter Colors;(Use 'Ctrl+Enter' for add New Line);(Press and hold 'T'/'R'/'I'/'U' button in Video Box for check)
+label_dL_color = default dL For RGB and Lab Filter Colors
+label_dA_color = default dA For Lab Filter Colors
+label_dB_color = default dB For Lab Filter Colors
+label_combine_to_single_cluster = Combine To Single Cluster;(can be used in case of multiple colors in single line)
diff --git a/backend/subfinder/windows/settings/general.cfg b/backend/subfinder/windows/settings/general.cfg
new file mode 100644
index 00000000..5e965749
--- /dev/null
+++ b/backend/subfinder/windows/settings/general.cfg
@@ -0,0 +1,82 @@
+prefered_locale = eng
+ocr_join_txt_images_split_line = [begin_time] --> [end_time]
+process_affinity_mask = -1
+fount_size_lbl = 10
+fount_size_btn = 13
+dont_delete_unrecognized_images1 = 0
+dont_delete_unrecognized_images2 = 1
+generate_cleared_text_images_on_test = 1
+dump_debug_images = 0
+dump_debug_second_filtration_images = 0
+clear_test_images_folder = 1
+show_transformed_images_only = 0
+use_ocl = 1
+use_cuda_gpu = 0
+use_filter_color = none
+use_outline_filter_color = none
+dL_color = 40
+dA_color = 30
+dB_color = 30
+combine_to_single_cluster = 0
+cuda_kmeans_initial_loop_iterations = 20
+cuda_kmeans_loop_iterations = 30
+cpu_kmeans_initial_loop_iterations = 20
+cpu_kmeans_loop_iterations = 30
+moderate_threshold_for_scaled_image = 0.25
+moderate_threshold = 0.25
+moderate_threshold_for_NEdges = 0.25
+segment_width = 8
+segment_height = 3
+minimum_segments_count = 2
+min_sum_color_diff = 0
+between_text_distace = 0.07
+text_centre_offset = 0.2
+image_scale_for_clear_image = 4
+use_ISA_images = 1
+use_ILA_images = 1
+use_ILA_images_for_getting_txt_symbols_areas = 0
+use_ILA_images_before_clear_txt_images_from_borders = 0
+use_gradient_images_for_clear_txt_images = 1
+clear_txt_images_by_main_color = 1
+use_ILA_images_for_clear_txt_images = 1
+min_points_number = 30
+min_points_density = 0.3
+min_symbol_height = 0.02
+min_symbol_density = 0.2
+min_NEdges_points_density = 0.2
+threads = -1
+ocr_threads = -1
+sub_frame_length = 6
+text_percent = 0.3
+min_text_len_in_percent = 0.022
+vedges_points_line_error = 0.3
+ila_points_line_error = 0.3
+video_contrast = 1
+video_gamma = 1
+clear_txt_folders = 1
+join_subs_and_correct_time = 1
+clear_image_logical = 0
+clean_rgb_images_after_run = 0
+def_string_for_empty_sub = sub duration: %sub_duration%
+min_sub_duration = 0
+txt_dw = 5
+txt_dy = 5
+use_ISA_images_for_search_subtitles = 1
+use_ILA_images_for_search_subtitles = 1
+replace_ISA_by_filtered_version = 1
+max_dl_down = 20
+max_dl_up = 40
+remove_wide_symbols = 0
+hw_device = cpu
+filter_descr = none
+text_alignment = Center
+save_each_substring_separately = 0
+save_scaled_images = 1
+playback_sound = 0
+border_is_darker = 1
+extend_by_grey_color = 0
+allow_min_luminance = 100
+bottom_video_image_percent_end = 0
+top_video_image_percent_end = 0.3
+left_video_image_percent_end = 0
+right_video_image_percent_end = 1
diff --git a/backend/subfinder/windows/swresample-3.dll b/backend/subfinder/windows/swresample-3.dll
new file mode 100644
index 00000000..51134cf1
Binary files /dev/null and b/backend/subfinder/windows/swresample-3.dll differ
diff --git a/backend/subfinder/windows/swscale-5.dll b/backend/subfinder/windows/swscale-5.dll
new file mode 100644
index 00000000..40e591e9
Binary files /dev/null and b/backend/subfinder/windows/swscale-5.dll differ
diff --git a/backend/tools/NotoSansCJK-Bold.otf b/backend/tools/NotoSansCJK-Bold.otf
new file mode 100644
index 00000000..7f666ddb
Binary files /dev/null and b/backend/tools/NotoSansCJK-Bold.otf differ
diff --git a/backend/tools/__init__.py b/backend/tools/__init__.py
new file mode 100644
index 00000000..d56c9dba
--- /dev/null
+++ b/backend/tools/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/backend/tools/constant.py b/backend/tools/constant.py
new file mode 100644
index 00000000..6b8f6637
--- /dev/null
+++ b/backend/tools/constant.py
@@ -0,0 +1,26 @@
+from enum import Enum
+
+
+# Default approximate region where subtitles are expected to appear
+class SubtitleArea(Enum):
+ # Subtitle area is in the lower half of the frame
+ LOWER_PART = 0
+ # Subtitle area is in the upper half of the frame
+ UPPER_PART = 1
+ # Position of the subtitle area is unknown
+ UNKNOWN = 2
+ # Position of the subtitle area is explicitly specified
+ CUSTOM = 3
+
+
+class BackgroundColor(Enum):
+ # Subtitle background color
+ WHITE = 0
+ DARK = 1
+ UNKNOWN = 2
+
+
+BGR_COLOR_GREEN = (0, 0xff, 0)
+BGR_COLOR_BLUE = (0xff, 0, 0)
+BGR_COLOR_RED = (0, 0, 0xff)
+BGR_COLOR_WHITE = (0xff, 0xff, 0xff)
diff --git a/backend/tools/eval.py b/backend/tools/eval.py
index 9817fa75..cab28334 100755
--- a/backend/tools/eval.py
+++ b/backend/tools/eval.py
@@ -20,15 +20,14 @@
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
-sys.path.append(__dir__)
-sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+sys.path.insert(0, __dir__)
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
from ppocr.data import build_dataloader
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
-from ppocr.utils.save_load import init_model
-from ppocr.utils.utility import print_dict
+from ppocr.utils.save_load import load_model
import tools.program as program
@@ -44,12 +43,51 @@ def main():
# build model
# for rec algorithm
if hasattr(post_process_class, 'character'):
- config['Architecture']["Head"]['out_channels'] = len(
- getattr(post_process_class, 'character'))
+ char_num = len(getattr(post_process_class, 'character'))
+ if config['Architecture']["algorithm"] in ["Distillation",
+ ]: # distillation model
+ for key in config['Architecture']["Models"]:
+ if config['Architecture']['Models'][key]['Head'][
+ 'name'] == 'MultiHead': # for multi head
+ out_channels_list = {}
+ if config['PostProcess'][
+ 'name'] == 'DistillationSARLabelDecode':
+ char_num = char_num - 2
+ out_channels_list['CTCLabelDecode'] = char_num
+ out_channels_list['SARLabelDecode'] = char_num + 2
+ config['Architecture']['Models'][key]['Head'][
+ 'out_channels_list'] = out_channels_list
+ else:
+ config['Architecture']["Models"][key]["Head"][
+ 'out_channels'] = char_num
+ elif config['Architecture']['Head'][
+ 'name'] == 'MultiHead': # for multi head
+ out_channels_list = {}
+ if config['PostProcess']['name'] == 'SARLabelDecode':
+ char_num = char_num - 2
+ out_channels_list['CTCLabelDecode'] = char_num
+ out_channels_list['SARLabelDecode'] = char_num + 2
+ config['Architecture']['Head'][
+ 'out_channels_list'] = out_channels_list
+ else: # base rec model
+ config['Architecture']["Head"]['out_channels'] = char_num
+
model = build_model(config['Architecture'])
- use_srn = config['Architecture']['algorithm'] == "SRN"
+ extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"]
+ extra_input = False
+ if config['Architecture']['algorithm'] == 'Distillation':
+ for key in config['Architecture']["Models"]:
+ extra_input = extra_input or config['Architecture']['Models'][key][
+ 'algorithm'] in extra_input_models
+ else:
+ extra_input = config['Architecture']['algorithm'] in extra_input_models
+ if "model_type" in config['Architecture'].keys():
+ model_type = config['Architecture']['model_type']
+ else:
+ model_type = None
- best_model_dict = init_model(config, model, logger)
+ best_model_dict = load_model(
+ config, model, model_type=model_type)
if len(best_model_dict):
logger.info('metric in ckpt ***************')
for k, v in best_model_dict.items():
@@ -57,10 +95,9 @@ def main():
# build metric
eval_class = build_metric(config['Metric'])
-
# start eval
metric = program.eval(model, valid_dataloader, post_process_class,
- eval_class, use_srn)
+ eval_class, model_type, extra_input)
logger.info('metric eval ***************')
for k, v in metric.items():
logger.info('{}:{}'.format(k, v))
diff --git a/backend/tools/export_center.py b/backend/tools/export_center.py
new file mode 100644
index 00000000..9a6372f1
--- /dev/null
+++ b/backend/tools/export_center.py
@@ -0,0 +1,76 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import sys
+import pickle
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+
+from ppocr.data import build_dataloader
+from ppocr.modeling.architectures import build_model
+from ppocr.postprocess import build_post_process
+from ppocr.utils.save_load import load_model
+import tools.program as program
+
+
+def main():
+ global_config = config['Global']
+ # build dataloader
+ config['Eval']['dataset']['name'] = config['Train']['dataset']['name']
+ config['Eval']['dataset']['data_dir'] = config['Train']['dataset'][
+ 'data_dir']
+ config['Eval']['dataset']['label_file_list'] = config['Train']['dataset'][
+ 'label_file_list']
+ eval_dataloader = build_dataloader(config, 'Eval', device, logger)
+
+ # build post process
+ post_process_class = build_post_process(config['PostProcess'],
+ global_config)
+
+ # build model
+ # for rec algorithm
+ if hasattr(post_process_class, 'character'):
+ char_num = len(getattr(post_process_class, 'character'))
+ config['Architecture']["Head"]['out_channels'] = char_num
+
+ # set return_feats = True so the head also returns features
+ config['Architecture']["Head"]["return_feats"] = True
+
+ model = build_model(config['Architecture'])
+
+ best_model_dict = load_model(config, model)
+ if len(best_model_dict):
+ logger.info('metric in ckpt ***************')
+ for k, v in best_model_dict.items():
+ logger.info('{}:{}'.format(k, v))
+
+ # get features from train data
+ char_center = program.get_center(model, eval_dataloader, post_process_class)
+
+ # serialize the character centers to disk
+ with open("train_center.pkl", 'wb') as f:
+ pickle.dump(char_center, f)
+ return
+
+
+if __name__ == '__main__':
+ config, device, logger, vdl_writer = program.preprocess()
+ main()
diff --git a/backend/tools/export_model.py b/backend/tools/export_model.py
index 1e9526e0..76c716e0 100755
--- a/backend/tools/export_model.py
+++ b/backend/tools/export_model.py
@@ -17,7 +17,7 @@
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+sys.path.append(os.path.abspath(os.path.join(__dir__, "..")))
import argparse
@@ -26,75 +26,146 @@
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
-from ppocr.utils.save_load import init_model
+from ppocr.utils.save_load import load_model
from ppocr.utils.logging import get_logger
from tools.program import load_config, merge_config, ArgsParser
-def parse_args():
- parser = argparse.ArgumentParser()
- parser.add_argument("-c", "--config", help="configuration file to use")
- parser.add_argument(
- "-o", "--output_path", type=str, default='./output/infer/')
- return parser.parse_args()
-
-
-def main():
- FLAGS = ArgsParser().parse_args()
- config = load_config(FLAGS.config)
- merge_config(FLAGS.opt)
- logger = get_logger()
- # build post process
-
- post_process_class = build_post_process(config['PostProcess'],
- config['Global'])
-
- # build model
- # for rec algorithm
- if hasattr(post_process_class, 'character'):
- char_num = len(getattr(post_process_class, 'character'))
- config['Architecture']["Head"]['out_channels'] = char_num
- model = build_model(config['Architecture'])
- init_model(config, model, logger)
- model.eval()
-
- save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
-
- if config['Architecture']['algorithm'] == "SRN":
+def export_single_model(model, arch_config, save_path, logger, quanter=None):
+ if arch_config["algorithm"] == "SRN":
+ max_text_length = arch_config["Head"]["max_text_length"]
other_shape = [
paddle.static.InputSpec(
- shape=[None, 1, 64, 256], dtype='float32'), [
+ shape=[None, 1, 64, 256], dtype="float32"), [
paddle.static.InputSpec(
shape=[None, 256, 1],
dtype="int64"), paddle.static.InputSpec(
- shape=[None, 25, 1],
- dtype="int64"), paddle.static.InputSpec(
- shape=[None, 8, 25, 25], dtype="int64"),
+ shape=[None, max_text_length, 1], dtype="int64"),
paddle.static.InputSpec(
- shape=[None, 8, 25, 25], dtype="int64")
+ shape=[None, 8, max_text_length, max_text_length],
+ dtype="int64"), paddle.static.InputSpec(
+ shape=[None, 8, max_text_length, max_text_length],
+ dtype="int64")
]
]
model = to_static(model, input_spec=other_shape)
+ elif arch_config["algorithm"] == "SAR":
+ other_shape = [
+ paddle.static.InputSpec(
+ shape=[None, 3, 48, 160], dtype="float32"),
+ ]
+ model = to_static(model, input_spec=other_shape)
+ elif arch_config["algorithm"] == "SVTR":
+ if arch_config["Head"]["name"] == 'MultiHead':
+ other_shape = [
+ paddle.static.InputSpec(
+ shape=[None, 3, 48, -1], dtype="float32"),
+ ]
+ else:
+ other_shape = [
+ paddle.static.InputSpec(
+ shape=[None, 3, 64, 256], dtype="float32"),
+ ]
+ model = to_static(model, input_spec=other_shape)
+ elif arch_config["algorithm"] == "PREN":
+ other_shape = [
+ paddle.static.InputSpec(
+ shape=[None, 3, 64, 512], dtype="float32"),
+ ]
+ model = to_static(model, input_spec=other_shape)
else:
infer_shape = [3, -1, -1]
- if config['Architecture']['model_type'] == "rec":
+ if arch_config["model_type"] == "rec":
infer_shape = [3, 32, -1] # for rec model, H must be 32
- if 'Transform' in config['Architecture'] and config['Architecture'][
- 'Transform'] is not None and config['Architecture'][
- 'Transform']['name'] == 'TPS':
+ if "Transform" in arch_config and arch_config[
+ "Transform"] is not None and arch_config["Transform"][
+ "name"] == "TPS":
logger.info(
- 'When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training'
+ "When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
)
infer_shape[-1] = 100
+ if arch_config["algorithm"] == "NRTR":
+ infer_shape = [1, 32, 100]
+ elif arch_config["model_type"] == "table":
+ infer_shape = [3, 488, 488]
model = to_static(
model,
input_spec=[
paddle.static.InputSpec(
- shape=[None] + infer_shape, dtype='float32')
+ shape=[None] + infer_shape, dtype="float32")
])
- paddle.jit.save(model, save_path)
- logger.info('inference model is saved to {}'.format(save_path))
+ if quanter is None:
+ paddle.jit.save(model, save_path)
+ else:
+ quanter.save_quantized_model(model, save_path)
+ logger.info("inference model is saved to {}".format(save_path))
+ return
+
+
+def main():
+ FLAGS = ArgsParser().parse_args()
+ config = load_config(FLAGS.settings_config)
+ config = merge_config(config, FLAGS.opt)
+ logger = get_logger()
+ # build post process
+
+ post_process_class = build_post_process(config["PostProcess"],
+ config["Global"])
+
+ # build model
+ # for rec algorithm
+ if hasattr(post_process_class, "character"):
+ char_num = len(getattr(post_process_class, "character"))
+ if config["Architecture"]["algorithm"] in ["Distillation",
+ ]: # distillation model
+ for key in config["Architecture"]["Models"]:
+ if config["Architecture"]["Models"][key]["Head"][
+ "name"] == 'MultiHead': # multi head
+ out_channels_list = {}
+ if config['PostProcess'][
+ 'name'] == 'DistillationSARLabelDecode':
+ char_num = char_num - 2
+ out_channels_list['CTCLabelDecode'] = char_num
+ out_channels_list['SARLabelDecode'] = char_num + 2
+ config['Architecture']['Models'][key]['Head'][
+ 'out_channels_list'] = out_channels_list
+ else:
+ config["Architecture"]["Models"][key]["Head"][
+ "out_channels"] = char_num
+ # only the final output tensor needs to be exported for inference
+ config["Architecture"]["Models"][key][
+ "return_all_feats"] = False
+ elif config['Architecture']['Head'][
+ 'name'] == 'MultiHead': # multi head
+ out_channels_list = {}
+ char_num = len(getattr(post_process_class, 'character'))
+ if config['PostProcess']['name'] == 'SARLabelDecode':
+ char_num = char_num - 2
+ out_channels_list['CTCLabelDecode'] = char_num
+ out_channels_list['SARLabelDecode'] = char_num + 2
+ config['Architecture']['Head'][
+ 'out_channels_list'] = out_channels_list
+ else: # base rec model
+ config["Architecture"]["Head"]["out_channels"] = char_num
+
+ model = build_model(config["Architecture"])
+ load_model(config, model)
+ model.eval()
+
+ save_path = config["Global"]["save_inference_dir"]
+
+ arch_config = config["Architecture"]
+
+ if arch_config["algorithm"] in ["Distillation", ]: # distillation model
+ archs = list(arch_config["Models"].values())
+ for idx, name in enumerate(model.model_name_list):
+ sub_model_save_path = os.path.join(save_path, name, "inference")
+ export_single_model(model.model_list[idx], archs[idx],
+ sub_model_save_path, logger)
+ else:
+ save_path = os.path.join(save_path, "inference")
+ export_single_model(model, arch_config, save_path, logger)
if __name__ == "__main__":
diff --git a/backend/tools/infer/predict_cls.py b/backend/tools/infer/predict_cls.py
index 074172cc..ed2f47c0 100755
--- a/backend/tools/infer/predict_cls.py
+++ b/backend/tools/infer/predict_cls.py
@@ -16,7 +16,7 @@
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
@@ -45,8 +45,9 @@ def __init__(self, args):
"label_list": args.label_list,
}
self.postprocess_op = build_post_process(postprocess_params)
- self.predictor, self.input_tensor, self.output_tensors = \
+ self.predictor, self.input_tensor, self.output_tensors, _ = \
utility.create_predictor(args, 'cls', logger)
+ self.use_onnx = args.use_onnx
def resize_norm_img(self, img):
imgC, imgH, imgW = self.cls_image_shape
@@ -84,9 +85,11 @@ def __call__(self, img_list):
batch_num = self.cls_batch_num
elapse = 0
for beg_img_no in range(0, img_num, batch_num):
+
end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = []
max_wh_ratio = 0
+ starttime = time.time()
for ino in range(beg_img_no, end_img_no):
h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w * 1.0 / h
@@ -97,11 +100,17 @@ def __call__(self, img_list):
norm_img_batch.append(norm_img)
norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy()
- starttime = time.time()
- self.input_tensor.copy_from_cpu(norm_img_batch)
- self.predictor.run()
- prob_out = self.output_tensors[0].copy_to_cpu()
+ if self.use_onnx:
+ input_dict = {}
+ input_dict[self.input_tensor.name] = norm_img_batch
+ outputs = self.predictor.run(self.output_tensors, input_dict)
+ prob_out = outputs[0]
+ else:
+ self.input_tensor.copy_from_cpu(norm_img_batch)
+ self.predictor.run()
+ prob_out = self.output_tensors[0].copy_to_cpu()
+ self.predictor.try_shrink_memory()
cls_result = self.postprocess_op(prob_out)
elapse += time.time() - starttime
for rno in range(len(cls_result)):
@@ -129,20 +138,13 @@ def main(args):
img_list.append(img)
try:
img_list, cls_res, predict_time = text_classifier(img_list)
- except:
+ except Exception as E:
logger.info(traceback.format_exc())
- logger.info(
- "ERROR!!!! \n"
- "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
- "If your model has tps module: "
- "TPS does not support variable shape.\n"
- "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
+ logger.info(E)
exit()
for ino in range(len(img_list)):
logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
cls_res[ino]))
- logger.info("Total predict time for {} images, cost: {:.3f}".format(
- len(img_list), predict_time))
if __name__ == "__main__":
diff --git a/backend/tools/infer/predict_det.py b/backend/tools/infer/predict_det.py
index b14825bd..5f2675d6 100755
--- a/backend/tools/infer/predict_det.py
+++ b/backend/tools/infer/predict_det.py
@@ -16,7 +16,7 @@
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
@@ -30,7 +30,7 @@
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process
-
+import json
logger = get_logger()
@@ -38,8 +38,12 @@ class TextDetector(object):
def __init__(self, args):
self.args = args
self.det_algorithm = args.det_algorithm
+ self.use_onnx = args.use_onnx
pre_process_list = [{
- 'DetResizeForTest': None
+ 'DetResizeForTest': {
+ 'limit_side_len': args.det_limit_side_len,
+ 'limit_type': args.det_limit_type,
+ }
}, {
'NormalizeImage': {
'std': [0.229, 0.224, 0.225],
@@ -62,6 +66,7 @@ def __init__(self, args):
postprocess_params["max_candidates"] = 1000
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
postprocess_params["use_dilation"] = args.use_dilation
+ postprocess_params["score_mode"] = args.det_db_score_mode
elif self.det_algorithm == "EAST":
postprocess_params['name'] = 'EASTPostProcess'
postprocess_params["score_thresh"] = args.det_east_score_thresh
@@ -85,38 +90,73 @@ def __init__(self, args):
postprocess_params["sample_pts_num"] = 2
postprocess_params["expand_scale"] = 1.0
postprocess_params["shrink_ratio_of_width"] = 0.3
+ elif self.det_algorithm == "PSE":
+ postprocess_params['name'] = 'PSEPostProcess'
+ postprocess_params["thresh"] = args.det_pse_thresh
+ postprocess_params["box_thresh"] = args.det_pse_box_thresh
+ postprocess_params["min_area"] = args.det_pse_min_area
+ postprocess_params["box_type"] = args.det_pse_box_type
+ postprocess_params["scale"] = args.det_pse_scale
+ self.det_pse_box_type = args.det_pse_box_type
+ elif self.det_algorithm == "FCE":
+ pre_process_list[0] = {
+ 'DetResizeForTest': {
+ 'rescale_img': [1080, 736]
+ }
+ }
+ postprocess_params['name'] = 'FCEPostProcess'
+ postprocess_params["scales"] = args.scales
+ postprocess_params["alpha"] = args.alpha
+ postprocess_params["beta"] = args.beta
+ postprocess_params["fourier_degree"] = args.fourier_degree
+ postprocess_params["box_type"] = args.det_fce_box_type
else:
logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
sys.exit(0)
self.preprocess_op = create_operators(pre_process_list)
self.postprocess_op = build_post_process(postprocess_params)
- self.predictor, self.input_tensor, self.output_tensors = utility.create_predictor(
- args, 'det', logger) # paddle.jit.load(args.det_model_dir)
- # self.predictor.eval()
+ self.predictor, self.input_tensor, self.output_tensors, self.config = utility.create_predictor(
+ args, 'det', logger)
+
+ if self.use_onnx:
+ img_h, img_w = self.input_tensor.shape[2:]
+ if img_h is not None and img_w is not None and img_h > 0 and img_w > 0:
+ pre_process_list[0] = {
+ 'DetResizeForTest': {
+ 'image_shape': [img_h, img_w]
+ }
+ }
+ self.preprocess_op = create_operators(pre_process_list)
+
+ if args.benchmark:
+ import auto_log
+ pid = os.getpid()
+ gpu_id = utility.get_infer_gpuid()
+ self.autolog = auto_log.AutoLogger(
+ model_name="det",
+ model_precision=args.precision,
+ batch_size=1,
+ data_shape="dynamic",
+ save_path=None,
+ inference_config=self.config,
+ pids=pid,
+ process_name=None,
+ gpu_ids=gpu_id if args.use_gpu else None,
+ time_keys=[
+ 'preprocess_time', 'inference_time', 'postprocess_time'
+ ],
+ warmup=2,
+ logger=logger)
def order_points_clockwise(self, pts):
- """
- reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
- # sort the points based on their x-coordinates
- """
- xSorted = pts[np.argsort(pts[:, 0]), :]
-
- # grab the left-most and right-most points from the sorted
- # x-roodinate points
- leftMost = xSorted[:2, :]
- rightMost = xSorted[2:, :]
-
- # now, sort the left-most coordinates according to their
- # y-coordinates so we can grab the top-left and bottom-left
- # points, respectively
- leftMost = leftMost[np.argsort(leftMost[:, 1]), :]
- (tl, bl) = leftMost
-
- rightMost = rightMost[np.argsort(rightMost[:, 1]), :]
- (tr, br) = rightMost
-
- rect = np.array([tl, tr, br, bl], dtype="float32")
+ rect = np.zeros((4, 2), dtype="float32")
+ s = pts.sum(axis=1)
+ rect[0] = pts[np.argmin(s)]
+ rect[2] = pts[np.argmax(s)]
+ diff = np.diff(pts, axis=1)
+ rect[1] = pts[np.argmin(diff)]
+ rect[3] = pts[np.argmax(diff)]
return rect
def clip_det_res(self, points, img_height, img_width):
@@ -151,6 +191,12 @@ def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
def __call__(self, img):
ori_im = img.copy()
data = {'image': img}
+
+ st = time.time()
+
+ if self.args.benchmark:
+ self.autolog.times.start()
+
data = transform(data, self.preprocess_op)
img, shape_list = data
if img is None:
@@ -158,14 +204,22 @@ def __call__(self, img):
img = np.expand_dims(img, axis=0)
shape_list = np.expand_dims(shape_list, axis=0)
img = img.copy()
- starttime = time.time()
- self.input_tensor.copy_from_cpu(img)
- self.predictor.run()
- outputs = []
- for output_tensor in self.output_tensors:
- output = output_tensor.copy_to_cpu()
- outputs.append(output)
+ if self.args.benchmark:
+ self.autolog.times.stamp()
+ if self.use_onnx:
+ input_dict = {}
+ input_dict[self.input_tensor.name] = img
+ outputs = self.predictor.run(self.output_tensors, input_dict)
+ else:
+ self.input_tensor.copy_from_cpu(img)
+ self.predictor.run()
+ outputs = []
+ for output_tensor in self.output_tensors:
+ output = output_tensor.copy_to_cpu()
+ outputs.append(output)
+ if self.args.benchmark:
+ self.autolog.times.stamp()
preds = {}
if self.det_algorithm == "EAST":
@@ -176,19 +230,28 @@ def __call__(self, img):
preds['f_score'] = outputs[1]
preds['f_tco'] = outputs[2]
preds['f_tvo'] = outputs[3]
- elif self.det_algorithm == 'DB':
+ elif self.det_algorithm in ['DB', 'PSE']:
preds['maps'] = outputs[0]
+ elif self.det_algorithm == 'FCE':
+ for i, output in enumerate(outputs):
+ preds['level_{}'.format(i)] = output
else:
raise NotImplementedError
+ #self.predictor.try_shrink_memory()
post_result = self.postprocess_op(preds, shape_list)
dt_boxes = post_result[0]['points']
- if self.det_algorithm == "SAST" and self.det_sast_polygon:
+ if (self.det_algorithm == "SAST" and self.det_sast_polygon) or (
+ self.det_algorithm in ["PSE", "FCE"] and
+ self.postprocess_op.box_type == 'poly'):
dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
else:
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
- elapse = time.time() - starttime
- return dt_boxes, elapse
+
+ if self.args.benchmark:
+ self.autolog.times.end(stamp=True)
+ et = time.time()
+ return dt_boxes, et - st
if __name__ == "__main__":
@@ -198,8 +261,15 @@ def __call__(self, img):
count = 0
total_time = 0
draw_img_save = "./inference_results"
+
+ if args.warmup:
+ img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
+ for i in range(2):
+ res = text_detector(img)
+
if not os.path.exists(draw_img_save):
os.makedirs(draw_img_save)
+ save_results = []
for image_file in image_file_list:
img, flag = check_and_read_gif(image_file)
if not flag:
@@ -207,16 +277,26 @@ def __call__(self, img):
if img is None:
logger.info("error in loading image:{}".format(image_file))
continue
- dt_boxes, elapse = text_detector(img)
+ st = time.time()
+ dt_boxes, _ = text_detector(img)
+ elapse = time.time() - st
if count > 0:
total_time += elapse
count += 1
- logger.info("Predict time of {}: {}".format(image_file, elapse))
+ save_pred = os.path.basename(image_file) + "\t" + str(
+ json.dumps([x.tolist() for x in dt_boxes])) + "\n"
+ save_results.append(save_pred)
+ logger.info(save_pred)
+ logger.info("The predict time of {}: {}".format(image_file, elapse))
src_im = utility.draw_text_det_res(dt_boxes, image_file)
img_name_pure = os.path.split(image_file)[-1]
img_path = os.path.join(draw_img_save,
"det_res_{}".format(img_name_pure))
cv2.imwrite(img_path, src_im)
logger.info("The visualized image saved in {}".format(img_path))
- if count > 1:
- logger.info("Avg Time: {}".format(total_time / (count - 1)))
+
+ with open(os.path.join(draw_img_save, "det_results.txt"), 'w') as f:
+ f.writelines(save_results)
+ f.close()
+ if args.benchmark:
+ text_detector.autolog.report()
diff --git a/backend/tools/infer/predict_e2e.py b/backend/tools/infer/predict_e2e.py
new file mode 100755
index 00000000..fb2859f0
--- /dev/null
+++ b/backend/tools/infer/predict_e2e.py
@@ -0,0 +1,169 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
+
+os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+
+import cv2
+import numpy as np
+import time
+import sys
+
+import tools.infer.utility as utility
+from ppocr.utils.logging import get_logger
+from ppocr.utils.utility import get_image_file_list, check_and_read_gif
+from ppocr.data import create_operators, transform
+from ppocr.postprocess import build_post_process
+
+logger = get_logger()
+
+
+class TextE2E(object):
+ def __init__(self, args):
+ self.args = args
+ self.e2e_algorithm = args.e2e_algorithm
+ self.use_onnx = args.use_onnx
+ pre_process_list = [{
+ 'E2EResizeForTest': {}
+ }, {
+ 'NormalizeImage': {
+ 'std': [0.229, 0.224, 0.225],
+ 'mean': [0.485, 0.456, 0.406],
+ 'scale': '1./255.',
+ 'order': 'hwc'
+ }
+ }, {
+ 'ToCHWImage': None
+ }, {
+ 'KeepKeys': {
+ 'keep_keys': ['image', 'shape']
+ }
+ }]
+ postprocess_params = {}
+ if self.e2e_algorithm == "PGNet":
+ pre_process_list[0] = {
+ 'E2EResizeForTest': {
+ 'max_side_len': args.e2e_limit_side_len,
+ 'valid_set': 'totaltext'
+ }
+ }
+ postprocess_params['name'] = 'PGPostProcess'
+ postprocess_params["score_thresh"] = args.e2e_pgnet_score_thresh
+ postprocess_params["character_dict_path"] = args.e2e_char_dict_path
+ postprocess_params["valid_set"] = args.e2e_pgnet_valid_set
+ postprocess_params["mode"] = args.e2e_pgnet_mode
+ else:
+ logger.info("unknown e2e_algorithm:{}".format(self.e2e_algorithm))
+ sys.exit(0)
+
+ self.preprocess_op = create_operators(pre_process_list)
+ self.postprocess_op = build_post_process(postprocess_params)
+ self.predictor, self.input_tensor, self.output_tensors, _ = utility.create_predictor(
+ args, 'e2e', logger) # paddle.jit.load(args.det_model_dir)
+ # self.predictor.eval()
+
+ def clip_det_res(self, points, img_height, img_width):
+ for pno in range(points.shape[0]):
+ points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
+ points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
+ return points
+
+ def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
+ img_height, img_width = image_shape[0:2]
+ dt_boxes_new = []
+ for box in dt_boxes:
+ box = self.clip_det_res(box, img_height, img_width)
+ dt_boxes_new.append(box)
+ dt_boxes = np.array(dt_boxes_new)
+ return dt_boxes
+
+ def __call__(self, img):
+
+ ori_im = img.copy()
+ data = {'image': img}
+ data = transform(data, self.preprocess_op)
+ img, shape_list = data
+ if img is None:
+ return None, 0
+ img = np.expand_dims(img, axis=0)
+ shape_list = np.expand_dims(shape_list, axis=0)
+ img = img.copy()
+ starttime = time.time()
+
+ if self.use_onnx:
+ input_dict = {}
+ input_dict[self.input_tensor.name] = img
+ outputs = self.predictor.run(self.output_tensors, input_dict)
+ preds = {}
+ preds['f_border'] = outputs[0]
+ preds['f_char'] = outputs[1]
+ preds['f_direction'] = outputs[2]
+ preds['f_score'] = outputs[3]
+ else:
+ self.input_tensor.copy_from_cpu(img)
+ self.predictor.run()
+ outputs = []
+ for output_tensor in self.output_tensors:
+ output = output_tensor.copy_to_cpu()
+ outputs.append(output)
+
+ preds = {}
+ if self.e2e_algorithm == 'PGNet':
+ preds['f_border'] = outputs[0]
+ preds['f_char'] = outputs[1]
+ preds['f_direction'] = outputs[2]
+ preds['f_score'] = outputs[3]
+ else:
+ raise NotImplementedError
+ post_result = self.postprocess_op(preds, shape_list)
+ points, strs = post_result['points'], post_result['texts']
+ dt_boxes = self.filter_tag_det_res_only_clip(points, ori_im.shape)
+ elapse = time.time() - starttime
+ return dt_boxes, strs, elapse
+
+
+if __name__ == "__main__":
+ args = utility.parse_args()
+ image_file_list = get_image_file_list(args.image_dir)
+ text_detector = TextE2E(args)
+ count = 0
+ total_time = 0
+ draw_img_save = "./inference_results"
+ if not os.path.exists(draw_img_save):
+ os.makedirs(draw_img_save)
+ for image_file in image_file_list:
+ img, flag = check_and_read_gif(image_file)
+ if not flag:
+ img = cv2.imread(image_file)
+ if img is None:
+ logger.info("error in loading image:{}".format(image_file))
+ continue
+ points, strs, elapse = text_detector(img)
+ if count > 0:
+ total_time += elapse
+ count += 1
+ logger.info("Predict time of {}: {}".format(image_file, elapse))
+ src_im = utility.draw_e2e_res(points, strs, image_file)
+ img_name_pure = os.path.split(image_file)[-1]
+ img_path = os.path.join(draw_img_save,
+ "e2e_res_{}".format(img_name_pure))
+ cv2.imwrite(img_path, src_im)
+ logger.info("The visualized image saved in {}".format(img_path))
+ if count > 1:
+ logger.info("Avg Time: {}".format(total_time / (count - 1)))
diff --git a/backend/tools/infer/predict_rec.py b/backend/tools/infer/predict_rec.py
index b3d9d490..3664ef2c 100755
--- a/backend/tools/infer/predict_rec.py
+++ b/backend/tools/infer/predict_rec.py
@@ -13,10 +13,10 @@
# limitations under the License.
import os
import sys
-
+from PIL import Image
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
@@ -38,44 +38,91 @@
class TextRecognizer(object):
def __init__(self, args):
self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
- self.character_type = args.rec_char_type
self.rec_batch_num = args.rec_batch_num
self.rec_algorithm = args.rec_algorithm
postprocess_params = {
'name': 'CTCLabelDecode',
- "character_type": args.rec_char_type,
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
if self.rec_algorithm == "SRN":
postprocess_params = {
'name': 'SRNLabelDecode',
- "character_type": args.rec_char_type,
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
elif self.rec_algorithm == "RARE":
postprocess_params = {
'name': 'AttnLabelDecode',
- "character_type": args.rec_char_type,
+ "character_dict_path": args.rec_char_dict_path,
+ "use_space_char": args.use_space_char
+ }
+ elif self.rec_algorithm == 'NRTR':
+ postprocess_params = {
+ 'name': 'NRTRLabelDecode',
+ "character_dict_path": args.rec_char_dict_path,
+ "use_space_char": args.use_space_char
+ }
+ elif self.rec_algorithm == "SAR":
+ postprocess_params = {
+ 'name': 'SARLabelDecode',
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
self.postprocess_op = build_post_process(postprocess_params)
- self.predictor, self.input_tensor, self.output_tensors = \
+ self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'rec', logger)
+ self.benchmark = args.benchmark
+ self.use_onnx = args.use_onnx
+ if args.benchmark:
+ import auto_log
+ pid = os.getpid()
+ gpu_id = utility.get_infer_gpuid()
+ self.autolog = auto_log.AutoLogger(
+ model_name="rec",
+ model_precision=args.precision,
+ batch_size=args.rec_batch_num,
+ data_shape="dynamic",
+ save_path=None, #args.save_log_path,
+ inference_config=self.config,
+ pids=pid,
+ process_name=None,
+ gpu_ids=gpu_id if args.use_gpu else None,
+ time_keys=[
+ 'preprocess_time', 'inference_time', 'postprocess_time'
+ ],
+ warmup=0,
+ logger=logger)
def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape
+ if self.rec_algorithm == 'NRTR':
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ # return padding_im
+ image_pil = Image.fromarray(np.uint8(img))
+ img = image_pil.resize([100, 32], Image.ANTIALIAS)
+ img = np.array(img)
+ norm_img = np.expand_dims(img, -1)
+ norm_img = norm_img.transpose((2, 0, 1))
+ return norm_img.astype(np.float32) / 128. - 1.
+
assert imgC == img.shape[2]
- if self.character_type == "ch":
- imgW = int((32 * max_wh_ratio))
+ imgW = int((imgH * max_wh_ratio))
+ if self.use_onnx:
+ w = self.input_tensor.shape[3:][0]
+ if w is not None and w > 0:
+ imgW = w
+
h, w = img.shape[:2]
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
+ if self.rec_algorithm == 'RARE':
+ if resized_w > self.rec_image_shape[2]:
+ resized_w = self.rec_image_shape[2]
+ imgW = self.rec_image_shape[2]
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255
@@ -85,6 +132,17 @@ def resize_norm_img(self, img, max_wh_ratio):
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
+ def resize_norm_img_svtr(self, img, image_shape):
+
+ imgC, imgH, imgW = image_shape
+ resized_image = cv2.resize(
+ img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
+ resized_image = resized_image.astype('float32')
+ resized_image = resized_image.transpose((2, 0, 1)) / 255
+ resized_image -= 0.5
+ resized_image /= 0.5
+ return resized_image
+
def resize_norm_img_srn(self, img, image_shape):
imgC, imgH, imgW = image_shape
@@ -157,6 +215,41 @@ def process_image_srn(self, img, image_shape, num_heads, max_text_length):
return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2)
+ def resize_norm_img_sar(self, img, image_shape,
+ width_downsample_ratio=0.25):
+ imgC, imgH, imgW_min, imgW_max = image_shape
+ h = img.shape[0]
+ w = img.shape[1]
+ valid_ratio = 1.0
+ # make sure new_width is an integral multiple of width_divisor.
+ width_divisor = int(1 / width_downsample_ratio)
+ # resize
+ ratio = w / float(h)
+ resize_w = math.ceil(imgH * ratio)
+ if resize_w % width_divisor != 0:
+ resize_w = round(resize_w / width_divisor) * width_divisor
+ if imgW_min is not None:
+ resize_w = max(imgW_min, resize_w)
+ if imgW_max is not None:
+ valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
+ resize_w = min(imgW_max, resize_w)
+ resized_image = cv2.resize(img, (resize_w, imgH))
+ resized_image = resized_image.astype('float32')
+ # norm
+ if image_shape[0] == 1:
+ resized_image = resized_image / 255
+ resized_image = resized_image[np.newaxis, :]
+ else:
+ resized_image = resized_image.transpose((2, 0, 1)) / 255
+ resized_image -= 0.5
+ resized_image /= 0.5
+ resize_shape = resized_image.shape
+ padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
+ padding_im[:, :, 0:resize_w] = resized_image
+ pad_shape = padding_im.shape
+
+ return padding_im, resize_shape, pad_shape, valid_ratio
+
def __call__(self, img_list):
img_num = len(img_list)
# Calculate the aspect ratio of all text bars
@@ -165,27 +258,32 @@ def __call__(self, img_list):
width_list.append(img.shape[1] / float(img.shape[0]))
# Sorting can speed up the recognition process
indices = np.argsort(np.array(width_list))
-
- # rec_res = []
rec_res = [['', 0.0]] * img_num
batch_num = self.rec_batch_num
- elapse = 0
+ st = time.time()
+ if self.benchmark:
+ self.autolog.times.start()
for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = []
- max_wh_ratio = 0
+ imgC, imgH, imgW = self.rec_image_shape
+ max_wh_ratio = imgW / imgH
+ # max_wh_ratio = 0
for ino in range(beg_img_no, end_img_no):
- # h, w = img_list[ino].shape[0:2]
h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for ino in range(beg_img_no, end_img_no):
- if self.rec_algorithm != "SRN":
- norm_img = self.resize_norm_img(img_list[indices[ino]],
- max_wh_ratio)
+
+ if self.rec_algorithm == "SAR":
+ norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
+ img_list[indices[ino]], self.rec_image_shape)
norm_img = norm_img[np.newaxis, :]
+ valid_ratio = np.expand_dims(valid_ratio, axis=0)
+ valid_ratios = []
+ valid_ratios.append(valid_ratio)
norm_img_batch.append(norm_img)
- else:
+ elif self.rec_algorithm == "SRN":
norm_img = self.process_image_srn(
img_list[indices[ino]], self.rec_image_shape, 8, 25)
encoder_word_pos_list = []
@@ -197,11 +295,22 @@ def __call__(self, img_list):
gsrm_slf_attn_bias1_list.append(norm_img[3])
gsrm_slf_attn_bias2_list.append(norm_img[4])
norm_img_batch.append(norm_img[0])
+ elif self.rec_algorithm == "SVTR":
+ norm_img = self.resize_norm_img_svtr(img_list[indices[ino]],
+ self.rec_image_shape)
+ norm_img = norm_img[np.newaxis, :]
+ norm_img_batch.append(norm_img)
+ else:
+ norm_img = self.resize_norm_img(img_list[indices[ino]],
+ max_wh_ratio)
+ norm_img = norm_img[np.newaxis, :]
+ norm_img_batch.append(norm_img)
norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy()
+ if self.benchmark:
+ self.autolog.times.stamp()
if self.rec_algorithm == "SRN":
- starttime = time.time()
encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
gsrm_slf_attn_bias1_list = np.concatenate(
@@ -216,33 +325,78 @@ def __call__(self, img_list):
gsrm_slf_attn_bias1_list,
gsrm_slf_attn_bias2_list,
]
- input_names = self.predictor.get_input_names()
- for i in range(len(input_names)):
- input_tensor = self.predictor.get_input_handle(input_names[
- i])
- input_tensor.copy_from_cpu(inputs[i])
- self.predictor.run()
- outputs = []
- for output_tensor in self.output_tensors:
- output = output_tensor.copy_to_cpu()
- outputs.append(output)
- preds = {"predict": outputs[2]}
+ if self.use_onnx:
+ input_dict = {}
+ input_dict[self.input_tensor.name] = norm_img_batch
+ outputs = self.predictor.run(self.output_tensors,
+ input_dict)
+ preds = {"predict": outputs[2]}
+ else:
+ input_names = self.predictor.get_input_names()
+ for i in range(len(input_names)):
+ input_tensor = self.predictor.get_input_handle(
+ input_names[i])
+ input_tensor.copy_from_cpu(inputs[i])
+ self.predictor.run()
+ outputs = []
+ for output_tensor in self.output_tensors:
+ output = output_tensor.copy_to_cpu()
+ outputs.append(output)
+ if self.benchmark:
+ self.autolog.times.stamp()
+ preds = {"predict": outputs[2]}
+ elif self.rec_algorithm == "SAR":
+ valid_ratios = np.concatenate(valid_ratios)
+ inputs = [
+ norm_img_batch,
+ valid_ratios,
+ ]
+ if self.use_onnx:
+ input_dict = {}
+ input_dict[self.input_tensor.name] = norm_img_batch
+ outputs = self.predictor.run(self.output_tensors,
+ input_dict)
+ preds = outputs[0]
+ else:
+ input_names = self.predictor.get_input_names()
+ for i in range(len(input_names)):
+ input_tensor = self.predictor.get_input_handle(
+ input_names[i])
+ input_tensor.copy_from_cpu(inputs[i])
+ self.predictor.run()
+ outputs = []
+ for output_tensor in self.output_tensors:
+ output = output_tensor.copy_to_cpu()
+ outputs.append(output)
+ if self.benchmark:
+ self.autolog.times.stamp()
+ preds = outputs[0]
else:
- starttime = time.time()
- self.input_tensor.copy_from_cpu(norm_img_batch)
- self.predictor.run()
-
- outputs = []
- for output_tensor in self.output_tensors:
- output = output_tensor.copy_to_cpu()
- outputs.append(output)
- preds = outputs[0]
-
+ if self.use_onnx:
+ input_dict = {}
+ input_dict[self.input_tensor.name] = norm_img_batch
+ outputs = self.predictor.run(self.output_tensors,
+ input_dict)
+ preds = outputs[0]
+ else:
+ self.input_tensor.copy_from_cpu(norm_img_batch)
+ self.predictor.run()
+ outputs = []
+ for output_tensor in self.output_tensors:
+ output = output_tensor.copy_to_cpu()
+ outputs.append(output)
+ if self.benchmark:
+ self.autolog.times.stamp()
+ if len(outputs) != 1:
+ preds = outputs
+ else:
+ preds = outputs[0]
rec_result = self.postprocess_op(preds)
for rno in range(len(rec_result)):
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
- elapse += time.time() - starttime
- return rec_res, elapse
+ if self.benchmark:
+ self.autolog.times.end(stamp=True)
+ return rec_res, time.time() - st
def main(args):
@@ -250,6 +404,17 @@ def main(args):
text_recognizer = TextRecognizer(args)
valid_image_file_list = []
img_list = []
+
+ logger.info(
+ "In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320', "
+ "if you are using recognition model with PP-OCRv2 or an older version, please set --rec_image_shape='3,32,320"
+ )
+ # warmup 2 times
+ if args.warmup:
+ img = np.random.uniform(0, 255, [48, 320, 3]).astype(np.uint8)
+ for i in range(2):
+ res = text_recognizer([img] * int(args.rec_batch_num))
+
for image_file in image_file_list:
img, flag = check_and_read_gif(image_file)
if not flag:
@@ -260,21 +425,17 @@ def main(args):
valid_image_file_list.append(image_file)
img_list.append(img)
try:
- rec_res, predict_time = text_recognizer(img_list)
- except:
+ rec_res, _ = text_recognizer(img_list)
+
+ except Exception as E:
logger.info(traceback.format_exc())
- logger.info(
- "ERROR!!!! \n"
- "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
- "If your model has tps module: "
- "TPS does not support variable shape.\n"
- "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
+ logger.info(E)
exit()
for ino in range(len(img_list)):
logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
rec_res[ino]))
- logger.info("Total predict time for {} images, cost: {:.3f}".format(
- len(img_list), predict_time))
+ if args.benchmark:
+ text_recognizer.autolog.report()
if __name__ == "__main__":
diff --git a/backend/tools/infer/predict_system.py b/backend/tools/infer/predict_system.py
index de7ee9d3..4af3da70 100755
--- a/backend/tools/infer/predict_system.py
+++ b/backend/tools/infer/predict_system.py
@@ -13,17 +13,20 @@
# limitations under the License.
import os
import sys
+import subprocess
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import copy
import numpy as np
+import json
import time
+import logging
from PIL import Image
import tools.infer.utility as utility
import tools.infer.predict_rec as predict_rec
@@ -31,13 +34,15 @@
import tools.infer.predict_cls as predict_cls
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
-from tools.infer.utility import draw_ocr_box_txt
-
+from tools.infer.utility import draw_ocr_box_txt, get_rotate_crop_image
logger = get_logger()
class TextSystem(object):
def __init__(self, args):
+ if not args.show_log:
+ logger.setLevel(logging.INFO)
+
self.text_detector = predict_det.TextDetector(args)
self.text_recognizer = predict_rec.TextRecognizer(args)
self.use_angle_cls = args.use_angle_cls
@@ -45,50 +50,24 @@ def __init__(self, args):
if self.use_angle_cls:
self.text_classifier = predict_cls.TextClassifier(args)
- def get_rotate_crop_image(self, img, points):
- '''
- img_height, img_width = img.shape[0:2]
- left = int(np.min(points[:, 0]))
- right = int(np.max(points[:, 0]))
- top = int(np.min(points[:, 1]))
- bottom = int(np.max(points[:, 1]))
- img_crop = img[top:bottom, left:right, :].copy()
- points[:, 0] = points[:, 0] - left
- points[:, 1] = points[:, 1] - top
- '''
- img_crop_width = int(
- max(
- np.linalg.norm(points[0] - points[1]),
- np.linalg.norm(points[2] - points[3])))
- img_crop_height = int(
- max(
- np.linalg.norm(points[0] - points[3]),
- np.linalg.norm(points[1] - points[2])))
- pts_std = np.float32([[0, 0], [img_crop_width, 0],
- [img_crop_width, img_crop_height],
- [0, img_crop_height]])
- M = cv2.getPerspectiveTransform(points, pts_std)
- dst_img = cv2.warpPerspective(
- img,
- M, (img_crop_width, img_crop_height),
- borderMode=cv2.BORDER_REPLICATE,
- flags=cv2.INTER_CUBIC)
- dst_img_height, dst_img_width = dst_img.shape[0:2]
- if dst_img_height * 1.0 / dst_img_width >= 1.5:
- dst_img = np.rot90(dst_img)
- return dst_img
-
- def print_draw_crop_rec_res(self, img_crop_list, rec_res):
+ self.args = args
+ self.crop_image_res_index = 0
+
+ def draw_crop_rec_res(self, output_dir, img_crop_list, rec_res):
+ os.makedirs(output_dir, exist_ok=True)
bbox_num = len(img_crop_list)
for bno in range(bbox_num):
- cv2.imwrite("./output/img_crop_%d.jpg" % bno, img_crop_list[bno])
- logger.info(bno, rec_res[bno])
+ cv2.imwrite(
+ os.path.join(output_dir,
+                             f"img_crop_{bno+self.crop_image_res_index}.jpg"),
+ img_crop_list[bno])
+ logger.debug(f"{bno}, {rec_res[bno]}")
+ self.crop_image_res_index += bbox_num
- def __call__(self, img):
+ def __call__(self, img, cls=True):
ori_im = img.copy()
dt_boxes, elapse = self.text_detector(img)
- logger.info("dt_boxes num : {}, elapse : {}".format(
- len(dt_boxes), elapse))
+
if dt_boxes is None:
return None, None
img_crop_list = []
@@ -97,24 +76,23 @@ def __call__(self, img):
for bno in range(len(dt_boxes)):
tmp_box = copy.deepcopy(dt_boxes[bno])
- img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
+ img_crop = get_rotate_crop_image(ori_im, tmp_box)
img_crop_list.append(img_crop)
- if self.use_angle_cls:
+ if self.use_angle_cls and cls:
img_crop_list, angle_list, elapse = self.text_classifier(
img_crop_list)
- logger.info("cls num : {}, elapse : {}".format(
- len(img_crop_list), elapse))
+
rec_res, elapse = self.text_recognizer(img_crop_list)
- logger.info("rec_res num : {}, elapse : {}".format(
- len(rec_res), elapse))
- # self.print_draw_crop_rec_res(img_crop_list, rec_res)
+ if self.args.save_crop_res:
+ self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list,
+ rec_res)
filter_boxes, filter_rec_res = [], []
- for box, rec_reuslt in zip(dt_boxes, rec_res):
- text, score = rec_reuslt
+ for box, rec_result in zip(dt_boxes, rec_res):
+ text, score = rec_result
if score >= self.drop_score:
filter_boxes.append(box)
- filter_rec_res.append(rec_reuslt)
+ filter_rec_res.append(rec_result)
return filter_boxes, filter_rec_res
@@ -141,24 +119,49 @@ def sorted_boxes(dt_boxes):
def main(args):
image_file_list = get_image_file_list(args.image_dir)
+ image_file_list = image_file_list[args.process_id::args.total_process_num]
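+    # with multi-process inference, each process handles every total_process_num-th image, offset by its process_id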
text_sys = TextSystem(args)
is_visualize = True
font_path = args.vis_font_path
drop_score = args.drop_score
- for image_file in image_file_list:
+ draw_img_save_dir = args.draw_img_save_dir
+ os.makedirs(draw_img_save_dir, exist_ok=True)
+ save_results = []
+
+ logger.info("In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320', "
+ "if you are using recognition model with PP-OCRv2 or an older version, please set --rec_image_shape='3,32,320")
+
+ # warm up 10 times
+ if args.warmup:
+ img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
+ for i in range(10):
+ res = text_sys(img)
+
+ total_time = 0
+ cpu_mem, gpu_mem, gpu_util = 0, 0, 0
+ _st = time.time()
+ count = 0
+ for idx, image_file in enumerate(image_file_list):
+
img, flag = check_and_read_gif(image_file)
if not flag:
img = cv2.imread(image_file)
if img is None:
- logger.info("error in loading image:{}".format(image_file))
+ logger.debug("error in loading image:{}".format(image_file))
continue
starttime = time.time()
dt_boxes, rec_res = text_sys(img)
elapse = time.time() - starttime
- logger.info("Predict time of %s: %.3fs" % (image_file, elapse))
+ total_time += elapse
+
- for text, score in rec_res:
- logger.info("{}, {:.3f}".format(text, score))
+ res = [{
+ "transcription": rec_res[idx][0],
+ "points": np.array(dt_boxes[idx]).astype(np.int32).tolist(),
+ } for idx in range(len(dt_boxes))]
+ save_pred = os.path.basename(image_file) + "\t" + json.dumps(
+ res, ensure_ascii=False) + "\n"
+ save_results.append(save_pred)
if is_visualize:
image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
@@ -173,15 +176,35 @@ def main(args):
scores,
drop_score=drop_score,
font_path=font_path)
- draw_img_save = "./inference_results/"
- if not os.path.exists(draw_img_save):
- os.makedirs(draw_img_save)
+ if flag:
+ image_file = image_file[:-3] + "png"
cv2.imwrite(
- os.path.join(draw_img_save, os.path.basename(image_file)),
+ os.path.join(draw_img_save_dir, os.path.basename(image_file)),
draw_img[:, :, ::-1])
- logger.info("The visualized image saved in {}".format(
- os.path.join(draw_img_save, os.path.basename(image_file))))
+
+
+ logger.info("The predict total time is {}".format(time.time() - _st))
+ if args.benchmark:
+ text_sys.text_detector.autolog.report()
+ text_sys.text_recognizer.autolog.report()
+
+ with open(os.path.join(draw_img_save_dir, "system_results.txt"), 'w', encoding='utf-8') as f:
+ f.writelines(save_results)
if __name__ == "__main__":
- main(utility.parse_args())
+ args = utility.parse_args()
+ if args.use_mp:
+ p_list = []
+ total_process_num = args.total_process_num
+ for process_id in range(total_process_num):
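+            # re-launch this script once per shard; each child runs single-process on its own --process_id slice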
+ cmd = [sys.executable, "-u"] + sys.argv + [
+ "--process_id={}".format(process_id),
+ "--use_mp={}".format(False)
+ ]
+ p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stdout)
+ p_list.append(p)
+ for p in p_list:
+ p.wait()
+ else:
+ main(args)
diff --git a/backend/tools/infer/utility.py b/backend/tools/infer/utility.py
old mode 100755
new mode 100644
index 92f3e745..29b3755e
--- a/backend/tools/infer/utility.py
+++ b/backend/tools/infer/utility.py
@@ -15,24 +15,29 @@
import argparse
import os
import sys
+import platform
import cv2
import numpy as np
-import json
+import paddle
from PIL import Image, ImageDraw, ImageFont
import math
from paddle import inference
+import time
+from ppocr.utils.logging import get_logger
-def parse_args():
- def str2bool(v):
- return v.lower() in ("true", "t", "1")
+def str2bool(v):
+ return v.lower() in ("true", "t", "1")
+
+def init_args():
parser = argparse.ArgumentParser()
# params for prediction engine
parser.add_argument("--use_gpu", type=str2bool, default=True)
parser.add_argument("--ir_optim", type=str2bool, default=True)
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
- parser.add_argument("--use_fp16", type=str2bool, default=False)
+ parser.add_argument("--min_subgraph_size", type=int, default=15)
+ parser.add_argument("--precision", type=str, default="fp32")
parser.add_argument("--gpu_mem", type=int, default=500)
# params for text detector
@@ -44,11 +49,11 @@ def str2bool(v):
# DB parmas
parser.add_argument("--det_db_thresh", type=float, default=0.3)
- parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
- parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
+ parser.add_argument("--det_db_box_thresh", type=float, default=0.6)
+ parser.add_argument("--det_db_unclip_ratio", type=float, default=1.5)
parser.add_argument("--max_batch_size", type=int, default=10)
- parser.add_argument("--use_dilation", type=bool, default=False)
-
+ parser.add_argument("--use_dilation", type=str2bool, default=False)
+ parser.add_argument("--det_db_score_mode", type=str, default="fast")
# EAST parmas
parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
@@ -57,13 +62,26 @@ def str2bool(v):
# SAST parmas
parser.add_argument("--det_sast_score_thresh", type=float, default=0.5)
parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)
- parser.add_argument("--det_sast_polygon", type=bool, default=False)
+ parser.add_argument("--det_sast_polygon", type=str2bool, default=False)
+
+    # PSE params
+ parser.add_argument("--det_pse_thresh", type=float, default=0)
+ parser.add_argument("--det_pse_box_thresh", type=float, default=0.85)
+ parser.add_argument("--det_pse_min_area", type=float, default=16)
+ parser.add_argument("--det_pse_box_type", type=str, default='quad')
+ parser.add_argument("--det_pse_scale", type=int, default=1)
+
+    # FCE params
+ parser.add_argument("--scales", type=list, default=[8, 16, 32])
+ parser.add_argument("--alpha", type=float, default=1.0)
+ parser.add_argument("--beta", type=float, default=1.0)
+ parser.add_argument("--fourier_degree", type=int, default=5)
+ parser.add_argument("--det_fce_box_type", type=str, default='poly')
# params for text recognizer
parser.add_argument("--rec_algorithm", type=str, default='CRNN')
parser.add_argument("--rec_model_dir", type=str)
- parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
- parser.add_argument("--rec_char_type", type=str, default='ch')
+ parser.add_argument("--rec_image_shape", type=str, default="3, 48, 320")
parser.add_argument("--rec_batch_num", type=int, default=6)
parser.add_argument("--max_text_length", type=int, default=25)
parser.add_argument(
@@ -75,6 +93,19 @@ def str2bool(v):
"--vis_font_path", type=str, default="./doc/fonts/simfang.ttf")
parser.add_argument("--drop_score", type=float, default=0.5)
+ # params for e2e
+ parser.add_argument("--e2e_algorithm", type=str, default='PGNet')
+ parser.add_argument("--e2e_model_dir", type=str)
+ parser.add_argument("--e2e_limit_side_len", type=float, default=768)
+ parser.add_argument("--e2e_limit_type", type=str, default='max')
+
+    # PGNet params
+ parser.add_argument("--e2e_pgnet_score_thresh", type=float, default=0.5)
+ parser.add_argument(
+ "--e2e_char_dict_path", type=str, default="./ppocr/utils/ic15_dict.txt")
+ parser.add_argument("--e2e_pgnet_valid_set", type=str, default='totaltext')
+ parser.add_argument("--e2e_pgnet_mode", type=str, default='fast')
+
# params for text classifier
parser.add_argument("--use_angle_cls", type=str2bool, default=False)
parser.add_argument("--cls_model_dir", type=str)
@@ -84,8 +115,31 @@ def str2bool(v):
parser.add_argument("--cls_thresh", type=float, default=0.9)
parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
+ parser.add_argument("--cpu_threads", type=int, default=10)
parser.add_argument("--use_pdserving", type=str2bool, default=False)
+ parser.add_argument("--warmup", type=str2bool, default=False)
+
+    # params for output and visualization
+ parser.add_argument(
+ "--draw_img_save_dir", type=str, default="./inference_results")
+ parser.add_argument("--save_crop_res", type=str2bool, default=False)
+ parser.add_argument("--crop_res_save_dir", type=str, default="./output")
+
+ # multi-process
+ parser.add_argument("--use_mp", type=str2bool, default=False)
+ parser.add_argument("--total_process_num", type=int, default=1)
+ parser.add_argument("--process_id", type=int, default=0)
+
+ parser.add_argument("--benchmark", type=str2bool, default=False)
+ parser.add_argument("--save_log_path", type=str, default="./log_output/")
+ parser.add_argument("--show_log", type=str2bool, default=False)
+ parser.add_argument("--use_onnx", type=str2bool, default=False)
+ return parser
+
+
+def parse_args():
+ parser = init_args()
return parser.parse_args()
@@ -94,59 +148,224 @@ def create_predictor(args, mode, logger):
model_dir = args.det_model_dir
elif mode == 'cls':
model_dir = args.cls_model_dir
- else:
+ elif mode == 'rec':
model_dir = args.rec_model_dir
+ elif mode == 'table':
+ model_dir = args.table_model_dir
+ else:
+ model_dir = args.e2e_model_dir
if model_dir is None:
logger.info("not find {} model file path {}".format(mode, model_dir))
sys.exit(0)
- model_file_path = model_dir + "/inference.pdmodel"
- params_file_path = model_dir + "/inference.pdiparams"
- if not os.path.exists(model_file_path):
- logger.info("not find model file path {}".format(model_file_path))
- sys.exit(0)
- if not os.path.exists(params_file_path):
- logger.info("not find params file path {}".format(params_file_path))
- sys.exit(0)
+ if args.use_onnx:
+ import onnxruntime as ort
+ model_file_path = model_dir
+ if not os.path.exists(model_file_path):
+ raise ValueError("not find model file path {}".format(
+ model_file_path))
+ sess = ort.InferenceSession(model_file_path)
+ return sess, sess.get_inputs()[0], None, None
- config = inference.Config(model_file_path, params_file_path)
-
- if args.use_gpu:
- config.enable_use_gpu(args.gpu_mem, 0)
- if args.use_tensorrt:
- config.enable_tensorrt_engine(
- precision_mode=inference.PrecisionType.Half
- if args.use_fp16 else inference.PrecisionType.Float32,
- max_batch_size=args.max_batch_size)
else:
- config.disable_gpu()
- config.set_cpu_math_library_num_threads(6)
- if args.enable_mkldnn:
- # cache 10 different shapes for mkldnn to avoid memory leak
- config.set_mkldnn_cache_capacity(10)
- config.enable_mkldnn()
- # TODO LDOUBLEV: fix mkldnn bug when bach_size > 1
- #config.set_mkldnn_op({'conv2d', 'depthwise_conv2d', 'pool2d', 'batch_norm'})
- args.rec_batch_num = 1
-
- # enable memory optim
- config.enable_memory_optim()
- config.disable_glog_info()
-
- config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
- config.switch_use_feed_fetch_ops(False)
-
- # create predictor
- predictor = inference.create_predictor(config)
- input_names = predictor.get_input_names()
- for name in input_names:
- input_tensor = predictor.get_input_handle(name)
+ model_file_path = model_dir + "/inference.pdmodel"
+ params_file_path = model_dir + "/inference.pdiparams"
+ if not os.path.exists(model_file_path):
+ raise ValueError("not find model file path {}".format(
+ model_file_path))
+ if not os.path.exists(params_file_path):
+ raise ValueError("not find params file path {}".format(
+ params_file_path))
+
+ config = inference.Config(model_file_path, params_file_path)
+
+ if hasattr(args, 'precision'):
+ if args.precision == "fp16" and args.use_tensorrt:
+ precision = inference.PrecisionType.Half
+ elif args.precision == "int8":
+ precision = inference.PrecisionType.Int8
+ else:
+ precision = inference.PrecisionType.Float32
+ else:
+ precision = inference.PrecisionType.Float32
+
+ if args.use_gpu:
+ gpu_id = get_infer_gpuid()
+ if gpu_id is None:
+ logger.warning(
+ "GPU is not found in current device by nvidia-smi. Please check your device or ignore it if run on jetson."
+ )
+ config.enable_use_gpu(args.gpu_mem, 0)
+ if args.use_tensorrt:
+ config.enable_tensorrt_engine(
+ workspace_size=1 << 30,
+ precision_mode=precision,
+ max_batch_size=args.max_batch_size,
+ min_subgraph_size=args.min_subgraph_size)
+                # skip the minimum trt subgraph
+ use_dynamic_shape = True
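+                # TensorRT needs explicit min/max/opt shapes for every variable-shape tensor in the subgraph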
+ if mode == "det":
+ min_input_shape = {
+ "x": [1, 3, 50, 50],
+ "conv2d_92.tmp_0": [1, 120, 20, 20],
+ "conv2d_91.tmp_0": [1, 24, 10, 10],
+ "conv2d_59.tmp_0": [1, 96, 20, 20],
+ "nearest_interp_v2_1.tmp_0": [1, 256, 10, 10],
+ "nearest_interp_v2_2.tmp_0": [1, 256, 20, 20],
+ "conv2d_124.tmp_0": [1, 256, 20, 20],
+ "nearest_interp_v2_3.tmp_0": [1, 64, 20, 20],
+ "nearest_interp_v2_4.tmp_0": [1, 64, 20, 20],
+ "nearest_interp_v2_5.tmp_0": [1, 64, 20, 20],
+ "elementwise_add_7": [1, 56, 2, 2],
+ "nearest_interp_v2_0.tmp_0": [1, 256, 2, 2]
+ }
+ max_input_shape = {
+ "x": [1, 3, 1536, 1536],
+ "conv2d_92.tmp_0": [1, 120, 400, 400],
+ "conv2d_91.tmp_0": [1, 24, 200, 200],
+ "conv2d_59.tmp_0": [1, 96, 400, 400],
+ "nearest_interp_v2_1.tmp_0": [1, 256, 200, 200],
+ "conv2d_124.tmp_0": [1, 256, 400, 400],
+ "nearest_interp_v2_2.tmp_0": [1, 256, 400, 400],
+ "nearest_interp_v2_3.tmp_0": [1, 64, 400, 400],
+ "nearest_interp_v2_4.tmp_0": [1, 64, 400, 400],
+ "nearest_interp_v2_5.tmp_0": [1, 64, 400, 400],
+ "elementwise_add_7": [1, 56, 400, 400],
+ "nearest_interp_v2_0.tmp_0": [1, 256, 400, 400]
+ }
+ opt_input_shape = {
+ "x": [1, 3, 640, 640],
+ "conv2d_92.tmp_0": [1, 120, 160, 160],
+ "conv2d_91.tmp_0": [1, 24, 80, 80],
+ "conv2d_59.tmp_0": [1, 96, 160, 160],
+ "nearest_interp_v2_1.tmp_0": [1, 256, 80, 80],
+ "nearest_interp_v2_2.tmp_0": [1, 256, 160, 160],
+ "conv2d_124.tmp_0": [1, 256, 160, 160],
+ "nearest_interp_v2_3.tmp_0": [1, 64, 160, 160],
+ "nearest_interp_v2_4.tmp_0": [1, 64, 160, 160],
+ "nearest_interp_v2_5.tmp_0": [1, 64, 160, 160],
+ "elementwise_add_7": [1, 56, 40, 40],
+ "nearest_interp_v2_0.tmp_0": [1, 256, 40, 40]
+ }
+ min_pact_shape = {
+ "nearest_interp_v2_26.tmp_0": [1, 256, 20, 20],
+ "nearest_interp_v2_27.tmp_0": [1, 64, 20, 20],
+ "nearest_interp_v2_28.tmp_0": [1, 64, 20, 20],
+ "nearest_interp_v2_29.tmp_0": [1, 64, 20, 20]
+ }
+ max_pact_shape = {
+ "nearest_interp_v2_26.tmp_0": [1, 256, 400, 400],
+ "nearest_interp_v2_27.tmp_0": [1, 64, 400, 400],
+ "nearest_interp_v2_28.tmp_0": [1, 64, 400, 400],
+ "nearest_interp_v2_29.tmp_0": [1, 64, 400, 400]
+ }
+ opt_pact_shape = {
+ "nearest_interp_v2_26.tmp_0": [1, 256, 160, 160],
+ "nearest_interp_v2_27.tmp_0": [1, 64, 160, 160],
+ "nearest_interp_v2_28.tmp_0": [1, 64, 160, 160],
+ "nearest_interp_v2_29.tmp_0": [1, 64, 160, 160]
+ }
+ min_input_shape.update(min_pact_shape)
+ max_input_shape.update(max_pact_shape)
+ opt_input_shape.update(opt_pact_shape)
+ elif mode == "rec":
+ if args.rec_algorithm != "CRNN":
+ use_dynamic_shape = False
+ imgH = int(args.rec_image_shape.split(',')[-2])
+ min_input_shape = {"x": [1, 3, imgH, 10]}
+ max_input_shape = {"x": [args.rec_batch_num, 3, imgH, 1536]}
+ opt_input_shape = {"x": [args.rec_batch_num, 3, imgH, 320]}
+ elif mode == "cls":
+ min_input_shape = {"x": [1, 3, 48, 10]}
+ max_input_shape = {"x": [args.rec_batch_num, 3, 48, 1024]}
+ opt_input_shape = {"x": [args.rec_batch_num, 3, 48, 320]}
+ else:
+ use_dynamic_shape = False
+ if use_dynamic_shape:
+ config.set_trt_dynamic_shape_info(
+ min_input_shape, max_input_shape, opt_input_shape)
+
+ else:
+ config.disable_gpu()
+ if hasattr(args, "cpu_threads"):
+ config.set_cpu_math_library_num_threads(args.cpu_threads)
+ else:
+ # default cpu threads as 10
+ config.set_cpu_math_library_num_threads(10)
+ if args.enable_mkldnn:
+ # cache 10 different shapes for mkldnn to avoid memory leak
+ config.set_mkldnn_cache_capacity(10)
+ config.enable_mkldnn()
+ if args.precision == "fp16":
+ config.enable_mkldnn_bfloat16()
+ # enable memory optim
+ config.enable_memory_optim()
+ config.disable_glog_info()
+ config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
+ config.delete_pass("matmul_transpose_reshape_fuse_pass")
+ if mode == 'table':
+ config.delete_pass("fc_fuse_pass") # not supported for table
+ config.switch_use_feed_fetch_ops(False)
+ config.switch_ir_optim(True)
+
+ # create predictor
+ predictor = inference.create_predictor(config)
+ input_names = predictor.get_input_names()
+ for name in input_names:
+ input_tensor = predictor.get_input_handle(name)
+ output_tensors = get_output_tensors(args, mode, predictor)
+ return predictor, input_tensor, output_tensors, config
+
+
+def get_output_tensors(args, mode, predictor):
output_names = predictor.get_output_names()
output_tensors = []
- for output_name in output_names:
- output_tensor = predictor.get_output_handle(output_name)
- output_tensors.append(output_tensor)
- return predictor, input_tensor, output_tensors
+ if mode == "rec" and args.rec_algorithm == "CRNN":
+ output_name = 'softmax_0.tmp_0'
+ if output_name in output_names:
+ return [predictor.get_output_handle(output_name)]
+ else:
+ for output_name in output_names:
+ output_tensor = predictor.get_output_handle(output_name)
+ output_tensors.append(output_tensor)
+ else:
+ for output_name in output_names:
+ output_tensor = predictor.get_output_handle(output_name)
+ output_tensors.append(output_tensor)
+ return output_tensors
+
+
+def get_infer_gpuid():
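+    # infer the GPU id from CUDA_VISIBLE_DEVICES / HIP_VISIBLE_DEVICES; fall back to 0 when unset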
+ sysstr = platform.system()
+ if sysstr == "Windows":
+ return 0
+
+ if not paddle.core.is_compiled_with_rocm():
+ cmd = "env | grep CUDA_VISIBLE_DEVICES"
+ else:
+ cmd = "env | grep HIP_VISIBLE_DEVICES"
+ env_cuda = os.popen(cmd).readlines()
+ if len(env_cuda) == 0:
+ return 0
+ else:
+ gpu_id = env_cuda[0].strip().split("=")[1]
+ return int(gpu_id[0])
+
+
+def draw_e2e_res(dt_boxes, strs, img_path):
+ src_im = cv2.imread(img_path)
+ for box, str in zip(dt_boxes, strs):
+ box = box.astype(np.int32).reshape((-1, 1, 2))
+ cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
+ cv2.putText(
+ src_im,
+ str,
+ org=(int(box[0, 0, 0]), int(box[0, 0, 1])),
+ fontFace=cv2.FONT_HERSHEY_COMPLEX,
+ fontScale=0.7,
+ color=(0, 255, 0),
+ thickness=1)
+ return src_im
def draw_text_det_res(dt_boxes, img_path):
@@ -174,7 +393,7 @@ def draw_ocr(image,
txts=None,
scores=None,
drop_score=0.5,
- font_path="./doc/simfang.ttf"):
+ font_path="./doc/fonts/simfang.ttf"):
"""
Visualize the results of OCR detection and recognition
args:
@@ -216,7 +435,7 @@ def draw_ocr_box_txt(image,
scores=None,
drop_score=0.5,
font_path="./doc/simfang.ttf"):
- h, w = image.frame_height, image.frame_width
+ h, w = image.height, image.width
img_left = image.copy()
img_right = Image.new('RGB', (w, h), (255, 255, 255))
@@ -381,23 +600,46 @@ def draw_boxes(image, boxes, scores=None, drop_score=0.5):
return image
+def get_rotate_crop_image(img, points):
+ '''
+ img_height, img_width = img.shape[0:2]
+ left = int(np.min(points[:, 0]))
+ right = int(np.max(points[:, 0]))
+ top = int(np.min(points[:, 1]))
+ bottom = int(np.max(points[:, 1]))
+ img_crop = img[top:bottom, left:right, :].copy()
+ points[:, 0] = points[:, 0] - left
+ points[:, 1] = points[:, 1] - top
+ '''
+ assert len(points) == 4, "shape of points must be 4*2"
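+    # perspective-warp the quadrilateral region to an axis-aligned crop; rotate if the crop is much taller than wide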
+ img_crop_width = int(
+ max(
+ np.linalg.norm(points[0] - points[1]),
+ np.linalg.norm(points[2] - points[3])))
+ img_crop_height = int(
+ max(
+ np.linalg.norm(points[0] - points[3]),
+ np.linalg.norm(points[1] - points[2])))
+ pts_std = np.float32([[0, 0], [img_crop_width, 0],
+ [img_crop_width, img_crop_height],
+ [0, img_crop_height]])
+ M = cv2.getPerspectiveTransform(points, pts_std)
+ dst_img = cv2.warpPerspective(
+ img,
+ M, (img_crop_width, img_crop_height),
+ borderMode=cv2.BORDER_REPLICATE,
+ flags=cv2.INTER_CUBIC)
+ dst_img_height, dst_img_width = dst_img.shape[0:2]
+ if dst_img_height * 1.0 / dst_img_width >= 1.5:
+ dst_img = np.rot90(dst_img)
+ return dst_img
+
+
+def check_gpu(use_gpu):
+ if use_gpu and not paddle.is_compiled_with_cuda():
+ use_gpu = False
+ return use_gpu
+
+
if __name__ == '__main__':
- test_img = "./doc/test_v2"
- predict_txt = "./doc/predict.txt"
- f = open(predict_txt, 'r')
- data = f.readlines()
- img_path, anno = data[0].strip().split('\t')
- img_name = os.path.basename(img_path)
- img_path = os.path.join(test_img, img_name)
- image = Image.open(img_path)
-
- data = json.loads(anno)
- boxes, txts, scores = [], [], []
- for dic in data:
- boxes.append(dic['points'])
- txts.append(dic['transcription'])
- scores.append(round(dic['scores'], 3))
-
- new_img = draw_ocr(image, boxes, txts, scores)
-
- cv2.imwrite(img_name, new_img)
+ pass
diff --git a/backend/tools/infer_cls.py b/backend/tools/infer_cls.py
index 49696482..7fd6b536 100755
--- a/backend/tools/infer_cls.py
+++ b/backend/tools/infer_cls.py
@@ -23,7 +23,7 @@
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
@@ -32,7 +32,7 @@
from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
-from ppocr.utils.save_load import init_model
+from ppocr.utils.save_load import load_model
from ppocr.utils.utility import get_image_file_list
import tools.program as program
@@ -47,7 +47,7 @@ def main():
# build model
model = build_model(config['Architecture'])
- init_model(config, model, logger)
+ load_model(config, model)
# create data ops
transforms = []
@@ -57,6 +57,8 @@ def main():
continue
elif op_name == 'KeepKeys':
op[op_name]['keep_keys'] = ['image']
+ elif op_name == "SSLRotateResize":
+ op[op_name]["mode"] = "test"
transforms.append(op)
global_config['infer_mode'] = True
ops = create_operators(transforms, global_config)
@@ -73,8 +75,8 @@ def main():
images = paddle.to_tensor(images)
preds = model(images)
post_result = post_process_class(preds)
- for rec_reuslt in post_result:
- logger.info('\t result: {}'.format(rec_reuslt))
+ for rec_result in post_result:
+ logger.info('\t result: {}'.format(rec_result))
logger.info("success!")
diff --git a/backend/tools/infer_det.py b/backend/tools/infer_det.py
index 913d617d..1acecedf 100755
--- a/backend/tools/infer_det.py
+++ b/backend/tools/infer_det.py
@@ -23,7 +23,7 @@
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
@@ -34,35 +34,33 @@
from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
-from ppocr.utils.save_load import init_model
+from ppocr.utils.save_load import load_model
from ppocr.utils.utility import get_image_file_list
import tools.program as program
-def draw_det_res(dt_boxes, config, img, img_name):
+def draw_det_res(dt_boxes, config, img, img_name, save_path):
if len(dt_boxes) > 0:
import cv2
src_im = img
for box in dt_boxes:
box = box.astype(np.int32).reshape((-1, 1, 2))
cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
- save_det_path = os.path.dirname(config['Global'][
- 'save_res_path']) + "/det_results/"
- if not os.path.exists(save_det_path):
- os.makedirs(save_det_path)
- save_path = os.path.join(save_det_path, os.path.basename(img_name))
+ if not os.path.exists(save_path):
+ os.makedirs(save_path)
+ save_path = os.path.join(save_path, os.path.basename(img_name))
cv2.imwrite(save_path, src_im)
logger.info("The detected Image saved in {}".format(save_path))
+@paddle.no_grad()
def main():
global_config = config['Global']
# build model
model = build_model(config['Architecture'])
- init_model(config, model, logger)
-
+ load_model(config, model)
# build post process
post_process_class = build_post_process(config['PostProcess'])
@@ -96,20 +94,41 @@ def main():
images = paddle.to_tensor(images)
preds = model(images)
post_result = post_process_class(preds, shape_list)
- boxes = post_result[0]['points']
- # write result
+
+ src_img = cv2.imread(file)
+
dt_boxes_json = []
- for box in boxes:
- tmp_json = {"transcription": ""}
- tmp_json['points'] = box.tolist()
- dt_boxes_json.append(tmp_json)
+            # parse boxes when post_result is a dict
+ if isinstance(post_result, dict):
+ det_box_json = {}
+ for k in post_result.keys():
+ boxes = post_result[k][0]['points']
+ dt_boxes_list = []
+ for box in boxes:
+ tmp_json = {"transcription": ""}
+ tmp_json['points'] = box.tolist()
+ dt_boxes_list.append(tmp_json)
+ det_box_json[k] = dt_boxes_list
+ save_det_path = os.path.dirname(config['Global'][
+ 'save_res_path']) + "/det_results_{}/".format(k)
+ draw_det_res(boxes, config, src_img, file, save_det_path)
+ else:
+ boxes = post_result[0]['points']
+ dt_boxes_json = []
+ # write result
+ for box in boxes:
+ tmp_json = {"transcription": ""}
+ tmp_json['points'] = box.tolist()
+ dt_boxes_json.append(tmp_json)
+ save_det_path = os.path.dirname(config['Global'][
+ 'save_res_path']) + "/det_results/"
+ draw_det_res(boxes, config, src_img, file, save_det_path)
otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n"
fout.write(otstr.encode())
- src_img = cv2.imread(file)
- draw_det_res(boxes, config, src_img, file)
+
logger.info("success!")
if __name__ == '__main__':
config, device, logger, vdl_writer = program.preprocess()
- main()
\ No newline at end of file
+ main()
diff --git a/backend/tools/infer_e2e.py b/backend/tools/infer_e2e.py
new file mode 100755
index 00000000..d3e6b28f
--- /dev/null
+++ b/backend/tools/infer_e2e.py
@@ -0,0 +1,122 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
+
+os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+
+import cv2
+import json
+import paddle
+
+from ppocr.data import create_operators, transform
+from ppocr.modeling.architectures import build_model
+from ppocr.postprocess import build_post_process
+from ppocr.utils.save_load import load_model
+from ppocr.utils.utility import get_image_file_list
+import tools.program as program
+
+
+def draw_e2e_res(dt_boxes, strs, config, img, img_name):
+ if len(dt_boxes) > 0:
+ src_im = img
+ for box, str in zip(dt_boxes, strs):
+ box = box.astype(np.int32).reshape((-1, 1, 2))
+ cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
+ cv2.putText(
+ src_im,
+ str,
+ org=(int(box[0, 0, 0]), int(box[0, 0, 1])),
+ fontFace=cv2.FONT_HERSHEY_COMPLEX,
+ fontScale=0.7,
+ color=(0, 255, 0),
+ thickness=1)
+ save_det_path = os.path.dirname(config['Global'][
+ 'save_res_path']) + "/e2e_results/"
+ if not os.path.exists(save_det_path):
+ os.makedirs(save_det_path)
+ save_path = os.path.join(save_det_path, os.path.basename(img_name))
+ cv2.imwrite(save_path, src_im)
+ logger.info("The e2e Image saved in {}".format(save_path))
+
+
+def main():
+ global_config = config['Global']
+
+ # build model
+ model = build_model(config['Architecture'])
+
+ load_model(config, model)
+
+ # build post process
+ post_process_class = build_post_process(config['PostProcess'],
+ global_config)
+
+ # create data ops
+ transforms = []
+ for op in config['Eval']['dataset']['transforms']:
+ op_name = list(op)[0]
+ if 'Label' in op_name:
+ continue
+ elif op_name == 'KeepKeys':
+ op[op_name]['keep_keys'] = ['image', 'shape']
+ transforms.append(op)
+
+ ops = create_operators(transforms, global_config)
+
+ save_res_path = config['Global']['save_res_path']
+ if not os.path.exists(os.path.dirname(save_res_path)):
+ os.makedirs(os.path.dirname(save_res_path))
+
+ model.eval()
+ with open(save_res_path, "wb") as fout:
+ for file in get_image_file_list(config['Global']['infer_img']):
+ logger.info("infer_img: {}".format(file))
+ with open(file, 'rb') as f:
+ img = f.read()
+ data = {'image': img}
+ batch = transform(data, ops)
+ images = np.expand_dims(batch[0], axis=0)
+ shape_list = np.expand_dims(batch[1], axis=0)
+ images = paddle.to_tensor(images)
+ preds = model(images)
+ post_result = post_process_class(preds, shape_list)
+ points, strs = post_result['points'], post_result['texts']
+ # write result
+ dt_boxes_json = []
+ for poly, str in zip(points, strs):
+ tmp_json = {"transcription": str}
+ tmp_json['points'] = poly.tolist()
+ dt_boxes_json.append(tmp_json)
+ otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n"
+ fout.write(otstr.encode())
+ src_img = cv2.imread(file)
+ draw_e2e_res(points, strs, config, src_img, file)
+ logger.info("success!")
+
+
+if __name__ == '__main__':
+ config, device, logger, vdl_writer = program.preprocess()
+ main()
diff --git a/backend/tools/infer_kie.py b/backend/tools/infer_kie.py
new file mode 100755
index 00000000..0cb0b870
--- /dev/null
+++ b/backend/tools/infer_kie.py
@@ -0,0 +1,153 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import paddle.nn.functional as F
+
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
+
+os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+
+import cv2
+import paddle
+
+from ppocr.data import create_operators, transform
+from ppocr.modeling.architectures import build_model
+from ppocr.utils.save_load import load_model
+import tools.program as program
+import time
+
+
+def read_class_list(filepath):
+ dict = {}
+ with open(filepath, "r") as f:
+ lines = f.readlines()
+ for line in lines:
+ key, value = line.split(" ")
+ dict[key] = value.rstrip()
+ return dict
+
+
+def draw_kie_result(batch, node, idx_to_cls, count):
+ img = batch[6].copy()
+ boxes = batch[7]
+ h, w = img.shape[:2]
+ pred_img = np.ones((h, w * 2, 3), dtype=np.uint8) * 255
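+    # white canvas, twice the image width, that holds the predicted label/score text for each box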
+ max_value, max_idx = paddle.max(node, -1), paddle.argmax(node, -1)
+ node_pred_label = max_idx.numpy().tolist()
+ node_pred_score = max_value.numpy().tolist()
+
+ for i, box in enumerate(boxes):
+ if i >= len(node_pred_label):
+ break
+ new_box = [[box[0], box[1]], [box[2], box[1]], [box[2], box[3]],
+ [box[0], box[3]]]
+ Pts = np.array([new_box], np.int32)
+ cv2.polylines(
+ img, [Pts.reshape((-1, 1, 2))],
+ True,
+ color=(255, 255, 0),
+ thickness=1)
+ x_min = int(min([point[0] for point in new_box]))
+ y_min = int(min([point[1] for point in new_box]))
+
+ pred_label = str(node_pred_label[i])
+ if pred_label in idx_to_cls:
+ pred_label = idx_to_cls[pred_label]
+ pred_score = '{:.2f}'.format(node_pred_score[i])
+ text = pred_label + '(' + pred_score + ')'
+ cv2.putText(pred_img, text, (x_min * 2, y_min),
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)
+ vis_img = np.ones((h, w * 3, 3), dtype=np.uint8) * 255
+ vis_img[:, :w] = img
+ vis_img[:, w:] = pred_img
+ save_kie_path = os.path.dirname(config['Global'][
+ 'save_res_path']) + "/kie_results/"
+ if not os.path.exists(save_kie_path):
+ os.makedirs(save_kie_path)
+ save_path = os.path.join(save_kie_path, str(count) + ".png")
+ cv2.imwrite(save_path, vis_img)
+ logger.info("The Kie Image saved in {}".format(save_path))
+
+
+def main():
+ global_config = config['Global']
+
+ # build model
+ model = build_model(config['Architecture'])
+ load_model(config, model)
+
+ # create data ops
+ transforms = []
+ for op in config['Eval']['dataset']['transforms']:
+ transforms.append(op)
+
+ data_dir = config['Eval']['dataset']['data_dir']
+
+ ops = create_operators(transforms, global_config)
+
+ save_res_path = config['Global']['save_res_path']
+ class_path = config['Global']['class_path']
+ idx_to_cls = read_class_list(class_path)
+ if not os.path.exists(os.path.dirname(save_res_path)):
+ os.makedirs(os.path.dirname(save_res_path))
+
+ model.eval()
+
+ warmup_times = 0
+ count_t = []
+ with open(save_res_path, "wb") as fout:
+ with open(config['Global']['infer_img'], "rb") as f:
+ lines = f.readlines()
+ for index, data_line in enumerate(lines):
+ if index == 10:
+ warmup_t = time.time()
+ data_line = data_line.decode('utf-8')
+ substr = data_line.strip("\n").split("\t")
+ img_path, label = data_dir + "/" + substr[0], substr[1]
+ data = {'img_path': img_path, 'label': label}
+ with open(data['img_path'], 'rb') as f:
+ img = f.read()
+ data['image'] = img
+ st = time.time()
+ batch = transform(data, ops)
+ batch_pred = [0] * len(batch)
+ for i in range(len(batch)):
+ batch_pred[i] = paddle.to_tensor(
+ np.expand_dims(
+ batch[i], axis=0))
+ st = time.time()
+ node, edge = model(batch_pred)
+ node = F.softmax(node, -1)
+ count_t.append(time.time() - st)
+ draw_kie_result(batch, node, idx_to_cls, index)
+ logger.info("success!")
+ logger.info("It took {} s for predict {} images.".format(
+ np.sum(count_t), len(count_t)))
+ ips = len(count_t[warmup_times:]) / np.sum(count_t[warmup_times:])
+ logger.info("The ips is {} images/s".format(ips))
+
+
+if __name__ == '__main__':
+ config, device, logger, vdl_writer = program.preprocess()
+ main()
diff --git a/backend/tools/infer_rec.py b/backend/tools/infer_rec.py
index 075ec261..193e24a4 100755
--- a/backend/tools/infer_rec.py
+++ b/backend/tools/infer_rec.py
@@ -20,10 +20,11 @@
import os
import sys
+import json
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
@@ -32,7 +33,7 @@
from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
-from ppocr.utils.save_load import init_model
+from ppocr.utils.save_load import load_model
from ppocr.utils.utility import get_image_file_list
import tools.program as program
@@ -46,12 +47,38 @@ def main():
# build model
if hasattr(post_process_class, 'character'):
- config['Architecture']["Head"]['out_channels'] = len(
- getattr(post_process_class, 'character'))
+ char_num = len(getattr(post_process_class, 'character'))
+ if config['Architecture']["algorithm"] in ["Distillation",
+ ]: # distillation model
+ for key in config['Architecture']["Models"]:
+ if config['Architecture']['Models'][key]['Head'][
+ 'name'] == 'MultiHead': # for multi head
+ out_channels_list = {}
+ if config['PostProcess'][
+ 'name'] == 'DistillationSARLabelDecode':
+ char_num = char_num - 2
+ out_channels_list['CTCLabelDecode'] = char_num
+ out_channels_list['SARLabelDecode'] = char_num + 2
+ config['Architecture']['Models'][key]['Head'][
+ 'out_channels_list'] = out_channels_list
+ else:
+ config['Architecture']["Models"][key]["Head"][
+ 'out_channels'] = char_num
+ elif config['Architecture']['Head'][
+ 'name'] == 'MultiHead': # for multi head loss
+ out_channels_list = {}
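+        # the SAR branch needs two more output channels than the CTC branch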
+ if config['PostProcess']['name'] == 'SARLabelDecode':
+ char_num = char_num - 2
+ out_channels_list['CTCLabelDecode'] = char_num
+ out_channels_list['SARLabelDecode'] = char_num + 2
+ config['Architecture']['Head'][
+ 'out_channels_list'] = out_channels_list
+ else: # base rec model
+ config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture'])
- init_model(config, model, logger)
+ load_model(config, model)
# create data ops
transforms = []
@@ -67,41 +94,70 @@ def main():
'image', 'encoder_word_pos', 'gsrm_word_pos',
'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2'
]
+ elif config['Architecture']['algorithm'] == "SAR":
+ op[op_name]['keep_keys'] = ['image', 'valid_ratio']
else:
op[op_name]['keep_keys'] = ['image']
transforms.append(op)
global_config['infer_mode'] = True
ops = create_operators(transforms, global_config)
+ save_res_path = config['Global'].get('save_res_path',
+ "./output/rec/predicts_rec.txt")
+ if not os.path.exists(os.path.dirname(save_res_path)):
+ os.makedirs(os.path.dirname(save_res_path))
+
model.eval()
- for file in get_image_file_list(config['Global']['infer_img']):
- logger.info("infer_img: {}".format(file))
- with open(file, 'rb') as f:
- img = f.read()
- data = {'image': img}
- batch = transform(data, ops)
- if config['Architecture']['algorithm'] == "SRN":
- encoder_word_pos_list = np.expand_dims(batch[1], axis=0)
- gsrm_word_pos_list = np.expand_dims(batch[2], axis=0)
- gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0)
- gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0)
-
- others = [
- paddle.to_tensor(encoder_word_pos_list),
- paddle.to_tensor(gsrm_word_pos_list),
- paddle.to_tensor(gsrm_slf_attn_bias1_list),
- paddle.to_tensor(gsrm_slf_attn_bias2_list)
- ]
-
- images = np.expand_dims(batch[0], axis=0)
- images = paddle.to_tensor(images)
- if config['Architecture']['algorithm'] == "SRN":
- preds = model(images, others)
- else:
- preds = model(images)
- post_result = post_process_class(preds)
- for rec_reuslt in post_result:
- logger.info('\t result: {}'.format(rec_reuslt))
+
+ with open(save_res_path, "w") as fout:
+ for file in get_image_file_list(config['Global']['infer_img']):
+ logger.info("infer_img: {}".format(file))
+ with open(file, 'rb') as f:
+ img = f.read()
+ data = {'image': img}
+ batch = transform(data, ops)
+ if config['Architecture']['algorithm'] == "SRN":
+ encoder_word_pos_list = np.expand_dims(batch[1], axis=0)
+ gsrm_word_pos_list = np.expand_dims(batch[2], axis=0)
+ gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0)
+ gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0)
+
+ others = [
+ paddle.to_tensor(encoder_word_pos_list),
+ paddle.to_tensor(gsrm_word_pos_list),
+ paddle.to_tensor(gsrm_slf_attn_bias1_list),
+ paddle.to_tensor(gsrm_slf_attn_bias2_list)
+ ]
+ if config['Architecture']['algorithm'] == "SAR":
+ valid_ratio = np.expand_dims(batch[-1], axis=0)
+ img_metas = [paddle.to_tensor(valid_ratio)]
+
+ images = np.expand_dims(batch[0], axis=0)
+ images = paddle.to_tensor(images)
+ if config['Architecture']['algorithm'] == "SRN":
+ preds = model(images, others)
+ elif config['Architecture']['algorithm'] == "SAR":
+ preds = model(images, img_metas)
+ else:
+ preds = model(images)
+ post_result = post_process_class(preds)
+ info = None
+ if isinstance(post_result, dict):
+ rec_info = dict()
+ for key in post_result:
+ if len(post_result[key][0]) >= 2:
+ rec_info[key] = {
+ "label": post_result[key][0][0],
+ "score": float(post_result[key][0][1]),
+ }
+ info = json.dumps(rec_info, ensure_ascii=False)
+ else:
+ if len(post_result[0]) >= 2:
+ info = post_result[0][0] + "\t" + str(post_result[0][1])
+
+ if info is not None:
+ logger.info("\t result: {}".format(info))
+ fout.write(file + "\t" + info)
logger.info("success!")
diff --git a/backend/tools/infer_table.py b/backend/tools/infer_table.py
new file mode 100644
index 00000000..66c2da44
--- /dev/null
+++ b/backend/tools/infer_table.py
@@ -0,0 +1,107 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+import os
+import sys
+import json
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
+
+os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+
+import paddle
+from paddle.jit import to_static
+
+from ppocr.data import create_operators, transform
+from ppocr.modeling.architectures import build_model
+from ppocr.postprocess import build_post_process
+from ppocr.utils.save_load import load_model
+from ppocr.utils.utility import get_image_file_list
+import tools.program as program
+import cv2
+
+
+def main(config, device, logger, vdl_writer):
+ global_config = config['Global']
+
+ # build post process
+ post_process_class = build_post_process(config['PostProcess'],
+ global_config)
+
+ # build model
+ if hasattr(post_process_class, 'character'):
+ config['Architecture']["Head"]['out_channels'] = len(
+ getattr(post_process_class, 'character'))
+
+ model = build_model(config['Architecture'])
+
+ load_model(config, model)
+
+ # create data ops
+ transforms = []
+ use_padding = False
+ for op in config['Eval']['dataset']['transforms']:
+ op_name = list(op)[0]
+ if 'Label' in op_name:
+ continue
+ if op_name == 'KeepKeys':
+ op[op_name]['keep_keys'] = ['image']
+ if op_name == "ResizeTableImage":
+ use_padding = True
+ padding_max_len = op['ResizeTableImage']['max_len']
+ transforms.append(op)
+
+ global_config['infer_mode'] = True
+ ops = create_operators(transforms, global_config)
+
+ model.eval()
+ for file in get_image_file_list(config['Global']['infer_img']):
+ logger.info("infer_img: {}".format(file))
+ with open(file, 'rb') as f:
+ img = f.read()
+ data = {'image': img}
+ batch = transform(data, ops)
+ images = np.expand_dims(batch[0], axis=0)
+ images = paddle.to_tensor(images)
+ preds = model(images)
+ post_result = post_process_class(preds)
+ res_html_code = post_result['res_html_code']
+ res_loc = post_result['res_loc']
+ img = cv2.imread(file)
+ imgh, imgw = img.shape[0:2]
+ res_loc_final = []
+ for rno in range(len(res_loc[0])):
+ x0, y0, x1, y1 = res_loc[0][rno]
+ left = max(int(imgw * x0), 0)
+ top = max(int(imgh * y0), 0)
+ right = min(int(imgw * x1), imgw - 1)
+ bottom = min(int(imgh * y1), imgh - 1)
+ cv2.rectangle(img, (left, top), (right, bottom), (0, 0, 255), 2)
+ res_loc_final.append([left, top, right, bottom])
+ res_loc_str = json.dumps(res_loc_final)
+ logger.info("result: {}, {}".format(res_html_code, res_loc_final))
+ logger.info("success!")
+
+
+if __name__ == '__main__':
+ config, device, logger, vdl_writer = program.preprocess()
+ main(config, device, logger, vdl_writer)
diff --git a/backend/tools/infer_vqa_token_ser.py b/backend/tools/infer_vqa_token_ser.py
new file mode 100755
index 00000000..83ed72b3
--- /dev/null
+++ b/backend/tools/infer_vqa_token_ser.py
@@ -0,0 +1,135 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
+
+os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+import cv2
+import json
+import paddle
+
+from ppocr.data import create_operators, transform
+from ppocr.modeling.architectures import build_model
+from ppocr.postprocess import build_post_process
+from ppocr.utils.save_load import load_model
+from ppocr.utils.visual import draw_ser_results
+from ppocr.utils.utility import get_image_file_list, load_vqa_bio_label_maps
+import tools.program as program
+
+
+def to_tensor(data):
+ import numbers
+ from collections import defaultdict
+ data_dict = defaultdict(list)
+ to_tensor_idxs = []
+ for idx, v in enumerate(data):
+ if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
+ if idx not in to_tensor_idxs:
+ to_tensor_idxs.append(idx)
+ data_dict[idx].append(v)
+ for idx in to_tensor_idxs:
+ data_dict[idx] = paddle.to_tensor(data_dict[idx])
+ return list(data_dict.values())
+
+
+class SerPredictor(object):
+ def __init__(self, config):
+ global_config = config['Global']
+
+ # build post process
+ self.post_process_class = build_post_process(config['PostProcess'],
+ global_config)
+
+ # build model
+ self.model = build_model(config['Architecture'])
+
+ load_model(
+ config, self.model, model_type=config['Architecture']["model_type"])
+
+ from paddleocr import PaddleOCR
+
+ self.ocr_engine = PaddleOCR(use_angle_cls=False, show_log=False)
+
+ # create data ops
+ transforms = []
+ for op in config['Eval']['dataset']['transforms']:
+ op_name = list(op)[0]
+ if 'Label' in op_name:
+ op[op_name]['ocr_engine'] = self.ocr_engine
+ elif op_name == 'KeepKeys':
+ op[op_name]['keep_keys'] = [
+ 'input_ids', 'labels', 'bbox', 'image', 'attention_mask',
+ 'token_type_ids', 'segment_offset_id', 'ocr_info',
+ 'entities'
+ ]
+
+ transforms.append(op)
+ global_config['infer_mode'] = True
+ self.ops = create_operators(config['Eval']['dataset']['transforms'],
+ global_config)
+ self.model.eval()
+
+ def __call__(self, img_path):
+ with open(img_path, 'rb') as f:
+ img = f.read()
+ data = {'image': img}
+ batch = transform(data, self.ops)
+ batch = to_tensor(batch)
+ preds = self.model(batch)
+ post_result = self.post_process_class(
+ preds,
+ attention_masks=batch[4],
+ segment_offset_ids=batch[6],
+ ocr_infos=batch[7])
+ return post_result, batch
+
+
+if __name__ == '__main__':
+ config, device, logger, vdl_writer = program.preprocess()
+ os.makedirs(config['Global']['save_res_path'], exist_ok=True)
+
+ ser_engine = SerPredictor(config)
+
+ infer_imgs = get_image_file_list(config['Global']['infer_img'])
+ with open(
+ os.path.join(config['Global']['save_res_path'],
+ "infer_results.txt"),
+ "w",
+ encoding='utf-8') as fout:
+ for idx, img_path in enumerate(infer_imgs):
+ save_img_path = os.path.join(
+ config['Global']['save_res_path'],
+ os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg")
+ logger.info("process: [{}/{}], save result to {}".format(
+ idx, len(infer_imgs), save_img_path))
+
+ result, _ = ser_engine(img_path)
+ result = result[0]
+ fout.write(img_path + "\t" + json.dumps(
+ {
+ "ocr_info": result,
+ }, ensure_ascii=False) + "\n")
+ img_res = draw_ser_results(img_path, result)
+ cv2.imwrite(save_img_path, img_res)
diff --git a/backend/tools/infer_vqa_token_ser_re.py b/backend/tools/infer_vqa_token_ser_re.py
new file mode 100755
index 00000000..40f1dd5c
--- /dev/null
+++ b/backend/tools/infer_vqa_token_ser_re.py
@@ -0,0 +1,199 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
+
+os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+import cv2
+import json
+import paddle
+import paddle.distributed as dist
+
+from ppocr.data import create_operators, transform
+from ppocr.modeling.architectures import build_model
+from ppocr.postprocess import build_post_process
+from ppocr.utils.save_load import load_model
+from ppocr.utils.visual import draw_re_results
+from ppocr.utils.logging import get_logger
+from ppocr.utils.utility import get_image_file_list, load_vqa_bio_label_maps, print_dict
+from tools.program import ArgsParser, load_config, merge_config, check_gpu
+from tools.infer_vqa_token_ser import SerPredictor
+
+
+class ReArgsParser(ArgsParser):
+ def __init__(self):
+ super(ReArgsParser, self).__init__()
+ self.add_argument(
+ "-c_ser", "--config_ser", help="ser configuration file to use")
+ self.add_argument(
+ "-o_ser",
+ "--opt_ser",
+ nargs='+',
+ help="set ser configuration options ")
+
+ def parse_args(self, argv=None):
+ args = super(ReArgsParser, self).parse_args(argv)
+ assert args.config_ser is not None, \
+ "Please specify --config_ser=ser_configure_file_path."
+ args.opt_ser = self._parse_opt(args.opt_ser)
+ return args
+
+
+def make_input(ser_inputs, ser_results):
+ entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}
+
+ entities = ser_inputs[8][0]
+ ser_results = ser_results[0]
+ assert len(entities) == len(ser_results)
+
+ # entities
+ start = []
+ end = []
+ label = []
+ entity_idx_dict = {}
+ for i, (res, entity) in enumerate(zip(ser_results, entities)):
+ if res['pred'] == 'O':
+ continue
+ entity_idx_dict[len(start)] = i
+ start.append(entity['start'])
+ end.append(entity['end'])
+ label.append(entities_labels[res['pred']])
+ entities = dict(start=start, end=end, label=label)
+
+ # relations
+ head = []
+ tail = []
+ for i in range(len(entities["label"])):
+ for j in range(len(entities["label"])):
+ if entities["label"][i] == 1 and entities["label"][j] == 2:
+ head.append(i)
+ tail.append(j)
+
+ relations = dict(head=head, tail=tail)
+
+ batch_size = ser_inputs[0].shape[0]
+ entities_batch = []
+ relations_batch = []
+ entity_idx_dict_batch = []
+ for b in range(batch_size):
+ entities_batch.append(entities)
+ relations_batch.append(relations)
+ entity_idx_dict_batch.append(entity_idx_dict)
+
+ ser_inputs[8] = entities_batch
+ ser_inputs.append(relations_batch)
+ # remove ocr_info segment_offset_id and label in ser input
+ ser_inputs.pop(7)
+ ser_inputs.pop(6)
+ ser_inputs.pop(1)
+ return ser_inputs, entity_idx_dict_batch
+
+
+class SerRePredictor(object):
+ def __init__(self, config, ser_config):
+ self.ser_engine = SerPredictor(ser_config)
+
+ # init re model
+ global_config = config['Global']
+
+ # build post process
+ self.post_process_class = build_post_process(config['PostProcess'],
+ global_config)
+
+ # build model
+ self.model = build_model(config['Architecture'])
+
+ load_model(
+ config, self.model, model_type=config['Architecture']["model_type"])
+
+ self.model.eval()
+
+ def __call__(self, img_path):
+ ser_results, ser_inputs = self.ser_engine(img_path)
+ paddle.save(ser_inputs, 'ser_inputs.npy')
+ paddle.save(ser_results, 'ser_results.npy')
+ re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
+ preds = self.model(re_input)
+ post_result = self.post_process_class(
+ preds,
+ ser_results=ser_results,
+ entity_idx_dict_batch=entity_idx_dict_batch)
+ return post_result
+
+
+def preprocess():
+ FLAGS = ReArgsParser().parse_args()
+    config = load_config(FLAGS.config)
+ config = merge_config(config, FLAGS.opt)
+
+ ser_config = load_config(FLAGS.config_ser)
+ ser_config = merge_config(ser_config, FLAGS.opt_ser)
+
+ logger = get_logger()
+
+ # check if set use_gpu=True in paddlepaddle cpu version
+ use_gpu = config['Global']['use_gpu']
+ check_gpu(use_gpu)
+
+ device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
+ device = paddle.set_device(device)
+
+ logger.info('{} re config {}'.format('*' * 10, '*' * 10))
+ print_dict(config, logger)
+ logger.info('\n')
+ logger.info('{} ser config {}'.format('*' * 10, '*' * 10))
+ print_dict(ser_config, logger)
+ logger.info('train with paddle {} and device {}'.format(paddle.__version__,
+ device))
+ return config, ser_config, device, logger
+
+
+if __name__ == '__main__':
+ config, ser_config, device, logger = preprocess()
+ os.makedirs(config['Global']['save_res_path'], exist_ok=True)
+
+ ser_re_engine = SerRePredictor(config, ser_config)
+
+ infer_imgs = get_image_file_list(config['Global']['infer_img'])
+ with open(
+ os.path.join(config['Global']['save_res_path'],
+ "infer_results.txt"),
+ "w",
+ encoding='utf-8') as fout:
+ for idx, img_path in enumerate(infer_imgs):
+ save_img_path = os.path.join(
+ config['Global']['save_res_path'],
+ os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg")
+ logger.info("process: [{}/{}], save result to {}".format(
+ idx, len(infer_imgs), save_img_path))
+
+ result = ser_re_engine(img_path)
+ result = result[0]
+ fout.write(img_path + "\t" + json.dumps(
+ {
+ "ser_result": result,
+ }, ensure_ascii=False) + "\n")
+ img_res = draw_re_results(img_path, result)
+ cv2.imwrite(save_img_path, img_res)
diff --git a/backend/tools/makedist.py b/backend/tools/makedist.py
new file mode 100644
index 00000000..c8fc9435
--- /dev/null
+++ b/backend/tools/makedist.py
@@ -0,0 +1,12 @@
+if __name__ == '__main__':
+    # Import QPT
+ from qpt.executor import CreateExecutableModule as CEM
+ import os
+ WORK_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+ print(WORK_DIR)
+ LAUNCH_PATH = os.path.join(WORK_DIR, 'gui.py')
+ SAVE_PATH = os.path.join(os.path.dirname(WORK_DIR), 'vse_out')
+ ICON_PATH = os.path.join(WORK_DIR, "design", "vse.ico")
+ module = CEM(work_dir=WORK_DIR, launcher_py_path=LAUNCH_PATH, save_path=SAVE_PATH, icon=ICON_PATH, hidden_terminal=False)
+    # Start packaging
+ module.make()
diff --git a/backend/tools/ocr.py b/backend/tools/ocr.py
new file mode 100644
index 00000000..018140db
--- /dev/null
+++ b/backend/tools/ocr.py
@@ -0,0 +1,124 @@
+from tools.infer import utility
+from tools.infer.predict_system import TextSystem
+import config
+import importlib
+
+
+# Load the text detection + recognition models
+class OcrRecogniser:
+ def __init__(self):
+        # Get the argument object
+ importlib.reload(config)
+ self.args = utility.parse_args()
+ self.recogniser = self.init_model()
+
+ @staticmethod
+ def y_round(y):
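+        # Round y to the nearest multiple of 10 so that boxes on the same subtitle line share the same key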
+ y_min = y + 10 - y % 10
+ y_max = y - y % 10
+ if abs(y - y_min) < abs(y - y_max):
+ return y_min
+ else:
+ return y_max
+
+ def predict(self, image):
+ detection_box, recognise_result = self.recogniser(image)
+ if len(detection_box) > 0:
+ coordinate_list = list()
+ if isinstance(detection_box, list):
+ for i in detection_box:
+ i = list(i)
+ (x1, y1) = int(i[0][0]), int(i[0][1])
+ (x2, y2) = int(i[1][0]), int(i[1][1])
+ (x3, y3) = int(i[2][0]), int(i[2][1])
+ (x4, y4) = int(i[3][0]), int(i[3][1])
+ xmin = max(x1, x4)
+ xmax = min(x2, x3)
+ ymin = max(y1, y2)
+ ymax = min(y3, y4)
+ coordinate_list.append([xmin, xmax, ymin, ymax])
+
+                # Work out how many subtitle lines there are and store the smallest ymin of each line in lines
+ lines = []
+ for i in coordinate_list:
+ if len(lines) < 1:
+ lines.append(self.y_round(i[2]))
+ else:
+ if self.y_round(i[2]) not in lines \
+ and self.y_round(i[2]) + 10 not in lines \
+ and self.y_round(i[2]) - 10 not in lines:
+ lines.append(self.y_round(i[2]))
+ lines = sorted(lines)
+
+ for i in coordinate_list:
+ for j in lines:
+ if abs(j - self.y_round(i[2])) <= 10:
+ i[2] = j
+
+ to_rank_res = list(zip(coordinate_list, recognise_result))
+ ranked_res = []
+ for line in lines:
+ tmp_list = []
+ for i in to_rank_res:
+ if i[0][2] == line:
+ tmp_list.append(i)
+                    # Sort by vertical coordinate first
+ for k in range(1, len(tmp_list)):
+ for j in range(0, len(tmp_list) - k):
+ if tmp_list[j][0][2] > tmp_list[j + 1][0][2]:
+ print(tmp_list[j][0][2])
+ tmp_list[j], tmp_list[j + 1] = tmp_list[j + 1], tmp_list[j]
+                    # Then sort by horizontal coordinate
+ for l in range(1, len(tmp_list)):
+ for j in range(0, len(tmp_list) - l):
+ if tmp_list[j][0][0] > tmp_list[j + 1][0][0]:
+ tmp_list[j], tmp_list[j + 1] = tmp_list[j + 1], tmp_list[j]
+ for m in tmp_list:
+ ranked_res.append(m)
+ dt_box = []
+ for i in [j[0] for j in ranked_res]:
+ dt_box.append([(i[0], i[2]), (i[1], i[2]), (i[1], i[3]), (i[0], i[3])])
+ res = [i[1] for i in ranked_res]
+ return dt_box, res
+ else:
+ return detection_box, recognise_result
+
+ def init_model(self):
+ self.args.use_gpu = config.USE_GPU
+ if not config.USE_GPU:
+ import paddle
+ paddle.set_device('cpu')
+        # Set the text detection model path
+ self.args.det_model_dir = config.DET_MODEL_PATH
+        # Set the text recognition model path
+ self.args.rec_model_dir = config.REC_MODEL_PATH
+ self.args.rec_char_dict_path = config.DICT_PATH
+ self.args.rec_image_shape = config.REC_IMAGE_SHAPE
+        # Set the language type of the text to recognize
+ self.args.rec_char_type = config.REC_CHAR_TYPE
+        # Set the batch size of text boxes per image
+ self.args.rec_batch_num = config.REC_BATCH_NUM
+ self.args.max_batch_size = config.MAX_BATCH_SIZE
+ return TextSystem(self.args)
+
+
+def get_coordinates(dt_box):
+ """
+    Get coordinates from the returned detection boxes
+    :param dt_box detection box results
+    :return list of coordinate tuples
+ """
+ coordinate_list = list()
+ if isinstance(dt_box, list):
+ for i in dt_box:
+ i = list(i)
+ (x1, y1) = int(i[0][0]), int(i[0][1])
+ (x2, y2) = int(i[1][0]), int(i[1][1])
+ (x3, y3) = int(i[2][0]), int(i[2][1])
+ (x4, y4) = int(i[3][0]), int(i[3][1])
+ xmin = max(x1, x4)
+ xmax = min(x2, x3)
+ ymin = max(y1, y2)
+ ymax = min(y3, y4)
+ coordinate_list.append((xmin, xmax, ymin, ymax))
+ return coordinate_list
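+
+
+# Usage sketch (illustrative only, not part of the pipeline): run OCR on a frame and
+# pair each coordinate with its recognised text. `frame` is assumed to be a BGR ndarray
+# such as one returned by cv2.imread().
+#
+#   recogniser = OcrRecogniser()
+#   dt_box, rec_res = recogniser.predict(frame)
+#   for (xmin, xmax, ymin, ymax), (text, score) in zip(get_coordinates(dt_box), rec_res):
+#       print(text, score, (xmin, xmax, ymin, ymax))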
diff --git a/backend/tools/program.py b/backend/tools/program.py
index ae649176..7c02dc01 100755
--- a/backend/tools/program.py
+++ b/backend/tools/program.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,9 +18,10 @@
import os
import sys
+import platform
import yaml
import time
-import shutil
+import datetime
import paddle
import paddle.distributed as dist
from tqdm import tqdm
@@ -28,10 +29,11 @@
from ppocr.utils.stats import TrainingStats
from ppocr.utils.save_load import save_model
-from ppocr.utils.utility import print_dict
+from ppocr.utils.utility import print_dict, AverageMeter
from ppocr.utils.logging import get_logger
+from ppocr.utils.loggers import VDLLogger, WandbLogger, Loggers
+from ppocr.utils import profiler
from ppocr.data import build_dataloader
-import numpy as np
class ArgsParser(ArgumentParser):
@@ -41,6 +43,14 @@ def __init__(self):
self.add_argument("-c", "--config", help="configuration file to use")
self.add_argument(
"-o", "--opt", nargs='+', help="set configuration options")
+ self.add_argument(
+ '-p',
+ '--profiler_options',
+ type=str,
+ default=None,
+ help='The option of profiler, which should be in format ' \
+ '\"key1=value1;key2=value2;key3=value3\".'
+ )
def parse_args(self, argv=None):
args = super(ArgsParser, self).parse_args(argv)
@@ -60,24 +70,6 @@ def _parse_opt(self, opts):
return config
-class AttrDict(dict):
- """Single level attribute dict, NOT recursive"""
-
- def __init__(self, **kwargs):
- super(AttrDict, self).__init__()
- super(AttrDict, self).update(kwargs)
-
- def __getattr__(self, key):
- if key in self:
- return self[key]
- raise AttributeError("object has no attribute '{}'".format(key))
-
-
-global_config = AttrDict()
-
-default_config = {'Global': {'debug': False, }}
-
-
def load_config(file_path):
"""
Load config from yml/yaml file.
@@ -85,38 +77,39 @@ def load_config(file_path):
file_path (str): Path of the config file to be loaded.
Returns: global config
"""
- merge_config(default_config)
_, ext = os.path.splitext(file_path)
assert ext in ['.yml', '.yaml'], "only support yaml files for now"
- merge_config(yaml.load(open(file_path, 'rb'), Loader=yaml.Loader))
- return global_config
+ config = yaml.load(open(file_path, 'rb'), Loader=yaml.Loader)
+ return config
-def merge_config(config):
+def merge_config(config, opts):
"""
Merge config into global config.
Args:
config (dict): Config to be merged.
Returns: global config
"""
- for key, value in config.items():
+ for key, value in opts.items():
if "." not in key:
- if isinstance(value, dict) and key in global_config:
- global_config[key].update(value)
+ if isinstance(value, dict) and key in config:
+ config[key].update(value)
else:
- global_config[key] = value
+ config[key] = value
else:
sub_keys = key.split('.')
assert (
- sub_keys[0] in global_config
- ), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
- global_config.keys(), sub_keys[0])
- cur = global_config[sub_keys[0]]
+ sub_keys[0] in config
+ ), "the sub_keys can only be one of global_config: {}, but get: " \
+ "{}, please check your running command".format(
+ config.keys(), sub_keys[0])
+ cur = config[sub_keys[0]]
for idx, sub_key in enumerate(sub_keys[1:]):
if idx == len(sub_keys) - 2:
cur[sub_key] = value
else:
cur = cur[sub_key]
+ return config
def check_gpu(use_gpu):
@@ -138,6 +131,25 @@ def check_gpu(use_gpu):
pass
+def check_xpu(use_xpu):
+ """
+ Log error and exit when set use_xpu=true in paddlepaddle
+ cpu/gpu version.
+ """
+ err = "Config use_xpu cannot be set as true while you are " \
+ "using paddlepaddle cpu/gpu version ! \nPlease try: \n" \
+ "\t1. Install paddlepaddle-xpu to run model on XPU \n" \
+ "\t2. Set use_xpu as false in config file to run " \
+ "model on CPU/GPU"
+
+ try:
+ if use_xpu and not paddle.is_compiled_with_xpu():
+ print(err)
+ sys.exit(1)
+ except Exception as e:
+ pass
+
+
def train(config,
train_dataloader,
valid_dataloader,
@@ -150,26 +162,33 @@ def train(config,
eval_class,
pre_best_model_dict,
logger,
- vdl_writer=None):
+ log_writer=None,
+ scaler=None):
cal_metric_during_train = config['Global'].get('cal_metric_during_train',
False)
+ calc_epoch_interval = config['Global'].get('calc_epoch_interval', 1)
log_smooth_window = config['Global']['log_smooth_window']
epoch_num = config['Global']['epoch_num']
print_batch_step = config['Global']['print_batch_step']
eval_batch_step = config['Global']['eval_batch_step']
+ profiler_options = config['profiler_options']
global_step = 0
+ if 'global_step' in pre_best_model_dict:
+ global_step = pre_best_model_dict['global_step']
start_eval_step = 0
if type(eval_batch_step) == list and len(eval_batch_step) >= 2:
start_eval_step = eval_batch_step[0]
eval_batch_step = eval_batch_step[1]
if len(valid_dataloader) == 0:
logger.info(
- 'No Images in eval dataset, evaluation during training will be disabled'
+ 'No Images in eval dataset, evaluation during training ' \
+ 'will be disabled'
)
start_eval_step = 1e111
logger.info(
- "During the training process, after the {}th iteration, an evaluation is run every {} iterations".
+ "During the training process, after the {}th iteration, " \
+ "an evaluation is run every {} iterations".
format(start_eval_step, eval_batch_step))
save_epoch_step = config['Global']['save_epoch_step']
save_model_dir = config['Global']['save_model_dir']
@@ -183,39 +202,96 @@ def train(config,
model.train()
use_srn = config['Architecture']['algorithm'] == "SRN"
-
- if 'start_epoch' in best_model_dict:
- start_epoch = best_model_dict['start_epoch']
+ extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"]
+ extra_input = False
+ if config['Architecture']['algorithm'] == 'Distillation':
+ for key in config['Architecture']["Models"]:
+ extra_input = extra_input or config['Architecture']['Models'][key][
+ 'algorithm'] in extra_input_models
else:
- start_epoch = 1
+ extra_input = config['Architecture']['algorithm'] in extra_input_models
+ try:
+ model_type = config['Architecture']['model_type']
+ except:
+ model_type = None
+
+ algorithm = config['Architecture']['algorithm']
+
+ start_epoch = best_model_dict[
+ 'start_epoch'] if 'start_epoch' in best_model_dict else 1
+
+ total_samples = 0
+ train_reader_cost = 0.0
+ train_batch_cost = 0.0
+ reader_start = time.time()
+ eta_meter = AverageMeter()
+
+ max_iter = len(train_dataloader) - 1 if platform.system(
+ ) == "Windows" else len(train_dataloader)
for epoch in range(start_epoch, epoch_num + 1):
- train_dataloader = build_dataloader(
- config, 'Train', device, logger, seed=epoch)
- train_batch_cost = 0.0
- train_reader_cost = 0.0
- batch_sum = 0
- batch_start = time.time()
+ if train_dataloader.dataset.need_reset:
+ train_dataloader = build_dataloader(
+ config, 'Train', device, logger, seed=epoch)
+ max_iter = len(train_dataloader) - 1 if platform.system(
+ ) == "Windows" else len(train_dataloader)
for idx, batch in enumerate(train_dataloader):
- train_reader_cost += time.time() - batch_start
- if idx >= len(train_dataloader):
+ profiler.add_profiler_step(profiler_options)
+ train_reader_cost += time.time() - reader_start
+ if idx >= max_iter:
break
lr = optimizer.get_lr()
images = batch[0]
if use_srn:
- others = batch[-4:]
- preds = model(images, others)
model_average = True
+
+ # use amp
+ if scaler:
+ with paddle.amp.auto_cast():
+ if model_type == 'table' or extra_input:
+ preds = model(images, data=batch[1:])
+ else:
+ preds = model(images)
else:
- preds = model(images)
+ if model_type == 'table' or extra_input:
+ preds = model(images, data=batch[1:])
+ elif model_type in ["kie", 'vqa']:
+ preds = model(batch)
+ else:
+ preds = model(images)
+
loss = loss_class(preds, batch)
avg_loss = loss['loss']
- avg_loss.backward()
- optimizer.step()
+
+ if scaler:
+ scaled_avg_loss = scaler.scale(avg_loss)
+ scaled_avg_loss.backward()
+ scaler.minimize(optimizer, scaled_avg_loss)
+ else:
+ avg_loss.backward()
+ optimizer.step()
optimizer.clear_grad()
- train_batch_cost += time.time() - batch_start
- batch_sum += len(images)
+ if cal_metric_during_train and epoch % calc_epoch_interval == 0: # only rec and cls need
+ batch = [item.numpy() for item in batch]
+ if model_type in ['table', 'kie']:
+ eval_class(preds, batch)
+ else:
+ if config['Loss']['name'] in ['MultiLoss', 'MultiLoss_v2'
+ ]: # for multi head loss
+ post_result = post_process_class(
+ preds['ctc'], batch[1]) # for CTC head out
+ else:
+ post_result = post_process_class(preds, batch[1])
+ eval_class(post_result, batch)
+ metric = eval_class.get_metric()
+ train_stats.update(metric)
+
+ train_batch_time = time.time() - reader_start
+ train_batch_cost += train_batch_time
+ eta_meter.update(train_batch_time)
+ global_step += 1
+ total_samples += len(images)
if not isinstance(lr_scheduler, float):
lr_scheduler.step()
@@ -225,32 +301,34 @@ def train(config,
stats['lr'] = lr
train_stats.update(stats)
- if cal_metric_during_train: # only rec and cls need
- batch = [item.numpy() for item in batch]
- post_result = post_process_class(preds, batch[1])
- eval_class(post_result, batch)
- metric = eval_class.get_metric()
- train_stats.update(metric)
+ if log_writer is not None and dist.get_rank() == 0:
+ log_writer.log_metrics(metrics=train_stats.get(), prefix="TRAIN", step=global_step)
- if vdl_writer is not None and dist.get_rank() == 0:
- for k, v in train_stats.get().items():
- vdl_writer.add_scalar('TRAIN/{}'.format(k), v, global_step)
- vdl_writer.add_scalar('TRAIN/lr', lr, global_step)
-
- if dist.get_rank(
- ) == 0 and global_step > 0 and global_step % print_batch_step == 0:
+ if dist.get_rank() == 0 and (
+ (global_step > 0 and global_step % print_batch_step == 0) or
+ (idx >= len(train_dataloader) - 1)):
logs = train_stats.log()
- strs = 'epoch: [{}/{}], iter: {}, {}, reader_cost: {:.5f} s, batch_cost: {:.5f} s, samples: {}, ips: {:.5f}'.format(
- epoch, epoch_num, global_step, logs, train_reader_cost /
- print_batch_step, train_batch_cost / print_batch_step,
- batch_sum, batch_sum / train_batch_cost)
+
+ eta_sec = ((epoch_num + 1 - epoch) * \
+ len(train_dataloader) - idx - 1) * eta_meter.avg
+ eta_sec_format = str(datetime.timedelta(seconds=int(eta_sec)))
+ strs = 'epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: ' \
+ '{:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ' \
+ 'ips: {:.5f} samples/s, eta: {}'.format(
+ epoch, epoch_num, global_step, logs,
+ train_reader_cost / print_batch_step,
+ train_batch_cost / print_batch_step,
+ total_samples / print_batch_step,
+ total_samples / train_batch_cost, eta_sec_format)
logger.info(strs)
- train_batch_cost = 0.0
+
+ total_samples = 0
train_reader_cost = 0.0
- batch_sum = 0
+ train_batch_cost = 0.0
# eval
if global_step > start_eval_step and \
- (global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
+ (global_step - start_eval_step) % eval_batch_step == 0 \
+ and dist.get_rank() == 0:
if model_average:
Model_Average = paddle.incubate.optimizer.ModelAverage(
0.15,
@@ -263,17 +341,16 @@ def train(config,
valid_dataloader,
post_process_class,
eval_class,
- use_srn=use_srn)
+ model_type,
+ extra_input=extra_input)
cur_metric_str = 'cur metric, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
logger.info(cur_metric_str)
# logger metric
- if vdl_writer is not None:
- for k, v in cur_metric.items():
- if isinstance(v, (float, int)):
- vdl_writer.add_scalar('EVAL/{}'.format(k),
- cur_metric[k], global_step)
+ if log_writer is not None:
+ log_writer.log_metrics(metrics=cur_metric, prefix="EVAL", step=global_step)
+
if cur_metric[main_indicator] >= best_model_dict[
main_indicator]:
best_model_dict.update(cur_metric)
@@ -283,75 +360,111 @@ def train(config,
optimizer,
save_model_dir,
logger,
+ config,
is_best=True,
prefix='best_accuracy',
best_model_dict=best_model_dict,
- epoch=epoch)
+ epoch=epoch,
+ global_step=global_step)
best_str = 'best metric, {}'.format(', '.join([
'{}: {}'.format(k, v) for k, v in best_model_dict.items()
]))
logger.info(best_str)
# logger best metric
- if vdl_writer is not None:
- vdl_writer.add_scalar('EVAL/best_{}'.format(main_indicator),
- best_model_dict[main_indicator],
- global_step)
- global_step += 1
- optimizer.clear_grad()
- batch_start = time.time()
+ if log_writer is not None:
+ log_writer.log_metrics(metrics={
+ "best_{}".format(main_indicator): best_model_dict[main_indicator]
+ }, prefix="EVAL", step=global_step)
+
+ log_writer.log_model(is_best=True, prefix="best_accuracy", metadata=best_model_dict)
+
+ reader_start = time.time()
if dist.get_rank() == 0:
save_model(
model,
optimizer,
save_model_dir,
logger,
+ config,
is_best=False,
prefix='latest',
best_model_dict=best_model_dict,
- epoch=epoch)
+ epoch=epoch,
+ global_step=global_step)
+
+ if log_writer is not None:
+ log_writer.log_model(is_best=False, prefix="latest")
+
if dist.get_rank() == 0 and epoch > 0 and epoch % save_epoch_step == 0:
save_model(
model,
optimizer,
save_model_dir,
logger,
+ config,
is_best=False,
prefix='iter_epoch_{}'.format(epoch),
best_model_dict=best_model_dict,
- epoch=epoch)
+ epoch=epoch,
+ global_step=global_step)
+ if log_writer is not None:
+ log_writer.log_model(is_best=False, prefix='iter_epoch_{}'.format(epoch))
+
best_str = 'best metric, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
logger.info(best_str)
- if dist.get_rank() == 0 and vdl_writer is not None:
- vdl_writer.close()
+ if dist.get_rank() == 0 and log_writer is not None:
+ log_writer.close()
return
-def eval(model, valid_dataloader, post_process_class, eval_class,
- use_srn=False):
+def eval(model,
+ valid_dataloader,
+ post_process_class,
+ eval_class,
+ model_type=None,
+ extra_input=False):
model.eval()
with paddle.no_grad():
total_frame = 0.0
total_time = 0.0
- pbar = tqdm(total=len(valid_dataloader), desc='eval model:')
+ pbar = tqdm(
+ total=len(valid_dataloader),
+ desc='eval model:',
+ position=0,
+ leave=True)
+ max_iter = len(valid_dataloader) - 1 if platform.system(
+ ) == "Windows" else len(valid_dataloader)
for idx, batch in enumerate(valid_dataloader):
- if idx >= len(valid_dataloader):
+ if idx >= max_iter:
break
images = batch[0]
start = time.time()
-
- if use_srn:
- others = batch[-4:]
- preds = model(images, others)
+ if model_type == 'table' or extra_input:
+ preds = model(images, data=batch[1:])
+ elif model_type in ["kie", 'vqa']:
+ preds = model(batch)
else:
preds = model(images)
- batch = [item.numpy() for item in batch]
+ batch_numpy = []
+ for item in batch:
+ if isinstance(item, paddle.Tensor):
+ batch_numpy.append(item.numpy())
+ else:
+ batch_numpy.append(item)
# Obtain usable results from post-processing methods
- post_result = post_process_class(preds, batch[1])
total_time += time.time() - start
# Evaluate the results of the current batch
- eval_class(post_result, batch)
+ if model_type in ['table', 'kie']:
+ eval_class(preds, batch_numpy)
+ elif model_type in ['vqa']:
+ post_result = post_process_class(preds, batch_numpy)
+ eval_class(post_result, batch_numpy)
+ else:
+ post_result = post_process_class(preds, batch_numpy[1])
+ eval_class(post_result, batch_numpy)
+
pbar.update(1)
total_frame += len(images)
# Get final metric,eg. acc or hmean
@@ -363,44 +476,127 @@ def eval(model, valid_dataloader, post_process_class, eval_class,
return metric
+def update_center(char_center, post_result, preds):
+ result, label = post_result
+ feats, logits = preds
+ logits = paddle.argmax(logits, axis=-1)
+ feats = feats.numpy()
+ logits = logits.numpy()
+
+ for idx_sample in range(len(label)):
+ if result[idx_sample][0] == label[idx_sample][0]:
+ feat = feats[idx_sample]
+ logit = logits[idx_sample]
+ for idx_time in range(len(logit)):
+ index = logit[idx_time]
+ if index in char_center.keys():
+ char_center[index][0] = (
+ char_center[index][0] * char_center[index][1] +
+ feat[idx_time]) / (char_center[index][1] + 1)
+ char_center[index][1] += 1
+ else:
+ char_center[index] = [feat[idx_time], 1]
+ return char_center
+
+
+def get_center(model, eval_dataloader, post_process_class):
+ pbar = tqdm(total=len(eval_dataloader), desc='get center:')
+ max_iter = len(eval_dataloader) - 1 if platform.system(
+ ) == "Windows" else len(eval_dataloader)
+ char_center = dict()
+ for idx, batch in enumerate(eval_dataloader):
+ if idx >= max_iter:
+ break
+ images = batch[0]
+ start = time.time()
+ preds = model(images)
+
+ batch = [item.numpy() for item in batch]
+ # Obtain usable results from post-processing methods
+ post_result = post_process_class(preds, batch[1])
+
+ #update char_center
+ char_center = update_center(char_center, post_result, preds)
+ pbar.update(1)
+
+ pbar.close()
+ for key in char_center.keys():
+ char_center[key] = char_center[key][0]
+ return char_center
+
+
def preprocess(is_train=False):
FLAGS = ArgsParser().parse_args()
+ profiler_options = FLAGS.profiler_options
config = load_config(FLAGS.config)
- merge_config(FLAGS.opt)
+ config = merge_config(config, FLAGS.opt)
+ profile_dic = {"profiler_options": FLAGS.profiler_options}
+ config = merge_config(config, profile_dic)
+
+ if is_train:
+ # save_config
+ save_model_dir = config['Global']['save_model_dir']
+ os.makedirs(save_model_dir, exist_ok=True)
+ with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
+ yaml.dump(
+ dict(config), f, default_flow_style=False, sort_keys=False)
+ log_file = '{}/train.log'.format(save_model_dir)
+ else:
+ log_file = None
+ logger = get_logger(log_file=log_file)
# check if set use_gpu=True in paddlepaddle cpu version
use_gpu = config['Global']['use_gpu']
check_gpu(use_gpu)
+ # check if set use_xpu=True in paddlepaddle cpu/gpu version
+ use_xpu = False
+ if 'use_xpu' in config['Global']:
+ use_xpu = config['Global']['use_xpu']
+ check_xpu(use_xpu)
+
alg = config['Architecture']['algorithm']
assert alg in [
- 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS'
+ 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
+ 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
+ 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR'
]
- device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
+ device = 'cpu'
+ if use_gpu:
+ device = 'gpu:{}'.format(dist.ParallelEnv().dev_id)
+ if use_xpu:
+ device = 'xpu'
device = paddle.set_device(device)
config['Global']['distributed'] = dist.get_world_size() != 1
- if is_train:
- # save_config
- save_model_dir = config['Global']['save_model_dir']
- os.makedirs(save_model_dir, exist_ok=True)
- with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
- yaml.dump(
- dict(config), f, default_flow_style=False, sort_keys=False)
- log_file = '{}/train.log'.format(save_model_dir)
- else:
- log_file = None
- logger = get_logger(name='root', log_file=log_file)
- if config['Global']['use_visualdl']:
- from visualdl import LogWriter
+
+ loggers = []
+
+ if 'use_visualdl' in config['Global'] and config['Global']['use_visualdl']:
save_model_dir = config['Global']['save_model_dir']
vdl_writer_path = '{}/vdl/'.format(save_model_dir)
- os.makedirs(vdl_writer_path, exist_ok=True)
- vdl_writer = LogWriter(logdir=vdl_writer_path)
+ log_writer = VDLLogger(save_model_dir)
+ loggers.append(log_writer)
+ if ('use_wandb' in config['Global'] and config['Global']['use_wandb']) or 'wandb' in config:
+ save_dir = config['Global']['save_model_dir']
+ wandb_writer_path = "{}/wandb".format(save_dir)
+ if "wandb" in config:
+ wandb_params = config['wandb']
+ else:
+ wandb_params = dict()
+ wandb_params.update({'save_dir': save_model_dir})
+ log_writer = WandbLogger(**wandb_params, config=config)
+ loggers.append(log_writer)
else:
- vdl_writer = None
+ log_writer = None
print_dict(config, logger)
+
+ if loggers:
+ log_writer = Loggers(loggers)
+ else:
+ log_writer = None
+
logger.info('train with paddle {} and device {}'.format(paddle.__version__,
device))
- return config, device, logger, vdl_writer
+ return config, device, logger, log_writer
diff --git a/backend/tools/reformat.py b/backend/tools/reformat.py
new file mode 100644
index 00000000..a044073d
--- /dev/null
+++ b/backend/tools/reformat.py
@@ -0,0 +1,153 @@
+# -*- coding: UTF-8 -*-
+"""
+@author: eritpchy
+@file : reformat.py
+@time : 2021/12/17 15:43
+@desc  : Split English words that have been run together
+"""
+import json
+import os
+import sys
+
+import pysrt
+import wordsegment as ws
+import re
+
+
+def execute(path, lang='en'):
+ # fix "RecursionError: maximum recursion depth exceeded in comparison" in wordsegment.segment call
+ if sys.getrecursionlimit() < 100000:
+ sys.setrecursionlimit(100000)
+
+ wordsegment = ws.Segmenter()
+ wordsegment.load()
+ subs = pysrt.open(path)
+ verb_forms = ["I'm", "you're", "he's", "she's", "we're", "it's", "isn't", "aren't", "they're", "there's", "wasn't",
+ "weren't", "I've", "you've", "we've", "they've", "hasn't", "haven't", "I'd", "you'd", "he'd", "she'd",
+ "it'd", "we'd", "they'd", "doesn't", "don't", "didn't", "I'll", "you'll", "he'll", "she'll", "we'll",
+ "they'll", "there'll", "there'd", "can't", "couldn't", "daren't", "hadn't", "mightn't", "mustn't",
+ "needn't", "oughtn't", "shan't", "shouldn't", "usedn't", "won't", "wouldn't", "that's", "what's", "it'll"]
+ verb_form_map = {}
+
+ with open(os.path.join(os.path.dirname(os.path.dirname(__file__)), 'configs', 'typoMap.json'), 'r', encoding='utf-8') as load_f:
+ typo_map = json.load(load_f)
+
+ for verb in verb_forms:
+ verb_form_map[verb.replace("'", "").lower()] = verb
+
+ def format_seg_list(seg_list):
+ new_seg = []
+ for seg in seg_list:
+ if seg in verb_form_map:
+ new_seg.append([seg, verb_form_map[seg]])
+ else:
+ new_seg.append([seg])
+ return new_seg
+
+ def typo_fix(text):
+ for k, v in typo_map.items():
+ text = re.sub(re.compile(k, re.I), v, text)
+ return text
+
+    # Filter seg in reverse, dropping segments that cannot be located in the text
+ def remove_invalid_segment(seg, text):
+ seg_len = len(seg)
+ span = None
+ new_seg = []
+ for i in range(seg_len - 1, -1, -1):
+ s = seg[i]
+ if len(s) > 1:
+ regex = re.compile(f"({s[0]}|{s[1]})", re.I)
+ else:
+ regex = re.compile(f"({s[0]})", re.I)
+ try:
+ ss = [(i) for i in re.finditer(regex, text)][-1]
+ except IndexError:
+ ss = None
+ if ss is None:
+ continue
+ text = text[:ss.span()[0]]
+ if span is None:
+ span = ss.span()
+ new_seg.append(s)
+ continue
+ if span > ss.span():
+ new_seg.append(s)
+ span = ss.span()
+ return list(reversed(new_seg))
+
+ for sub in subs:
+ sub.text = typo_fix(sub.text)
+ seg = wordsegment.segment(sub.text)
+ if len(seg) == 1:
+ seg = wordsegment.segment(re.sub(re.compile(f"(\ni)([^\\s])", re.I), "\\1 \\2", sub.text))
+ seg = format_seg_list(seg)
+
+        # Collapse multiple spaces before a Chinese character into one, to avoid wrong Chinese/English line splits
+ sub.text = re.sub(' +([\\u4e00-\\u9fa5])', ' \\1', sub.text)
+        # Put Chinese and English on separate lines
+ if lang in ["ch", "ch_tra"]:
+ sub.text = sub.text.replace(" ", "\n")
+ lines = []
+ remain = sub.text
+ seg = remove_invalid_segment(seg, sub.text)
+ seg_len = len(seg)
+ for i in range(0, seg_len):
+ s = seg[i]
+ global regex
+ if len(s) > 1:
+ regex = re.compile(f"(.*?)({s[0]}|{s[1]})", re.I)
+ else:
+ regex = re.compile(f"(.*?)({s[0]})", re.I)
+ ss = re.search(regex, remain)
+ if ss is None:
+ if i == seg_len - 1:
+ lines.append(remain.strip())
+ continue
+
+ lines.append(remain[:ss.span()[1]].strip())
+ remain = remain[ss.span()[1]:].strip()
+ if i == seg_len - 1:
+ lines.append(remain)
+ if seg_len > 0:
+ ss = " ".join(lines)
+ else:
+ ss = remain
+ # again
+ ss = typo_fix(ss)
+            # Add a space before an uppercase letter that follows a non-uppercase character
+ ss = re.sub("([^\\sA-Z\\-])([A-Z])", "\\1 \\2", ss)
+            # Remove duplicated spaces
+ ss = ss.replace(" ", " ")
+ ss = ss.replace("。", ".")
+            # Remove spaces before . ? ! ,
+ ss = re.sub(" *([\\.\\?\\!\\,])", "\\1", ss)
+            # Remove spaces around apostrophes
+ ss = re.sub(" *([\\']) *", "\\1", ss)
+            # Remove spaces after a newline, usually the leading spaces at the start of the second line
+ ss = re.sub('\n\\s*', '\n', ss)
+            # Remove leading spaces
+ ss = re.sub('^\\s*', '', ss)
+            # Remove the space to the left of a hyphen
+ ss = re.sub("([A-Za-z0-9]) (\\-[A-Za-z0-9])", '\\1\\2', ss)
+            # Remove the space to the left of a percent sign
+ ss = re.sub("([A-Za-z0-9]) %", '\\1%', ss)
+            # Replace a trailing '·' with '.'
+ ss = re.sub('·$', '.', ss)
+            # Remove the space after "Dr."
+ ss = re.sub(r'\bDr\. *\b', "Dr.", ss)
+            # Convert Chinese double quotes to English ones
+ ss = re.sub(r'[“”]', "\"", ss)
+            # Convert Chinese commas to English ones
+ ss = re.sub(r',', ",", ss)
+            # Add a space after . , ! ?
+ ss = re.sub('([\\.,\\!\\?])([A-Za-z0-9\\u4e00-\\u9fa5])', '\\1 \\2', ss)
+ ss = ss.replace("\n\n", "\n")
+ sub.text = ss.strip()
+ subs.save(path, encoding='utf-8')
+
+
+if __name__ == '__main__':
+ path = "/home/yao/Videos/null.srt"
+ execute(path)
+
diff --git a/backend/tools/subtitle_ocr.py b/backend/tools/subtitle_ocr.py
new file mode 100644
index 00000000..5b452dd8
--- /dev/null
+++ b/backend/tools/subtitle_ocr.py
@@ -0,0 +1,280 @@
+import os
+import re
+from multiprocessing import Queue, Process
+import cv2
+from PIL import ImageFont, ImageDraw, Image
+from tqdm import tqdm
+from tools.ocr import OcrRecogniser, get_coordinates
+from tools.constant import SubtitleArea
+from tools import constant
+from threading import Thread
+import queue
+from shapely.geometry import Polygon
+from types import SimpleNamespace
+import shutil
+import numpy as np
+from collections import namedtuple
+
+
+def extract_subtitles(data, text_recogniser, img, raw_subtitle_file,
+ sub_area, options, dt_box_arg, rec_res_arg, ocr_loss_debug_path):
+ """
+    Extract subtitle information from a video frame
+ """
+    # Get the detection boxes and recognition results from the arguments
+ dt_box = dt_box_arg
+ rec_res = rec_res_arg
+    # If no detection results were provided, run detection/recognition now
+ if dt_box is None or rec_res is None:
+ dt_box, rec_res = text_recogniser.predict(img)
+    # rec_res format: ("hello", 0.997)
+    # Get the text coordinates
+ coordinates = get_coordinates(dt_box)
+    # Write the results into the txt file
+ if options.REC_CHAR_TYPE == 'en':
+        # If the recognition language is English, strip Chinese characters
+ text_res = [(re.sub('[\u4e00-\u9fa5]', '', res[0]), res[1]) for res in rec_res]
+ else:
+ text_res = [(res[0], res[1]) for res in rec_res]
+ line = ''
+ loss_list = []
+ for content, coordinate in zip(text_res, coordinates):
+ text = content[0]
+ prob = content[1]
+ if sub_area is not None:
+ selected = False
+            # Initialise the overflow ratio to 0
+ overflow_area_rate = 0
+            # Subtitle area specified by the user
+ sub_area_polygon = sub_area_to_polygon(sub_area)
+            # Subtitle area detected by OCR
+ coordinate_polygon = coordinate_to_polygon(coordinate)
+            # Check whether the two areas intersect
+ intersection = sub_area_polygon.intersection(coordinate_polygon)
+            # If they intersect
+ if not intersection.is_empty:
+                # Compute how far the detected box overflows the user-specified area
+ overflow_area_rate = ((sub_area_polygon.area + coordinate_polygon.area - intersection.area) / sub_area_polygon.area) - 1
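+                # i.e. union_area / user_area - 1, which is 0 when the detected box lies entirely inside the user area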
+                # If the overflow ratio is below the threshold and the confidence of this line is above the threshold
+ if overflow_area_rate <= options.SUB_AREA_DEVIATION_RATE and prob > options.DROP_SCORE:
+                    # Keep this text line
+ selected = True
+ line += f'{str(data["i"]).zfill(8)}\t{coordinate}\t{text}\n'
+ raw_subtitle_file.write(f'{str(data["i"]).zfill(8)}\t{coordinate}\t{text}\n')
+            # Save the discarded recognition results
+ loss_info = namedtuple('loss_info', 'text prob overflow_area_rate coordinate selected')
+ loss_list.append(loss_info(text, prob, overflow_area_rate, coordinate, selected))
+ else:
+ raw_subtitle_file.write(f'{str(data["i"]).zfill(8)}\t{coordinate}\t{text}\n')
+    # Dump debug information
+ dump_debug_info(options, line, img, loss_list, ocr_loss_debug_path, sub_area, data)
+
+
+def dump_debug_info(options, line, img, loss_list, ocr_loss_debug_path, sub_area, data):
+ loss = False
+    if options.DEBUG_OCR_LOSS and options.REC_CHAR_TYPE in ('ch', 'japan', 'korea', 'ch_tra'):
+ loss = len(line) > 0 and re.search(r'[\u4e00-\u9fa5\u3400-\u4db5\u3130-\u318F\uAC00-\uD7A3\u0800-\u4e00]', line) is None
+ if loss:
+ if not os.path.exists(ocr_loss_debug_path):
+ os.makedirs(ocr_loss_debug_path, mode=0o777, exist_ok=True)
+ img = cv2.rectangle(img, (sub_area[2], sub_area[0]), (sub_area[3], sub_area[1]), constant.BGR_COLOR_BLUE, 2)
+ for loss_info in loss_list:
+ coordinate = loss_info.coordinate
+ color = constant.BGR_COLOR_GREEN if loss_info.selected else constant.BGR_COLOR_RED
+ text = f"[{loss_info.text}] prob:{loss_info.prob:.4f} or:{loss_info.overflow_area_rate:.2f}"
+ img = paint_chinese_opencv(img, text, pos=(coordinate[0], coordinate[2] - 30), color=color)
+ img = cv2.rectangle(img, (coordinate[0], coordinate[2]), (coordinate[1], coordinate[3]), color, 2)
+ cv2.imwrite(os.path.join(os.path.abspath(ocr_loss_debug_path), f'{str(data["i"]).zfill(8)}.png'), img)
+
+
+def sub_area_to_polygon(sub_area):
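+    # sub_area is ordered (ymin, ymax, xmin, xmax)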
+ s_ymin = sub_area[0]
+ s_ymax = sub_area[1]
+ s_xmin = sub_area[2]
+ s_xmax = sub_area[3]
+ return Polygon([[s_xmin, s_ymin], [s_xmax, s_ymin], [s_xmax, s_ymax], [s_xmin, s_ymax]])
+
+
+def coordinate_to_polygon(coordinate):
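+    # coordinate is ordered (xmin, xmax, ymin, ymax)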
+ xmin = coordinate[0]
+ xmax = coordinate[1]
+ ymin = coordinate[2]
+ ymax = coordinate[3]
+ return Polygon([[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]])
+
+
+FONT_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'NotoSansCJK-Bold.otf')
+FONT = ImageFont.truetype(FONT_PATH, 20)
+
+
+def paint_chinese_opencv(im, chinese, pos, color):
+ img_pil = Image.fromarray(im)
+ fill_color = color # (color[2], color[1], color[0])
+ position = pos
+ draw = ImageDraw.Draw(img_pil)
+ draw.text(position, chinese, font=FONT, fill=fill_color)
+ img = np.asarray(img_pil)
+ return img
+
+
+def ocr_task_consumer(ocr_queue, raw_subtitle_path, sub_area, video_path, options):
+ """
+    Consumer: consumes ocr_queue, takes items off the queue, runs OCR recognition and writes the results to the subtitle file
+    :param ocr_queue (current_frame_no frame number, frame video frame, dt_box detection boxes, rec_res recognition results)
+ :param raw_subtitle_path
+ :param sub_area
+ :param video_path
+ :param options
+ """
+ data = {'i': 1}
+    # Initialise the text recogniser
+ text_recogniser = OcrRecogniser()
+    # Path for storing frames whose subtitles were dropped
+ ocr_loss_debug_path = os.path.join(os.path.abspath(os.path.splitext(video_path)[0]), 'loss')
+    # Remove stale cache from previous runs
+ if os.path.exists(ocr_loss_debug_path):
+ shutil.rmtree(ocr_loss_debug_path, True)
+
+ with open(raw_subtitle_path, mode='w+', encoding='utf-8') as raw_subtitle_file:
+ while True:
+ try:
+ frame_no, frame, dt_box, rec_res = ocr_queue.get(block=True)
+ if frame_no == -1:
+ return
+ data['i'] = frame_no
+ extract_subtitles(data, text_recogniser, frame, raw_subtitle_file, sub_area, options, dt_box,
+ rec_res, ocr_loss_debug_path)
+ except Exception as e:
+ print(e)
+ break
+
+
+def ocr_task_producer(ocr_queue, task_queue, progress_queue, video_path, raw_subtitle_path):
+ """
+    Producer: produces the data used for OCR recognition and puts it on ocr_queue
+    :param ocr_queue (current_frame_no frame number, frame video frame, dt_box detection boxes, rec_res recognition results)
+    :param task_queue (total_frame_count total frames, current_frame_no frame number, dt_box detection boxes, rec_res recognition results, total_ms timestamp in ms, subtitle_area subtitle area)
+ :param progress_queue
+ :param video_path
+ :param raw_subtitle_path
+ """
+ cap = cv2.VideoCapture(video_path)
+ tbar = None
+ while True:
+ try:
+            # Pull a task from the task queue
+ total_frame_count, current_frame_no, dt_box, rec_res, total_ms, default_subtitle_area = task_queue.get(block=True)
+ progress_queue.put(current_frame_no)
+ if tbar is None:
+ tbar = tqdm(total=round(total_frame_count), position=1)
+            # current_frame_no == -1 means all video frames have been read
+ if current_frame_no == -1:
+                # Put the end-of-stream marker on the OCR queue
+ ocr_queue.put((-1, None, None, None))
+                # Update the progress bar
+ tbar.update(tbar.total - tbar.n)
+ break
+ tbar.update(round(current_frame_no - tbar.n))
+            # Seek to the current video frame
+            # If total_ms is not None, VSF was used to extract subtitles
+ if total_ms is not None:
+ cap.set(cv2.CAP_PROP_POS_MSEC, total_ms)
+ else:
+ cap.set(cv2.CAP_PROP_POS_FRAMES, current_frame_no - 1)
+            # Read the video frame
+ ret, frame = cap.read()
+ ocr = OcrRecogniser()
+ dt_box, rec_res = ocr.predict(frame)
+            # If the frame was read successfully
+ if ret:
+                # Crop the frame to the default subtitle area before further processing
+ if default_subtitle_area is not None:
+ frame = frame_preprocess(default_subtitle_area, frame)
+ ocr_queue.put((current_frame_no, frame, dt_box, rec_res))
+ except Exception as e:
+ print(e)
+ break
+ cap.release()
+
+
+def subtitle_extract_handler(task_queue, progress_queue, video_path, raw_subtitle_path, sub_area, options):
+ """
+    Create and start one video frame extraction thread and one OCR recognition thread
+    :param task_queue task queue, (total_frame_count total frames, current_frame_no current frame, dt_box detection boxes, rec_res recognition results, total_ms timestamp in ms, subtitle_area subtitle area)
+    :param progress_queue progress queue
+    :param video_path video path
+    :param raw_subtitle_path raw subtitle file path
+    :param sub_area subtitle area
+    :param options options
+ """
+    # Remove cached files
+ if os.path.exists(raw_subtitle_path):
+ os.remove(raw_subtitle_path)
+    # Create an OCR queue; a size of 8-20 is recommended
+ ocr_queue = queue.Queue(20)
+    # Create the OCR producer thread
+ ocr_event_producer_thread = Thread(target=ocr_task_producer,
+ args=(ocr_queue, task_queue, progress_queue, video_path, raw_subtitle_path,),
+ daemon=True)
+    # Create the OCR consumer thread
+ ocr_event_consumer_thread = Thread(target=ocr_task_consumer,
+ args=(ocr_queue, raw_subtitle_path, sub_area, video_path, options,),
+ daemon=True)
+    # Start the producer thread
+ ocr_event_producer_thread.start()
+    # Start the consumer thread
+ ocr_event_consumer_thread.start()
+    # join() blocks the main thread until both worker threads have finished
+ ocr_event_producer_thread.join()
+ ocr_event_consumer_thread.join()
+
+
+def async_start(video_path, raw_subtitle_path, sub_area, options):
+ """
+    Start a new process to handle the asynchronous task
+ options.REC_CHAR_TYPE
+ options.DROP_SCORE
+ options.SUB_AREA_DEVIATION_RATE
+ options.DEBUG_OCR_LOSS
+ """
+    assert 'REC_CHAR_TYPE' in options, "options is missing parameter: REC_CHAR_TYPE"
+    assert 'DROP_SCORE' in options, "options is missing parameter: DROP_SCORE"
+    assert 'SUB_AREA_DEVIATION_RATE' in options, "options is missing parameter: SUB_AREA_DEVIATION_RATE"
+    assert 'DEBUG_OCR_LOSS' in options, "options is missing parameter: DEBUG_OCR_LOSS"
+    # Create a task queue
+    # Task format: (total_frame_count total frames, current_frame_no current frame, dt_box detection boxes, rec_res recognition results, total_ms timestamp in ms, subtitle_area subtitle area)
+ task_queue = Queue()
+    # Create a progress update queue
+ progress_queue = Queue()
+    # Spawn a new process
+ p = Process(target=subtitle_extract_handler,
+ args=(task_queue, progress_queue, video_path, raw_subtitle_path, sub_area, SimpleNamespace(**options),))
+    # Start the process
+ p.start()
+ return p, task_queue, progress_queue
+
+
+def frame_preprocess(subtitle_area, frame):
+ """
+    Crop the video frame
+ """
+    # For videos with a resolution above 1920*1080, the frames could be scaled down to 1280*720 for recognition
+    # paddlepaddle compresses images to 640*640
+ # if self.frame_width > 1280:
+ # scale_rate = round(float(1280 / self.frame_width), 2)
+ # frames = cv2.resize(frames, None, fx=scale_rate, fy=scale_rate, interpolation=cv2.INTER_AREA)
+    # If the subtitle area is in the lower part of the frame
+ if subtitle_area == SubtitleArea.LOWER_PART:
+ cropped = int(frame.shape[0] // 2)
+ # 将视频帧切割为下半部分
+ frame = frame[cropped:]
+ # 如果字幕出现的区域在上半部分
+ elif subtitle_area == SubtitleArea.UPPER_PART:
+ cropped = int(frame.shape[0] // 2)
+        # Keep only the upper half of the frame
+ frame = frame[:cropped]
+ return frame
+
+
+if __name__ == "__main__":
+ pass
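+    # Usage sketch (illustrative only; the paths, subtitle area and option values below are placeholders):
+    #
+    #   p, task_queue, progress_queue = async_start(
+    #       '/path/to/video.mp4', '/path/to/raw_subtitle.txt', sub_area=(842, 1069, 72, 1368),
+    #       options={'REC_CHAR_TYPE': 'ch', 'DROP_SCORE': 0.75,
+    #                'SUB_AREA_DEVIATION_RATE': 0.03, 'DEBUG_OCR_LOSS': False})
+    #   # the caller then feeds tasks of the form
+    #   # (total_frame_count, current_frame_no, dt_box, rec_res, total_ms, subtitle_area)
+    #   # and finally a task with current_frame_no == -1 to signal the end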
diff --git a/backend/tools/test_hubserving.py b/backend/tools/test_hubserving.py
index 3beb4965..ec17a941 100755
--- a/backend/tools/test_hubserving.py
+++ b/backend/tools/test_hubserving.py
@@ -25,7 +25,9 @@
import time
from PIL import Image
from ppocr.utils.utility import get_image_file_list
-from tools.infer.utility import draw_ocr, draw_boxes
+from tools.infer.utility import draw_ocr, draw_boxes, str2bool
+from ppstructure.utility import draw_structure_result
+from ppstructure.predict_system import to_excel
import requests
import json
@@ -64,12 +66,38 @@ def draw_server_result(image_file, res):
scores.append(res[dno]['confidence'])
boxes = np.array(boxes)
scores = np.array(scores)
- draw_img = draw_ocr(image, boxes, texts, scores, drop_score=0.5)
+ draw_img = draw_ocr(
+ image, boxes, texts, scores, draw_txt=True, drop_score=0.5)
return draw_img
-def main(url, image_path):
- image_file_list = get_image_file_list(image_path)
+def save_structure_res(res, save_folder, image_file):
+ img = cv2.imread(image_file)
+ excel_save_folder = os.path.join(save_folder, os.path.basename(image_file))
+ os.makedirs(excel_save_folder, exist_ok=True)
+ # save res
+ with open(
+ os.path.join(excel_save_folder, 'res.txt'), 'w',
+ encoding='utf8') as f:
+ for region in res:
+ if region['type'] == 'Table':
+ excel_path = os.path.join(excel_save_folder,
+ '{}.xlsx'.format(region['bbox']))
+ to_excel(region['res'], excel_path)
+ elif region['type'] == 'Figure':
+ x1, y1, x2, y2 = region['bbox']
+ print(region['bbox'])
+ roi_img = img[y1:y2, x1:x2, :]
+ img_path = os.path.join(excel_save_folder,
+ '{}.jpg'.format(region['bbox']))
+ cv2.imwrite(img_path, roi_img)
+ else:
+ for text_result in region['res']:
+ f.write('{}\n'.format(json.dumps(text_result)))
+
+
+def main(args):
+ image_file_list = get_image_file_list(args.image_dir)
is_visualize = False
headers = {"Content-type": "application/json"}
cnt = 0
@@ -79,38 +107,51 @@ def main(url, image_path):
if img is None:
logger.info("error in loading image:{}".format(image_file))
continue
-
- # 发送HTTP请求
+ img_name = os.path.basename(image_file)
+        # send http request
starttime = time.time()
data = {'images': [cv2_to_base64(img)]}
- r = requests.post(url=url, headers=headers, data=json.dumps(data))
+ r = requests.post(
+ url=args.server_url, headers=headers, data=json.dumps(data))
elapse = time.time() - starttime
total_time += elapse
logger.info("Predict time of %s: %.3fs" % (image_file, elapse))
res = r.json()["results"][0]
logger.info(res)
- if is_visualize:
- draw_img = draw_server_result(image_file, res)
+ if args.visualize:
+ draw_img = None
+ if 'structure_table' in args.server_url:
+ to_excel(res['html'], './{}.xlsx'.format(img_name))
+ elif 'structure_system' in args.server_url:
+ save_structure_res(res['regions'], args.output, image_file)
+ else:
+ draw_img = draw_server_result(image_file, res)
if draw_img is not None:
- draw_img_save = "./server_results/"
- if not os.path.exists(draw_img_save):
- os.makedirs(draw_img_save)
+ if not os.path.exists(args.output):
+ os.makedirs(args.output)
cv2.imwrite(
- os.path.join(draw_img_save, os.path.basename(image_file)),
+ os.path.join(args.output, os.path.basename(image_file)),
draw_img[:, :, ::-1])
logger.info("The visualized image saved in {}".format(
- os.path.join(draw_img_save, os.path.basename(image_file))))
+ os.path.join(args.output, os.path.basename(image_file))))
cnt += 1
if cnt % 100 == 0:
logger.info("{} processed".format(cnt))
logger.info("avg time cost: {}".format(float(total_time) / cnt))
+def parse_args():
+ import argparse
+ parser = argparse.ArgumentParser(description="args for hub serving")
+ parser.add_argument("--server_url", type=str, required=True)
+ parser.add_argument("--image_dir", type=str, required=True)
+ parser.add_argument("--visualize", type=str2bool, default=False)
+ parser.add_argument("--output", type=str, default='./hubserving_result')
+ args = parser.parse_args()
+ return args
+
+
if __name__ == '__main__':
- if len(sys.argv) != 3:
- logger.info("Usage: %s server_url image_path" % sys.argv[0])
- else:
- server_url = sys.argv[1]
- image_path = sys.argv[2]
- main(server_url, image_path)
+ args = parse_args()
+ main(args)
diff --git a/backend/tools/train.py b/backend/tools/train.py
index fab10b64..42aba548 100755
--- a/backend/tools/train.py
+++ b/backend/tools/train.py
@@ -21,21 +21,20 @@
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
-sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
import yaml
import paddle
import paddle.distributed as dist
-paddle.seed(2)
-
from ppocr.data import build_dataloader
from ppocr.modeling.architectures import build_model
from ppocr.losses import build_loss
from ppocr.optimizer import build_optimizer
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
-from ppocr.utils.save_load import init_model
+from ppocr.utils.save_load import load_model
+from ppocr.utils.utility import set_seed
import tools.program as program
dist.get_world_size()
@@ -52,7 +51,10 @@ def main(config, device, logger, vdl_writer):
train_dataloader = build_dataloader(config, 'Train', device, logger)
if len(train_dataloader) == 0:
logger.error(
- 'No Images in train dataset, please check annotation file and path in the configuration file'
+ "No Images in train dataset, please ensure\n" +
+ "\t1. The images num in the train label_file_list should be larger than or equal with batch size.\n"
+ +
+ "\t2. The annotation file and path in the configuration file are provided normally."
)
return
@@ -69,7 +71,52 @@ def main(config, device, logger, vdl_writer):
# for rec algorithm
if hasattr(post_process_class, 'character'):
char_num = len(getattr(post_process_class, 'character'))
- config['Architecture']["Head"]['out_channels'] = char_num
+ if config['Architecture']["algorithm"] in ["Distillation",
+ ]: # distillation model
+ for key in config['Architecture']["Models"]:
+ if config['Architecture']['Models'][key]['Head'][
+ 'name'] == 'MultiHead': # for multi head
+ if config['PostProcess'][
+ 'name'] == 'DistillationSARLabelDecode':
+ char_num = char_num - 2
+ # update SARLoss params
+ assert list(config['Loss']['loss_config_list'][-1].keys())[
+ 0] == 'DistillationSARLoss'
+ config['Loss']['loss_config_list'][-1][
+ 'DistillationSARLoss']['ignore_index'] = char_num + 1
+ out_channels_list = {}
+ out_channels_list['CTCLabelDecode'] = char_num
+ out_channels_list['SARLabelDecode'] = char_num + 2
+ config['Architecture']['Models'][key]['Head'][
+ 'out_channels_list'] = out_channels_list
+ else:
+ config['Architecture']["Models"][key]["Head"][
+ 'out_channels'] = char_num
+ elif config['Architecture']['Head'][
+ 'name'] == 'MultiHead': # for multi head
+ if config['PostProcess']['name'] == 'SARLabelDecode':
+ char_num = char_num - 2
+ # update SARLoss params
+ assert list(config['Loss']['loss_config_list'][1].keys())[
+ 0] == 'SARLoss'
+ if config['Loss']['loss_config_list'][1]['SARLoss'] is None:
+ config['Loss']['loss_config_list'][1]['SARLoss'] = {
+ 'ignore_index': char_num + 1
+ }
+ else:
+ config['Loss']['loss_config_list'][1]['SARLoss'][
+ 'ignore_index'] = char_num + 1
+ out_channels_list = {}
+ out_channels_list['CTCLabelDecode'] = char_num
+ out_channels_list['SARLabelDecode'] = char_num + 2
+ config['Architecture']['Head'][
+ 'out_channels_list'] = out_channels_list
+ else: # base rec model
+ config['Architecture']["Head"]['out_channels'] = char_num
+
+ if config['PostProcess']['name'] == 'SARLabelDecode': # for SAR model
+ config['Loss']['ignore_index'] = char_num - 1
+
model = build_model(config['Architecture'])
if config['Global']['distributed']:
model = paddle.DataParallel(model)
@@ -82,19 +129,38 @@ def main(config, device, logger, vdl_writer):
config['Optimizer'],
epochs=config['Global']['epoch_num'],
step_each_epoch=len(train_dataloader),
- parameters=model.parameters())
+ model=model)
# build metric
eval_class = build_metric(config['Metric'])
# load pretrain model
- pre_best_model_dict = init_model(config, model, logger, optimizer)
+ pre_best_model_dict = load_model(config, model, optimizer,
+ config['Architecture']["model_type"])
+ logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
+ if valid_dataloader is not None:
+ logger.info('valid dataloader has {} iters'.format(
+ len(valid_dataloader)))
+
+ use_amp = config["Global"].get("use_amp", False)
+ if use_amp:
+ AMP_RELATED_FLAGS_SETTING = {
+ 'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
+ 'FLAGS_max_inplace_grad_add': 8,
+ }
+ paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
+ scale_loss = config["Global"].get("scale_loss", 1.0)
+ use_dynamic_loss_scaling = config["Global"].get(
+ "use_dynamic_loss_scaling", False)
+ scaler = paddle.amp.GradScaler(
+ init_loss_scaling=scale_loss,
+ use_dynamic_loss_scaling=use_dynamic_loss_scaling)
+ else:
+ scaler = None
- logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
- format(len(train_dataloader), len(valid_dataloader)))
# start train
program.train(config, train_dataloader, valid_dataloader, device, model,
loss_class, optimizer, lr_scheduler, post_process_class,
- eval_class, pre_best_model_dict, logger, vdl_writer)
+ eval_class, pre_best_model_dict, logger, vdl_writer, scaler)
def test_reader(config, device, logger):
@@ -117,5 +183,7 @@ def test_reader(config, device, logger):
if __name__ == '__main__':
config, device, logger, vdl_writer = program.preprocess(is_train=True)
+ seed = config['Global']['seed'] if 'seed' in config['Global'] else 1024
+ set_seed(seed)
main(config, device, logger, vdl_writer)
# test_reader(config, device, logger)
diff --git a/config.py b/config.py
deleted file mode 100644
index 80a930c6..00000000
--- a/config.py
+++ /dev/null
@@ -1,75 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-@Author : Fang Yao
-@Time : 2021/3/24 9:36 上午
-@FileName: config.py
-@desc: 项目配置文件,可以在这里调参,牺牲时间换取精确度,或者牺牲准确度换取时间
-"""
-import os
-from pathlib import Path
-from enum import Enum
-from fsplit.filesplit import Filesplit
-
-# --------------------- 请你不要改 start-----------------------------
-# 项目的base目录
-BASE_DIR = str(Path(os.path.abspath(__file__)).parent)
-
-# 模型文件目录
-# 文本检测模型
-DET_MODEL_PATH = os.path.join(BASE_DIR, 'backend', 'models', 'ch_det')
-# 文本识别模型
-REC_MODEL_PATH = os.path.join(BASE_DIR, 'backend', 'models', 'ch_rec')
-
-# 查看该路径下是否有文本模型识别完整文件,没有的话合并小文件生成完整文件
-if 'inference.pdiparams' not in (os.listdir(REC_MODEL_PATH)):
- fs = Filesplit()
- fs.merge(input_dir=REC_MODEL_PATH)
-
-# 字典路径
-DICT_PATH = os.path.join(BASE_DIR, 'backend', 'ppocr', 'utils', 'ppocr_keys_v1.txt')
-
-
-# 默认字幕出现的大致区域
-class SubtitleArea(Enum):
- # 字幕区域出现在下半部分
- LOWER_PART = 0
- # 字幕区域出现在上半部分
- UPPER_PART = 1
- # 不知道字幕区域可能出现的位置
- UNKNOWN = 2
- # 明确知道字幕区域出现的位置
- CUSTOM = 3
-# --------------------- 请你不要改 end-----------------------------
-
-
-# --------------------- 请根据自己的实际情况改 start-----------------------------
-# 是否使用GPU
-# 使用GPU可以提速20倍+,你要是有N卡你就改成 True
-USE_GPU = False
-
-# 默认字幕出现区域为下方
-SUBTITLE_AREA = SubtitleArea.LOWER_PART
-
-# 余弦相似度阈值
-# 数值越小生成的视频帧越少,相对提取速度更快但生成的字幕越不精准
-# 1表示最精准,每一帧视频帧都进行字幕检测与提取,生成的字幕最精准
-# 0.925表示,当视频帧1与视频帧2相似度高达92.5%时,视频帧2将直接pass,不字检测与提取视频帧2的字幕
-COSINE_SIMILARITY_THRESHOLD = 0.95 if SUBTITLE_AREA == SubtitleArea.UNKNOWN else 0.91
-
-# 欧式距离相似值
-EUCLIDEAN_SIMILARITY_THRESHOLD = 0.9
-
-# 容忍的像素点偏差
-PIXEL_TOLERANCE_Y = 50 # 允许检测框纵向偏差50个像素点
-PIXEL_TOLERANCE_X = 100 # 允许检测框横向偏差100个像素点
-
-# 字幕区域偏移量
-SUBTITLE_AREA_DEVIATION_PIXEL = 50
-
-# 最有可能出现的水印区域
-WATERMARK_AREA_NUM = 5
-
-# 文本相似度阈值
-# 用于去重时判断两行字幕是不是统一行
-TEXT_SIMILARITY_THRESHOLD = 0.95
-# --------------------- 请根据自己的实际情况改 end-----------------------------
diff --git a/design/UI design.png b/design/UI design.png
new file mode 100644
index 00000000..fb2e0ed3
Binary files /dev/null and b/design/UI design.png differ
diff --git a/design/bg.png b/design/bg.png
new file mode 100644
index 00000000..9223cb14
Binary files /dev/null and b/design/bg.png differ
diff --git a/design/demo.gif b/design/demo.gif
new file mode 100644
index 00000000..276874e8
Binary files /dev/null and b/design/demo.gif differ
diff --git a/design/demo.png b/design/demo.png
new file mode 100644
index 00000000..a9c20e8b
Binary files /dev/null and b/design/demo.png differ
diff --git a/design/demo2.gif b/design/demo2.gif
new file mode 100644
index 00000000..d96b9927
Binary files /dev/null and b/design/demo2.gif differ
diff --git a/design/gui.spec b/design/gui.spec
new file mode 100644
index 00000000..4fe25080
--- /dev/null
+++ b/design/gui.spec
@@ -0,0 +1,41 @@
+# -*- mode: python ; coding: utf-8 -*-
+
+
+block_cipher = None
+
+
+a = Analysis(['gui.py'],
+ pathex=['/Users/yao/anaconda3/envs/subEnv/lib/python3.7/site-packages', '/Users/yao/Github/video-subtitle-extractor'],
+ binaries=[('/Users/yao/Github/video-subtitle-extractor/dylib/libgeos_c.dylib', '.')],
+ datas=[('/Users/yao/Github/video-subtitle-extractor/backend', 'backend'),
+ ('/Users/yao/Github/video-subtitle-extractor/vse.ico', '.')
+ ],
+ hiddenimports=['imgaug', 'skimage.filters.rank.core_cy_3d',
+ 'pyclipper', 'lmdb'],
+ hookspath=[],
+ runtime_hooks=[],
+ excludes=[],
+ win_no_prefer_redirects=False,
+ win_private_assemblies=False,
+ cipher=block_cipher,
+ noarchive=False)
+pyz = PYZ(a.pure, a.zipped_data,
+ cipher=block_cipher)
+exe = EXE(pyz,
+ a.scripts,
+ a.binaries,
+ a.zipfiles,
+ a.datas,
+ [],
+ name='vse',
+ debug=False,
+ bootloader_ignore_signals=False,
+ strip=False,
+ upx=True,
+ upx_exclude=[],
+ runtime_tmpdir=None,
+ console=False , icon='vse.ico')
+app = BUNDLE(exe,
+ name='Subtitle Extractor.app',
+ icon='vse.ico',
+ bundle_identifier=None)
diff --git a/design/paper (2020).pdf b/design/paper (2020).pdf
new file mode 100644
index 00000000..440a065a
Binary files /dev/null and b/design/paper (2020).pdf differ
diff --git a/design/vse.ico b/design/vse.ico
new file mode 100644
index 00000000..84a782b7
Binary files /dev/null and b/design/vse.ico differ
diff --git a/google_colab.ipynb b/google_colab.ipynb
new file mode 100644
index 00000000..61c4789e
--- /dev/null
+++ b/google_colab.ipynb
@@ -0,0 +1,164 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "WAJ7lA2wuvR8"
+ },
+ "source": [
+ "# 运行教程\n",
+ "\n",
+ "1. 点击“修改” -> \"笔记本设置\" -> \"硬件加速器GPU\" -> 保存\n",
+ "\n",
+ "\n",
+ "2. 点左侧文件夹图标,上传视频文件,复制上传的视频路径\n",
+ "\n",
+ "\n",
+ "\n",
+ "3. 运行代码, 输入粘贴的视频路径\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "_jPi_FBwyZyr"
+ },
+ "source": [
+ "查看是否有GPU"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "eHPHc_Bheo-j"
+ },
+ "outputs": [],
+ "source": [
+ "!nvidia-smi"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "TkQKKGKZkkT2"
+ },
+ "outputs": [],
+ "source": [
+ "!nvcc -V"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "_85O6zgPyhto"
+ },
+ "source": [
+ "# 安装依赖"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ICeq0T1FeqjT"
+ },
+ "outputs": [],
+ "source": [
+ "!git clone --depth=1 https://github.com/YaoFANGUK/video-subtitle-extractor.git"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "GHutEWynkMKR"
+ },
+ "outputs": [],
+ "source": [
+ "cd video-subtitle-extractor"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ynJydzo1kMKR"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install -r requirements_gpu.txt"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "3-GdvmaGl-aF"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install paddlepaddle-gpu==2.4.2.post117 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "SGb0i3tPyw9Q"
+ },
+ "source": [
+ "# 运行程序"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "输入视频路径,如:/content/video-subtitle-extractor/test/test_cn2.mp4\n",
+ "\n",
+ "输入字幕区域,如:842 1069 72 1368"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "B2MPjMOOgGbD"
+ },
+ "outputs": [],
+ "source": [
+ "!python ./backend/main.py"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "collapsed_sections": [],
+ "name": "video-subtitle-extractor.ipynb",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}
diff --git a/google_colab_en.ipynb b/google_colab_en.ipynb
new file mode 100644
index 00000000..aac8014e
--- /dev/null
+++ b/google_colab_en.ipynb
@@ -0,0 +1,177 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Set Up\n",
+ "\n",
+    "1. Click \"Edit\" -> \"Notebook Settings\" -> \"Hardware accelerator: GPU\" -> Save\n",
+ "\n",
+ "\n",
+ "2. Click the folder icon on the left, upload your video file, and copy the uploaded video path\n",
+ "\n",
+ "\n",
+ "\n",
+    "3. Run the code and enter the pasted video path\n",
+ "\n",
+ "\n"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "_jPi_FBwyZyr"
+ },
+ "source": [
+    "Check whether a GPU is available"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "eHPHc_Bheo-j"
+ },
+ "outputs": [],
+ "source": [
+ "!nvidia-smi"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "TkQKKGKZkkT2"
+ },
+ "outputs": [],
+ "source": [
+ "!nvcc -V"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "_85O6zgPyhto"
+ },
+ "source": [
+ "# Install Dependencies"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ICeq0T1FeqjT"
+ },
+ "outputs": [],
+ "source": [
+ "!git clone --depth=1 https://github.com/YaoFANGUK/video-subtitle-extractor.git"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "GHutEWynkMKR"
+ },
+ "outputs": [],
+ "source": [
+ "cd video-subtitle-extractor"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ynJydzo1kMKR"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install -r requirements_gpu.txt"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "outputs": [],
+ "source": [
+ "!echo -e '[DEFAULT]\\nInterface = English\\nLanguage = en\\nMode = fast' > /content/video-subtitle-extractor/settings.ini"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "3-GdvmaGl-aF"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install paddlepaddle-gpu==2.4.2.post117 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "SGb0i3tPyw9Q"
+ },
+ "source": [
+ "# Run"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Here is an example:\n",
+ "\n",
+    "Input video path: /content/video-subtitle-extractor/test/test_en.mp4\n",
+    "\n",
+    "Input subtitle area: 612 717 90 1191"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "B2MPjMOOgGbD"
+ },
+ "outputs": [],
+ "source": [
+ "!python ./backend/main.py"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "collapsed_sections": [],
+ "name": "video-subtitle-extractor.ipynb",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}
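The echo cell in the notebook above hard-codes the settings.ini that the backend reads (interface, subtitle language, recognition mode). A minimal alternative sketch writes the same three keys with configparser; the [DEFAULT] Interface/Language/Mode layout mirrors what LanguageModeGUI.set_config in gui.py writes, and the path is the Colab clone location used in this notebook:

    # Sketch only: produces the same settings.ini as the echo cell above.
    import configparser

    config = configparser.ConfigParser()
    config.optionxform = str  # keep the Interface/Language/Mode capitalisation the GUI uses
    config['DEFAULT'] = {'Interface': 'English', 'Language': 'en', 'Mode': 'fast'}
    with open('/content/video-subtitle-extractor/settings.ini', 'w', encoding='utf-8') as f:
        config.write(f)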
diff --git a/gui.py b/gui.py
new file mode 100644
index 00000000..b9b9dfc8
--- /dev/null
+++ b/gui.py
@@ -0,0 +1,592 @@
+# -*- coding: utf-8 -*-
+"""
+@Author : Fang Yao
+@Time    : 2021/4/1 6:07 PM
+@FileName: gui.py
+@desc: GUI for the subtitle extractor
+"""
+import backend.main
+import os
+import configparser
+import PySimpleGUI as sg
+import cv2
+from threading import Thread
+import multiprocessing
+
+
+class SubtitleExtractorGUI:
+ def _load_config(self):
+ self.config_file = os.path.join(os.path.dirname(__file__), 'settings.ini')
+ self.subtitle_config_file = os.path.join(os.path.dirname(__file__), 'subtitle.ini')
+ self.config = configparser.ConfigParser()
+ self.interface_config = configparser.ConfigParser()
+ if not os.path.exists(self.config_file):
+            # If no config file exists, show the language-selection dialog by default
+ LanguageModeGUI(self).run()
+ self.INTERFACE_KEY_NAME_MAP = {
+ '简体中文': 'ch',
+ '繁體中文': 'chinese_cht',
+ 'English': 'en',
+ '한국어': 'ko',
+ '日本語': 'japan',
+ 'Tiếng Việt': 'vi',
+ 'Español': 'es'
+ }
+ self.config.read(self.config_file, encoding='utf-8')
+ self.interface_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'backend', 'interface',
+ f"{self.INTERFACE_KEY_NAME_MAP[self.config['DEFAULT']['Interface']]}.ini")
+ self.interface_config.read(self.interface_file, encoding='utf-8')
+
+ def __init__(self):
+        # On first run, check that the runtime environment works
+ from paddle import utils
+ utils.run_check()
+ self.font = 'Arial 10'
+ self.theme = 'LightBrown12'
+ sg.theme(self.theme)
+ self.icon = os.path.join(os.path.dirname(__file__), 'design', 'vse.ico')
+ self._load_config()
+ self.screen_width, self.screen_height = sg.Window.get_screen_size()
+ print(self.screen_width, self.screen_height)
+        # Size of the video preview area
+        self.video_preview_width = 960
+        self.video_preview_height = self.video_preview_width * 9 // 16
+        # Default widget sizes
+        self.horizontal_slider_size = (120, 20)
+        self.output_size = (100, 10)
+        self.progressbar_size = (60, 20)
+        # Screen resolution below 1080p
+        if self.screen_width // 2 < 960:
+            self.video_preview_width = 640
+            self.video_preview_height = self.video_preview_width * 9 // 16
+            self.horizontal_slider_size = (60, 20)
+            self.output_size = (58, 10)
+            self.progressbar_size = (28, 20)
+        # Subtitle extractor layout
+        self.layout = None
+        # Subtitle extractor window
+        self.window = None
+        # Video path
+        self.video_path = None
+        # Video capture object
+        self.video_cap = None
+        # Video frame rate
+        self.fps = None
+        # Video frame count
+        self.frame_count = None
+        # Video width
+        self.frame_width = None
+        # Video height
+        self.frame_height = None
+        # Subtitle area bounds
+        self.xmin = None
+        self.xmax = None
+        self.ymin = None
+        self.ymax = None
+        # Subtitle extractor instance
+ self.se = None
+
+ def run(self):
+        # Build the layout
+        self._create_layout()
+        # Create the window
+ self.window = sg.Window(title=self.interface_config['SubtitleExtractorGUI']['Title'], layout=self.layout,
+ icon=self.icon)
+ while True:
+            # Poll for events
+            event, values = self.window.read(timeout=10)
+            # Handle the [Open] event
+            self._file_event_handler(event, values)
+            # Handle slider events
+            self._slide_event_handler(event, values)
+            # Handle the [Setting] (language/mode) event
+            self._language_mode_event_handler(event)
+            # Handle the [Run] event
+            self._run_event_handler(event, values)
+            # Exit when the window is closed
+ if event == sg.WIN_CLOSED:
+ break
+            # Update the progress bar
+ if self.se is not None:
+ self.window['-PROG-'].update(self.se.progress_total)
+ if self.se.isFinished:
+                    # 1) Re-enable the subtitle-area sliders
+                    self.window['-Y-SLIDER-'].update(disabled=False)
+                    self.window['-X-SLIDER-'].update(disabled=False)
+                    self.window['-Y-SLIDER-H-'].update(disabled=False)
+                    self.window['-X-SLIDER-W-'].update(disabled=False)
+                    # 2) Re-enable the [Run], [Open] and [Setting] buttons
+                    self.window['-RUN-'].update(disabled=False)
+                    self.window['-FILE-'].update(disabled=False)
+                    self.window['-FILE_BTN-'].update(disabled=False)
+                    self.window['-LANGUAGE-MODE-'].update(disabled=False)
+                    self.se = None
+                if len(self.video_paths) >= 1:
+                    # 1) Disable the subtitle-area sliders
+                    self.window['-Y-SLIDER-'].update(disabled=True)
+                    self.window['-X-SLIDER-'].update(disabled=True)
+                    self.window['-Y-SLIDER-H-'].update(disabled=True)
+                    self.window['-X-SLIDER-W-'].update(disabled=True)
+                    # 2) Disable the [Run], [Open] and [Setting] buttons
+                    self.window['-RUN-'].update(disabled=True)
+                    self.window['-FILE-'].update(disabled=True)
+                    self.window['-FILE_BTN-'].update(disabled=True)
+                    self.window['-LANGUAGE-MODE-'].update(disabled=True)
+
+ def update_interface_text(self):
+ self._load_config()
+ self.window.set_title(self.interface_config['SubtitleExtractorGUI']['Title'])
+ self.window['-FILE_BTN-'].Update(self.interface_config['SubtitleExtractorGUI']['Open'])
+ self.window['-FRAME1-'].Update(self.interface_config['SubtitleExtractorGUI']['Vertical'])
+ self.window['-FRAME2-'].Update(self.interface_config['SubtitleExtractorGUI']['Horizontal'])
+ self.window['-RUN-'].Update(self.interface_config['SubtitleExtractorGUI']['Run'])
+ self.window['-LANGUAGE-MODE-'].Update(self.interface_config['SubtitleExtractorGUI']['Setting'])
+
+ def _create_layout(self):
+ """
+        Build the subtitle extractor layout
+ """
+ garbage = os.path.join(os.path.dirname(__file__), 'output')
+ if os.path.exists(garbage):
+ import shutil
+ shutil.rmtree(garbage, True)
+ self.layout = [
+            # Video preview display
+ [sg.Image(size=(self.video_preview_width, self.video_preview_height), background_color='black',
+ key='-DISPLAY-')],
+            # Open button + seek slider
+ [sg.Input(key='-FILE-', visible=False, enable_events=True),
+ sg.FilesBrowse(button_text=self.interface_config['SubtitleExtractorGUI']['Open'], file_types=((
+ self.interface_config['SubtitleExtractorGUI']['AllFile'], '*.*'), ('mp4', '*.mp4'),
+ ('flv', '*.flv'),
+ ('wmv', '*.wmv'),
+ ('avi', '*.avi')),
+ key='-FILE_BTN-', size=(10, 1), font=self.font),
+ sg.Slider(size=self.horizontal_slider_size, range=(1, 1), key='-SLIDER-', orientation='h',
+ enable_events=True, font=self.font,
+ disable_number_display=True),
+ ],
+            # Output area
+ [sg.Output(size=self.output_size, font=self.font),
+ sg.Frame(title=self.interface_config['SubtitleExtractorGUI']['Vertical'], font=self.font, key='-FRAME1-',
+ layout=[[
+ sg.Slider(range=(0, 0), orientation='v', size=(10, 20),
+ disable_number_display=True,
+ enable_events=True, font=self.font,
+ pad=((10, 10), (20, 20)),
+ default_value=0, key='-Y-SLIDER-'),
+ sg.Slider(range=(0, 0), orientation='v', size=(10, 20),
+ disable_number_display=True,
+ enable_events=True, font=self.font,
+ pad=((10, 10), (20, 20)),
+ default_value=0, key='-Y-SLIDER-H-'),
+ ]], pad=((15, 5), (0, 0))),
+ sg.Frame(title=self.interface_config['SubtitleExtractorGUI']['Horizontal'], font=self.font, key='-FRAME2-',
+ layout=[[
+ sg.Slider(range=(0, 0), orientation='v', size=(10, 20),
+ disable_number_display=True,
+ pad=((10, 10), (20, 20)),
+ enable_events=True, font=self.font,
+ default_value=0, key='-X-SLIDER-'),
+ sg.Slider(range=(0, 0), orientation='v', size=(10, 20),
+ disable_number_display=True,
+ pad=((10, 10), (20, 20)),
+ enable_events=True, font=self.font,
+ default_value=0, key='-X-SLIDER-W-'),
+ ]], pad=((15, 5), (0, 0)))
+ ],
+
+            # Run button + progress bar
+ [sg.Button(button_text=self.interface_config['SubtitleExtractorGUI']['Run'], key='-RUN-',
+ font=self.font, size=(20, 1)),
+ sg.Button(button_text=self.interface_config['SubtitleExtractorGUI']['Setting'], key='-LANGUAGE-MODE-',
+ font=self.font, size=(20, 1)),
+ sg.ProgressBar(100, orientation='h', size=self.progressbar_size, key='-PROG-', auto_size_text=True)
+ ],
+ ]
+
+ def _file_event_handler(self, event, values):
+ """
+        When the Open button is clicked:
+        1) Open the video file and show a frame on the preview canvas
+        2) Read the video info and initialise the slider ranges
+ """
+ if event == '-FILE-':
+ self.video_paths = values['-FILE-'].split(';')
+ self.video_path = self.video_paths[0]
+ if self.video_path != '':
+ self.video_cap = cv2.VideoCapture(self.video_path)
+ if self.video_cap is None:
+ return
+ if self.video_cap.isOpened():
+ ret, frame = self.video_cap.read()
+ if ret:
+ for video in self.video_paths:
+ print(f"{self.interface_config['SubtitleExtractorGUI']['OpenVideoSuccess']}:{video}")
+                        # Get the frame count
+                        self.frame_count = self.video_cap.get(cv2.CAP_PROP_FRAME_COUNT)
+                        # Get the frame height
+                        self.frame_height = self.video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
+                        # Get the frame width
+                        self.frame_width = self.video_cap.get(cv2.CAP_PROP_FRAME_WIDTH)
+                        # Get the frame rate
+                        self.fps = self.video_cap.get(cv2.CAP_PROP_FPS)
+                        # Resize the frame so the preview can display it
+                        resized_frame = self._img_resize(frame)
+                        # resized_frame = cv2.resize(src=frame, dsize=(self.video_preview_width, self.video_preview_height))
+                        # Show the frame
+                        self.window['-DISPLAY-'].update(data=cv2.imencode('.png', resized_frame)[1].tobytes())
+                        # Update the seek slider range
+                        self.window['-SLIDER-'].update(range=(1, self.frame_count))
+                        self.window['-SLIDER-'].update(1)
+                        # Preset subtitle-area position
+ y_p, h_p, x_p, w_p = self.parse_subtitle_config()
+ y = self.frame_height * y_p
+ h = self.frame_height * h_p
+ x = self.frame_width * x_p
+ w = self.frame_width * w_p
+                        # Update the subtitle-position slider ranges and default values
+                        self.window['-Y-SLIDER-'].update(range=(0, self.frame_height), disabled=False)
+                        self.window['-Y-SLIDER-'].update(y)
+                        self.window['-X-SLIDER-'].update(range=(0, self.frame_width), disabled=False)
+                        self.window['-X-SLIDER-'].update(x)
+                        self.window['-Y-SLIDER-H-'].update(range=(0, self.frame_height - y))
+                        self.window['-Y-SLIDER-H-'].update(h)
+                        self.window['-X-SLIDER-W-'].update(range=(0, self.frame_width - x))
+                        self.window['-X-SLIDER-W-'].update(w)
+ self._update_preview(frame, (y, h, x, w))
+
+ def _language_mode_event_handler(self, event):
+ if event != '-LANGUAGE-MODE-':
+ return
+ if 'OK' == LanguageModeGUI(self).run():
+            # Reload the config
+ pass
+
+ def _run_event_handler(self, event, values):
+ """
+        When the Run button is clicked:
+        1) Disable the subtitle-area sliders
+        2) Disable the [Run] and [Open] buttons
+        3) Set the subtitle-area position
+ """
+ if event == '-RUN-':
+ if self.video_cap is None:
+ print(self.interface_config['SubtitleExtractorGUI']['OpenVideoFirst'])
+ else:
+                # 1) Disable the subtitle-area sliders
+                self.window['-Y-SLIDER-'].update(disabled=True)
+                self.window['-X-SLIDER-'].update(disabled=True)
+                self.window['-Y-SLIDER-H-'].update(disabled=True)
+                self.window['-X-SLIDER-W-'].update(disabled=True)
+                # 2) Disable the [Run], [Open] and [Setting] buttons
+                self.window['-RUN-'].update(disabled=True)
+                self.window['-FILE-'].update(disabled=True)
+                self.window['-FILE_BTN-'].update(disabled=True)
+                self.window['-LANGUAGE-MODE-'].update(disabled=True)
+                # 3) Set the subtitle-area position
+ self.xmin = int(values['-X-SLIDER-'])
+ self.xmax = int(values['-X-SLIDER-'] + values['-X-SLIDER-W-'])
+ self.ymin = int(values['-Y-SLIDER-'])
+ self.ymax = int(values['-Y-SLIDER-'] + values['-Y-SLIDER-H-'])
+ if self.ymax > self.frame_height:
+ self.ymax = self.frame_height
+ if self.xmax > self.frame_width:
+ self.xmax = self.frame_width
+ print(f"{self.interface_config['SubtitleExtractorGUI']['SubtitleArea']}:({self.ymin},{self.ymax},{self.xmin},{self.xmax})")
+ subtitle_area = (self.ymin, self.ymax, self.xmin, self.xmax)
+ y_p = self.ymin / self.frame_height
+ h_p = (self.ymax - self.ymin) / self.frame_height
+ x_p = self.xmin / self.frame_width
+ w_p = (self.xmax - self.xmin) / self.frame_width
+ self.set_subtitle_config(y_p, h_p, x_p, w_p)
+
+ def task():
+ while self.video_paths:
+ video_path = self.video_paths.pop()
+ self.se = backend.main.SubtitleExtractor(video_path, subtitle_area)
+ self.se.run()
+ Thread(target=task, daemon=True).start()
+ self.video_cap.release()
+ self.video_cap = None
+
+ def _slide_event_handler(self, event, values):
+ """
+        When the video seek slider or a subtitle-area slider moves:
+        1) If a video is open, show the corresponding frame
+        2) Draw the selection rectangle
+ """
+ if event == '-SLIDER-' or event == '-Y-SLIDER-' or event == '-Y-SLIDER-H-' or event == '-X-SLIDER-' or event \
+ == '-X-SLIDER-W-':
+ if self.video_cap is not None and self.video_cap.isOpened():
+ frame_no = int(values['-SLIDER-'])
+ self.video_cap.set(cv2.CAP_PROP_POS_FRAMES, frame_no)
+ ret, frame = self.video_cap.read()
+ if ret:
+ self.window['-Y-SLIDER-H-'].update(range=(0, self.frame_height-values['-Y-SLIDER-']))
+ self.window['-X-SLIDER-W-'].update(range=(0, self.frame_width-values['-X-SLIDER-']))
+                    # Draw the subtitle box
+ y = int(values['-Y-SLIDER-'])
+ h = int(values['-Y-SLIDER-H-'])
+ x = int(values['-X-SLIDER-'])
+ w = int(values['-X-SLIDER-W-'])
+ self._update_preview(frame, (y, h, x, w))
+
+ def _update_preview(self, frame, y_h_x_w):
+ y, h, x, w = y_h_x_w
+        # Draw the subtitle box
+ draw = cv2.rectangle(img=frame, pt1=(int(x), int(y)), pt2=(int(x) + int(w), int(y) + int(h)),
+ color=(0, 255, 0), thickness=3)
+        # Resize the frame so the preview can display it
+ resized_frame = self._img_resize(draw)
+        # Show the frame
+ self.window['-DISPLAY-'].update(data=cv2.imencode('.png', resized_frame)[1].tobytes())
+
+
+ def _img_resize(self, image):
+ top, bottom, left, right = (0, 0, 0, 0)
+ height, width = image.shape[0], image.shape[1]
+        # For non-square images, find the longest edge
+        longest_edge = height
+        # Work out how much padding the short edge needs to match the long edge
+        if width < longest_edge:
+            dw = longest_edge - width
+            left = dw // 2
+            right = dw - left
+        else:
+            pass
+        # Pad the image with black borders
+ constant = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=[0, 0, 0])
+ return cv2.resize(constant, (self.video_preview_width, self.video_preview_height))
+
+ def set_subtitle_config(self, y, h, x, w):
+        # Write the config file
+ with open(self.subtitle_config_file, mode='w', encoding='utf-8') as f:
+ f.write('[AREA]\n')
+ f.write(f'Y = {y}\n')
+ f.write(f'H = {h}\n')
+ f.write(f'X = {x}\n')
+ f.write(f'W = {w}\n')
+
+ def parse_subtitle_config(self):
+ y_p, h_p, x_p, w_p = .78, .21, .05, .9
+        # If the config file doesn't exist, create it with the defaults
+ if not os.path.exists(self.subtitle_config_file):
+ self.set_subtitle_config(y_p, h_p, x_p, w_p)
+ return y_p, h_p, x_p, w_p
+ else:
+ try:
+ config = configparser.ConfigParser()
+ config.read(self.subtitle_config_file, encoding='utf-8')
+ conf_y_p, conf_h_p, conf_x_p, conf_w_p = float(config['AREA']['Y']), float(config['AREA']['H']), float(config['AREA']['X']), float(config['AREA']['W'])
+ return conf_y_p, conf_h_p, conf_x_p, conf_w_p
+ except Exception:
+ self.set_subtitle_config(y_p, h_p, x_p, w_p)
+ return y_p, h_p, x_p, w_p
+
+
+class LanguageModeGUI:
+ def __init__(self, subtitle_extractor_gui):
+ self.subtitle_extractor_gui = subtitle_extractor_gui
+ self.icon = os.path.join(os.path.dirname(__file__), 'design', 'vse.ico')
+ self.config_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'settings.ini')
+        # Interface settings
+ self.INTERFACE_DEF = '简体中文'
+ if not os.path.exists(self.config_file):
+ self.interface_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'backend', 'interface',
+ "ch.ini")
+ self.interface_config = configparser.ConfigParser()
+        # Interface language options
+ self.INTERFACE_KEY_NAME_MAP = {
+ '简体中文': 'ch',
+ '繁體中文': 'chinese_cht',
+ 'English': 'en',
+ '한국어': 'ko',
+ '日本語': 'japan',
+ 'Tiếng Việt': 'vi',
+ 'Español': 'es'
+ }
+        # Subtitle language and mode settings
+ self.LANGUAGE_DEF = 'ch'
+ self.LANGUAGE_NAME_KEY_MAP = None
+ self.LANGUAGE_KEY_NAME_MAP = None
+ self.MODE_DEF = 'fast'
+ self.MODE_NAME_KEY_MAP = None
+ self.MODE_KEY_NAME_MAP = None
+        # Language-selection layout
+        self.layout = None
+        # Language-selection window
+ self.window = None
+
+ def run(self):
+        # Build the layout
+        title = self._create_layout()
+        # Create the window
+        self.window = sg.Window(title=title, layout=self.layout, icon=self.icon)
+        while True:
+            # Poll for events
+            event, values = self.window.read(timeout=10)
+            # Handle the [OK] event
+            self._ok_event_handler(event, values)
+            # Handle the interface-language switch event
+            self._interface_event_handler(event, values)
+            # Exit when the window is closed
+ if event == sg.WIN_CLOSED:
+ if os.path.exists(self.config_file):
+ break
+ else:
+ exit(0)
+ if event == 'Cancel':
+ if os.path.exists(self.config_file):
+ self.window.close()
+ break
+ else:
+ exit(0)
+
+ def _load_interface_text(self):
+ self.interface_config.read(self.interface_file, encoding='utf-8')
+ config_language_mode_gui = self.interface_config["LanguageModeGUI"]
+        # Interface defaults
+ self.INTERFACE_DEF = config_language_mode_gui["InterfaceDefault"]
+
+ self.LANGUAGE_DEF = config_language_mode_gui["LanguageCH"]
+ self.LANGUAGE_NAME_KEY_MAP = {}
+ for lang in backend.main.config.MULTI_LANG:
+ self.LANGUAGE_NAME_KEY_MAP[config_language_mode_gui[f"Language{lang.upper()}"]] = lang
+ self.LANGUAGE_NAME_KEY_MAP = dict(sorted(self.LANGUAGE_NAME_KEY_MAP.items(), key=lambda item: item[1]))
+ self.LANGUAGE_KEY_NAME_MAP = {v: k for k, v in self.LANGUAGE_NAME_KEY_MAP.items()}
+ self.MODE_DEF = config_language_mode_gui['ModeFast']
+ self.MODE_NAME_KEY_MAP = {
+ config_language_mode_gui['ModeAuto']: 'auto',
+ config_language_mode_gui['ModeFast']: 'fast',
+ config_language_mode_gui['ModeAccurate']: 'accurate',
+ }
+ self.MODE_KEY_NAME_MAP = {v: k for k, v in self.MODE_NAME_KEY_MAP.items()}
+
+ def _create_layout(self):
+ interface_def, language_def, mode_def = self.parse_config(self.config_file)
+        # Load the interface text
+ self._load_interface_text()
+ choose_language_text = self.interface_config["LanguageModeGUI"]["InterfaceLanguage"]
+ choose_sub_lang_text = self.interface_config["LanguageModeGUI"]["SubtitleLanguage"]
+ choose_mode_text = self.interface_config["LanguageModeGUI"]["Mode"]
+ self.layout = [
+            # Interface-language selector
+ [sg.Text(choose_language_text),
+ sg.DropDown(values=list(self.INTERFACE_KEY_NAME_MAP.keys()), size=(30, 20),
+ pad=(0, 20),
+ key='-INTERFACE-', readonly=True,
+ default_value=interface_def),
+ sg.OK(key='-INTERFACE-OK-')],
+            # Subtitle-language selector
+ [sg.Text(choose_sub_lang_text),
+ sg.DropDown(values=list(self.LANGUAGE_NAME_KEY_MAP.keys()), size=(30, 20),
+ pad=(0, 20),
+ key='-LANGUAGE-', readonly=True, default_value=language_def)],
+            # Recognition-mode selector
+ [sg.Text(choose_mode_text),
+ sg.DropDown(values=list(self.MODE_NAME_KEY_MAP.keys()), size=(30, 20), pad=(0, 20),
+ key='-MODE-', readonly=True, default_value=mode_def)],
+            # OK / Cancel buttons
+ [sg.OK(), sg.Cancel()]
+ ]
+ return self.interface_config["LanguageModeGUI"]["Title"]
+
+ def _ok_event_handler(self, event, values):
+ if event == 'OK':
+            # Apply the language and mode configuration
+            interface = None
+            language = None
+            mode = None
+            # Interface language
+            interface_str = values['-INTERFACE-']
+            if interface_str in self.INTERFACE_KEY_NAME_MAP:
+                interface = interface_str
+            language_str = values['-LANGUAGE-']
+            # Subtitle language
+            print(self.interface_config["LanguageModeGUI"]["SubtitleLanguage"], language_str)
+            if language_str in self.LANGUAGE_NAME_KEY_MAP:
+                language = self.LANGUAGE_NAME_KEY_MAP[language_str]
+            # Recognition mode
+ mode_str = values['-MODE-']
+ print(self.interface_config["LanguageModeGUI"]["Mode"], mode_str)
+ if mode_str in self.MODE_NAME_KEY_MAP:
+ mode = self.MODE_NAME_KEY_MAP[mode_str]
+ self.set_config(self.config_file, interface, language, mode)
+ if self.subtitle_extractor_gui is not None:
+ self.subtitle_extractor_gui.update_interface_text()
+ self.window.close()
+
+ def _interface_event_handler(self, event, values):
+ if event == '-INTERFACE-OK-':
+ self.interface_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'backend', 'interface',
+ f"{self.INTERFACE_KEY_NAME_MAP[values['-INTERFACE-']]}.ini")
+ self.interface_config.read(self.interface_file, encoding='utf-8')
+ config = configparser.ConfigParser()
+ if os.path.exists(self.config_file):
+ config.read(self.config_file, encoding='utf-8')
+ self.set_config(self.config_file, values['-INTERFACE-'], config['DEFAULT']['Language'],
+ config['DEFAULT']['Mode'])
+ self.window.close()
+ title = self._create_layout()
+ self.window = sg.Window(title=title, layout=self.layout, icon=self.icon)
+
+ @staticmethod
+ def set_config(config_file, interface, language_code, mode):
+        # Write the config file
+ with open(config_file, mode='w', encoding='utf-8') as f:
+ f.write('[DEFAULT]\n')
+ f.write(f'Interface = {interface}\n')
+ f.write(f'Language = {language_code}\n')
+ f.write(f'Mode = {mode}\n')
+
+ def parse_config(self, config_file):
+ if not os.path.exists(config_file):
+ self.interface_config.read(self.interface_file, encoding='utf-8')
+ interface_def = self.interface_config['LanguageModeGUI']['InterfaceDefault']
+ language_def = self.interface_config['LanguageModeGUI']['InterfaceDefault']
+ mode_def = self.interface_config['LanguageModeGUI']['ModeFast']
+ return interface_def, language_def, mode_def
+ config = configparser.ConfigParser()
+ config.read(config_file, encoding='utf-8')
+ interface = config['DEFAULT']['Interface']
+ language = config['DEFAULT']['Language']
+ mode = config['DEFAULT']['Mode']
+ self.interface_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'backend', 'interface',
+ f"{self.INTERFACE_KEY_NAME_MAP[interface]}.ini")
+ self._load_interface_text()
+ interface_def = interface if interface in self.INTERFACE_KEY_NAME_MAP else \
+ self.INTERFACE_DEF
+ language_def = self.LANGUAGE_KEY_NAME_MAP[language] if language in self.LANGUAGE_KEY_NAME_MAP else \
+ self.LANGUAGE_DEF
+ mode_def = self.MODE_KEY_NAME_MAP[mode] if mode in self.MODE_KEY_NAME_MAP else self.MODE_DEF
+ return interface_def, language_def, mode_def
+
+
+if __name__ == '__main__':
+ try:
+ multiprocessing.set_start_method("spawn")
+        # Launch the GUI
+ subtitleExtractorGUI = SubtitleExtractorGUI()
+ subtitleExtractorGUI.run()
+ except Exception as e:
+ print(f'[{type(e)}] {e}')
+ import traceback
+ traceback.print_exc()
+ msg = traceback.format_exc()
+ err_log_path = os.path.join(os.path.expanduser('~'), 'VSE-Error-Message.log')
+ with open(err_log_path, 'w', encoding='utf-8') as f:
+ f.writelines(msg)
+ import platform
+ if platform.system() == 'Windows':
+ os.system('pause')
+ else:
+ input()
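gui.py persists the last-used subtitle area in subtitle.ini as proportions of the frame size (set_subtitle_config / parse_subtitle_config above), and _file_event_handler multiplies them back out when a video is opened. A minimal sketch of that conversion as a standalone helper; the function name below is illustrative, not part of the repo:

    # Sketch only: turns the Y/H/X/W proportions from subtitle.ini into pixel bounds,
    # returned as (ymin, ymax, xmin, xmax), the ordering SubtitleExtractor expects.
    import configparser

    def subtitle_area_pixels(subtitle_ini, frame_width, frame_height):
        config = configparser.ConfigParser()
        config.read(subtitle_ini, encoding='utf-8')
        y = frame_height * float(config['AREA']['Y'])
        h = frame_height * float(config['AREA']['H'])
        x = frame_width * float(config['AREA']['X'])
        w = frame_width * float(config['AREA']['W'])
        return int(y), int(y + h), int(x), int(x + w)

    # With the defaults (.78, .21, .05, .9) on a 1920x1080 frame this gives
    # roughly (842, 1069, 96, 1824).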
diff --git a/main.py b/main.py
deleted file mode 100644
index ad1c98ad..00000000
--- a/main.py
+++ /dev/null
@@ -1,545 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-@Author : Fang Yao
-@Time : 2021/3/24 9:28 上午
-@FileName: main.py
-@desc: 主程序入口文件
-"""
-import config
-from config import SubtitleArea
-from backend.tools.infer.predict_system import TextSystem
-from backend.tools.infer import utility
-import cv2
-import random
-import os
-import math
-from collections import Counter
-import numpy as np
-from PIL import Image
-from numpy import average, dot, linalg
-from Levenshtein import ratio
-
-
-# 加载文本检测+识别模型
-def load_model():
- # 获取参数对象
- args = utility.parse_args()
- # 设置文本检测模型路径
- args.det_model_dir = config.DET_MODEL_PATH
- # 设置文本识别模型路径
- args.rec_model_dir = config.REC_MODEL_PATH
- # 设置字典路径
- args.rec_char_dict_path = config.DICT_PATH
- # 是否使用GPU加速
- args.use_gpu = config.USE_GPU
- return TextSystem(args)
-
-
-class SubtitleExtractor:
- """
- 视频字幕提取类
- """
-
- def __init__(self, vd_path):
- # 视频路径
- self.video_path = vd_path
- self.video_cap = cv2.VideoCapture(vd_path)
- # 视频帧总数
- self.frame_count = self.video_cap.get(cv2.CAP_PROP_FRAME_COUNT)
- # 视频帧率
- self.fps = self.video_cap.get(cv2.CAP_PROP_FPS)
- # 视频尺寸
- self.frame_height = int(self.video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
- self.frame_width = int(self.video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
- # 字幕出现区域
- self.subtitle_area = config.SUBTITLE_AREA
- print(f'帧数:{self.frame_count},帧率:{self.fps}')
- # 临时存储文件夹
- self.temp_output_dir = os.path.join(config.BASE_DIR, 'output')
- # 提取的视频帧储存目录
- self.frame_output_dir = os.path.join(self.temp_output_dir, 'frames')
- # 提取的字幕文件存储目录
- self.subtitle_output_dir = os.path.join(self.temp_output_dir, 'subtitle')
- # 不存在则创建文件夹
- if not os.path.exists(self.frame_output_dir):
- os.makedirs(self.frame_output_dir)
- if not os.path.exists(self.subtitle_output_dir):
- os.makedirs(self.subtitle_output_dir)
- # 提取的原始字幕文本存储路径
- self.raw_subtitle_path = os.path.join(self.subtitle_output_dir, 'raw.txt')
-
- def run(self):
- """
- 运行整个提取视频的步骤
- """
- self.extract_frame()
- self.extract_subtitles()
- self.detect_watermark_area()
- self.filter_watermark()
- self.detect_subtitle_area()
- self.filter_scene_text()
- self.generate_subtitle_file()
-
- def extract_frame(self):
- """
- 根据视频的分辨率,将高分辨的视频帧缩放到1280*720p
- 根据字幕区域位置,将该图像区域截取出来
- """
- # 当前视频帧的帧号
- frame_no = 0
-
- while self.video_cap.isOpened():
- ret, frame = self.video_cap.read()
- # 如果读取视频帧失败(视频读到最后一帧)
- if not ret:
- break
- # 读取视频帧成功
- else:
- frame_no += 1
- # 对于分辨率大于1920*1080的视频,将其视频帧进行等比缩放至1280*720进行识别
- # paddlepaddle会将图像压缩为640*640
- if self.frame_width > 1280:
- scale_rate = round(float(1280 / self.frame_width), 2)
- frame = cv2.resize(frame, None, fx=scale_rate, fy=scale_rate, interpolation=cv2.INTER_AREA)
-
- cropped = int(frame.shape[0] // 2)
-
- # 如果字幕出现的区域在下部分
- if self.subtitle_area == SubtitleArea.LOWER_PART:
- # 将视频帧切割为下半部分
- frame = frame[cropped:]
- # 如果字幕出现的区域在上半部分
- elif self.subtitle_area == SubtitleArea.UPPER_PART:
- # 将视频帧切割为下半部分
- frame = frame[:cropped]
-
- # 帧名往前补零,后续用于排序与时间戳转换,补足8位
- # 一部10h电影,fps120帧最多也才1*60*60*120=432000 6位,所以8位足够
- filename = os.path.join(self.frame_output_dir, str(frame_no).zfill(8) + '.jpg')
- # 保存视频帧
- cv2.imwrite(filename, frame)
-
- # 将当前帧与接下来的帧进行比较,计算余弦相似度
- while self.video_cap.isOpened():
- ret, frame_next = self.video_cap.read()
- if ret:
- frame_no += 1
- cosine_distance = self._compute_image_similarity(Image.fromarray(frame),
- Image.fromarray(frame_next))
- print(cosine_distance)
- if cosine_distance > config.COSINE_SIMILARITY_THRESHOLD:
- # 如果下一帧与当前帧的相似度大于设定阈值,则略过该帧
- continue
- # 如果相似度小于设定阈值,停止该while循环
- else:
- break
- else:
- break
-
- self.video_cap.release()
- cv2.destroyAllWindows()
-
- def extract_subtitles(self):
- """
- 提取视频帧中的字幕信息,生成一个txt文件
- """
- # 初始化文本识别对象
- text_recogniser = load_model()
- # 视频帧列表
- frame_list = [i for i in sorted(os.listdir(self.frame_output_dir)) if i.endswith('.jpg')]
-
- # 新建文件
- f = open(self.raw_subtitle_path, mode='w+', encoding='utf-8')
-
- for frame in frame_list:
- # 读取视频帧
- img = cv2.imread(os.path.join(self.frame_output_dir, frame))
- # 获取检测结果
- dt_box, rec_res = text_recogniser(img)
- # 获取文本坐标
- coordinates = self.__get_coordinates(dt_box)
- # 将结果写入txt文本中
- for content, coordinate in zip(([res[0] for res in rec_res]), coordinates):
- f.write(f'{os.path.splitext(frame)[0]}\t'
- f'{coordinate}\t'
- f'{content}\n')
- # 关闭文件
- f.close()
-
- def detect_watermark_area(self):
- """
- 根据识别出来的raw txt文件中的坐标点信息,查找水印区域
- 假定:水印区域(台标)的坐标在水平和垂直方向都是固定的,也就是具有(xmin, xmax, ymin, ymax)相对固定
- 根据坐标点信息,进行统计,将一直具有固定坐标的文本区域选出
- :return 返回最有可能的水印区域
- """
- f = open(self.raw_subtitle_path, mode='r', encoding='utf-8') # 打开txt文件,以‘utf-8’编码读取
- line = f.readline() # 以行的形式进行读取文件
- # 坐标点列表
- coordinates_list = []
- # 帧列表
- frame_no_list = []
- # 内容列表
- content_list = []
- while line:
- frame_no = line.split('\t')[0]
- text_position = line.split('\t')[1].split('(')[1].split(')')[0].split(', ')
- content = line.split('\t')[2]
- frame_no_list.append(frame_no)
- coordinates_list.append((int(text_position[0]),
- int(text_position[1]),
- int(text_position[2]),
- int(text_position[3])))
- content_list.append(content)
- line = f.readline()
- f.close()
- # 将坐标列表的相似值统一
- coordinates_list = self._unite_coordinates(coordinates_list)
-
- # 将原txt文件的坐标更新为归一后的坐标
- with open(self.raw_subtitle_path, mode='w', encoding='utf-8') as f:
- for frame_no, coordinate, content in zip(frame_no_list, coordinates_list, content_list):
- f.write(f'{frame_no}\t{coordinate}\t{content}')
-
- if len(Counter(coordinates_list).most_common()) > config.WATERMARK_AREA_NUM:
- # 读取配置文件,返回可能为水印区域的坐标列表
- return Counter(coordinates_list).most_common(config.WATERMARK_AREA_NUM)
- else:
- # 不够则有几个返回几个
- return Counter(coordinates_list).most_common()
-
- def filter_watermark(self):
- """
- 去除原始字幕文本中的水印区域的文本
- """
- # 获取潜在水印区域
- watermark_areas = self.detect_watermark_area()
-
- # 从frame目录随机读取一张图片,将所水印区域标记出来,用户看图判断是否是水印区域
- frame_path = os.path.join(self.frame_output_dir,
- random.choice(
- [i for i in sorted(os.listdir(self.frame_output_dir)) if i.endswith('.jpg')]))
- sample_frame = cv2.imread(frame_path)
-
- # 给潜在的水印区域编号
- area_num = ['E', 'D', 'C', 'B', 'A']
-
- for watermark_area in watermark_areas:
- ymin = watermark_area[0][2]
- ymax = watermark_area[0][3]
- xmin = watermark_area[0][0]
- xmax = watermark_area[0][1]
- cover = sample_frame[ymin:ymax, xmin:xmax]
- cover = cv2.blur(cover, (10, 10))
- cv2.rectangle(cover, pt1=(0, cover.shape[0]), pt2=(cover.shape[1], 0), color=(0, 0, 255), thickness=3)
- sample_frame[watermark_area[0][2]:watermark_area[0][3], watermark_area[0][0]:watermark_area[0][1]] = cover
- position = ((xmin + xmax) // 2, ymax)
-
- cv2.putText(sample_frame, text=area_num.pop(), org=position, fontFace=cv2.FONT_HERSHEY_SIMPLEX,
- fontScale=1, color=(255, 0, 0), thickness=2, lineType=cv2.LINE_AA)
-
- sample_frame_file_path = os.path.join(os.path.dirname(self.frame_output_dir), 'watermark_area.jpg')
- cv2.imwrite(sample_frame_file_path, sample_frame)
- print(f'请查看图片, 确定水印区域: {sample_frame_file_path}')
-
- area_num = ['E', 'D', 'C', 'B', 'A']
- for watermark_area in watermark_areas:
- user_input = input(f'是否去除区域{area_num.pop()}{str(watermark_area)}中的字幕?'
- f'\n输入 "y" 或 "回车" 表示去除,输入"n"或其他表示不去除: ').strip()
- if user_input == 'y' or user_input == '\n':
- with open(self.raw_subtitle_path, mode='r+', encoding='utf-8') as f:
- content = f.readlines()
- f.seek(0)
- for i in content:
- if i.find(str(watermark_area[0])) == -1:
- f.write(i)
- f.truncate()
- print(f'已经删除该区域字幕...')
- print('水印区域字幕过滤完毕...')
-
- def detect_subtitle_area(self):
- """
- 读取过滤水印区域后的raw txt文件,根据坐标信息,查找字幕区域
- 假定:字幕区域在y轴上有一个相对固定的坐标范围,相对于场景文本,这个范围出现频率更高
- :return 返回字幕的区域位置
- """
- # 打开去水印区域处理过的raw txt
- f = open(self.raw_subtitle_path, mode='r', encoding='utf-8') # 打开txt文件,以‘utf-8’编码读取
- line = f.readline() # 以行的形式进行读取文件
- # y坐标点列表
- y_coordinates_list = []
- while line:
- text_position = line.split('\t')[1].split('(')[1].split(')')[0].split(', ')
- y_coordinates_list.append((int(text_position[2]), int(text_position[3])))
- line = f.readline()
- f.close()
- return Counter(y_coordinates_list).most_common(1)
-
- def filter_scene_text(self):
- """
- 将场景里提取的文字过滤,仅保留字幕区域
- """
- # 获取潜在字幕区域
- subtitle_area = self.detect_subtitle_area()[0][0]
-
- # 从frame目录随机读取一张图片,将所水印区域标记出来,用户看图判断是否是水印区域
- frame_path = os.path.join(self.frame_output_dir,
- random.choice(
- [i for i in sorted(os.listdir(self.frame_output_dir)) if i.endswith('.jpg')]))
- sample_frame = cv2.imread(frame_path)
-
- # 为了防止有双行字幕,根据容忍度,将字幕区域y范围加高
- ymin = abs(subtitle_area[0] - config.SUBTITLE_AREA_DEVIATION_PIXEL)
- ymax = subtitle_area[1] + config.SUBTITLE_AREA_DEVIATION_PIXEL
- # 画出字幕框的区域
- cv2.rectangle(sample_frame, pt1=(0, ymin), pt2=(sample_frame.shape[1], ymax), color=(0, 0, 255), thickness=3)
- sample_frame_file_path = os.path.join(os.path.dirname(self.frame_output_dir), 'subtitle_area.jpg')
- cv2.imwrite(sample_frame_file_path, sample_frame)
- print(f'请查看图片, 确定字幕区域是否正确: {sample_frame_file_path}')
-
- user_input = input(f'是否去除红色框区域外{(ymin, ymax)}的字幕?'
- f'\n输入 "y" 或 "回车" 表示去除,输入"n"或其他表示不去除: ').strip()
- if user_input == 'y' or user_input == '\n':
- with open(self.raw_subtitle_path, mode='r+', encoding='utf-8') as f:
- content = f.readlines()
- f.seek(0)
- for i in content:
- i_ymin = int(i.split('\t')[1].split('(')[1].split(')')[0].split(', ')[2])
- i_ymax = int(i.split('\t')[1].split('(')[1].split(')')[0].split(', ')[3])
- if ymin <= i_ymin and i_ymax <= ymax:
- f.write(i)
- f.truncate()
-
- def generate_subtitle_file(self):
- """
- 生成srt格式的字幕文件
- """
- subtitle_content = self._remove_duplicate_subtitle()
- print(os.path.splitext(self.video_path)[0])
- srt_filename = os.path.join(os.path.splitext(self.video_path)[0] + '.srt')
- with open(srt_filename, mode='w', encoding='utf-8') as f:
- for index, content in enumerate(subtitle_content):
- line_code = index + 1
- frame_start = self._frame_to_timecode(int(content[0]))
- frame_end = self._frame_to_timecode(int(content[1]))
- frame_content = content[2]
- subtitle_line = f'{line_code}\n{frame_start} --> {frame_end}\n{frame_content}\n'
- f.write(subtitle_line)
-
- def _frame_to_timecode(self, frame_no):
- """
- 将视频帧转换成时间
- :param frame_no: 视频的帧号,i.e. 第几帧视频帧
- :param frame_rate: 视频的帧率
- :param drop: 帧率不为整数时,是否添加drop frame进行补帧
- :returns: SMPTE格式时间戳 as string, 如'01:02:12:32' 或者 '01:02:12;32'
- """
- # 将小数点后两位的数字丢弃
- tmp = str(self.fps).split('.')
- tmp[1] = tmp[1][:2]
- frame_rate = float('.'.join(tmp))
-
- drop = False
-
- if frame_rate in [29.97, 59.94]:
- drop = True
-
- # 将fps就近取整,如29.97或59.94取整为30和60
- fps_int = int(round(frame_rate))
- # 然后添加drop frames进行补偿
-
- if drop:
- # drop-frame-mode
- # 每分钟添加两个fake frames,每十分钟的时候不添加
- # 1分钟内,non-drop和drop的时间戳对比
- # frame: 1795 non-drop: 00:00:59:25 drop: 00:00:59;25
- # frame: 1796 non-drop: 00:00:59:26 drop: 00:00:59;26
- # frame: 1797 non-drop: 00:00:59:27 drop: 00:00:59;27
- # frame: 1798 non-drop: 00:00:59:28 drop: 00:00:59;28
- # frame: 1799 non-drop: 00:00:59:29 drop: 00:00:59;29
- # frame: 1800 non-drop: 00:01:00:00 drop: 00:01:00;02
- # frame: 1801 non-drop: 00:01:00:01 drop: 00:01:00;03
- # frame: 1802 non-drop: 00:01:00:02 drop: 00:01:00;04
- # frame: 1803 non-drop: 00:01:00:03 drop: 00:01:00;05
- # frame: 1804 non-drop: 00:01:00:04 drop: 00:01:00;06
- # frame: 1805 non-drop: 00:01:00:05 drop: 00:01:00;07
- #
- # 10分钟内,non-drop和drop的时间戳对比
- #
- # frame: 17977 non-drop: 00:09:59:07 drop: 00:09:59;25
- # frame: 17978 non-drop: 00:09:59:08 drop: 00:09:59;26
- # frame: 17979 non-drop: 00:09:59:09 drop: 00:09:59;27
- # frame: 17980 non-drop: 00:09:59:10 drop: 00:09:59;28
- # frame: 17981 non-drop: 00:09:59:11 drop: 00:09:59;29
- # frame: 17982 non-drop: 00:09:59:12 drop: 00:10:00;00
- # frame: 17983 non-drop: 00:09:59:13 drop: 00:10:00;01
- # frame: 17984 non-drop: 00:09:59:14 drop: 00:10:00;02
- # frame: 17985 non-drop: 00:09:59:15 drop: 00:10:00;03
- # frame: 17986 non-drop: 00:09:59:16 drop: 00:10:00;04
- # frame: 17987 non-drop: 00:09:59:17 drop: 00:10:00;05
-
- # 计算29.97 std NTSC工作流程的丢帧数。1分钟一共有30*60 = 1800 frames
-
- FRAMES_IN_ONE_MINUTE = 1800 - 2
-
- FRAMES_IN_TEN_MINUTES = (FRAMES_IN_ONE_MINUTE * 10) - 2
-
- ten_minute_chunks = frame_no / FRAMES_IN_TEN_MINUTES
- one_minute_chunks = frame_no % FRAMES_IN_TEN_MINUTES
-
- ten_minute_part = 18 * ten_minute_chunks
- one_minute_part = 2 * ((one_minute_chunks - 2) / FRAMES_IN_ONE_MINUTE)
-
- if one_minute_part < 0:
- one_minute_part = 0
-
- # 添加额外的帧
- frame_no += ten_minute_part + one_minute_part
-
- # 对于60 fps的drop frame计算, 添加两倍的帧数
- if fps_int == 60:
- frame_no = frame_no * 2
-
- # time codes are on the form 12:12:12;12
- smpte_token = ";"
-
- else:
- # time codes are on the form 12:12:12:12
- smpte_token = ","
-
- # 将视频帧转化为时间戳
- hours = int(frame_no / (3600 * fps_int))
- minutes = int(frame_no / (60 * fps_int) % 60)
- seconds = int(frame_no / fps_int % 60)
- frames = int(frame_no % fps_int)
- return "%02d:%02d:%02d%s%02d" % (hours, minutes, seconds, smpte_token, frames)
-
- def _remove_duplicate_subtitle(self):
- """
- 读取原始的raw txt,去除重复行,返回去除了重复后的字幕列表
- """
- with open(self.raw_subtitle_path, 'r') as r:
- lines = r.readlines()
- content_list = []
- for line in lines:
- frame_no = line.split('\t')[0]
- content = line.split('\t')[2]
- content_list.append((frame_no, content))
- # 循环遍历每行字幕,记录开始时间与结束时间
- index = 0
- # 去重后的字幕列表
- unique_subtitle_list = []
- for i in content_list:
- # TODO: 时间复杂度非常高,有待优化
- # 定义字幕开始帧帧号
- start_frame = i[0]
- for j in content_list[index:]:
- # 计算当前行与下一行的Levenshtein距离
- distance = ratio(i[1], j[1])
- if distance < config.TEXT_SIMILARITY_THRESHOLD or j == content_list[-1]:
- # 定义字幕结束帧帧号
- end_frame = content_list[content_list.index(j) - 1][0]
- if end_frame == start_frame:
- end_frame = j[0]
- if str(unique_subtitle_list).find(i[1].replace('\n', '')) == -1:
- unique_subtitle_list.append((start_frame, end_frame, i[1]))
- index += 1
- break
- else:
- continue
- return unique_subtitle_list
-
- def _unite_coordinates(self, coordinates_list):
- """
- 给定一个坐标列表,将这个列表中相似的坐标统一为一个值
- e.g. 由于检测框检测的结果不是一直的,相同位置文字的坐标可能一次检测为(255,123,456,789),另一次检测为(253,122,456,799)
- 因此要对相似的坐标进行值的统一
- :param coordinates_list 包含坐标点的列表
- :return: 返回一个统一值后的坐标列表
- """
- # 将相似的坐标统一为一个
- index = 0
- for coordinate in coordinates_list: # TODO:时间复杂度n^2,待优化
- for i in coordinates_list:
- if self.__is_coordinate_similar(coordinate, i):
- coordinates_list[index] = i
- index += 1
- return coordinates_list
-
- def _compute_image_similarity(self, image1, image2):
- """
- 计算两张图片的余弦相似度
- """
- image1 = self.__get_thum(image1)
- image2 = self.__get_thum(image2)
- images = [image1, image2]
- vectors = []
- norms = []
- for image in images:
- vector = []
- for pixel_tuple in image.getdata():
- vector.append(average(pixel_tuple))
- vectors.append(vector)
- # linalg=linear(线性)+algebra(代数),norm则表示范数
- # 求图片的范数
- norms.append(linalg.norm(vector, 2))
- a, b = vectors
- a_norm, b_norm = norms
- # dot返回的是点积,对二维数组(矩阵)进行计算
- res = dot(a / a_norm, b / b_norm)
- return res
-
- @staticmethod
- def __get_coordinates(dt_box):
- """
- 从返回的检测框中获取坐标
- :param dt_box 检测框返回结果
- :return list 坐标点列表
- """
- coordinate_list = list()
- if isinstance(dt_box, list):
- for i in dt_box:
- i = list(i)
- (x1, y1) = int(i[0][0]), int(i[0][1])
- (x2, y2) = int(i[1][0]), int(i[1][1])
- (x3, y3) = int(i[2][0]), int(i[2][1])
- (x4, y4) = int(i[3][0]), int(i[3][1])
- xmin = max(x1, x4)
- xmax = min(x2, x3)
- ymin = max(y1, y2)
- ymax = min(y3, y4)
- coordinate_list.append((xmin, xmax, ymin, ymax))
- return coordinate_list
-
- @staticmethod
- def __is_coordinate_similar(coordinate1, coordinate2):
- """
- 计算两个坐标是否相似,如果两个坐标点的xmin,xmax,ymin,ymax的差值都在像素点容忍度内
- 则认为这两个坐标点相似
- """
- return abs(coordinate1[0] - coordinate2[0]) < config.PIXEL_TOLERANCE_X and \
- abs(coordinate1[1] - coordinate2[1]) < config.PIXEL_TOLERANCE_X and \
- abs(coordinate1[2] - coordinate2[2]) < config.PIXEL_TOLERANCE_Y and \
- abs(coordinate1[3] - coordinate2[3]) < config.PIXEL_TOLERANCE_Y
-
- @staticmethod
- def __get_thum(image, size=(64, 64), greyscale=False):
- """
- 对图片进行统一化处理
- """
- # 利用image对图像大小重新设置, Image.ANTIALIAS为高质量的
- image = image.resize(size, Image.ANTIALIAS)
- if greyscale:
- # 将图片转换为L模式,其为灰度图,其每个像素用8个bit表示
- image = image.convert('L')
- return image
-
-
-if __name__ == '__main__':
- # 提示用户输入视频路径
- video_path = input("请输入视频完整路径:").strip()
- # 新建字幕提取对象
- se = SubtitleExtractor(video_path)
- # 开始提取字幕
- se.run()
-
diff --git a/requirements.txt b/requirements.txt
index c324df0d..3b952f7a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,14 +1,13 @@
-shapely~=1.7.1
-scikit-image==0.17.2
+opencv-python==4.10.0.84
+python-Levenshtein==0.26.0
+pillow==10.4.0
+tqdm==4.66.5
+filesplit==3.0.2
+pysrt==1.1.2
+wordsegment==1.3.1
+scikit-image==0.24.0
+lmdb==1.5.1
imgaug==0.4.0
-pyclipper~=1.2.1
-lmdb~=1.1.1
-opencv-python==4.2.0.32
-tqdm~=4.59.0
-numpy~=1.19.0
-visualdl
-python-Levenshtein
-six~=1.15.0
-pillow~=8.1.2
-pyyaml~=5.4.1
-requests~=2.25.1
\ No newline at end of file
+pyclipper==1.3.0.post5
+PySimpleGUI==4.70.1
+numpy==1.26.4
\ No newline at end of file
diff --git a/test/test_ar.flv b/test/test_ar.flv
new file mode 100644
index 00000000..1edac235
Binary files /dev/null and b/test/test_ar.flv differ
diff --git a/test/test_chinese_cht.flv b/test/test_chinese_cht.flv
new file mode 100644
index 00000000..09818b91
Binary files /dev/null and b/test/test_chinese_cht.flv differ
diff --git a/test/test_cn.mp4 b/test/test_cn.mp4
new file mode 100644
index 00000000..e7899c05
Binary files /dev/null and b/test/test_cn.mp4 differ
diff --git a/test/test_cn2.mp4 b/test/test_cn2.mp4
new file mode 100644
index 00000000..c0ff2d44
Binary files /dev/null and b/test/test_cn2.mp4 differ
diff --git a/test/test_en.mp4 b/test/test_en.mp4
new file mode 100644
index 00000000..2d435b56
Binary files /dev/null and b/test/test_en.mp4 differ
diff --git a/test/test_en_ch.mp4 b/test/test_en_ch.mp4
new file mode 100644
index 00000000..2b058dc5
Binary files /dev/null and b/test/test_en_ch.mp4 differ
diff --git a/test/test_es.flv b/test/test_es.flv
new file mode 100644
index 00000000..8e8243df
Binary files /dev/null and b/test/test_es.flv differ
diff --git a/test/test_german.mp4 b/test/test_german.mp4
new file mode 100644
index 00000000..f31f9ae5
Binary files /dev/null and b/test/test_german.mp4 differ
diff --git a/test/test_it.flv b/test/test_it.flv
new file mode 100644
index 00000000..de8b917e
Binary files /dev/null and b/test/test_it.flv differ
diff --git a/test/test_japan.mp4 b/test/test_japan.mp4
new file mode 100644
index 00000000..f0f854db
Binary files /dev/null and b/test/test_japan.mp4 differ
diff --git a/test/test_korean.flv b/test/test_korean.flv
new file mode 100644
index 00000000..d0f0e86b
Binary files /dev/null and b/test/test_korean.flv differ
diff --git a/test/test_ru.flv b/test/test_ru.flv
new file mode 100644
index 00000000..cfba4cf9
Binary files /dev/null and b/test/test_ru.flv differ
| |