#include "ModularNetwork.hpp"

/// Builds a network from a textual list of module definitions.
///
/// @param name              Base name of this network (used to prefix sub-module names).
/// @param nbOutputsPerState Number of output classes for each parser state; one output
///                          Linear layer is created per state.
/// @param definitions       One line per module, of the form "<ModuleType> : <parameters>".
/// @param path              Filesystem path forwarded to modules that load external data
///                          (currently only ContextModule).
/// @throws via util::myThrow on an unknown module type, a missing MLP definition,
///         or a missing InputDropout definition.
ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions, std::filesystem::path path)
{
  setName(name);

  // Splits "<moduleType> : <parameters>" into a (type, parameters) pair,
  // tolerating any amount of whitespace around the separator.
  // If the line does not match, both pair members stay empty and the
  // dispatch below reports it as an unknown module.
  std::string anyBlanks = "(?:(?:\\s|\\t)*)";
  auto splitLine = [anyBlanks](std::string line)
  {
    std::pair<std::string,std::string> result;
    util::doIfNameMatch(std::regex(fmt::format("{}(\\S+){}:{}(.+)",anyBlanks,anyBlanks,anyBlanks)), line,
      [&result](auto sm)
      {
        result.first = sm.str(1);
        result.second = sm.str(2);
      });
    return result;
  };

  // Largest number of outputs over all states; used to size AppliableTransModule.
  std::size_t maxNbOutputs = 0;
  for (auto & it : nbOutputsPerState)
    maxNbOutputs = std::max<std::size_t>(it.second, maxNbOutputs);

  int currentInputSize = 0;   // running offset of each module's slice in the input tensor
  int currentOutputSize = 0;  // total concatenated feature size fed to the MLP
  std::string mlpDef;         // deferred: the MLP needs currentOutputSize, known only after the loop

  for (auto & line : definitions)
  {
    auto splited = splitLine(line);
    // NOTE(review): renamed from 'name' — the original shadowed the constructor parameter.
    // The generated strings are identical: "<index>_<type>" and "<networkName>_<index>_<type>".
    std::string moduleName = fmt::format("{}_{}", modules.size(), splited.first);
    std::string nameH = fmt::format("{}_{}", getName(), moduleName);
    if (splited.first == "Context")
      modules.emplace_back(register_module(moduleName, ContextModule(nameH, splited.second, path)));
    else if (splited.first == "StateName")
      modules.emplace_back(register_module(moduleName, StateNameModule(nameH, splited.second)));
    else if (splited.first == "History")
      modules.emplace_back(register_module(moduleName, HistoryModule(nameH, splited.second)));
    else if (splited.first == "NumericColumn")
      modules.emplace_back(register_module(moduleName, NumericColumnModule(nameH, splited.second)));
    else if (splited.first == "UppercaseRate")
      modules.emplace_back(register_module(moduleName, UppercaseRateModule(nameH, splited.second)));
    else if (splited.first == "Focused")
      modules.emplace_back(register_module(moduleName, FocusedColumnModule(nameH, splited.second)));
    else if (splited.first == "RawInput")
      modules.emplace_back(register_module(moduleName, RawInputModule(nameH, splited.second)));
    else if (splited.first == "SplitTrans")
      modules.emplace_back(register_module(moduleName, SplitTransModule(nameH, Config::maxNbAppliableSplitTransitions, splited.second)));
    else if (splited.first == "AppliableTrans")
      modules.emplace_back(register_module(moduleName, AppliableTransModule(nameH, maxNbOutputs)));
    else if (splited.first == "Distance")
      modules.emplace_back(register_module(moduleName, DistanceModule(nameH, splited.second)));
    else if (splited.first == "DepthLayerTree")
      modules.emplace_back(register_module(moduleName, DepthLayerTreeEmbeddingModule(nameH, splited.second)));
    else if (splited.first == "MLP")
    {
      // Not a feature module: remember the definition, instantiate after the loop.
      mlpDef = splited.second;
      continue;
    }
    else if (splited.first == "InputDropout")
    {
      // std::stof throws on a malformed rate, which surfaces the bad definition line.
      inputDropout = register_module("inputDropout", torch::nn::Dropout(std::stof(splited.second)));
      continue;
    }
    else
      util::myThrow(fmt::format("unknown module '{}' for line '{}'", splited.first, line));

    // Bookkeeping for the module just added (MLP/InputDropout 'continue' above).
    modules.back()->setFirstInputIndex(currentInputSize);
    currentInputSize += modules.back()->getInputSize();
    currentOutputSize += modules.back()->getOutputSize();
  }

  if (mlpDef.empty())
    util::myThrow("no MLP definition found");
  if (inputDropout.is_empty())
    util::myThrow("no InputDropout definition found");

  mlp = register_module("mlp", MLP(currentOutputSize, mlpDef));

  // One classification head per parser state, all fed by the shared MLP.
  for (auto & it : nbOutputsPerState)
    outputLayersPerState.emplace(it.first, register_module(fmt::format("output_{}",it.first), torch::nn::Linear(mlp->outputSize(), it.second)));
}

/// Runs the full network: each sub-module embeds its slice of the input,
/// the embeddings are concatenated, dropped out, passed through the MLP,
/// then through the output layer of the current state.
torch::Tensor ModularNetworkImpl::forward(torch::Tensor input)
{
  // Promote a single example to a batch of one.
  if (input.dim() == 1)
    input = input.unsqueeze(0);

  std::vector<torch::Tensor> outputs;
  for (auto & mod : modules)
    outputs.emplace_back(mod->forward(input));

  auto totalInput = inputDropout(torch::cat(outputs, 1));
  return outputLayersPerState.at(getState())(mlp(totalInput));
}

/// Extracts from the configuration the feature indices each module needs;
/// starts from a single (possibly forked by modules) context row.
std::vector<std::vector<long>> ModularNetworkImpl::extractContext(Config & config)
{
  std::vector<std::vector<long>> context(1);
  for (auto & mod : modules)
    mod->addToContext(context, config);
  return context;
}

/// Asks every module to (re)create its embedding tables.
void ModularNetworkImpl::registerEmbeddings()
{
  for (auto & mod : modules)
    mod->registerEmbeddings();
}

/// Saves every module's dictionary under the given directory.
void ModularNetworkImpl::saveDicts(std::filesystem::path path)
{
  for (auto & mod : modules)
    mod->saveDict(path);
}

/// Loads every module's dictionary from the given directory.
void ModularNetworkImpl::loadDicts(std::filesystem::path path)
{
  for (auto & mod : modules)
    mod->loadDict(path);
}

/// Sets the open/closed state of every non-pretrained dictionary
/// (pretrained dictionaries must keep their fixed vocabulary).
void ModularNetworkImpl::setDictsState(Dict::State state)
{
  for (auto & mod : modules)
  {
    if (!mod->dictIsPretrained())
      mod->getDict().setState(state);
  }
}

/// Enables or disables occurrence counting in every module's dictionary.
void ModularNetworkImpl::setCountOcc(bool countOcc)
{
  for (auto & mod : modules)
    mod->getDict().countOcc(countOcc);
}

/// Shrinks each dictionary by repeatedly dropping its rarest elements until
/// either rarityThreshold percent of the original entries have been removed
/// or the dictionary reaches the minimum size floor.
void ModularNetworkImpl::removeRareDictElements(float rarityThreshold)
{
  std::size_t minNbElems = 1000;  // never shrink a dictionary below this many entries
  for (auto & mod : modules)
  {
    auto & dict = mod->getDict();
    std::size_t originalSize = dict.size();
    // NOTE(review): guard added — an empty dict made the percentage below 0.0/0.0 (NaN).
    // The NaN comparison happened to terminate the loop, but only by accident.
    if (originalSize == 0)
      continue;
    while (100.0*(originalSize-dict.size())/originalSize < rarityThreshold and dict.size() > minNbElems)
      dict.removeRareElements();
  }
}

/// Switches the active parser state on the base class and on every module,
/// so forward() selects the matching output layer.
void ModularNetworkImpl::setState(const std::string & state)
{
  NeuralNetworkImpl::setState(state);
  for (auto & mod : modules)
    mod->setState(state);
}