Franck Dary / macaon / Commits / d3ecc26c

Commit d3ecc26c authored 5 years ago by Franck Dary
Working version with SparseAdam
parent 92e9fda7
Showing 3 changed files with 187 additions and 143 deletions:

  dev/src/dev.cpp                          +164 −142
  torch_modules/include/TestNetwork.hpp      +5   −0
  torch_modules/src/TestNetwork.cpp         +18   −1
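The commit message points at the key change: the word-embedding table is trained with torch::optim::SparseAdam while the dense layers keep a regular torch::optim::Adam, so each optimizer only sees parameters whose gradients it can handle. Below is a minimal, self-contained sketch of that split, not the project's code; it assumes the libtorch C++ API as used in this commit (torch::optim::SparseAdam and EmbeddingOptions(...).sparse(true)), and all sizes and learning rates are illustrative.

// Hypothetical sketch of the dense/sparse optimizer split; sizes are made up.
#include <torch/torch.h>

int main()
{
  // Embedding with sparse(true): backward() produces a sparse gradient for its weight.
  auto wordEmbeddings = torch::nn::Embedding(
      torch::nn::EmbeddingOptions(/*num_embeddings=*/1000, /*embedding_dim=*/20).sparse(true));
  auto linear = torch::nn::Linear(20, 15);

  // Sparse parameters go to SparseAdam, dense parameters to Adam, as in dev.cpp.
  torch::optim::SparseAdam sparseOptimizer(wordEmbeddings->parameters(),
                                           torch::optim::SparseAdamOptions(2e-4));
  torch::optim::Adam denseOptimizer(linear->parameters(),
                                    torch::optim::AdamOptions(2e-4));

  // One training step on random data: a batch of 50 "sentences" of 5 word ids each.
  auto input = torch::randint(1000, {50, 5}, at::kLong);
  auto target = torch::randint(15, {50}, at::kLong);

  sparseOptimizer.zero_grad();
  denseOptimizer.zero_grad();
  auto sentence = wordEmbeddings(input).mean(1);   // sentence = mean of its word embeddings
  auto loss = torch::nll_loss(torch::log_softmax(linear(sentence), /*dim=*/1), target);
  loss.backward();
  sparseOptimizer.step();                          // updates only the looked-up rows
  denseOptimizer.step();
  return 0;
}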
dev/src/dev.cpp  +164 −142

@@ -8,155 +8,177 @@
#include "TestNetwork.hpp"
#include "ConfigDataset.hpp"

-constexpr int batchSize = 50;
-constexpr int nbExamples = 350000;
-constexpr int embeddingSize = 20;
-constexpr int nbClasses = 15;
-constexpr int nbWordsPerDatapoint = 5;
-constexpr int maxNbEmbeddings = 1000000;
-
-//3m15s
-struct NetworkImpl : torch::nn::Module
-{
-  torch::nn::Linear linear{nullptr};
-  torch::nn::Embedding wordEmbeddings{nullptr};
-
-  NetworkImpl()
-  {
-    linear = register_module("dense_linear", torch::nn::Linear(embeddingSize, nbClasses));
-    wordEmbeddings = register_module("sparse_word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingSize).sparse(true)));
-  };
-  torch::Tensor forward(const torch::Tensor & input)
-  {
-    // I have a batch of sentences (list of word embeddings), so as the sentence embedding I take the mean of the embedding of its words
-    auto embeddingsOfInput = wordEmbeddings(input).mean(1);
-    return torch::softmax(linear(embeddingsOfInput),1);
-  }
-};
-TORCH_MODULE(Network);
-
-int main(int argc, char * argv[])
-{
-  auto nn = Network();
-  torch::optim::SparseAdam sparseOptimizer(nn->parameters(), torch::optim::SparseAdamOptions(2e-4).beta1(0.5));
-  torch::optim::Adam denseOptimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
-  std::vector<std::pair<torch::Tensor,torch::Tensor>> batches;
-  for (int nbBatch = 0; nbBatch < nbExamples / batchSize; ++nbBatch)
-    batches.emplace_back(std::make_pair(torch::randint(maxNbEmbeddings,{batchSize,nbWordsPerDatapoint}, at::kLong), torch::randint(nbClasses, batchSize, at::kLong)));
-
-  for (auto & batch : batches)
-  {
-    sparseOptimizer.zero_grad();
-    denseOptimizer.zero_grad();
-    auto prediction = nn(batch.first);
-    auto loss = torch::nll_loss(torch::log(prediction), batch.second);
-    loss.backward();
-    sparseOptimizer.step();
-    denseOptimizer.step();
-  }
-  return 0;
-}
+//constexpr int batchSize = 50;
+//constexpr int nbExamples = 350000;
+//constexpr int embeddingSize = 20;
+//constexpr int nbClasses = 15;
+//constexpr int nbWordsPerDatapoint = 5;
+//constexpr int maxNbEmbeddings = 1000000;
+//
+//struct NetworkImpl : torch::nn::Module
+//{
+//  torch::nn::Linear linear{nullptr};
+//  torch::nn::Embedding wordEmbeddings{nullptr};
+//
+//  std::vector<torch::Tensor> _sparseParameters;
+//  std::vector<torch::Tensor> _denseParameters;
+//  NetworkImpl()
+//  {
+//    linear = register_module("dense_linear", torch::nn::Linear(embeddingSize, nbClasses));
+//    auto params = linear->parameters();
+//    _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
+//
+//    wordEmbeddings = register_module("sparse_word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingSize).sparse(true)));
+//    params = wordEmbeddings->parameters();
+//    _sparseParameters.insert(_sparseParameters.end(), params.begin(), params.end());
+//  };
+//  const std::vector<torch::Tensor> & denseParameters()
+//  {
+//    return _denseParameters;
+//  }
+//  const std::vector<torch::Tensor> & sparseParameters()
+//  {
+//    return _sparseParameters;
+//  }
+//  torch::Tensor forward(const torch::Tensor & input)
+//  {
+//    // I have a batch of sentences (list of word embeddings), so as the sentence embedding I take the mean of the embedding of its words
+//    auto embeddingsOfInput = wordEmbeddings(input).mean(1);
+//    return torch::softmax(linear(embeddingsOfInput),1);
+//  }
+//};
+//TORCH_MODULE(Network);
+//int main(int argc, char * argv[])
+//{
+//  auto nn = Network();
+//  torch::optim::SparseAdam sparseOptimizer(nn->sparseParameters(), torch::optim::SparseAdamOptions(2e-4).beta1(0.5));
+//  torch::optim::Adam denseOptimizer(nn->denseParameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
+//  std::vector<std::pair<torch::Tensor,torch::Tensor>> batches;
+//  for (int nbBatch = 0; nbBatch < nbExamples / batchSize; ++nbBatch)
+//    batches.emplace_back(std::make_pair(torch::randint(maxNbEmbeddings,{batchSize,nbWordsPerDatapoint}, at::kLong), torch::randint(nbClasses, batchSize, at::kLong)));
+//
+//  for (auto & batch : batches)
+//  {
+//    sparseOptimizer.zero_grad();
+//    denseOptimizer.zero_grad();
+//    auto prediction = nn(batch.first);
+//    auto loss = torch::nll_loss(torch::log(prediction), batch.second);
+//    loss.backward();
+//    sparseOptimizer.step();
+//    denseOptimizer.step();
+//  }
+//  return 0;
+//}
+
+int main(int argc, char * argv[])
+{
+  if (argc != 5)
+  {
+    fmt::print(stderr, "needs 4 arguments.\n");
+    exit(1);
+  }
+
+  at::init_num_threads();
+
+  std::string machineFile = argv[1];
+  std::string mcdFile = argv[2];
+  std::string tsvFile = argv[3];
+  //std::string rawFile = argv[4];
+  std::string rawFile = "";
+
+  ReadingMachine machine(machineFile);
+
+  BaseConfig goldConfig(mcdFile, tsvFile, rawFile);
+  SubConfig config(goldConfig);
+
+  config.setState(machine.getStrategy().getInitialState());
+
+  std::vector<torch::Tensor> contexts;
+  std::vector<torch::Tensor> classes;
+
+  fmt::print("Generating dataset...\n");
+
+  Dict dict(Dict::State::Open);
+
+  while (true)
+  {
+    auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
+    if (!transition)
+      util::myThrow("No transition appliable !");
+
+    auto context = config.extractContext(5,5,dict);
+    contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
+
+    int goldIndex = 3;
+    auto gold = torch::zeros(1, at::kLong);
+    gold[0] = goldIndex;
+
+    classes.emplace_back(gold);
+
+    transition->apply(config);
+    config.addToHistory(transition->getName());
+
+    auto movement = machine.getStrategy().getMovement(config, transition->getName());
+    if (movement == Strategy::endMovement)
+      break;
+
+    config.setState(movement.first);
+    if (!config.moveWordIndex(movement.second))
+      util::myThrow("Cannot move word index !");
+
+    if (config.needsUpdate())
+      config.update();
+  }
+
+  auto dataset = ConfigDataset(contexts, classes).map(torch::data::transforms::Stack<>());
+
+  int nbExamples = *dataset.size();
+  fmt::print("Done! size={}\n", nbExamples);
+
+  int batchSize = 100;
+  auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
+
+  TestNetwork nn(machine.getTransitionSet().size(), 5);
+  torch::optim::Adam denseOptimizer(nn->denseParameters(), torch::optim::AdamOptions(2e-1).beta1(0.5));
+  torch::optim::SparseAdam sparseOptimizer(nn->sparseParameters(), torch::optim::SparseAdamOptions(2e-1).beta1(0.5));
+
+  for (int epoch = 1; epoch <= 2; ++epoch)
+  {
+    float totalLoss = 0.0;
+    float lossSoFar = 0.0;
+    torch::Tensor example;
+    int currentBatchNumber = 0;
+
+    for (auto & batch : *dataLoader)
+    {
+      denseOptimizer.zero_grad();
+      sparseOptimizer.zero_grad();
+
+      auto data = batch.data;
+      auto labels = batch.target.squeeze();
+
+      auto prediction = nn(data);
+      example = prediction[0];
+
+      auto loss = torch::nll_loss(torch::log(prediction), labels);
+      totalLoss += loss.item<float>();
+      lossSoFar += loss.item<float>();
+      loss.backward();
+      denseOptimizer.step();
+      sparseOptimizer.step();
+
+      if (++currentBatchNumber*batchSize % 1000 == 0)
+      {
+        fmt::print("\rcurrent epoch : {:6.2f}% loss={:<15}", 100.0*currentBatchNumber*batchSize/nbExamples, lossSoFar);
+        std::fflush(stdout);
+        lossSoFar = 0;
+      }
+    }
+
+    fmt::print("\nEpoch {} : loss={:.2f}\n", epoch, totalLoss);
+  }
+
+  return 0;
+}
//int main(int argc, char * argv[])
//{
// if (argc != 5)
// {
// fmt::print(stderr, "needs 4 arguments.\n");
// exit(1);
// }
//
// at::init_num_threads();
//
// std::string machineFile = argv[1];
// std::string mcdFile = argv[2];
// std::string tsvFile = argv[3];
// //std::string rawFile = argv[4];
// std::string rawFile = "";
//
// ReadingMachine machine(machineFile);
//
// BaseConfig goldConfig(mcdFile, tsvFile, rawFile);
// SubConfig config(goldConfig);
//
// config.setState(machine.getStrategy().getInitialState());
//
// std::vector<torch::Tensor> contexts;
// std::vector<torch::Tensor> classes;
//
// fmt::print("Generating dataset...\n");
//
// Dict dict(Dict::State::Open);
//
// while (true)
// {
// auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
// if (!transition)
// util::myThrow("No transition appliable !");
//
// auto context = config.extractContext(5,5,dict);
// contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
//
// int goldIndex = 3;
// auto gold = torch::from_blob(&goldIndex, {1}, at::kLong).clone();
//
// classes.emplace_back(gold);
//
// transition->apply(config);
// config.addToHistory(transition->getName());
//
// auto movement = machine.getStrategy().getMovement(config, transition->getName());
// if (movement == Strategy::endMovement)
// break;
//
// config.setState(movement.first);
// if (!config.moveWordIndex(movement.second))
// util::myThrow("Cannot move word index !");
//
// if (config.needsUpdate())
// config.update();
// }
//
// auto dataset = ConfigDataset(contexts, classes).map(torch::data::transforms::Stack<>());
//
// int nbExamples = *dataset.size();
// fmt::print("Done! size={}\n", nbExamples);
//
// int batchSize = 100;
// auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
//
// TestNetwork nn(machine.getTransitionSet().size(), 5);
// torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
//
// for (int epoch = 1; epoch <= 1; ++epoch)
// {
// float totalLoss = 0.0;
// torch::Tensor example;
// int currentBatchNumber = 0;
//
// for (auto & batch : *dataLoader)
// {
// optimizer.zero_grad();
//
// auto data = batch.data;
// auto labels = batch.target.squeeze();
//
// auto prediction = nn(data);
// example = prediction[0];
//
// auto loss = torch::nll_loss(torch::log(prediction), labels);
// totalLoss += loss.item<float>();
// loss.backward();
// optimizer.step();
//
// if (++currentBatchNumber*batchSize % 1000 == 0)
// {
// fmt::print("\rcurrent epoch : {:6.2f}%", 100.0*currentBatchNumber*batchSize/nbExamples);
// std::fflush(stdout);
// }
// }
//
// fmt::print("Epoch {} : loss={:.2f}\n", epoch, totalLoss);
// }
//
// return 0;
//}
//
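The new main() builds its training set by pushing one context tensor and one gold-class tensor per configuration, then wraps them in ConfigDataset(contexts, classes).map(torch::data::transforms::Stack<>()) and a data loader. ConfigDataset itself is project code outside this diff, so the sketch below is an assumption about its general shape: a hypothetical PairDataset with the two overrides the libtorch dataset API requires, used with the same Stack<> and make_data_loader calls as dev.cpp.

// Hypothetical PairDataset standing in for ConfigDataset (an assumption,
// not the project's implementation): one (context, class) tensor pair per example.
#include <torch/torch.h>
#include <iostream>
#include <utility>
#include <vector>

class PairDataset : public torch::data::datasets::Dataset<PairDataset>
{
  std::vector<torch::Tensor> contexts;
  std::vector<torch::Tensor> classes;

  public:

  PairDataset(std::vector<torch::Tensor> contexts, std::vector<torch::Tensor> classes)
    : contexts(std::move(contexts)), classes(std::move(classes)) {}

  // One example = (context word ids, gold class index).
  torch::data::Example<> get(size_t index) override
  {
    return {contexts[index], classes[index]};
  }

  torch::optional<size_t> size() const override
  {
    return contexts.size();
  }
};

int main()
{
  // Fake data in place of the extracted configurations.
  std::vector<torch::Tensor> contexts, classes;
  for (int i = 0; i < 10; ++i)
  {
    contexts.push_back(torch::randint(1000, {11}, at::kLong));
    classes.push_back(torch::zeros(1, at::kLong));
  }

  // Same pattern as dev.cpp: Stack<> collates examples into batched tensors.
  auto dataset = PairDataset(contexts, classes).map(torch::data::transforms::Stack<>());
  auto dataLoader = torch::data::make_data_loader(std::move(dataset),
                        torch::data::DataLoaderOptions(4).workers(0));

  for (auto & batch : *dataLoader)
    std::cout << batch.data.sizes() << " " << batch.target.sizes() << "\n";
  return 0;
}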
torch_modules/include/TestNetwork.hpp  +5 −0

@@ -12,10 +12,15 @@ class TestNetworkImpl : public torch::nn::Module
  torch::nn::Linear linear{nullptr};
  int focusedIndex;

+ std::vector<torch::Tensor> _denseParameters;
+ std::vector<torch::Tensor> _sparseParameters;
+
  public :

  TestNetworkImpl(int nbOutputs, int focusedIndex);
  torch::Tensor forward(torch::Tensor input);
+ std::vector<torch::Tensor> & denseParameters();
+ std::vector<torch::Tensor> & sparseParameters();
};

TORCH_MODULE(TestNetwork);
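TORCH_MODULE(TestNetwork) generates the TestNetwork holder type, so a TestNetwork value acts like a shared pointer to TestNetworkImpl: nn(input) calls forward(), and nn->denseParameters() / nn->sparseParameters() expose the two parameter groups declared above. The usage fragment below is a hedged sketch, not repository code; it only builds inside this project (it needs TestNetwork.hpp), and the constructor arguments and input shape are illustrative assumptions.

// Hypothetical usage of the interface above, mirroring dev.cpp.
#include "TestNetwork.hpp"
#include <torch/torch.h>

int main()
{
  // nbOutputs and focusedIndex are illustrative values, not the project's.
  TestNetwork nn(/*nbOutputs=*/15, /*focusedIndex=*/5);

  // Each parameter group gets its own optimizer, as in dev.cpp.
  torch::optim::Adam denseOptimizer(nn->denseParameters(), torch::optim::AdamOptions(2e-1));
  torch::optim::SparseAdam sparseOptimizer(nn->sparseParameters(),
                                           torch::optim::SparseAdamOptions(2e-1));

  // A fake batch of 100 contexts of 11 word ids each (the shape is an assumption).
  auto scores = nn(torch::randint(200000, {100, 11}, at::kLong));
  return 0;
}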
torch_modules/src/TestNetwork.cpp  +18 −1

@@ -3,11 +3,28 @@
TestNetworkImpl::TestNetworkImpl(int nbOutputs, int focusedIndex)
{
  constexpr int embeddingsSize = 30;

- wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(200000, embeddingsSize));
+ wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(200000, embeddingsSize).sparse(true)));
+ auto params = wordEmbeddings->parameters();
+ _sparseParameters.insert(_sparseParameters.end(), params.begin(), params.end());
+
  linear = register_module("linear", torch::nn::Linear(embeddingsSize, nbOutputs));
+ params = linear->parameters();
+ _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
+
  this->focusedIndex = focusedIndex;
}

+std::vector<torch::Tensor> & TestNetworkImpl::denseParameters()
+{
+  return _denseParameters;
+}
+
+std::vector<torch::Tensor> & TestNetworkImpl::sparseParameters()
+{
+  return _sparseParameters;
+}
+
torch::Tensor TestNetworkImpl::forward(torch::Tensor input)
{
  // input dim = {batch, sequence, embeddings}
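The constructor above routes the embedding weight into _sparseParameters because EmbeddingOptions(...).sparse(true) makes backward() produce a sparse gradient for that weight, which is the gradient type SparseAdam consumes and a dense optimizer does not. A small self-contained check of this behaviour, using the same 200000 x 30 embedding as the constructor (the lookup batch is made up):

// Sketch: sparse(true) yields a sparse gradient on the embedding weight.
#include <torch/torch.h>
#include <iostream>

int main()
{
  auto wordEmbeddings = torch::nn::Embedding(
      torch::nn::EmbeddingOptions(200000, 30).sparse(true));

  // Look up a handful of ids and backpropagate a scalar.
  auto ids = torch::randint(200000, {4, 5}, at::kLong);
  wordEmbeddings(ids).sum().backward();

  // The gradient only carries rows for the looked-up ids, hence "sparse".
  std::cout << std::boolalpha
            << wordEmbeddings->weight.grad().is_sparse() << "\n";  // expected: true
  return 0;
}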