Commit 5769657d authored by Franck Dary's avatar Franck Dary
Browse files

Trying to load gpu tansor onto cpu mem

parent 9df2da0f
......@@ -48,7 +48,7 @@ void Beam::update(ReadingMachine & machine, bool debug)
elements[index].config.setAppliableTransitions(appliableTransitions);
auto context = classifier.getNN()->extractContext(elements[index].config).back();
auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device));
auto prediction = classifier.isRegression() ? classifier.getNN()(neuralInput).squeeze(0) : torch::softmax(classifier.getNN()(neuralInput).squeeze(0), 0);
float entropy = classifier.isRegression() ? 0.0 : NeuralNetworkImpl::entropy(prediction);
......
......@@ -81,17 +81,15 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std
getNN()->loadDicts(path);
getNN()->registerEmbeddings();
getNN()->to(torch::kCPU);
if (!train)
{
torch::load(getNN(), getBestFilename());
torch::load(getNN(), getBestFilename(), NeuralNetworkImpl::device);
getNN()->registerEmbeddings();
getNN()->to(NeuralNetworkImpl::device);
}
else if (std::filesystem::exists(getLastFilename()))
{
torch::load(getNN(), getLastFilename());
torch::load(getNN(), getLastFilename(), NeuralNetworkImpl::device);
getNN()->to(NeuralNetworkImpl::device);
resetOptimizer();
loadOptimizer();
......@@ -185,7 +183,7 @@ void Classifier::loadOptimizer()
{
auto optimizerPath = std::filesystem::path(fmt::format("{}/{}_optimizer.pt", path.string(), name));
if (std::filesystem::exists(optimizerPath))
torch::load(*optimizer, optimizerPath);
torch::load(*optimizer, optimizerPath, NeuralNetworkImpl::device);
}
void Classifier::saveOptimizer()
......@@ -273,16 +271,12 @@ std::string Classifier::getLastFilename() const
void Classifier::saveBest()
{
getNN()->to(torch::kCPU);
torch::save(getNN(), getBestFilename());
getNN()->to(NeuralNetworkImpl::device);
}
void Classifier::saveLast()
{
getNN()->to(torch::kCPU);
torch::save(getNN(), getLastFilename());
getNN()->to(NeuralNetworkImpl::device);
saveOptimizer();
}
......
......@@ -46,7 +46,7 @@ NumericColumnModuleImpl::NumericColumnModuleImpl(std::string name, const std::st
torch::Tensor NumericColumnModuleImpl::forward(torch::Tensor input)
{
auto context = input.narrow(1, firstInputIndex, getInputSize());
auto values = torch::from_blob(context.data_ptr(), context.sizes(), context.strides(), torch::TensorOptions(torch::kDouble).requires_grad(false).device(NeuralNetworkImpl::device)).to(torch::kFloat).unsqueeze(-1).clone();
auto values = torch::from_blob(context.data_ptr(), context.sizes(), context.strides(), torch::TensorOptions(torch::kDouble).requires_grad(false).device(NeuralNetworkImpl::device)).to(torch::kFloat).unsqueeze(-1);
return myModule->forward(values).reshape({input.size(0), -1});
}
......
......@@ -93,7 +93,7 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter")
{
auto & classifier = *machine.getClassifier(config.getState());
auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device));
auto prediction = classifier.isRegression() ? classifier.getNN()(neuralInput).squeeze(0) : torch::softmax(classifier.getNN()(neuralInput).squeeze(0), 0);
entropy = NeuralNetworkImpl::entropy(prediction);
......@@ -291,7 +291,7 @@ void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem:
int nbClasses = classes[0].size(0);
auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1).to(torch::kCPU);
auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1);
auto filename = fmt::format("{}-{}_{}-{}.{}.{}.tensor", state, nbClasses, lastSavedIndex, currentExampleIndex-1, epoch, dynamicOracle);
torch::save(tensorToSave, dir/filename);
lastSavedIndex = currentExampleIndex;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment