#ifndef MODULARNETWORK__H
#define MODULARNETWORK__H

#include "NeuralNetwork.hpp"
#include "ContextModule.hpp"
#include "RawInputModule.hpp"
#include "SplitTransModule.hpp"
#include "FocusedColumnModule.hpp"
#include "DepthLayerTreeEmbeddingModule.hpp"
#include "StateNameModule.hpp"
#include "UppercaseRateModule.hpp"
#include "NumericColumnModule.hpp"
#include "MLP.hpp"

/// @brief NeuralNetworkImpl subclass that aggregates a list of Submodule
///        objects, an MLP, and one torch Linear layer per state name.
///
/// Only declarations live in this header; every member function is defined
/// out of line, so the notes below describe the declared contract only.
class ModularNetworkImpl : public NeuralNetworkImpl
{
public:
  /// @brief Constructs the network.
  /// @param name              Name given to this network.
  /// @param nbOutputsPerState Map from a state name to a size — presumably the
  ///                          number of outputs for that state; confirm in the
  ///                          implementation file.
  /// @param definitions       List of strings — presumably one textual
  ///                          definition per submodule; confirm in the
  ///                          implementation file.
  ModularNetworkImpl(std::string name,
                     std::map<std::string, std::size_t> nbOutputsPerState,
                     std::vector<std::string> definitions);

  // ---- Overrides of the NeuralNetworkImpl interface ----

  /// @brief Maps an input tensor to an output tensor.
  torch::Tensor forward(torch::Tensor input) override;

  /// @brief Produces a list of long-integer sequences from @p config.
  std::vector<std::vector<long>> extractContext(Config & config) override;

  /// @brief Embedding registration hook (no arguments, no return value).
  void registerEmbeddings() override;

  /// @brief Dictionary persistence hooks; both receive a filesystem path.
  void saveDicts(std::filesystem::path path) override;
  void loadDicts(std::filesystem::path path) override;

  /// @brief Receives a Dict::State value.
  /// NOTE(review): presumably forwarded to the dictionaries used by the
  /// submodules — verify against the implementation.
  void setDictsState(Dict::State state) override;

  /// @brief Receives the occurrence-counting flag.
  void setCountOcc(bool countOcc) override;

  /// @brief Receives a rarity threshold.
  /// NOTE(review): the exact comparison applied to the threshold is not
  /// visible from this header.
  void removeRareDictElements(float rarityThreshold) override;

private:
  // Declaration order of the data members is kept as-is on purpose:
  // it determines their construction and destruction order.

  /// Dropout module holder, default-initialized to an empty (null) holder.
  torch::nn::Dropout inputDropout{nullptr};

  /// MLP holder, default-initialized to an empty (null) holder.
  MLP mlp{nullptr};

  /// Shared handles to the submodules held by this network.
  std::vector<std::shared_ptr<Submodule>> modules;

  /// One Linear layer per key; keys are presumably state names (see member
  /// name) — confirm in the implementation file.
  std::map<std::string, torch::nn::Linear> outputLayersPerState;
};

#endif