Commit 15d915ca authored by Franck Dary

Fixed some problems in dependency parsing

parent cac30a69
......@@ -89,11 +89,13 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
// Force EOS when needed
if (machine.getTransitionSet().getTransition("EOS") and config.getLastNotEmptyHypConst(Config::EOSColName, config.getWordIndex()) != Config::EOSSymbol1)
{
Action shift = Action::pushWordIndexOnStack();
shift.apply(config, shift);
machine.getTransitionSet().getTransition("SHIFT")->apply(config);
machine.getTransitionSet().getTransition("EOS")->apply(config);
if (debug)
{
fmt::print(stderr, "Forcing EOS transition\n");
config.printForDebug(stderr);
}
}
// Fill holes in important columns like "ID" and "HEAD" to be compatible with eval script
......
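Note on the Decoder hunk above: instead of hand-building a pushWordIndexOnStack action, the forced end of sentence now goes through the machine's own SHIFT and EOS transitions. A minimal compilable sketch of that control flow, using hypothetical stand-in types (the real Transition/TransitionSet/Config API is much richer than this):

    #include <cstdio>
    #include <map>
    #include <string>

    struct Config
    {
      std::string lastEOS = "0"; // stand-in for getLastNotEmptyHypConst(EOSColName, wordIndex)
      void printForDebug(std::FILE * dest) const { std::fprintf(dest, "eos=%s\n", lastEOS.c_str()); }
    };

    struct Transition
    {
      std::string name;
      void apply(Config & config) const
      {
        if (name == "EOS")
          config.lastEOS = "1"; // mark end of sentence on the last word
        // a real SHIFT would push the current word index on the stack
      }
    };

    struct TransitionSet
    {
      std::map<std::string, Transition> transitions;
      const Transition * getTransition(const std::string & name) const
      {
        auto it = transitions.find(name);
        return it == transitions.end() ? nullptr : &it->second;
      }
    };

    // When decoding ends on a sentence that was never closed, reuse the
    // machine's own SHIFT and EOS transitions instead of an ad-hoc action.
    // Like the original code, this assumes the machine defines SHIFT.
    void forceEOS(const TransitionSet & machine, Config & config, bool debug)
    {
      if (machine.getTransition("EOS") and config.lastEOS != "1")
      {
        machine.getTransition("SHIFT")->apply(config);
        machine.getTransition("EOS")->apply(config);
        if (debug)
        {
          std::fprintf(stderr, "Forcing EOS transition\n");
          config.printForDebug(stderr);
        }
      }
    }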
......@@ -21,6 +21,8 @@ class Config
static constexpr const char * headColName = "HEAD";
static constexpr const char * deprelColName = "DEPREL";
static constexpr const char * idColName = "ID";
static constexpr const char * isMultiColName = "MULTI";
static constexpr const char * childsColName = "CHILDS";
static constexpr int nbHypothesesMax = 1;
static constexpr int maxNbAppliableSplitTransitions = 8;
......
......@@ -270,6 +270,9 @@ Action Action::pushWordIndexOnStack()
if (config.hasStack(0) and config.getStack(0) == config.getWordIndex())
return false;
if (config.hasStack(0) and !config.isTokenPredicted(config.getStack(0)))
return false;
return (int)config.getWordIndex() != config.getLastPoppedStack();
};
......@@ -292,7 +295,7 @@ Action Action::popStack()
auto appliable = [](const Config & config, const Action &)
{
return config.hasStack(0);
return config.hasStack(0) and config.getStack(0) != config.getWordIndex();
};
return {Type::Pop, apply, undo, appliable};
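Note on the Action hunks above: every transition carries an appliable predicate, and popStack is tightened so the top of the stack can no longer be popped while it is still the current word index. A self-contained sketch of that pattern with a toy Config (a hypothetical stand-in, not the project's class):

    #include <functional>
    #include <vector>

    struct Config
    {
      std::vector<int> stack;
      int wordIndex = 0;
      bool hasStack(int i) const { return i < (int)stack.size(); }
      int getStack(int i) const { return stack[stack.size() - 1 - i]; }
      int getWordIndex() const { return wordIndex; }
    };

    struct Action
    {
      std::function<void(Config &)> apply;
      std::function<void(Config &)> undo;
      std::function<bool(const Config &)> appliable;
    };

    Action popStack()
    {
      auto apply = [](Config & config) { config.stack.pop_back(); };
      auto undo = [](Config &) { /* the real undo restores the popped element */ };
      auto appliable = [](const Config & config)
      {
        // Tightened precondition: never pop the word currently being read.
        return config.hasStack(0) and config.getStack(0) != config.getWordIndex();
      };
      return {apply, undo, appliable};
    }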
......@@ -499,7 +502,7 @@ Action Action::setRoot()
auto appliable = [](const Config & config, const Action &)
{
return config.hasStack(0);
return config.hasStack(0) and config.isTokenPredicted(config.getStack(0)) and config.getLastNotEmptyConst(Config::isMultiColName, config.getStack(0)) != Config::EOSSymbol1;
};
return {Type::Write, apply, undo, appliable};
......@@ -605,6 +608,9 @@ Action Action::attach(Object governorObject, int governorIndex, Object dependent
depLineIndex = config.getStack(dependentIndex);
}
if (!config.isTokenPredicted(govLineIndex) or !config.isTokenPredicted(depLineIndex))
return false;
// Check for cycles
while (govLineIndex != depLineIndex)
{
......
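Note on the attach hunk above: governor and dependent must now both already be predicted, and the existing loop that follows refuses attachments that would close a loop in the HEAD chain. A compilable sketch of that cycle check, with a plain map standing in for the HEAD column (a simplification of the real Config accessors):

    #include <map>

    // head[i] == -1 means token i has no governor yet. Assumes the heads
    // already recorded form a forest (no pre-existing cycle).
    bool wouldCreateCycle(const std::map<int, int> & head, int govLineIndex, int depLineIndex)
    {
      int current = govLineIndex;
      while (current != depLineIndex)
      {
        auto it = head.find(current);
        if (it == head.end() or it->second == -1)
          return false; // reached a token with no governor: attaching is safe
        current = it->second; // climb one step toward the root
      }
      return true; // came back to the dependent: attaching would close a loop
    }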
......@@ -28,6 +28,16 @@ void BaseConfig::readMCD(std::string_view mcdFilename)
std::fclose(file);
if (colName2Index.count(isMultiColName))
util::myThrow(fmt::format("mcd '{}' must not contain column '{}'", mcdFilename, isMultiColName));
colIndex2Name.emplace_back(isMultiColName);
colName2Index.emplace(isMultiColName, colIndex2Name.size()-1);
if (colName2Index.count(childsColName))
util::myThrow(fmt::format("mcd '{}' must not contain column '{}'", mcdFilename, childsColName));
colIndex2Name.emplace_back(childsColName);
colName2Index.emplace(childsColName, colIndex2Name.size()-1);
if (colName2Index.count(EOSColName))
util::myThrow(fmt::format("mcd '{}' must not contain column '{}'", mcdFilename, EOSColName));
colIndex2Name.emplace_back(EOSColName);
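Note on the readMCD hunk above: after the user-supplied MCD is parsed, the internal bookkeeping columns declared in Config (MULTI, CHILDS, EOS) are appended, and an MCD that already defines one of them is rejected. A compact sketch of that registration step, with simplified names and error handling:

    #include <cstddef>
    #include <map>
    #include <stdexcept>
    #include <string>
    #include <vector>

    void addReservedColumn(std::vector<std::string> & colIndex2Name,
                           std::map<std::string, std::size_t> & colName2Index,
                           const std::string & name, const std::string & mcdFilename)
    {
      if (colName2Index.count(name))
        throw std::runtime_error("mcd '" + mcdFilename + "' must not contain column '" + name + "'");
      colIndex2Name.emplace_back(name);
      colName2Index.emplace(name, colIndex2Name.size() - 1);
    }

    // Typical use after parsing the MCD file:
    //   for (const char * name : {"MULTI", "CHILDS", "EOS"})
    //     addReservedColumn(colIndex2Name, colName2Index, name, mcdFilename);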
......@@ -64,6 +74,7 @@ void BaseConfig::readTSVInput(std::string_view tsvFilename)
int inputLineIndex = 0;
bool inputHasBeenRead = false;
int usualNbCol = -1;
int nbMultiwords = 0;
while (!std::feof(file))
{
......@@ -116,6 +127,7 @@ void BaseConfig::readTSVInput(std::string_view tsvFilename)
{
addLines(1);
get(EOSColName, getNbLines()-1, 0) = EOSSymbol0;
get(isMultiColName, getNbLines()-1, 0) = EOSSymbol0;
get(0, getNbLines()-1, 0) = std::string(line);
continue;
}
......@@ -134,6 +146,13 @@ void BaseConfig::readTSVInput(std::string_view tsvFilename)
addLines(1);
get(EOSColName, getNbLines()-1, 0) = EOSSymbol0;
if (nbMultiwords > 0)
{
get(isMultiColName, getNbLines()-1, 0) = EOSSymbol1;
nbMultiwords--;
}
else
get(isMultiColName, getNbLines()-1, 0) = EOSSymbol0;
for (unsigned int i = 0; i < splited.size(); i++)
if (i < colIndex2Name.size())
......@@ -141,6 +160,9 @@ void BaseConfig::readTSVInput(std::string_view tsvFilename)
std::string value = std::string(splited[i]);
get(i, getNbLines()-1, 0) = value;
}
if (isMultiword(getNbLines()-1))
nbMultiwords = getMultiwordSize(getNbLines()-1)+1;
}
std::fclose(file);
......
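Note on the readTSVInput hunks above: a CoNLL-U ID such as "3-4" introduces a multiword token, and the following lines it covers are flagged in the internal MULTI column via the nbMultiwords counter. A small self-contained sketch of the range detection, as a hypothetical simplification of isMultiword()/getMultiwordSize():

    #include <cstdio>
    #include <string>

    // Returns the number of covered tokens (e.g. 2 for "3-4"), or 0 if the ID
    // is not a multiword range.
    int multiwordSpan(const std::string & id)
    {
      auto dash = id.find('-');
      if (dash == std::string::npos)
        return 0;
      try
      {
        int first = std::stoi(id.substr(0, dash));
        int last = std::stoi(id.substr(dash + 1));
        return last >= first ? last - first + 1 : 0;
      }
      catch (const std::exception &)
      {
        return 0; // ranges like "1.1" (empty nodes) or malformed IDs are not multiwords
      }
    }

    int main()
    {
      std::printf("%d %d\n", multiwordSpan("3-4"), multiwordSpan("7")); // prints "2 0"
      return 0;
    }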
......@@ -98,6 +98,12 @@ void Config::print(FILE * dest) const
}
for (unsigned int i = 0; i < getNbColumns()-1; i++)
{
if (getColName(i) == isMultiColName or getColName(i) == childsColName)
{
if (i == getNbColumns()-2)
currentSequence.back().back() = '\n';
continue;
}
auto & colContent = getAsFeature(i, getFirstLineIndex()+line);
std::string valueToPrint = colContent;
try
......@@ -139,7 +145,11 @@ void Config::printForDebug(FILE * dest) const
toPrint.emplace_back();
toPrint.back().emplace_back("");
for (unsigned int i = 0; i < getNbColumns(); i++)
{
if (getColName(i) == isMultiColName or getColName(i) == childsColName)
continue;
toPrint.back().emplace_back(getColName(i));
}
for (int line = firstLineToPrint; line <= lastLineToPrint; line++)
{
......@@ -149,6 +159,8 @@ void Config::printForDebug(FILE * dest) const
toPrint.back().emplace_back(line == (int)wordIndex ? "=>" : "");
for (unsigned int i = 0; i < getNbColumns(); i++)
{
if (getColName(i) == isMultiColName or getColName(i) == childsColName)
continue;
std::string colContent = has(i,line,0) ? getAsFeature(i, line).get() : "?";
std::string toPrintCol = colContent;
try
......
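Note on the print()/printForDebug() hunks above: the internal MULTI and CHILDS columns are skipped on output so the printed corpus keeps its original column layout. A toy sketch of that filtering (column names hard-coded here purely for illustration):

    #include <cstdio>
    #include <string>
    #include <vector>

    bool isInternalColumn(const std::string & name)
    {
      return name == "MULTI" or name == "CHILDS";
    }

    void printHeader(std::FILE * dest, const std::vector<std::string> & columns)
    {
      bool first = true;
      for (auto & name : columns)
      {
        if (isInternalColumn(name))
          continue; // never leak internal bookkeeping columns to the output
        std::fprintf(dest, "%s%s", first ? "" : "\t", name.c_str());
        first = false;
      }
      std::fprintf(dest, "\n");
    }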
......@@ -2,22 +2,28 @@
#define DEPTHLAYERTREEEMBEDDING__H
#include <torch/torch.h>
#include "fmt/core.h"
#include "Submodule.hpp"
#include "LSTM.hpp"
class DepthLayerTreeEmbeddingImpl : public torch::nn::Module
class DepthLayerTreeEmbeddingImpl : public torch::nn::Module, public Submodule
{
private :
std::vector<std::string> columns{"DEPREL"};
std::vector<int> focusedBuffer{0};
std::vector<int> focusedStack{0};
std::string firstElem{"__special_DepthLayerTreeEmbeddingImpl__"};
std::vector<LSTM> depthLstm;
int maxDepth;
int maxElemPerDepth;
public :
DepthLayerTreeEmbeddingImpl(int maxDepth, int maxElemPerDepth);
DepthLayerTreeEmbeddingImpl(int maxDepth, int maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options);
torch::Tensor forward(torch::Tensor input);
int getOutputSize();
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
};
TORCH_MODULE(DepthLayerTreeEmbedding);
......
#include "DepthLayerTreeEmbedding.hpp"
DepthLayerTreeEmbeddingImpl::DepthLayerTreeEmbeddingImpl(int maxDepth, int maxElemPerDepth) : maxDepth(maxDepth), maxElemPerDepth(maxElemPerDepth)
DepthLayerTreeEmbeddingImpl::DepthLayerTreeEmbeddingImpl(int maxDepth, int maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options) : maxDepth(maxDepth), maxElemPerDepth(maxElemPerDepth)
{
for (int i = 0; i < maxDepth; i++)
depthLstm.emplace_back(register_module(fmt::format("lstm_{}",i), LSTM(embeddingsSize, outEmbeddingsSize, options)));
}
torch::Tensor DepthLayerTreeEmbeddingImpl::forward(torch::Tensor input)
{
auto context = input.narrow(1, firstInputIndex, getInputSize());
std::vector<torch::Tensor> outputs;
for (unsigned int i = 0; i < depthLstm.size(); i++)
for (unsigned int j = 0; j < focusedBuffer.size()+focusedStack.size(); j++)
outputs.emplace_back(depthLstm[i](input.narrow(1,i*(focusedBuffer.size()+focusedStack.size())*columns.size()*maxElemPerDepth + j*maxElemPerDepth, maxElemPerDepth)));
return torch::cat(outputs, 1);
}
std::size_t DepthLayerTreeEmbeddingImpl::getOutputSize()
{
std::size_t outputSize = 0;
for (auto & lstm : depthLstm)
outputSize += lstm->getOutputSize(maxElemPerDepth);
return outputSize;
}
std::size_t DepthLayerTreeEmbeddingImpl::getInputSize()
{
return (focusedBuffer.size()+focusedStack.size())*columns.size()*maxDepth*maxElemPerDepth;
}
int DepthLayerTreeEmbeddingImpl::getOutputSize()
void DepthLayerTreeEmbeddingImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
{
std::vector<long> focusedIndexes;
for (int index : focusedBuffer)
focusedIndexes.emplace_back(config.getRelativeWordIndex(index));
for (int index : focusedStack)
if (config.hasStack(index))
focusedIndexes.emplace_back(config.getStack(index));
else
focusedIndexes.emplace_back(-1);
for (auto & contextElement : context)
{
for (auto index : focusedIndexes)
{
}
}
}
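Note on the DepthLayerTreeEmbedding hunks above: the module now also derives from Submodule, registers one LSTM per tree depth, and concatenates their outputs in forward(); addToContext() is shown truncated in this diff. A minimal libtorch sketch of the per-depth LSTM pattern, with slicing and option handling simplified compared to the project's LSTM wrapper:

    #include <torch/torch.h>
    #include <string>
    #include <vector>

    struct PerDepthEncoderImpl : torch::nn::Module
    {
      std::vector<torch::nn::LSTM> depthLstm;
      int maxElemPerDepth;

      PerDepthEncoderImpl(int maxDepth, int maxElemPerDepth, int embSize, int outSize)
        : maxElemPerDepth(maxElemPerDepth)
      {
        // One LSTM per tree depth, registered so its parameters are trained.
        for (int i = 0; i < maxDepth; i++)
          depthLstm.emplace_back(register_module(
            "lstm_" + std::to_string(i),
            torch::nn::LSTM(torch::nn::LSTMOptions(embSize, outSize).batch_first(true))));
      }

      // input : (batch, maxDepth * maxElemPerDepth, embSize)
      torch::Tensor forward(torch::Tensor input)
      {
        std::vector<torch::Tensor> outputs;
        for (unsigned int i = 0; i < depthLstm.size(); i++)
        {
          auto slice = input.narrow(1, i * maxElemPerDepth, maxElemPerDepth);
          auto out = std::get<0>(depthLstm[i]->forward(slice)); // (batch, maxElemPerDepth, outSize)
          outputs.emplace_back(out.select(1, maxElemPerDepth - 1)); // keep the last state per depth
        }
        return torch::cat(outputs, 1); // (batch, maxDepth * outSize)
      }
    };
    TORCH_MODULE(PerDepthEncoder);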