Skip to content
Snippets Groups Projects
Commit 5769657d authored by Franck Dary's avatar Franck Dary
Browse files

Trying to load GPU tensor onto CPU memory

parent 9df2da0f
Branches
No related tags found
No related merge requests found
......@@ -48,7 +48,7 @@ void Beam::update(ReadingMachine & machine, bool debug)
elements[index].config.setAppliableTransitions(appliableTransitions);
auto context = classifier.getNN()->extractContext(elements[index].config).back();
auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device));
auto prediction = classifier.isRegression() ? classifier.getNN()(neuralInput).squeeze(0) : torch::softmax(classifier.getNN()(neuralInput).squeeze(0), 0);
float entropy = classifier.isRegression() ? 0.0 : NeuralNetworkImpl::entropy(prediction);
......
......@@ -81,17 +81,15 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std
getNN()->loadDicts(path);
getNN()->registerEmbeddings();
getNN()->to(torch::kCPU);
if (!train)
{
torch::load(getNN(), getBestFilename());
torch::load(getNN(), getBestFilename(), NeuralNetworkImpl::device);
getNN()->registerEmbeddings();
getNN()->to(NeuralNetworkImpl::device);
}
else if (std::filesystem::exists(getLastFilename()))
{
torch::load(getNN(), getLastFilename());
torch::load(getNN(), getLastFilename(), NeuralNetworkImpl::device);
getNN()->to(NeuralNetworkImpl::device);
resetOptimizer();
loadOptimizer();
......@@ -185,7 +183,7 @@ void Classifier::loadOptimizer()
{
// Restore the serialized optimizer state from "<path>/<name>_optimizer.pt",
// if such a checkpoint exists; otherwise leave the optimizer untouched.
auto optimizerPath = std::filesystem::path(fmt::format("{}/{}_optimizer.pt", path.string(), name));
if (std::filesystem::exists(optimizerPath))
// NOTE(review): scraped diff — the next two statements are the before/after
// versions of one line; the commit adds the device argument so the optimizer
// state is mapped onto NeuralNetworkImpl::device at load time. Only one of
// the two lines exists in the real file — confirm against the repository.
torch::load(*optimizer, optimizerPath);
torch::load(*optimizer, optimizerPath, NeuralNetworkImpl::device);
}
void Classifier::saveOptimizer()
......@@ -273,16 +271,12 @@ std::string Classifier::getLastFilename() const
// Serialize the current (best-scoring) model to getBestFilename().
// The parameters are moved to CPU before torch::save so the checkpoint can be
// read on GPU-less machines, then moved back to the working device.
// NOTE(review): scraped diff (hunk -273,16 +271,12 removes four lines) — the
// two to() round-trip lines below are likely the ones this commit deletes,
// since loading now maps tensors to the target device; verify in the repo.
void Classifier::saveBest()
{
getNN()->to(torch::kCPU);
torch::save(getNN(), getBestFilename());
getNN()->to(NeuralNetworkImpl::device);
}
// Serialize the most recent model state to getLastFilename(), then persist the
// optimizer state as well (saveOptimizer) so training can resume exactly.
// NOTE(review): scraped diff — as in saveBest, the CPU round-trip lines below
// appear to be removals made by this commit (device mapping now happens on
// load); only one version exists in the real file — confirm in the repo.
void Classifier::saveLast()
{
getNN()->to(torch::kCPU);
torch::save(getNN(), getLastFilename());
getNN()->to(NeuralNetworkImpl::device);
saveOptimizer();
}
......
......@@ -46,7 +46,7 @@ NumericColumnModuleImpl::NumericColumnModuleImpl(std::string name, const std::st
// Slice this module's numeric columns out of the raw context tensor
// (columns [firstInputIndex, firstInputIndex + getInputSize())) and feed them,
// reinterpreted as doubles and converted to float, through the wrapped module.
torch::Tensor NumericColumnModuleImpl::forward(torch::Tensor input)
{
auto context = input.narrow(1, firstInputIndex, getInputSize());
// NOTE(review): scraped diff — the next two statements are the before/after
// versions of a single line; the commit drops the trailing .clone(). Without
// the clone, `values` comes from from_blob over context.data_ptr() and may
// alias `context`'s storage — confirm the lifetime is safe in the real file.
auto values = torch::from_blob(context.data_ptr(), context.sizes(), context.strides(), torch::TensorOptions(torch::kDouble).requires_grad(false).device(NeuralNetworkImpl::device)).to(torch::kFloat).unsqueeze(-1).clone();
auto values = torch::from_blob(context.data_ptr(), context.sizes(), context.strides(), torch::TensorOptions(torch::kDouble).requires_grad(false).device(NeuralNetworkImpl::device)).to(torch::kFloat).unsqueeze(-1);
return myModule->forward(values).reshape({input.size(0), -1});
}
......
......@@ -93,7 +93,7 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter")
{
auto & classifier = *machine.getClassifier(config.getState());
auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device));
auto prediction = classifier.isRegression() ? classifier.getNN()(neuralInput).squeeze(0) : torch::softmax(classifier.getNN()(neuralInput).squeeze(0), 0);
entropy = NeuralNetworkImpl::entropy(prediction);
......@@ -291,7 +291,7 @@ void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem:
int nbClasses = classes[0].size(0);
auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1).to(torch::kCPU);
auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1);
auto filename = fmt::format("{}-{}_{}-{}.{}.{}.tensor", state, nbClasses, lastSavedIndex, currentExampleIndex-1, epoch, dynamicOracle);
torch::save(tensorToSave, dir/filename);
lastSavedIndex = currentExampleIndex;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment