forked from voicegain/transcription-compare
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtranscribe-compare
148 lines (126 loc) · 6.1 KB
/
transcribe-compare
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
#!/usr/bin/env python
import click
from transcription_compare.levenshtein_distance_calculator import UKKLevenshteinDistanceCalculator
from transcription_compare.tokenizer import CharacterTokenizer, WordTokenizer
from transcription_compare.local_optimizer.digit_util import DigitUtil
from transcription_compare.local_optimizer.local_cer_optimizer import LocalCerOptimizer
import os
from transcription_compare.utils.plot_util import plot_alignment_result_only_distance
from transcription_compare.results import MultiResult
@click.command()
@click.option('--reference', '-r', type=str, help='source string')
@click.option('--output', '-o', type=str, multiple=True, help='target string')
@click.option('--reference_file', '-R', type=click.File('r'), help='source file path')
@click.option('--output_file', '-O', multiple=True, type=click.File('r'), help='target file path')
@click.option('--alignment', '-a', default=False, is_flag=True,
              help='Do you want to see the alignment result? True/False')
@click.option('--error_type', '-e', default='CER', type=click.Choice(['CER', 'WER']))
@click.option('--output_format', '-j', default='TABLE',
              type=click.Choice(['JSON', 'TABLE', 'HTML']))
@click.option('--to_lower', '-l', default=False, is_flag=True, help='Do you want to lower all the words? True/False')
@click.option('--remove_punctuation', '-p', default=False, is_flag=True,
              help='Do you want to remove all the punctuation? True/False')
@click.option('--to_save_plot', '-P', default=False, is_flag=True, help='Do you want to see the windows? True/False')
@click.option('--to_edit_step', '-s', type=int, default=500, help='Please enter the step')
@click.option('--to_edit_width', '-w', type=int, default=500, help='Please enter the width')
@click.option('--file_path', '-f', help='Please enter the path where you would like to save the files')
def main(reference, output, reference_file, output_file, alignment, error_type, output_format, to_lower,
         remove_punctuation, to_save_plot, to_edit_step, to_edit_width, file_path):
    """
    Transcription compare tool provided by VoiceGain
    """
    # NOTE: click renders the docstring above as the command's --help text, so
    # it is intentionally left short; developer documentation lives here.
    #
    # Compares one reference transcript against one or more output transcripts
    # and reports the result as a table, JSON, or an HTML file.
    #
    # Parameters (all supplied by the click options declared above):
    #   reference / reference_file -- reference text, inline or as an open file;
    #       the inline string wins when both are given.
    #   output / output_file -- tuples of inline strings / open files to compare.
    #   alignment -- forwarded as ``get_alignment_result`` to the calculator.
    #   error_type -- "CER" tokenizes by character, "WER" by word.
    #   output_format -- "TABLE", "JSON" or "HTML".
    #   to_lower, remove_punctuation -- forwarded to ``get_distance``.
    #   to_save_plot, to_edit_width, to_edit_step -- control the optional plot.
    #   file_path -- existing directory for the HTML file; also passed to the
    #       plot helper.
    #
    # Raises:
    #   ValueError -- ``file_path`` is not an existing directory, no reference
    #       or no output was supplied, or TABLE output was requested for more
    #       than one output.

    if file_path is not None and not os.path.isdir(file_path):
        raise ValueError("No such file or directory")

    # An inline reference takes precedence over a reference file.
    if reference is None:
        if reference_file is None:
            raise ValueError("One of --reference and --reference_file must be specified")
        reference = reference_file.read()

    total_outputs = len(output) + len(output_file)
    if total_outputs == 0:
        raise ValueError("One of --output and --output_file must be specified")
    is_multiple = total_outputs > 1

    # Fail fast: reject the unsupported combination before doing any distance
    # computation instead of after all of it.
    if output_format == 'TABLE' and is_multiple:
        raise ValueError("TABLE output doesn't support multi-way comparison")

    if error_type == "CER":
        calculator = UKKLevenshteinDistanceCalculator(
            tokenizer=CharacterTokenizer(),
            get_alignment_result=alignment
        )
    else:
        # DigitUtil is only applied for a single output (kept from the
        # original behaviour).
        if is_multiple:
            local_optimizers = [LocalCerOptimizer()]
        else:
            local_optimizers = [DigitUtil(), LocalCerOptimizer()]
        calculator = UKKLevenshteinDistanceCalculator(
            tokenizer=WordTokenizer(),
            get_alignment_result=alignment,
            local_optimizers=local_optimizers
        )

    # output identifier -> output string
    output_all = dict()
    for index, text in enumerate(output):
        output_all["string_output_{}".format(index)] = text
    for handle in output_file:
        base_name = os.path.basename(handle.name)
        # Files from different directories can share a basename; add a numeric
        # suffix so a later file never silently replaces an earlier one.
        identifier = base_name
        suffix = 1
        while identifier in output_all:
            identifier = "{}_{}".format(base_name, suffix)
            suffix += 1
        output_all[identifier] = handle.read()

    # output identifier -> result returned by calculator.get_distance
    output_results = dict()
    for identifier, text in output_all.items():
        output_results[identifier] = calculator.get_distance(
            reference, text, to_lower=to_lower, remove_punctuation=remove_punctuation
        )

    # Single and multi-way comparisons are wrapped the same way; MultiResult
    # also receives a character-level calculator without alignment output.
    calculator_local = UKKLevenshteinDistanceCalculator(
        tokenizer=CharacterTokenizer(),
        get_alignment_result=False
    )
    result = MultiResult(output_results, calculator_local)

    if output_format == 'TABLE':
        click.echo(result)
    elif output_format == 'JSON':
        click.echo(result.to_json())
    elif output_format == 'HTML':
        html = result.to_html()
        gen_html = "transcription-compare.html"
        html_path = gen_html if file_path is None else os.path.join(file_path, gen_html)
        # The context manager closes the file even if the write fails.
        with open(html_path, 'w') as html_file:
            html_file.write(html)

    if to_save_plot:
        # NOTE(review): this reads ``alignment_result`` from every result;
        # presumably --alignment must be set for it to be populated -- confirm
        # against the calculator implementation.
        sub_plot_name = list(output_results.keys())
        alignment_results = [res.alignment_result for res in output_results.values()]
        plot_alignment_result_only_distance(alignment_results, to_edit_width, to_edit_step, sub_plot_name, file_path)
# Run the click command only when this file is executed as a script, not when
# it is imported.
if __name__ == '__main__':
    main()