#ifndef DEPTHLAYERTREEEMBEDDING__H
#define DEPTHLAYERTREEEMBEDDING__H

#include <torch/torch.h>
#include "Submodule.hpp"
#include "LSTM.hpp"

/// Torch module + Submodule that builds an embedding organised by "depth layers".
///
/// NOTE(review): only the declaration is visible here; the exact algorithm
/// (how depths/elements are gathered and fed to the LSTMs) lives in the
/// implementation file — the hedged notes below should be confirmed there.
class DepthLayerTreeEmbeddingImpl : public torch::nn::Module, public Submodule
{
  private :

  /// Feature columns used by this submodule; defaults to {"DEPREL"}.
  std::vector<std::string> columns{"DEPREL"};
  /// Buffer indices considered; defaults to {0}. Presumably positions relative
  /// to the buffer head — TODO confirm against the implementation.
  std::vector<int> focusedBuffer{0};
  /// Stack indices considered; defaults to {0}. Presumably positions relative
  /// to the stack top — TODO confirm against the implementation.
  std::vector<int> focusedStack{0};
  /// Special token string owned by this class. Presumably inserted as a marker
  /// element (e.g. sequence start) — TODO confirm.
  std::string firstElem{"__special_DepthLayerTreeEmbeddingImpl__"};
  /// LSTM submodules; presumably one per depth layer — TODO confirm.
  std::vector<LSTM> depthLstm;
  // Value-initialized so these scalars are never indeterminate; presumably
  // overwritten from the constructor arguments of the same names.
  int maxDepth{0};
  int maxElemPerDepth{0};

  public :

  /// Builds the module.
  /// @param maxDepth          number of depth layers handled (presumably).
  /// @param maxElemPerDepth   cap on elements per depth layer (presumably).
  /// @param embeddingsSize    size of the input embeddings fed to the LSTMs (presumably).
  /// @param outEmbeddingsSize output embedding size of each LSTM (presumably).
  /// @param options           options forwarded to the LSTM submodules.
  DepthLayerTreeEmbeddingImpl(int maxDepth, int maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options);
  /// Runs the module on @p input and returns the resulting tensor.
  /// NOTE(review): expected input layout is defined by the implementation — document it there.
  torch::Tensor forward(torch::Tensor input);
  /// Size of the output produced by forward() (Submodule interface).
  std::size_t getOutputSize() override;
  /// Number of input values this submodule consumes (Submodule interface).
  std::size_t getInputSize() override;
  /// Fills @p context with this submodule's inputs, using @p dict and the
  /// current @p config (Submodule interface). Does not modify the module itself.
  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
};
// libtorch macro: defines `DepthLayerTreeEmbedding`, the ModuleHolder
// (shared-pointer wrapper) around DepthLayerTreeEmbeddingImpl.
TORCH_MODULE(DepthLayerTreeEmbedding);

#endif