#include <cstdio> #include "fmt/core.h" #include "util.hpp" #include "BaseConfig.hpp" #include "SubConfig.hpp" #include "TransitionSet.hpp" #include "ReadingMachine.hpp" #include "TestNetwork.hpp" #include "ConfigDataset.hpp" int main(int argc, char * argv[]) { if (argc != 5) { fmt::print(stderr, "needs 4 arguments.\n"); exit(1); } std::string machineFile = argv[1]; std::string mcdFile = argv[2]; std::string tsvFile = argv[3]; //std::string rawFile = argv[4]; std::string rawFile = ""; ReadingMachine machine(machineFile); BaseConfig goldConfig(mcdFile, tsvFile, rawFile); SubConfig config(goldConfig); config.setState(machine.getStrategy().getInitialState()); TestNetwork nn(machine.getTransitionSet().size()); torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5)); optimizer.zero_grad(); std::vector<torch::Tensor> predictionsBatch; std::vector<torch::Tensor> referencesBatch; std::vector<SubConfig> configs; while (true) { auto * transition = machine.getTransitionSet().getBestAppliableTransition(config); if (!transition) util::myThrow("No transition appliable !"); //here train int goldIndex = 3; auto gold = torch::zeros(machine.getTransitionSet().size(), at::kLong); gold[goldIndex] = 1; // referencesBatch.emplace_back(gold); // predictionsBatch.emplace_back(nn(config)); // auto loss = torch::nll_loss(prediction, gold); // loss.backward(); // optimizer.step(); configs.emplace_back(config); if (config.getWordIndex()%1 == 0) fmt::print("{:.5f}%\n", config.getWordIndex()*100.0/goldConfig.getNbLines()); // if (config.getWordIndex() >= 500) // exit(1); transition->apply(config); config.addToHistory(transition->getName()); auto movement = machine.getStrategy().getMovement(config, transition->getName()); if (movement == Strategy::endMovement) break; config.setState(movement.first); if (!config.moveWordIndex(movement.second)) util::myThrow("Cannot move word index !"); if (config.needsUpdate()) config.update(); } return 0; }