diff --git a/alternate_slu.py b/alternate_slu.py index 22b9677a8f6fc4565b784760b6621dc5022c0361..173f44f70a8b38ecf1cf31fef1b743492d3e35c2 100644 --- a/alternate_slu.py +++ b/alternate_slu.py @@ -49,14 +49,17 @@ class SLU: continue history_score = len(history & set(word.history)) / float(history_length) target_score = len(set(word.word) & set(target)) / float(len(target)) + distance = abs(word.num - expected) if history_score == 0 or target_score == 0: continue + if expected != -1 and distance > 10: + continue value = history_score + target_score if value > max_value: max_value = value argmax = word - if expected != -1 and value == max_value and (min_distance == None or abs(word.num - expected) < min_distance): - min_distance = abs(word.num - expected) + if expected != -1 and value == max_value and (min_distance == None or distance < min_distance): + min_distance = distance max_value = value argmax = word @@ -67,6 +70,12 @@ class SLU: return True return False + def expected_at_section_start(self, section): + for word in self.words: + if word.section == section: + return word.num + return -1 + if __name__ == '__main__': slu = SLU('data/homeostasis_25nov.xml') history = 'open technical'.split() diff --git a/main.py b/main.py index 412f0e18cdf6a444db68a5150752f18bf50a9977..4706c8cd3cd855fc35b325cfe49c7d113a089bff 100644 --- a/main.py +++ b/main.py @@ -178,17 +178,20 @@ class ScriptedASR(Gtk.Window): self.slu2_last_section = section_id self.slu2_performed = set() self.slu2_history = [] - self.slu2_history.extend(hyp[-1].strip().replace('_', ' ').split()) - found = self.slu2.process(self.slu2_history[:-1], self.slu2_history[-1], self.slu2_expected, section_id) - if found and found.num not in self.slu2_performed and found.word == self.slu2_history[-1]: - self.slu2_performed.add(found.num) - print found.section, found.action, found.history, found.word - found.keyword = self.xmlview.keywords[found.num] - self.xmlview.highlight(found) - self.slu2_expected = found.num + 1 - if self.slu2.last_in_section(found): - new_section = self.xmlview.get_section() + 1 - self.confirmer.confirm('Go to section %d?' % (new_section + 1), 3, lambda: self.set_section(new_section)) + self.slu2_expected = self.slu2.expected_at_section_start(section_id) + for word in hyp[-1].strip().replace('_', ' ').split(): + #self.slu2_history.extend(hyp[-1].strip().replace('_', ' ').split()) + self.slu2_history.append(word) + found = self.slu2.process(self.slu2_history[:-1], self.slu2_history[-1], self.slu2_expected, section_id) + if found and found.num not in self.slu2_performed and found.word == self.slu2_history[-1]: + self.slu2_performed.add(found.num) + print found.section, found.action, found.history, found.word + found.keyword = self.xmlview.keywords[found.num] + self.xmlview.highlight(found) + self.slu2_expected = found.num + 1 + if self.slu2.last_in_section(found): + new_section = self.xmlview.get_section() + 1 + self.confirmer.confirm('Go to section %d?' % (new_section + 1), 3, lambda: self.set_section(new_section)) def quit(self, window): for slu in self.slu.values():