# evaluate_cer.py — evaluate character error rate (CER) of ASR transcriptions.
# (Web-scrape residue — GitHub page chrome and the line-number gutter — removed;
# the content below is the actual 169-line source file.)
import json
from evaluate import load
from itertools import product
import regex as re
import doctest
from tqdm import tqdm
# Matches one CJK unified ideograph, plus 〆 (U+3006) and 〇 (U+3007).
# The \p{...} property classes require the third-party `regex` module
# (imported above as `re`); the stdlib `re` does not support them.
cjk_char_re = re.compile(r'[\p{Unified_Ideograph}\u3006\u3007]')
# Matches any Unicode punctuation character.
punctuation_re = re.compile(r'\p{P}')
cer = load("cer")
# Read JSON data from file
with open("google/common-voice-15-transcriptions.json", "r", encoding="utf-8") as f:
    transcript_json = json.load(f)
# Accumulators filled in by evaluate_cer().
references = []
best_predictions = []
closest_predictions = []
sound_closest_predictions = []
sound_closest_references = []
# char -> list of candidate Jyutping readings, loaded from charlist.csv
# (first column is the character, remaining columns are its readings).
char_sounds = {}
# Specify UTF-8 explicitly: the file contains CJK characters and the
# platform default encoding is not guaranteed to be UTF-8 (see PEP 597).
with open("charlist.csv", "r", encoding="utf-8") as charlist:
    for line in charlist:
        row = line.strip().split(",")
        char_sounds[row[0]] = row[1:]
def to_sounds(sentence: str, output_tones: bool = True) -> list[list[str]]:
    """
    Convert a sentence containing Chinese characters, English words,
    punctuations, and potentially other symbols into Jyutping.

    * Keep Chinese characters outside the vocabulary and English words
      as-is.
    * Ignore punctuations, except periods and dashes

    :param sentence: input text mixing CJK characters, Latin words,
        punctuation and whitespace
    :param output_tones: when False, strip the trailing tone digit from
        each Jyutping reading and deduplicate the remainder
    :return: one list of candidate readings per character/word, in
        sentence order

    >>> to_sounds('佢行山也')
    [['heoi5', 'keoi5'], ['haang1', 'haang4', 'hang4', 'hang6', 'hong2', 'hong4'], ['saan1'], ['jaa2', 'jaa4', 'jaa5', 'jaa6', 'je1']]
    >>> to_sounds('我鍾意Python。')
    [['ngo5'], ['zung1'], ['ji1', 'ji2', 'ji3', 'ji5'], ['Python']]
    >>> to_sounds('Hello!')
    [['Hello']]
    >>> to_sounds('你好 - Hello!')
    [['nei5'], ['hou2', 'hou3'], ['-'], ['Hello']]
    >>> to_sounds(',,abc,,')
    [['abc']]
    >>> to_sounds('')
    []
    >>> to_sounds('。')
    []
    """
    sounds: list[list[str]] = []
    segment = ""

    def flush() -> str:
        # Emit the accumulated non-CJK run (if any) as a single-candidate
        # entry; returns the new (empty) segment value.
        if segment:
            sounds.append([segment.strip()])
        return ""

    # Collapse runs of whitespace so segments split cleanly on single spaces.
    sentence = re.sub(r"\s+", " ", sentence.strip())
    for char in sentence:
        if cjk_char_re.match(char):
            # Flush the pending segment before emitting this character's sounds.
            segment = flush()
            if char in char_sounds:
                if output_tones:
                    sounds.append(char_sounds[char])
                else:
                    # Strip the tone digit; dict.fromkeys deduplicates while
                    # keeping a deterministic (insertion) order, unlike set(),
                    # whose iteration order varies across runs.
                    sounds.append(list(dict.fromkeys(jp[:-1] for jp in char_sounds[char])))
            else:
                # Out-of-vocabulary Chinese character: keep it as-is.
                sounds.append([char])
        elif punctuation_re.match(char) and char != '.' and char != '-':
            # ignore punctuations except periods and dashes
            continue
        elif char == ' ':
            # Word boundary: flush the pending segment.
            segment = flush()
        else:
            segment += char
    # Flush any trailing segment.
    flush()
    return sounds
def normalize_sentence(s: str) -> str:
    """Strip all punctuation and map variant characters ("噶" -> "㗎",
    "咧" -> "呢") to the canonical forms used by the references."""
    stripped = punctuation_re.sub("", s)
    # Single-pass character mapping instead of chained .replace() calls.
    return stripped.translate(str.maketrans("噶咧", "㗎呢"))
def evaluate_cer():
    """Compute and print corpus-level CER for the loaded transcriptions.

    Reads the module-level ``transcript_json`` and appends to the
    module-level ``references``, ``best_predictions`` and
    ``closest_predictions`` lists.
    NOTE(review): those lists are never cleared, so calling this function
    twice would double-count every entry — confirm single-shot usage.

    Three provider result schemas are detected per entry:
      * Apple:  ``entry["segments"]`` with ``substring`` /
        ``alternativeSubstrings``
      * Azure:  ``entry["result"]["DisplayText"]`` and
        ``entry["result"]["NBest"][i]["Display"]``
      * Google: ``entry["result"]`` is a list of ``{"transcript": ...}``
        candidates (empty list for empty audio)
    """
    # Iterate over each sentence in the transcript
    for entry in tqdm(transcript_json):
        # NOTE(review): references are punctuation-stripped/normalized but
        # the predictions below are not, so punctuation in a prediction
        # counts as character errors — confirm this asymmetry is intended.
        reference = normalize_sentence(entry["sentence"])
        references.append(reference)
        """
        Best prediction outputted by the ASR
        """
        if "segments" in entry:
            # apple
            best_segments = [segment["substring"] for segment in entry["segments"]]
            best_prediction = "".join(best_segments)
            best_predictions.append(best_prediction)
        elif "DisplayText" in entry["result"]:
            # azure
            best_prediction = entry["result"]["DisplayText"]
            best_predictions.append(best_prediction)
        else:
            # google
            # not empty audio
            if len(entry["result"]) > 0:
                best_prediction = entry["result"][0]["transcript"]
                best_predictions.append(best_prediction)
            else:
                # remove empty audio's reference
                references.pop()
        """
        Find the closest prediction to the reference among all alternative segments
        """
        if "segments" in entry:
            # apple
            # Prepare a list of lists, where each inner list contains the substring
            # and its alternativeSubstrings.
            all_segment_options = []
            for segment in entry["segments"]:
                options = [segment["substring"]] + segment.get("alternativeSubstrings", [])
                all_segment_options.append(options)
            # Generate all possible combinations of segments and alternative segments.
            # NOTE(review): the cartesian product grows exponentially in the
            # number of segments with alternatives — fine for short
            # utterances, but could blow up on long ones.
            all_combinations = list(product(*all_segment_options))
        elif "DisplayText" in entry["result"]:
            # azure
            # Each "combination" here is already a whole candidate string,
            # so the ''.join() below is a no-op for this branch.
            all_combinations = list(map(lambda candidate: candidate["Display"], entry["result"]["NBest"]))
        else:
            # google
            # Likewise: whole candidate transcripts, not segment tuples.
            all_combinations = list(map(lambda candidate: candidate["transcript"], entry["result"]))
        # skip empty audios
        if len(all_combinations) > 0:
            # Keep the candidate with the lowest CER against this reference.
            min_cer = float('inf')
            closest_prediction = None
            for combination in all_combinations:
                prediction = ''.join(combination)
                cer_result = cer.compute(predictions=[prediction], references=[reference])
                if cer_result < min_cer:
                    min_cer = cer_result
                    closest_prediction = prediction
            closest_predictions.append(closest_prediction)
    # Corpus-level scores over all accumulated prediction/reference pairs.
    cer_score_best = cer.compute(predictions=best_predictions, references=references)
    print(f"CER Score for best predictions: {cer_score_best}")
    cer_score_closest = cer.compute(predictions=closest_predictions, references=references)
    print(f"CER Score for closest predictions: {cer_score_closest}")
if __name__ == "__main__":
    # Run the to_sounds doctests before evaluating.
    # (doctest is already imported at the top of the file; the redundant
    # local re-import was removed.)
    doctest.testmod()
    evaluate_cer()