macaon · Commit d3ecc26c
Authored 5 years ago by Franck Dary

    Working version with SparseAdam

parent 92e9fda7
Changes: 3 changed files, with 187 additions and 143 deletions

  dev/src/dev.cpp                          +164  −142
  torch_modules/include/TestNetwork.hpp      +5    −0
  torch_modules/src/TestNetwork.cpp         +18    −1
dev/src/dev.cpp  (+164 −142)

@@ -8,155 +8,177 @@
The previous commit's SparseAdam micro-benchmark (the constants, NetworkImpl, and its random-batch main()) is commented out, and a new main() is activated that trains TestNetwork on a ReadingMachine-generated dataset with one optimizer per parameter group. New content of the hunk:

#include "TestNetwork.hpp"
#include "ConfigDataset.hpp"

//constexpr int batchSize = 50;
//constexpr int nbExamples = 350000;
//constexpr int embeddingSize = 20;
//constexpr int nbClasses = 15;
//constexpr int nbWordsPerDatapoint = 5;
//constexpr int maxNbEmbeddings = 1000000;
//
//3m15s
//struct NetworkImpl : torch::nn::Module
//{
//  torch::nn::Linear linear{nullptr};
//  torch::nn::Embedding wordEmbeddings{nullptr};
//
//  std::vector<torch::Tensor> _sparseParameters;
//  std::vector<torch::Tensor> _denseParameters;
//  NetworkImpl()
//  {
//    linear = register_module("dense_linear", torch::nn::Linear(embeddingSize, nbClasses));
//    auto params = linear->parameters();
//    _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
//
//    wordEmbeddings = register_module("sparse_word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingSize).sparse(true)));
//    params = wordEmbeddings->parameters();
//    _sparseParameters.insert(_sparseParameters.end(), params.begin(), params.end());
//  };
//  const std::vector<torch::Tensor> & denseParameters()
//  {
//    return _denseParameters;
//  }
//  const std::vector<torch::Tensor> & sparseParameters()
//  {
//    return _sparseParameters;
//  }
//  torch::Tensor forward(const torch::Tensor & input)
//  {
//    // I have a batch of sentences (list of word embeddings), so as the sentence embedding I take the mean of the embedding of its words
//    auto embeddingsOfInput = wordEmbeddings(input).mean(1);
//    return torch::softmax(linear(embeddingsOfInput),1);
//  }
//};
//TORCH_MODULE(Network);
//int main(int argc, char * argv[])
//{
//  auto nn = Network();
//  torch::optim::SparseAdam sparseOptimizer(nn->sparseParameters(), torch::optim::SparseAdamOptions(2e-4).beta1(0.5));
//  torch::optim::Adam denseOptimizer(nn->denseParameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
//  std::vector<std::pair<torch::Tensor,torch::Tensor>> batches;
//  for (int nbBatch = 0; nbBatch < nbExamples / batchSize; ++nbBatch)
//    batches.emplace_back(std::make_pair(torch::randint(maxNbEmbeddings,{batchSize,nbWordsPerDatapoint}, at::kLong), torch::randint(nbClasses, batchSize, at::kLong)));
//
//  for (auto & batch : batches)
//  {
//    sparseOptimizer.zero_grad();
//    denseOptimizer.zero_grad();
//    auto prediction = nn(batch.first);
//    auto loss = torch::nll_loss(torch::log(prediction), batch.second);
//    loss.backward();
//    sparseOptimizer.step();
//    denseOptimizer.step();
//  }
//  return 0;
//}

int main(int argc, char * argv[])
{
  if (argc != 5)
  {
    fmt::print(stderr, "needs 4 arguments.\n");
    exit(1);
  }

  at::init_num_threads();

  std::string machineFile = argv[1];
  std::string mcdFile = argv[2];
  std::string tsvFile = argv[3];
  //std::string rawFile = argv[4];
  std::string rawFile = "";

  ReadingMachine machine(machineFile);

  BaseConfig goldConfig(mcdFile, tsvFile, rawFile);
  SubConfig config(goldConfig);

  config.setState(machine.getStrategy().getInitialState());

  std::vector<torch::Tensor> contexts;
  std::vector<torch::Tensor> classes;

  fmt::print("Generating dataset...\n");

  Dict dict(Dict::State::Open);

  while (true)
  {
    auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
    if (!transition)
      util::myThrow("No transition appliable !");

    auto context = config.extractContext(5,5,dict);
    contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());

    int goldIndex = 3;
    auto gold = torch::zeros(1, at::kLong);
    gold[0] = goldIndex;

    classes.emplace_back(gold);

    transition->apply(config);
    config.addToHistory(transition->getName());

    auto movement = machine.getStrategy().getMovement(config, transition->getName());
    if (movement == Strategy::endMovement)
      break;

    config.setState(movement.first);
    if (!config.moveWordIndex(movement.second))
      util::myThrow("Cannot move word index !");

    if (config.needsUpdate())
      config.update();
  }

  auto dataset = ConfigDataset(contexts, classes).map(torch::data::transforms::Stack<>());

  int nbExamples = *dataset.size();
  fmt::print("Done! size={}\n", nbExamples);

  int batchSize = 100;
  auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));

  TestNetwork nn(machine.getTransitionSet().size(), 5);
  torch::optim::Adam denseOptimizer(nn->denseParameters(), torch::optim::AdamOptions(2e-1).beta1(0.5));
  torch::optim::SparseAdam sparseOptimizer(nn->sparseParameters(), torch::optim::SparseAdamOptions(2e-1).beta1(0.5));

  for (int epoch = 1; epoch <= 2; ++epoch)
  {
    float totalLoss = 0.0;
    float lossSoFar = 0.0;
    torch::Tensor example;
    int currentBatchNumber = 0;

    for (auto & batch : *dataLoader)
    {
      denseOptimizer.zero_grad();
      sparseOptimizer.zero_grad();

      auto data = batch.data;
      auto labels = batch.target.squeeze();

      auto prediction = nn(data);
      example = prediction[0];

      auto loss = torch::nll_loss(torch::log(prediction), labels);
      totalLoss += loss.item<float>();
      lossSoFar += loss.item<float>();
      loss.backward();
      denseOptimizer.step();
      sparseOptimizer.step();

      if (++currentBatchNumber*batchSize % 1000 == 0)
      {
        fmt::print("\rcurrent epoch : {:6.2f}% loss={:<15}", 100.0*currentBatchNumber*batchSize/nbExamples, lossSoFar);
        std::fflush(stdout);
        lossSoFar = 0;
      }
    }

    fmt::print("\nEpoch {} : loss={:.2f}\n", epoch, totalLoss);
  }

  return 0;
}

Removed by the same hunk: the previously active (uncommented) copies of the constants, NetworkImpl, and benchmark main() above (the old benchmark passed nn->parameters() to both optimizers and used learning rate 2e-4), plus the old commented-out draft of main() that sat at the end of the file, with a single Adam optimizer, one epoch, and no loss reporting:
//int main(int argc, char * argv[])
//{
// if (argc != 5)
// {
// fmt::print(stderr, "needs 4 arguments.\n");
// exit(1);
// }
//
// at::init_num_threads();
//
// std::string machineFile = argv[1];
// std::string mcdFile = argv[2];
// std::string tsvFile = argv[3];
// //std::string rawFile = argv[4];
// std::string rawFile = "";
//
// ReadingMachine machine(machineFile);
//
// BaseConfig goldConfig(mcdFile, tsvFile, rawFile);
// SubConfig config(goldConfig);
//
// config.setState(machine.getStrategy().getInitialState());
//
// std::vector<torch::Tensor> contexts;
// std::vector<torch::Tensor> classes;
//
// fmt::print("Generating dataset...\n");
//
// Dict dict(Dict::State::Open);
//
// while (true)
// {
// auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
// if (!transition)
// util::myThrow("No transition appliable !");
//
// auto context = config.extractContext(5,5,dict);
// contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
//
// int goldIndex = 3;
// auto gold = torch::from_blob(&goldIndex, {1}, at::kLong).clone();
//
// classes.emplace_back(gold);
//
// transition->apply(config);
// config.addToHistory(transition->getName());
//
// auto movement = machine.getStrategy().getMovement(config, transition->getName());
// if (movement == Strategy::endMovement)
// break;
//
// config.setState(movement.first);
// if (!config.moveWordIndex(movement.second))
// util::myThrow("Cannot move word index !");
//
// if (config.needsUpdate())
// config.update();
// }
//
// auto dataset = ConfigDataset(contexts, classes).map(torch::data::transforms::Stack<>());
//
// int nbExamples = *dataset.size();
// fmt::print("Done! size={}\n", nbExamples);
//
// int batchSize = 100;
// auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
//
// TestNetwork nn(machine.getTransitionSet().size(), 5);
// torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
//
// for (int epoch = 1; epoch <= 1; ++epoch)
// {
// float totalLoss = 0.0;
// torch::Tensor example;
// int currentBatchNumber = 0;
//
// for (auto & batch : *dataLoader)
// {
// optimizer.zero_grad();
//
// auto data = batch.data;
// auto labels = batch.target.squeeze();
//
// auto prediction = nn(data);
// example = prediction[0];
//
// auto loss = torch::nll_loss(torch::log(prediction), labels);
// totalLoss += loss.item<float>();
// loss.backward();
// optimizer.step();
//
// if (++currentBatchNumber*batchSize % 1000 == 0)
// {
// fmt::print("\rcurrent epoch : {:6.2f}%", 100.0*currentBatchNumber*batchSize/nbExamples);
// std::fflush(stdout);
// }
// }
//
// fmt::print("Epoch {} : loss={:.2f}\n", epoch, totalLoss);
// }
//
// return 0;
//}
//
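The pattern this commit lands on (one optimizer per parameter group) is easiest to see in isolation. Below is a minimal standalone sketch, not part of the repository: it mirrors the benchmark code above and assumes the LibTorch API used by this commit (SparseAdamOptions::beta1).

#include <torch/torch.h>

int main()
{
  // An embedding built with sparse(true) yields sparse gradients, which
  // torch::optim::Adam rejects; those parameters go to SparseAdam, while
  // the dense Linear layer keeps an ordinary Adam.
  auto embeddings = torch::nn::Embedding(torch::nn::EmbeddingOptions(1000, 30).sparse(true));
  auto linear = torch::nn::Linear(30, 5);

  torch::optim::SparseAdam sparseOptimizer(embeddings->parameters(), torch::optim::SparseAdamOptions(2e-1).beta1(0.5));
  torch::optim::Adam denseOptimizer(linear->parameters(), torch::optim::AdamOptions(2e-1).beta1(0.5));

  auto input = torch::randint(1000, {8, 4}, at::kLong); // {batch, words}
  auto labels = torch::randint(5, {8}, at::kLong);      // {batch}

  sparseOptimizer.zero_grad();
  denseOptimizer.zero_grad();
  auto prediction = torch::softmax(linear(embeddings(input).mean(1)), 1);
  auto loss = torch::nll_loss(torch::log(prediction), labels);
  loss.backward();
  sparseOptimizer.step(); // updates only the embedding rows referenced by input
  denseOptimizer.step();
  return 0;
}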
torch_modules/include/TestNetwork.hpp  (+5 −0)

@@ -12,10 +12,15 @@ class TestNetworkImpl : public torch::nn::Module
The private parameter-group vectors and their public accessors are added (+ marks added lines):

   torch::nn::Linear linear{nullptr};
   int focusedIndex;
+  std::vector<torch::Tensor> _denseParameters;
+  std::vector<torch::Tensor> _sparseParameters;

   public :

   TestNetworkImpl(int nbOutputs, int focusedIndex);
   torch::Tensor forward(torch::Tensor input);
+  std::vector<torch::Tensor> & denseParameters();
+  std::vector<torch::Tensor> & sparseParameters();
 };
 TORCH_MODULE(TestNetwork);
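Why the header grows these two lists: once the embedding table is constructed with sparse(true) (see TestNetwork.cpp below), its weight gradient comes back as a sparse tensor, which torch::optim::Adam will not accept. A tiny illustration, as a hypothetical snippet that is not from the repository:

#include <torch/torch.h>
#include <iostream>

int main()
{
  auto emb = torch::nn::Embedding(torch::nn::EmbeddingOptions(100, 8).sparse(true));
  auto sum = emb(torch::randint(100, {3}, at::kLong)).sum();
  sum.backward();
  // Prints 1: the embedding weight's gradient is sparse, so these
  // parameters must be routed to SparseAdam rather than Adam.
  std::cout << emb->weight.grad().is_sparse() << std::endl;
  return 0;
}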
torch_modules/src/TestNetwork.cpp  (+18 −1)

@@ -3,11 +3,28 @@
The constructor now builds the embeddings as sparse and records each submodule's parameters in the matching group; the two accessors are defined below it (+/− mark added/removed lines):

 TestNetworkImpl::TestNetworkImpl(int nbOutputs, int focusedIndex)
 {
   constexpr int embeddingsSize = 30;
-  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(200000, embeddingsSize));
+  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(200000, embeddingsSize).sparse(true)));
+  auto params = wordEmbeddings->parameters();
+  _sparseParameters.insert(_sparseParameters.end(), params.begin(), params.end());
   linear = register_module("linear", torch::nn::Linear(embeddingsSize, nbOutputs));
+  params = linear->parameters();
+  _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
   this->focusedIndex = focusedIndex;
 }

+std::vector<torch::Tensor> & TestNetworkImpl::denseParameters()
+{
+  return _denseParameters;
+}
+
+std::vector<torch::Tensor> & TestNetworkImpl::sparseParameters()
+{
+  return _sparseParameters;
+}
+
 torch::Tensor TestNetworkImpl::forward(torch::Tensor input)
 {
   // input dim = {batch, sequence, embeddings}

(The rest of forward() is unchanged and collapsed in the diff view.)
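Putting the pieces together, consumption of the two accessors looks like the wiring in the new dev/src/dev.cpp above. A hedged sketch; the constructor arguments are placeholders, not values from the repository:

#include "TestNetwork.hpp"
#include <torch/torch.h>
#include <cassert>

int main()
{
  TestNetwork nn(15, 5); // placeholder: 15 outputs, focused word index 5

  // Each group feeds its own optimizer; both are stepped on every batch.
  torch::optim::Adam denseOptimizer(nn->denseParameters(), torch::optim::AdamOptions(2e-1).beta1(0.5));
  torch::optim::SparseAdam sparseOptimizer(nn->sparseParameters(), torch::optim::SparseAdamOptions(2e-1).beta1(0.5));

  // register_module() still exposes the union of both groups through
  // parameters(), so whole-module code paths are unaffected by the split.
  assert(nn->parameters().size() == nn->denseParameters().size() + nn->sparseParameters().size());
  return 0;
}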