Skip to content
Snippets Groups Projects
Commit d813dc2b authored by Benoit Favre's avatar Benoit Favre
Browse files

refine alternative slu

parent 234a6d88
Branches
No related tags found
No related merge requests found
......@@ -49,14 +49,17 @@ class SLU:
continue
history_score = len(history & set(word.history)) / float(history_length)
target_score = len(set(word.word) & set(target)) / float(len(target))
distance = abs(word.num - expected)
if history_score == 0 or target_score == 0:
continue
if expected != -1 and distance > 10:
continue
value = history_score + target_score
if value > max_value:
max_value = value
argmax = word
if expected != -1 and value == max_value and (min_distance == None or abs(word.num - expected) < min_distance):
min_distance = abs(word.num - expected)
if expected != -1 and value == max_value and (min_distance == None or distance < min_distance):
min_distance = distance
max_value = value
argmax = word
......@@ -67,6 +70,12 @@ class SLU:
return True
return False
def expected_at_section_start(self, section):
for word in self.words:
if word.section == section:
return word.num
return -1
if __name__ == '__main__':
slu = SLU('data/homeostasis_25nov.xml')
history = 'open technical'.split()
......
......@@ -178,7 +178,10 @@ class ScriptedASR(Gtk.Window):
self.slu2_last_section = section_id
self.slu2_performed = set()
self.slu2_history = []
self.slu2_history.extend(hyp[-1].strip().replace('_', ' ').split())
self.slu2_expected = self.slu2.expected_at_section_start(section_id)
for word in hyp[-1].strip().replace('_', ' ').split():
#self.slu2_history.extend(hyp[-1].strip().replace('_', ' ').split())
self.slu2_history.append(word)
found = self.slu2.process(self.slu2_history[:-1], self.slu2_history[-1], self.slu2_expected, section_id)
if found and found.num not in self.slu2_performed and found.word == self.slu2_history[-1]:
self.slu2_performed.add(found.num)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment