Commit 11e87ce1 authored by Franck Dary

extractContext now directly gives a torch::Tensor

parent 0c86cb53
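
In practice the change removes the manual torch::from_blob conversion on the caller side: the network now builds the flat input tensor itself and hands it back ready to use. A rough before/after of the caller-side pattern, condensed from the Beam::update hunk below (identifiers as in the diff, surrounding code elided):

```cpp
// Before: extractContext returned std::vector<std::vector<long>>, so the caller
// had to wrap the indices in a tensor by hand.
auto context = classifier.getNN()->extractContext(config).back();
auto neuralInput = torch::from_blob(context.data(), {(long)context.size()},
                                    torch::TensorOptions(torch::kLong))
                       .clone().to(NeuralNetworkImpl::device);

// After: extractContext returns a ready-made torch::Tensor.
auto neuralInput = classifier.getNN()->extractContext(config);
```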
@@ -45,8 +45,7 @@ void Beam::update(ReadingMachine & machine, bool debug)
auto appliableTransitions = machine.getTransitionSet(elements[index].config.getState()).getAppliableTransitions(elements[index].config);
elements[index].config.setAppliableTransitions(appliableTransitions);
auto context = classifier.getNN()->extractContext(elements[index].config).back();
auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(torch::kLong)).clone().to(NeuralNetworkImpl::device);
auto neuralInput = classifier.getNN()->extractContext(elements[index].config);
auto prediction = classifier.isRegression() ? classifier.getNN()->forward(neuralInput, elements[index].config.getState()).squeeze(0) : torch::softmax(classifier.getNN()->forward(neuralInput, elements[index].config.getState()).squeeze(0), 0);
float entropy = classifier.isRegression() ? 0.0 : NeuralNetworkImpl::entropy(prediction);
@@ -19,7 +19,7 @@ class AppliableTransModuleImpl : public Submodule
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void addToContext(torch::Tensor & context, const Config & config) override;
void registerEmbeddings() override;
};
TORCH_MODULE(AppliableTransModule);
@@ -30,7 +30,7 @@ class ContextModuleImpl : public Submodule
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void addToContext(torch::Tensor & context, const Config & config) override;
void registerEmbeddings() override;
};
TORCH_MODULE(ContextModule);
@@ -31,7 +31,7 @@ class ContextualModuleImpl : public Submodule
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void addToContext(torch::Tensor & context, const Config & config) override;
void registerEmbeddings() override;
};
TORCH_MODULE(ContextualModule);
@@ -27,7 +27,7 @@ class DepthLayerTreeEmbeddingModuleImpl : public Submodule
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void addToContext(torch::Tensor & context, const Config & config) override;
void registerEmbeddings() override;
};
TORCH_MODULE(DepthLayerTreeEmbeddingModule);
@@ -26,7 +26,7 @@ class DistanceModuleImpl : public Submodule
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void addToContext(torch::Tensor & context, const Config & config) override;
void registerEmbeddings() override;
};
TORCH_MODULE(DistanceModule);
@@ -29,7 +29,7 @@ class FocusedColumnModuleImpl : public Submodule
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void addToContext(torch::Tensor & context, const Config & config) override;
void registerEmbeddings() override;
};
TORCH_MODULE(FocusedColumnModule);
@@ -25,7 +25,7 @@ class HistoryModuleImpl : public Submodule
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void addToContext(torch::Tensor & context, const Config & config) override;
void registerEmbeddings() override;
};
TORCH_MODULE(HistoryModule);
@@ -25,12 +25,13 @@ class ModularNetworkImpl : public NeuralNetworkImpl
MLP mlp{nullptr};
std::vector<std::shared_ptr<Submodule>> modules;
std::map<std::string,torch::nn::Linear> outputLayersPerState;
std::size_t totalInputSize{0};
public :
ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions, std::filesystem::path path);
torch::Tensor forward(torch::Tensor input, const std::string & state) override;
std::vector<std::vector<long>> extractContext(Config & config) override;
torch::Tensor extractContext(Config & config) override;
void registerEmbeddings() override;
void saveDicts(std::filesystem::path path) override;
void loadDicts(std::filesystem::path path) override;
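
The new totalInputSize member suggests that ModularNetworkImpl now pre-allocates one flat feature tensor and lets every submodule fill its own slice. The corresponding .cpp is not part of this excerpt, so the following is only a hedged sketch of what extractContext could look like under the new signatures; the allocation size, fill value and device handling are assumptions:

```cpp
// Hypothetical sketch, not the actual ModularNetwork.cpp from this commit.
torch::Tensor ModularNetworkImpl::extractContext(Config & config)
{
  // One flat kLong tensor covering every submodule's input window.
  auto context = torch::zeros({(long)totalInputSize}, torch::kLong);
  for (auto & mod : modules)
    mod->addToContext(context, config); // each module writes at its own firstInputIndex
  return context;
}
```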
@@ -15,7 +15,7 @@ class NeuralNetworkImpl : public torch::nn::Module, public NameHolder
public :
virtual torch::Tensor forward(torch::Tensor input, const std::string & state) = 0;
virtual std::vector<std::vector<long>> extractContext(Config & config) = 0;
virtual torch::Tensor extractContext(Config & config) = 0;
virtual void registerEmbeddings() = 0;
virtual void saveDicts(std::filesystem::path path) = 0;
virtual void loadDicts(std::filesystem::path path) = 0;
@@ -24,7 +24,7 @@ class NumericColumnModuleImpl : public Submodule
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void addToContext(torch::Tensor & context, const Config & config) override;
void registerEmbeddings() override;
};
TORCH_MODULE(NumericColumnModule);
@@ -13,7 +13,7 @@ class RandomNetworkImpl : public NeuralNetworkImpl
RandomNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState);
torch::Tensor forward(torch::Tensor input, const std::string & state) override;
std::vector<std::vector<long>> extractContext(Config &) override;
torch::Tensor extractContext(Config &) override;
void registerEmbeddings() override;
void saveDicts(std::filesystem::path path) override;
void loadDicts(std::filesystem::path path) override;
@@ -25,7 +25,7 @@ class RawInputModuleImpl : public Submodule
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void addToContext(torch::Tensor & context, const Config & config) override;
void registerEmbeddings() override;
};
TORCH_MODULE(RawInputModule);
@@ -24,7 +24,7 @@ class SplitTransModuleImpl : public Submodule
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void addToContext(torch::Tensor & context, const Config & config) override;
void registerEmbeddings() override;
};
TORCH_MODULE(SplitTransModule);
@@ -21,7 +21,7 @@ class StateNameModuleImpl : public Submodule
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void addToContext(torch::Tensor & context, const Config & config) override;
void registerEmbeddings() override;
};
TORCH_MODULE(StateNameModule);
@@ -24,7 +24,7 @@ class Submodule : public torch::nn::Module, public DictHolder
void loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std::filesystem::path path, std::string prefix);
virtual std::size_t getOutputSize() = 0;
virtual std::size_t getInputSize() = 0;
virtual void addToContext(std::vector<std::vector<long>> & context, const Config & config) = 0;
virtual void addToContext(torch::Tensor & context, const Config & config) = 0;
virtual torch::Tensor forward(torch::Tensor input) = 0;
virtual void registerEmbeddings() = 0;
std::function<std::string(const std::string &)> getFunction(const std::string functionNames);
@@ -22,7 +22,7 @@ class UppercaseRateModuleImpl : public Submodule
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void addToContext(torch::Tensor & context, const Config & config) override;
void registerEmbeddings() override;
};
TORCH_MODULE(UppercaseRateModule);
@@ -20,15 +20,12 @@ std::size_t AppliableTransModuleImpl::getInputSize()
return nbTrans;
}
void AppliableTransModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
void AppliableTransModuleImpl::addToContext(torch::Tensor & context, const Config & config)
{
auto & appliableTrans = config.getAppliableTransitions();
for (auto & contextElement : context)
for (int i = 0; i < nbTrans; i++)
if (i < (int)appliableTrans.size())
contextElement.emplace_back(appliableTrans[i]);
else
contextElement.emplace_back(0);
for (int i = 0; i < nbTrans; i++)
if (i < (int)appliableTrans.size())
context[firstInputIndex+i] = appliableTrans[i];
}
void AppliableTransModuleImpl::registerEmbeddings()
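
Seen from a single submodule, the rewrite replaces "append one value to every row of a std::vector<std::vector<long>>" with "assign into a fixed window of the shared tensor, starting at firstInputIndex". A small self-contained illustration of that assignment pattern (the offset and values below are made up for the example):

```cpp
#include <torch/torch.h>
#include <iostream>
#include <vector>

int main()
{
  // Hypothetical layout: this module owns context[4..6].
  long firstInputIndex = 4;
  long nbTrans = 3;
  std::vector<long> appliableTrans{1, 0, 1};

  auto context = torch::zeros({8}, torch::kLong);
  for (long i = 0; i < nbTrans; i++)
    if (i < (long)appliableTrans.size())
      context[firstInputIndex + i] = appliableTrans[i]; // scalar assignment into a 1-D tensor

  std::cout << context << std::endl; // zeros everywhere except indices 4..6 = {1, 0, 1}
}
```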
@@ -83,7 +83,7 @@ std::size_t ContextModuleImpl::getInputSize()
return columns.size()*(targets.size());
}
void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
void ContextModuleImpl::addToContext(torch::Tensor & context, const Config & config)
{
auto & dict = getDict();
std::vector<long> contextIndexes;
@@ -125,24 +125,22 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
contextIndexes.emplace_back(-3);
}
int insertIndex = 0;
for (auto index : contextIndexes)
for (unsigned int colIndex = 0; colIndex < columns.size(); colIndex++)
{
auto & col = columns[colIndex];
if (index == -1)
{
for (auto & contextElement : context)
contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, col));
context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::nullValueStr, col);
}
else if (index == -2)
{
for (auto & contextElement : context)
contextElement.push_back(dict.getIndexOrInsert(Dict::noChildValueStr, col));
context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::noChildValueStr, col);
}
else if (index == -3)
{
for (auto & contextElement : context)
contextElement.push_back(dict.getIndexOrInsert(Dict::oobValueStr, col));
context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::oobValueStr, col);
}
else
{
@@ -162,9 +160,9 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
dictIndex = dict.getIndexOrInsert(featureValue, col);
}
for (auto & contextElement : context)
contextElement.push_back(dictIndex);
context[firstInputIndex+insertIndex] = dictIndex;
}
insertIndex++;
}
}
@@ -87,7 +87,7 @@ std::size_t ContextualModuleImpl::getInputSize()
return columns.size()*(4+window.second-window.first)+targets.size();
}
void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
void ContextualModuleImpl::addToContext(torch::Tensor & context, const Config & config)
{
auto & dict = getDict();
std::vector<long> contextIndexes;
@@ -132,24 +132,23 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
else
targetIndexes.emplace_back(-1);
int insertIndex = 0;
for (auto index : contextIndexes)
for (unsigned int colIndex = 0; colIndex < columns.size(); colIndex++)
{
auto & col = columns[colIndex];
if (index == -1)
{
for (auto & contextElement : context)
contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, col));
context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::nullValueStr, col);
}
else if (index == -2)
{
for (auto & contextElement : context)
contextElement.push_back(dict.getIndexOrInsert(Dict::noChildValueStr, col));
context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::noChildValueStr, col);
}
else if (index == -3)
{
for (auto & contextElement : context)
contextElement.push_back(dict.getIndexOrInsert(Dict::oobValueStr, col));
context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::oobValueStr, col);
}
else
{
@@ -169,33 +168,32 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
dictIndex = dict.getIndexOrInsert(functions[colIndex](featureValue), col);
}
for (auto & contextElement : context)
contextElement.push_back(dictIndex);
context[firstInputIndex+insertIndex] = dictIndex;
}
insertIndex++;
}
for (auto index : targetIndexes)
{
if (configIndex2ContextIndex.count(index))
{
for (auto & contextElement : context)
contextElement.push_back(configIndex2ContextIndex.at(index));
context[firstInputIndex+insertIndex] = configIndex2ContextIndex.at(index);
}
else
{
for (auto & contextElement : context)
{
// -1 == doesn't exist (s.0 when no stack)
if (index == -1)
contextElement.push_back(0);
// -2 == nochild
else if (index == -2)
contextElement.push_back(1);
// other == out of context bounds
else
contextElement.push_back(2);
}
// -1 == doesn't exist (s.0 when no stack)
if (index == -1)
context[firstInputIndex+insertIndex] = 0;
// -2 == nochild
else if (index == -2)
context[firstInputIndex+insertIndex] = 1;
// other == out of context bounds
else
context[firstInputIndex+insertIndex] = 2;
}
insertIndex++;
}
}