Newer
Older
#include "BaseConfig.hpp"
#include "SubConfig.hpp"
Franck Dary
committed
#include "ReadingMachine.hpp"
// Command-line handling: exactly four positional arguments (argv[1]..argv[4]) are required.
// NOTE(review): argv[4] (the raw-text file) is currently ignored -- rawFile is forced to ""
// at the end of this section -- yet the count check still demands it be supplied. Either
// relax the check or restore the argv[4] read; confirm which is intended.
if (argc != 5)
{
fmt::print(stderr, "needs 4 arguments.\n");
exit(1);
}
std::string machineFile = argv[1]; // used to construct the ReadingMachine
std::string mcdFile = argv[2];     // first argument given to BaseConfig
std::string tsvFile = argv[3];     // second argument given to BaseConfig
//std::string rawFile = argv[4];
// Raw-text input is disabled for now: an empty path is handed to BaseConfig instead.
std::string rawFile = "";
// Load the transition machine and build the reference configuration from the input files.
ReadingMachine machine(machineFile);
BaseConfig goldConfig(mcdFile, tsvFile, rawFile);
// NOTE(review): `config` is not declared in the visible lines -- presumably a SubConfig
// constructed from goldConfig on an omitted line; confirm. Put it in the strategy's initial state.
config.setState(machine.getStrategy().getInitialState());
// The network is sized by the number of transitions (presumably one output score per
// transition -- confirm against TestNetwork's definition).
TestNetwork nn(machine.getTransitionSet().size());
// Adam over all network parameters: learning rate 2e-4, beta1 = 0.5.
torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
optimizer.zero_grad();
// Batch buffers. NOTE(review): predictionsBatch/referencesBatch are only referenced by the
// commented-out training code, so they stay empty as the code currently stands.
std::vector<torch::Tensor> predictionsBatch;
std::vector<torch::Tensor> referencesBatch;
// Snapshot of the configuration taken at each processing step.
std::vector<SubConfig> configs;
// One processing step. The enclosing loop header is not visible in this excerpt; the
// `break` below exits that loop.
// Choose the best transition applicable to the current configuration; fail hard if none is.
auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
if (!transition)
util::myThrow("No transition appliable !");
// NOTE(review): the gold transition index is hard-coded to 3, and the `gold` tensor built
// here is only consumed by the commented-out training code below -- a placeholder.
int goldIndex = 3;
auto gold = torch::zeros(machine.getTransitionSet().size(), at::kLong);
gold[goldIndex] = 1;
// Training step, currently disabled.
// NOTE(review): if re-enabled, `prediction` is not defined anywhere visible, and nll_loss
// expects a class-index target rather than the one-hot tensor `gold` -- confirm before use.
// referencesBatch.emplace_back(gold);
// predictionsBatch.emplace_back(nn(config));
// auto loss = torch::nll_loss(prediction, gold);
// loss.backward();
// optimizer.step();
// Keep a copy of the configuration at every step.
// NOTE(review): this vector grows for the whole run; confirm all snapshots are needed.
configs.emplace_back(config);
// Progress report: current word index as a percentage of goldConfig's line count.
// NOTE(review): `% 1 == 0` is always true, so this prints on every step -- likely a
// debugging leftover of a larger reporting interval.
if (config.getWordIndex()%1 == 0)
fmt::print("{:.5f}%\n", config.getWordIndex()*100.0/goldConfig.getNbLines());
// if (config.getWordIndex() >= 500)
// exit(1);
// Record the chosen transition, then ask the strategy for the next (state, word move).
config.addToHistory(transition->getName());
auto movement = machine.getStrategy().getMovement(config, transition->getName());
// The strategy signals the end of processing with endMovement: leave the loop.
if (movement == Strategy::endMovement)
break;
// Enter the next state, then shift the word index by movement.second
// (presumably a relative offset -- confirm against SubConfig::moveWordIndex).
config.setState(movement.first);
if (!config.moveWordIndex(movement.second))
util::myThrow("Cannot move word index !");
// Refresh derived configuration data when it reports being stale.
if (config.needsUpdate())
config.update();