// NOTE(review): two web-scraper artifacts ("Newer"/"Older" pager links) were
// removed from the top of this file; they were never part of the source.
/// @brief Build a modular network from textual module definitions.
///
/// Each line of @p definitions has the form "ModuleName : arguments"; one
/// submodule is instantiated per line and registered with libtorch. Two
/// pseudo-modules configure the trunk instead of adding a submodule:
/// "MLP" (hidden-layer definition) and "InputDropout" (dropout probability).
///
/// @param name              Name of this network.
/// @param nbOutputsPerState Number of output units for each state name.
/// @param definitions       One textual module definition per line.
/// @param path              Base path handed to modules that read files.
///
/// NOTE(review): this constructor was damaged by text extraction — the
/// opening brace / ctor-initializer and a span of original lines were lost,
/// and the guarding conditions of several 'else if' branches were dropped.
/// The dispatch chain below is reconstructed; the branch keywords marked
/// "inferred" are deduced from the module class names — confirm against
/// upstream VCS. Any module types that lived entirely inside the lost span
/// are missing here and must be restored from history.
ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions, std::filesystem::path path) : NeuralNetworkImpl(name) // NOTE(review): base-class init reconstructed; confirm
{
  // Matches "Name : arguments", tolerating blanks/tabs around the colon.
  std::string anyBlanks = "(?:(?:\\s|\\t)*)";
  auto splitLine = [anyBlanks](std::string line)
  {
    std::pair<std::string,std::string> result;
    util::doIfNameMatch(std::regex(fmt::format("{}(\\S+){}:{}(.+)",anyBlanks,anyBlanks,anyBlanks)),line,[&result](auto sm)
    {
      result.first = sm.str(1);   // module keyword
      result.second = sm.str(2);  // module arguments
    });
    return result;
  };
  // Largest output layer over all states; used by AppliableTransModule.
  std::size_t maxNbOutputs = 0;
  for (auto & it : nbOutputsPerState)
    maxNbOutputs = std::max<std::size_t>(it.second, maxNbOutputs);
  int currentInputSize = 0;
  int currentOutputSize = 0;
  std::string mlpDef;
  for (auto & line : definitions)
  {
    auto splited = splitLine(line);
    // Unique registration name (index + keyword); note it shadows the ctor
    // parameter 'name', as in the original code.
    std::string name = fmt::format("{}_{}", modules.size(), splited.first);
    std::string nameH = fmt::format("{}_{}", getName(), name);
    if (splited.first == "Context") // inferred keyword
      modules.emplace_back(register_module(name, ContextModule(nameH, splited.second, path)));
    else if (splited.first == "StateName") // inferred keyword
      modules.emplace_back(register_module(name, StateNameModule(nameH, splited.second)));
    else if (splited.first == "History")
      modules.emplace_back(register_module(name, HistoryModule(nameH, splited.second)));
    else if (splited.first == "NumericColumn")
      modules.emplace_back(register_module(name, NumericColumnModule(nameH, splited.second)));
    else if (splited.first == "UppercaseRate")
      modules.emplace_back(register_module(name, UppercaseRateModule(nameH, splited.second)));
    else if (splited.first == "Focused") // inferred keyword
      modules.emplace_back(register_module(name, FocusedColumnModule(nameH, splited.second)));
    else if (splited.first == "RawInput") // inferred keyword
      modules.emplace_back(register_module(name, RawInputModule(nameH, splited.second)));
    else if (splited.first == "SplitTrans") // inferred keyword
      modules.emplace_back(register_module(name, SplitTransModule(nameH, Config::maxNbAppliableSplitTransitions, splited.second)));
    else if (splited.first == "AppliableTrans")
      modules.emplace_back(register_module(name, AppliableTransModule(nameH, maxNbOutputs)));
    else if (splited.first == "Distance")
      modules.emplace_back(register_module(name, DistanceModule(nameH, splited.second)));
    else if (splited.first == "DepthLayerTree") // inferred keyword
      modules.emplace_back(register_module(name, DepthLayerTreeEmbeddingModule(nameH, splited.second)));
    else if (splited.first == "MLP")
    {
      // Not a submodule: remembered and instantiated after the loop.
      mlpDef = splited.second;
      continue;
    }
    else if (splited.first == "InputDropout")
    {
      // Not a submodule: dropout applied to the concatenated module outputs.
      inputDropout = register_module("inputDropout", torch::nn::Dropout(std::stof(splited.second)));
      continue;
    }
    else
      util::myThrow(fmt::format("unknown module '{}' for line '{}'", splited.first, line));
    // Each real submodule reads a contiguous slice of the input vector.
    modules.back()->setFirstInputIndex(currentInputSize);
    currentInputSize += modules.back()->getInputSize();
    currentOutputSize += modules.back()->getOutputSize();
  }
  if (mlpDef.empty())
    util::myThrow("no MLP definition found");
  if (inputDropout.is_empty())
    util::myThrow("no InputDropout definition found");
  mlp = register_module("mlp", MLP(currentOutputSize, mlpDef));
  // One final linear layer per state, mapping MLP output to that state's actions.
  for (auto & it : nbOutputsPerState)
    outputLayersPerState.emplace(it.first,register_module(fmt::format("output_{}",it.first), torch::nn::Linear(mlp->outputSize(), it.second)));
}
/// @brief Run the network: each submodule embeds its slice of @p input,
/// the embeddings are concatenated, dropped out, fed to the MLP, and the
/// output layer of the current state produces the final scores.
torch::Tensor ModularNetworkImpl::forward(torch::Tensor input)
{
  // Promote a single unbatched example to a batch of one.
  if (input.dim() == 1)
    input = input.unsqueeze(0);
  // Collect every submodule's embedding of the input.
  std::vector<torch::Tensor> moduleOutputs;
  moduleOutputs.reserve(modules.size());
  for (auto & submodule : modules)
    moduleOutputs.emplace_back(submodule->forward(input));
  // Concatenate along the feature dimension, then apply input dropout.
  auto combined = torch::cat(moduleOutputs, 1);
  auto dropped = inputDropout(combined);
  auto hidden = mlp(dropped);
  // Pick the output head that matches the current state.
  return outputLayersPerState.at(getState())(hidden);
}
/// @brief Extract from @p config the integer context each submodule needs.
/// @return A batch of contexts (initially a single, growing row).
///
/// NOTE(review): the loop body and return statement of this function were
/// lost to text extraction (only the loop header survived, interleaved with
/// VCS-blame junk). The body below is reconstructed on the assumption that
/// each submodule appends its indices to the shared context — TODO confirm
/// the exact call (addToContext signature) against upstream VCS.
std::vector<std::vector<long>> ModularNetworkImpl::extractContext(Config & config)
{
  std::vector<std::vector<long>> context(1);
  for (auto & mod : modules)
    mod->addToContext(context, config);
  return context;
}
void ModularNetworkImpl::registerEmbeddings()
Franck Dary
committed
{
for (auto & mod : modules)
Franck Dary
committed
mod->registerEmbeddings();
}
/// @brief Persist every submodule's dictionary under @p path.
void ModularNetworkImpl::saveDicts(std::filesystem::path path)
{
  for (const auto & submodule : modules)
  {
    submodule->saveDict(path);
  }
}
/// @brief Load every submodule's dictionary from @p path.
void ModularNetworkImpl::loadDicts(std::filesystem::path path)
{
  for (const auto & submodule : modules)
  {
    submodule->loadDict(path);
  }
}
/// @brief Switch every non-pretrained submodule dictionary to @p state.
/// Pretrained dictionaries are left untouched so their contents stay frozen.
/// (Removed VCS-blame junk lines that had been pasted into the body by
/// the text extraction; the actual code is unchanged.)
void ModularNetworkImpl::setDictsState(Dict::State state)
{
  for (auto & mod : modules)
  {
    if (!mod->dictIsPretrained())
      mod->getDict().setState(state);
  }
}
/// @brief Enable or disable occurrence counting on every submodule dictionary.
void ModularNetworkImpl::setCountOcc(bool countOcc)
{
  for (const auto & submodule : modules)
  {
    submodule->getDict().countOcc(countOcc);
  }
}
/// @brief Shrink each submodule dictionary by repeatedly dropping its rarest
/// elements until @p rarityThreshold percent of the original size has been
/// removed, never going below a fixed floor of entries.
/// @param rarityThreshold Target percentage (0-100) of entries to remove.
///
/// (Removed VCS-blame junk lines that the text extraction had pasted before
/// the closing brace; also added a guard against division by zero when a
/// dictionary is empty.)
void ModularNetworkImpl::removeRareDictElements(float rarityThreshold)
{
  // Never shrink a dictionary below this many entries, whatever the threshold.
  constexpr std::size_t minNbElems = 1000;
  for (auto & mod : modules)
  {
    auto & dict = mod->getDict();
    std::size_t originalSize = dict.size();
    if (originalSize == 0)
      continue; // nothing to remove; also avoids division by zero below
    while (100.0*(originalSize-dict.size())/originalSize < rarityThreshold and dict.size() > minNbElems)
      dict.removeRareElements();
  }
}
/// @brief Change the current state: update the base network first, then
/// propagate the new state to every submodule.
void ModularNetworkImpl::setState(const std::string & state)
{
  NeuralNetworkImpl::setState(state);
  for (const auto & submodule : modules)
  {
    submodule->setState(state);
  }
}