#ifndef MODULARNETWORK__H #define MODULARNETWORK__H #include "NeuralNetwork.hpp" #include "ContextModule.hpp" #include "RawInputModule.hpp" #include "SplitTransModule.hpp" #include "FocusedColumnModule.hpp" #include "DepthLayerTreeEmbeddingModule.hpp" #include "StateNameModule.hpp" #include "UppercaseRateModule.hpp" #include "NumericColumnModule.hpp" #include "MLP.hpp" class ModularNetworkImpl : public NeuralNetworkImpl { private : torch::nn::Dropout inputDropout{nullptr}; MLP mlp{nullptr}; std::vector<std::shared_ptr<Submodule>> modules; std::map<std::string,torch::nn::Linear> outputLayersPerState; public : ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions); torch::Tensor forward(torch::Tensor input) override; std::vector<std::vector<long>> extractContext(Config & config) override; void registerEmbeddings() override; void saveDicts(std::filesystem::path path) override; void loadDicts(std::filesystem::path path) override; void setDictsState(Dict::State state) override; void setCountOcc(bool countOcc) override; void removeRareDictElements(float rarityThreshold) override; }; #endif