From d813dc2bf3929733699290b6003686e1f246750c Mon Sep 17 00:00:00 2001 From: Benoit Favre <benoit.favre@lif.univ-mrs.fr> Date: Fri, 27 Feb 2015 13:52:26 +0100 Subject: [PATCH] refine alternative slu --- alternate_slu.py | 13 +++++++++++-- main.py | 25 ++++++++++++++----------- 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/alternate_slu.py b/alternate_slu.py index 22b9677..173f44f 100644 --- a/alternate_slu.py +++ b/alternate_slu.py @@ -49,14 +49,17 @@ class SLU: continue history_score = len(history & set(word.history)) / float(history_length) target_score = len(set(word.word) & set(target)) / float(len(target)) + distance = abs(word.num - expected) if history_score == 0 or target_score == 0: continue + if expected != -1 and distance > 10: + continue value = history_score + target_score if value > max_value: max_value = value argmax = word - if expected != -1 and value == max_value and (min_distance == None or abs(word.num - expected) < min_distance): - min_distance = abs(word.num - expected) + if expected != -1 and value == max_value and (min_distance == None or distance < min_distance): + min_distance = distance max_value = value argmax = word @@ -67,6 +70,12 @@ class SLU: return True return False + def expected_at_section_start(self, section): + for word in self.words: + if word.section == section: + return word.num + return -1 + if __name__ == '__main__': slu = SLU('data/homeostasis_25nov.xml') history = 'open technical'.split() diff --git a/main.py b/main.py index 412f0e1..4706c8c 100644 --- a/main.py +++ b/main.py @@ -178,17 +178,20 @@ class ScriptedASR(Gtk.Window): self.slu2_last_section = section_id self.slu2_performed = set() self.slu2_history = [] - self.slu2_history.extend(hyp[-1].strip().replace('_', ' ').split()) - found = self.slu2.process(self.slu2_history[:-1], self.slu2_history[-1], self.slu2_expected, section_id) - if found and found.num not in self.slu2_performed and found.word == self.slu2_history[-1]: - self.slu2_performed.add(found.num) - print found.section, found.action, found.history, found.word - found.keyword = self.xmlview.keywords[found.num] - self.xmlview.highlight(found) - self.slu2_expected = found.num + 1 - if self.slu2.last_in_section(found): - new_section = self.xmlview.get_section() + 1 - self.confirmer.confirm('Go to section %d?' % (new_section + 1), 3, lambda: self.set_section(new_section)) + self.slu2_expected = self.slu2.expected_at_section_start(section_id) + for word in hyp[-1].strip().replace('_', ' ').split(): + #self.slu2_history.extend(hyp[-1].strip().replace('_', ' ').split()) + self.slu2_history.append(word) + found = self.slu2.process(self.slu2_history[:-1], self.slu2_history[-1], self.slu2_expected, section_id) + if found and found.num not in self.slu2_performed and found.word == self.slu2_history[-1]: + self.slu2_performed.add(found.num) + print found.section, found.action, found.history, found.word + found.keyword = self.xmlview.keywords[found.num] + self.xmlview.highlight(found) + self.slu2_expected = found.num + 1 + if self.slu2.last_in_section(found): + new_section = self.xmlview.get_section() + 1 + self.confirmer.confirm('Go to section %d?' % (new_section + 1), 3, lambda: self.set_section(new_section)) def quit(self, window): for slu in self.slu.values(): -- GitLab