Skip to content
Snippets Groups Projects
Commit d813dc2b authored by Benoit Favre's avatar Benoit Favre
Browse files

refine alternative slu

parent 234a6d88
No related branches found
No related tags found
No related merge requests found
...@@ -49,14 +49,17 @@ class SLU: ...@@ -49,14 +49,17 @@ class SLU:
continue continue
history_score = len(history & set(word.history)) / float(history_length) history_score = len(history & set(word.history)) / float(history_length)
target_score = len(set(word.word) & set(target)) / float(len(target)) target_score = len(set(word.word) & set(target)) / float(len(target))
distance = abs(word.num - expected)
if history_score == 0 or target_score == 0: if history_score == 0 or target_score == 0:
continue continue
if expected != -1 and distance > 10:
continue
value = history_score + target_score value = history_score + target_score
if value > max_value: if value > max_value:
max_value = value max_value = value
argmax = word argmax = word
if expected != -1 and value == max_value and (min_distance == None or abs(word.num - expected) < min_distance): if expected != -1 and value == max_value and (min_distance == None or distance < min_distance):
min_distance = abs(word.num - expected) min_distance = distance
max_value = value max_value = value
argmax = word argmax = word
...@@ -67,6 +70,12 @@ class SLU: ...@@ -67,6 +70,12 @@ class SLU:
return True return True
return False return False
def expected_at_section_start(self, section):
for word in self.words:
if word.section == section:
return word.num
return -1
if __name__ == '__main__': if __name__ == '__main__':
slu = SLU('data/homeostasis_25nov.xml') slu = SLU('data/homeostasis_25nov.xml')
history = 'open technical'.split() history = 'open technical'.split()
......
...@@ -178,17 +178,20 @@ class ScriptedASR(Gtk.Window): ...@@ -178,17 +178,20 @@ class ScriptedASR(Gtk.Window):
self.slu2_last_section = section_id self.slu2_last_section = section_id
self.slu2_performed = set() self.slu2_performed = set()
self.slu2_history = [] self.slu2_history = []
self.slu2_history.extend(hyp[-1].strip().replace('_', ' ').split()) self.slu2_expected = self.slu2.expected_at_section_start(section_id)
found = self.slu2.process(self.slu2_history[:-1], self.slu2_history[-1], self.slu2_expected, section_id) for word in hyp[-1].strip().replace('_', ' ').split():
if found and found.num not in self.slu2_performed and found.word == self.slu2_history[-1]: #self.slu2_history.extend(hyp[-1].strip().replace('_', ' ').split())
self.slu2_performed.add(found.num) self.slu2_history.append(word)
print found.section, found.action, found.history, found.word found = self.slu2.process(self.slu2_history[:-1], self.slu2_history[-1], self.slu2_expected, section_id)
found.keyword = self.xmlview.keywords[found.num] if found and found.num not in self.slu2_performed and found.word == self.slu2_history[-1]:
self.xmlview.highlight(found) self.slu2_performed.add(found.num)
self.slu2_expected = found.num + 1 print found.section, found.action, found.history, found.word
if self.slu2.last_in_section(found): found.keyword = self.xmlview.keywords[found.num]
new_section = self.xmlview.get_section() + 1 self.xmlview.highlight(found)
self.confirmer.confirm('Go to section %d?' % (new_section + 1), 3, lambda: self.set_section(new_section)) self.slu2_expected = found.num + 1
if self.slu2.last_in_section(found):
new_section = self.xmlview.get_section() + 1
self.confirmer.confirm('Go to section %d?' % (new_section + 1), 3, lambda: self.set_section(new_section))
def quit(self, window): def quit(self, window):
for slu in self.slu.values(): for slu in self.slu.values():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment