// NumericColumnModule.hpp
// Author: Franck Dary
#ifndef NUMERICCOLUMNMODULE__H
#define NUMERICCOLUMNMODULE__H
#include <torch/torch.h>
#include "Submodule.hpp"
#include "MyModule.hpp"
#include "LSTM.hpp"
#include "GRU.hpp"
#include "Concat.hpp"
/// Submodule configured from a textual definition string.
///
/// Judging only by its members, it presumably reads the values of one
/// column (`column`) at a set of focused buffer/stack positions and passes
/// them through an inner module (`myModule`). TODO: confirm against the
/// implementation file.
class NumericColumnModuleImpl : public Submodule
{
  private :

  /// Output size of this submodule. It is zero-initialised so the member
  /// always holds a defined value, even before a constructor assigns it.
  int outSize{0};
  /// Buffer / stack positions this module focuses on. Their exact meaning
  /// is defined in the implementation file (TODO confirm).
  std::vector<int> focusedBuffer, focusedStack;
  /// Inner module, presumably used by forward(). Null until constructed.
  std::shared_ptr<MyModule> myModule{nullptr};
  /// Column identifier, presumably the config column whose values are
  /// gathered in addToContext() (TODO confirm).
  std::string column;

  public :

  /// @param name       Module name (taken by value).
  /// @param definition Textual definition, presumably parsed to fill the
  ///                   members above (TODO confirm).
  NumericColumnModuleImpl(std::string name, const std::string & definition);
  /// Applies the module to `input` and returns the resulting tensor. The
  /// expected input shape is not visible from this header.
  torch::Tensor forward(torch::Tensor input);
  /// Size of the tensor produced by forward() (Submodule interface).
  std::size_t getOutputSize() override;
  /// Number of context entries this module consumes (Submodule interface).
  std::size_t getInputSize() override;
  /// Writes this module's input values for `config` into `context`.
  /// Presumably this appends, as the name suggests (TODO confirm).
  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
  /// Registers embedding parameters. `pretrained` presumably locates
  /// pretrained weights (TODO confirm).
  void registerEmbeddings(std::filesystem::path pretrained) override;
};
// Declares the `NumericColumnModule` holder type for NumericColumnModuleImpl,
// following libtorch's TORCH_MODULE Impl/holder naming convention.
TORCH_MODULE(NumericColumnModule);
#endif