// dev.cpp
// Commit "Learn to ignore specific revisions" (Franck Dary)
    #include <cstdio>
    
    #include "fmt/core.h"
    
    #include "util.hpp"
    
    #include "BaseConfig.hpp"
    #include "SubConfig.hpp"
    
    #include "TransitionSet.hpp"
    
    #include "TestNetwork.hpp"
    
    #include "ConfigDataset.hpp"
    
    
    /// Development driver: replays a ReadingMachine's strategy over a gold
    /// configuration, applying the best appliable transition at every step and
    /// printing progress, until the strategy returns Strategy::endMovement.
    ///
    /// Usage: dev <machineFile> <mcdFile> <tsvFile> <rawFile>
    ///   argv[1]  file passed to ReadingMachine
    ///   argv[2]  first file passed to BaseConfig
    ///   argv[3]  second file passed to BaseConfig
    ///   argv[4]  required by the argument check but currently ignored
    ///            (an empty raw-file path is passed to BaseConfig instead)
    ///
    /// Exits with status 1 on a wrong argument count. Errors raised through
    /// util::myThrow (no appliable transition, word index cannot be moved)
    /// propagate out of main.
    int main(int argc, char * argv[])
    {
      if (argc != 5)
      {
        fmt::print(stderr, "needs 4 arguments.\n");
        exit(1);
      }

      std::string machineFile = argv[1];
      std::string mcdFile = argv[2];
      std::string tsvFile = argv[3];
      // The raw file is deliberately not read yet; argv[4] is still required
      // by the argument-count check above.
      //std::string rawFile = argv[4];
      std::string rawFile = "";

      ReadingMachine machine(machineFile);

      BaseConfig goldConfig(mcdFile, tsvFile, rawFile);

      SubConfig config(goldConfig);

      config.setState(machine.getStrategy().getInitialState());

      // Network and optimizer for the training step, which is still disabled
      // inside the loop below; they are constructed but not otherwise used.
      TestNetwork nn(machine.getTransitionSet().size());
      torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
      optimizer.zero_grad();

      std::vector<torch::Tensor> predictionsBatch;
      std::vector<torch::Tensor> referencesBatch;
      std::vector<SubConfig> configs;

      // Progress is printed whenever the word index is a multiple of this
      // value (1 = print after every step).
      constexpr int progressPeriod = 1;

      // One transition per iteration; the strategy decides when to stop.
      while (true)
      {
        auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
        if (!transition)
          util::myThrow("No transition appliable !");

        // Training step (disabled).
        // NOTE(review): the gold index is a placeholder; indexing a fixed slot
        // of a tensor sized by the transition set throws when the set is too
        // small, so it must not run unconditionally. torch::nll_loss also
        // expects class-index targets rather than a one-hot vector — confirm
        // the target format before enabling.
    //    int goldIndex = 3;
    //    auto gold = torch::zeros(machine.getTransitionSet().size(), at::kLong);
    //    gold[goldIndex] = 1;
    //    referencesBatch.emplace_back(gold);
    //    predictionsBatch.emplace_back(nn(config));

    //    auto loss = torch::nll_loss(prediction, gold);
    //    loss.backward();
    //    optimizer.step();
        // Keep a snapshot of the configuration before the transition is applied.
        configs.emplace_back(config);

        if (config.getWordIndex() % progressPeriod == 0)
          fmt::print("{:.5f}%\n", config.getWordIndex()*100.0/goldConfig.getNbLines());

    //    if (config.getWordIndex() >= 500)
    //      exit(1);

        transition->apply(config);

        config.addToHistory(transition->getName());

        auto movement = machine.getStrategy().getMovement(config, transition->getName());
        if (movement == Strategy::endMovement)
          break;

        // movement.first is the next state, movement.second the word-index shift.
        config.setState(movement.first);
        if (!config.moveWordIndex(movement.second))
          util::myThrow("Cannot move word index !");

        if (config.needsUpdate())
          config.update();
      }

      return 0;
    }