Skip to content
Snippets Groups Projects
RandomNetwork.hpp 746 B
#ifndef RANDOMNETWORK__H
#define RANDOMNETWORK__H

#include "NeuralNetwork.hpp"

/// @brief Concrete network type deriving from NeuralNetworkImpl whose only
/// visible state is a table of per-state output counts.
///
/// NOTE(review): the "Random" in the name suggests forward() yields random
/// scores rather than learned ones, but the definitions are not visible in
/// this header — confirm against the corresponding .cpp before relying on it.
class RandomNetworkImpl : public NeuralNetworkImpl
{
  public :

  /// @param name              Network name (how it is used is defined in the .cpp; not visible here).
  /// @param nbOutputsPerState Table keyed by string with a std::size_t count per key;
  ///                          presumably copied into the member of the same name — confirm in the .cpp.
  RandomNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState);

  // ---- Core NeuralNetworkImpl overrides -----------------------------------

  /// Forward pass; input/output tensor shapes are not established by this header.
  torch::Tensor forward(torch::Tensor input) override;
  /// Builds a context (list of integer vectors) from a Config.
  std::vector<std::vector<long>> extractContext(Config &) override;

  // ---- Embedding / dictionary hooks (all virtual in a base, per `override`) --

  void registerEmbeddings() override;
  void saveDicts(std::filesystem::path path) override;
  void loadDicts(std::filesystem::path path) override;
  void setDictsState(Dict::State state) override;
  void setCountOcc(bool countOcc) override;
  void removeRareDictElements(float rarityThreshold) override;

  private :

  /// Per-state output counts. Presumably keys are state names and values
  /// are output widths for that state — TODO confirm against the .cpp.
  std::map<std::string,std::size_t> nbOutputsPerState;
};

#endif