Skip to content
Snippets Groups Projects
Commit 3489e388 authored by Franck Dary's avatar Franck Dary
Browse files

Config is now aware of what is predicted

parent e85a3122
No related branches found
No related tags found
No related merge requests found
......@@ -7,6 +7,8 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine)
void Decoder::decode(BaseConfig & config, std::size_t beamSize)
{
config.addPredicted(machine.getPredicted());
try
{
config.setState(machine.getStrategy().getInitialState());
......
......@@ -30,6 +30,7 @@ class Config
private :
std::vector<String> lines;
std::set<std::string> predicted;
protected :
......@@ -61,6 +62,8 @@ class Config
String & get(int colIndex, int lineIndex, int hypothesisIndex);
const String & getConst(int colIndex, int lineIndex, int hypothesisIndex) const;
String & getLastNotEmpty(int colIndex, int lineIndex);
String & getLastNotEmptyHyp(int colIndex, int lineIndex);
const String & getLastNotEmptyHypConst(int colIndex, int lineIndex) const;
const String & getLastNotEmptyConst(int colIndex, int lineIndex) const;
ValueIterator getIterator(int colIndex, int lineIndex, int hypothesisIndex);
ConstValueIterator getConstIterator(int colIndex, int lineIndex, int hypothesisIndex) const;
......@@ -75,6 +78,8 @@ class Config
const String & getConst(const std::string & colName, int lineIndex, int hypothesisIndex) const;
String & getLastNotEmpty(const std::string & colName, int lineIndex);
const String & getLastNotEmptyConst(const std::string & colName, int lineIndex) const;
String & getLastNotEmptyHyp(const std::string & colName, int lineIndex);
const String & getLastNotEmptyHypConst(const std::string & colName, int lineIndex) const;
String & getFirstEmpty(int colIndex, int lineIndex);
String & getFirstEmpty(const std::string & colName, int lineIndex);
bool hasCharacter(int letterIndex) const;
......@@ -100,7 +105,8 @@ class Config
void setState(const std::string state);
bool stateIsDone() const;
std::vector<long> extractContext(int leftBorder, int rightBorder, Dict & dict) const;
void addPredicted(const std::set<std::string> & predicted);
bool isPredicted(const std::string & colName) const;
};
#endif
......@@ -39,6 +39,7 @@ class ReadingMachine
Classifier * getClassifier();
void save() const;
bool isPredicted(const std::string & columnName) const;
const std::set<std::string> & getPredicted() const;
};
#endif
......@@ -70,7 +70,10 @@ void Config::print(FILE * dest) const
continue;
}
for (unsigned int i = 0; i < getNbColumns()-1; i++)
fmt::print(dest, "{}{}", getLastNotEmptyConst(i, getFirstLineIndex()+line), i < getNbColumns()-2 ? "\t" : "\n");
{
auto & colContent = isPredicted(getColName(i)) ? getLastNotEmptyHypConst(i, getFirstLineIndex()+line) : getLastNotEmptyConst(i, getFirstLineIndex()+line);
fmt::print(dest, "{}{}", colContent, i < getNbColumns()-2 ? "\t" : "\n");
}
if (getLastNotEmptyConst(EOSColName, getFirstLineIndex()+line) == EOSSymbol1)
fmt::print(dest, "\n");
}
......@@ -105,7 +108,10 @@ void Config::printForDebug(FILE * dest) const
toPrint.emplace_back();
toPrint.back().emplace_back(line == (int)wordIndex ? "=>" : "");
for (unsigned int i = 0; i < getNbColumns(); i++)
toPrint.back().emplace_back(util::shrink(getLastNotEmptyConst(i, line), maxWordLength));
{
auto & colContent = isPredicted(getColName(i)) ? getLastNotEmptyHypConst(i, line) : getLastNotEmptyConst(i, getFirstLineIndex()+line);
toPrint.back().emplace_back(util::shrink(colContent, maxWordLength));
}
}
std::vector<std::size_t> colLength(toPrint[0].size(), 0);
......@@ -167,6 +173,17 @@ Config::String & Config::getLastNotEmpty(int colIndex, int lineIndex)
return lines[baseIndex];
}
Config::String & Config::getLastNotEmptyHyp(int colIndex, int lineIndex)
{
int baseIndex = getIndexOfLine(lineIndex-getFirstLineIndex()) + getIndexOfCol(colIndex);
for (int i = nbHypothesesMax; i > 0; --i)
if (!util::isEmpty(lines[baseIndex+i]))
return lines[baseIndex+i];
return lines[baseIndex+1];
}
Config::String & Config::getFirstEmpty(int colIndex, int lineIndex)
{
int baseIndex = getIndexOfLine(lineIndex-getFirstLineIndex()) + getIndexOfCol(colIndex);
......@@ -194,16 +211,37 @@ const Config::String & Config::getLastNotEmptyConst(int colIndex, int lineIndex)
return lines[baseIndex];
}
const Config::String & Config::getLastNotEmptyHypConst(int colIndex, int lineIndex) const
{
int baseIndex = getIndexOfLine(lineIndex-getFirstLineIndex()) + getIndexOfCol(colIndex);
for (int i = nbHypothesesMax; i > 0; --i)
if (!util::isEmpty(lines[baseIndex+i]))
return lines[baseIndex+i];
return lines[baseIndex+1];
}
Config::String & Config::getLastNotEmpty(const std::string & colName, int lineIndex)
{
return getLastNotEmpty(getColIndex(colName), lineIndex);
}
Config::String & Config::getLastNotEmptyHyp(const std::string & colName, int lineIndex)
{
return getLastNotEmptyHyp(getColIndex(colName), lineIndex);
}
const Config::String & Config::getLastNotEmptyConst(const std::string & colName, int lineIndex) const
{
return getLastNotEmptyConst(getColIndex(colName), lineIndex);
}
const Config::String & Config::getLastNotEmptyHypConst(const std::string & colName, int lineIndex) const
{
return getLastNotEmptyHypConst(getColIndex(colName), lineIndex);
}
Config::ValueIterator Config::getIterator(int colIndex, int lineIndex, int hypothesisIndex)
{
return lines.begin() + getIndexOfLine(lineIndex-getFirstLineIndex()) + getIndexOfCol(colIndex) + hypothesisIndex;
......@@ -393,3 +431,13 @@ std::vector<long> Config::extractContext(int leftBorder, int rightBorder, Dict &
return context;
}
void Config::addPredicted(const std::set<std::string> & predicted)
{
this->predicted.insert(predicted.begin(), predicted.end());
}
bool Config::isPredicted(const std::string & colName) const
{
return predicted.count(colName);
}
......@@ -119,3 +119,8 @@ bool ReadingMachine::isPredicted(const std::string & columnName) const
return predicted.count(columnName);
}
const std::set<std::string> & ReadingMachine::getPredicted() const
{
return predicted;
}
......@@ -7,6 +7,7 @@ Trainer::Trainer(ReadingMachine & machine) : machine(machine)
void Trainer::createDataset(SubConfig & config)
{
config.addPredicted(machine.getPredicted());
config.setState(machine.getStrategy().getInitialState());
std::vector<torch::Tensor> contexts;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment