macaon / Commits / d4ec0a24

Commit d4ec0a24, authored 5 years ago by Franck Dary
Made SparseAdam optimizer
parent be63334b
Changes: 1 changed file
dev/src/dev.cpp  +240 −72 (240 additions, 72 deletions)
@@ -8,105 +8,273 @@
#include "TestNetwork.hpp"
#include "ConfigDataset.hpp"

namespace torch
{
namespace optim
{

class SparseAdam : public Optimizer
{
  public :

  template <typename ParameterContainer>
  explicit SparseAdam(ParameterContainer && parameters, const AdamOptions & options_)
    : Optimizer(std::forward<ParameterContainer>(parameters)), options(options_)
  {
  }

  void step() override
  {
    for (size_t i = 0; i < parameters_.size(); ++i)
    {
      Tensor p = parameters_.at(i);
      if (!p.grad().defined())
        continue;

      auto & exp_average = buffer_at(exp_average_buffers, i);
      auto & exp_average_sq = buffer_at(exp_average_sq_buffers, i);

      buffer_at(step_buffers, i) += 1;
      const auto bias_correction1 = 1 - std::pow(options.beta1(), buffer_at(step_buffers, i));
      const auto bias_correction2 = 1 - std::pow(options.beta2(), buffer_at(step_buffers, i));

      if (p.grad().is_sparse())
      {
        NoGradGuard guard;

        // Sparse gradient : only the rows present in the gradient are touched.
        p.grad() = p.grad().coalesce();
        auto indices = p.grad().indices().squeeze();
        auto values = p.grad().values();

        auto old_exp_average_values = exp_average.sparse_mask(p.grad())._values();
        auto exp_average_update_values = values.sub(old_exp_average_values).mul_(1 - options.beta1());
        for (unsigned int j = 0; j < indices.size(0); j++)
          exp_average[indices[j].item<long>()] += exp_average_update_values[j];

        auto old_exp_average_sq_values = exp_average_sq.sparse_mask(p.grad())._values();
        auto exp_average_sq_update_values = values.pow(2).sub_(old_exp_average_sq_values).mul_(1 - options.beta2());
        for (unsigned int j = 0; j < indices.size(0); j++)
          exp_average_sq[indices[j].item<long>()] += exp_average_sq_update_values[j];

        // Recover the updated first/second moments at the touched rows.
        auto numer = exp_average_update_values.add_(old_exp_average_values);
        exp_average_sq_update_values.add_(old_exp_average_sq_values);
        auto denom = exp_average_sq_update_values.sqrt_().add_(options.eps());

        const auto step_size = options.learning_rate() * std::sqrt(bias_correction2) / bias_correction1;

        auto divided = numer.div(denom);
        for (unsigned int j = 0; j < indices.size(0); j++)
          p.data()[indices[j].item<long>()] += -step_size * divided[j];
      }
      else
      {
        if (options.weight_decay() > 0)
        {
          NoGradGuard guard;
          p.grad() = p.grad() + options.weight_decay() * p;
        }

        exp_average.mul_(options.beta1()).add_(p.grad(), 1 - options.beta1());
        exp_average_sq.mul_(options.beta2()).addcmul_(p.grad(), p.grad(), 1 - options.beta2());

        Tensor denom;
        if (options.amsgrad())
        {
          auto & max_exp_average_sq = buffer_at(max_exp_average_sq_buffers, i);
          max_exp_average_sq = torch::max(max_exp_average_sq, exp_average_sq);
          denom = max_exp_average_sq / bias_correction2;
        }
        else
        {
          denom = exp_average_sq / bias_correction2;
        }

        const auto step_size = options.learning_rate() / bias_correction1;

        NoGradGuard guard;
        p.addcdiv_(exp_average, denom.sqrt() + options.eps(), -step_size);
      }
    }
  }

  void save(serialize::OutputArchive & archive) const override
  {
    //serialize(*this, archive);
  }

  void load(serialize::InputArchive & archive) override
  {
    //serialize(*this, archive);
  }

  public :

  AdamOptions options;

  std::vector<int64_t> step_buffers;
  std::vector<Tensor> exp_average_buffers;
  std::vector<Tensor> exp_average_sq_buffers;
  std::vector<Tensor> max_exp_average_sq_buffers;

  private :

  SparseAdam() : options(0) {}

  template <typename Self, typename Archive>
  static void serialize(Self & self, Archive & archive)
  {
    _TORCH_OPTIM_SERIALIZE(step_buffers);
    _TORCH_OPTIM_SERIALIZE(exp_average_buffers);
    _TORCH_OPTIM_SERIALIZE(exp_average_sq_buffers);
    _TORCH_OPTIM_SERIALIZE(max_exp_average_sq_buffers);
  }
};

} // optim
} // torch
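For reference, the update that step() implements is the standard Adam recurrence (Kingma & Ba, 2015); in the notation below, g_t is the gradient, m_t and v_t are the exp_average and exp_average_sq buffers, and \alpha, \beta_1, \beta_2, \epsilon come from AdamOptions:

  m_t = \beta_1\, m_{t-1} + (1 - \beta_1)\, g_t
  v_t = \beta_2\, v_{t-1} + (1 - \beta_2)\, g_t^2
  \theta_t = \theta_{t-1} - \frac{\alpha\,\sqrt{1 - \beta_2^{\,t}}}{1 - \beta_1^{\,t}} \cdot \frac{m_t}{\sqrt{v_t} + \epsilon}

The sparse branch applies exactly this form, with the moment recurrences restricted to the rows present in the sparse gradient (so the \beta decay only touches rows that were observed). The dense branch instead folds the bias correction into the denominator, p.addcdiv_(m_t, \sqrt{v_t/(1-\beta_2^{\,t})} + \epsilon, -\alpha/(1-\beta_1^{\,t})); the two are algebraically equivalent up to how \epsilon is scaled.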
constexpr int batchSize = 50;
constexpr int nbExamples = 350000;
constexpr int embeddingSize = 20;
constexpr int nbClasses = 15;
constexpr int nbWordsPerDatapoint = 5;
constexpr int maxNbEmbeddings = 1000000;

//3m15s
struct NetworkImpl : torch::nn::Module
{
  torch::nn::Linear linear{nullptr};
  torch::nn::Embedding wordEmbeddings{nullptr};

  NetworkImpl()
  {
    linear = register_module("linear", torch::nn::Linear(embeddingSize, nbClasses));
    wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingSize).sparse(false)));
  }

  torch::Tensor forward(const torch::Tensor & input)
  {
    // I have a batch of sentences (lists of word ids), so as the sentence
    // embedding I take the mean of the embeddings of its words.
    auto embeddingsOfInput = wordEmbeddings(input).mean(1);
    return torch::softmax(linear(embeddingsOfInput), 1);
  }
};
TORCH_MODULE(Network);
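Note that the embedding table above is created with .sparse(false), so its weight gradient comes back dense and an optimizer takes the ordinary Adam path. A minimal sketch (libtorch only; the sizes and names below are illustrative, not part of this commit) of what flipping that flag changes — with .sparse(true) the backward pass emits a sparse gradient, which is precisely the p.grad().is_sparse() case SparseAdam handles:

#include <iostream>
#include <torch/torch.h>

int main()
{
  // Illustrative table of 1000 rows of dimension 20, with sparse gradients enabled.
  auto table = torch::nn::Embedding(torch::nn::EmbeddingOptions(1000, 20).sparse(true));

  auto ids = torch::randint(1000, {4, 5}, at::kLong); // 4 "sentences" of 5 word ids
  auto loss = table(ids).mean(1).sum();               // toy scalar objective
  loss.backward();

  // Only the looked-up rows carry gradient entries; the gradient tensor is sparse.
  std::cout << table->weight.grad().is_sparse() << '\n'; // prints 1
}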
int main(int argc, char * argv[])
{
  auto nn = Network();
  torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5).weight_decay(0.1));

  std::vector<std::pair<torch::Tensor, torch::Tensor>> batches;
  for (int nbBatch = 0; nbBatch < nbExamples / batchSize; ++nbBatch)
    batches.emplace_back(std::make_pair(torch::randint(maxNbEmbeddings, {batchSize, nbWordsPerDatapoint}, at::kLong), torch::randint(nbClasses, batchSize, at::kLong)));

  for (auto & batch : batches)
  {
    optimizer.zero_grad();
    auto prediction = nn(batch.first);
    auto loss = torch::nll_loss(torch::log(prediction), batch.second);
    loss.backward();
    optimizer.step();
  }

  return 0;
}
//int main(int argc, char * argv[])
//{
// if (argc != 5)
// {
// fmt::print(stderr, "needs 4 arguments.\n");
// exit(1);
// }
//
// at::init_num_threads();
//
// std::string machineFile = argv[1];
// std::string mcdFile = argv[2];
// std::string tsvFile = argv[3];
// //std::string rawFile = argv[4];
// std::string rawFile = "";
//
// ReadingMachine machine(machineFile);
//
// BaseConfig goldConfig(mcdFile, tsvFile, rawFile);
// SubConfig config(goldConfig);
//
// config.setState(machine.getStrategy().getInitialState());
//
// std::vector<torch::Tensor> contexts;
// std::vector<torch::Tensor> classes;
//
// fmt::print("Generating dataset...\n");
//
// Dict dict(Dict::State::Open);
//
// while (true)
// {
// auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
// if (!transition)
// util::myThrow("No transition appliable !");
//
// auto context = config.extractContext(5,5,dict);
// contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
//
// int goldIndex = 3;
// auto gold = torch::from_blob(&goldIndex, {1}, at::kLong).clone();
//
// classes.emplace_back(gold);
//
// transition->apply(config);
// config.addToHistory(transition->getName());
//
// auto movement = machine.getStrategy().getMovement(config, transition->getName());
// if (movement == Strategy::endMovement)
// break;
//
// config.setState(movement.first);
// if (!config.moveWordIndex(movement.second))
// util::myThrow("Cannot move word index !");
//
// if (config.needsUpdate())
// config.update();
// }
//
// auto dataset = ConfigDataset(contexts, classes).map(torch::data::transforms::Stack<>());
//
// int nbExamples = *dataset.size();
// fmt::print("Done! size={}\n", nbExamples);
//
// int batchSize = 100;
// auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
//
// TestNetwork nn(machine.getTransitionSet().size(), 5);
// torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
//
// for (int epoch = 1; epoch <= 1; ++epoch)
// {
// float totalLoss = 0.0;
// torch::Tensor example;
// int currentBatchNumber = 0;
//
// for (auto & batch : *dataLoader)
// {
// optimizer.zero_grad();
//
// auto data = batch.data;
// auto labels = batch.target.squeeze();
//
// auto prediction = nn(data);
// example = prediction[0];
//
// auto loss = torch::nll_loss(torch::log(prediction), labels);
// totalLoss += loss.item<float>();
// loss.backward();
// optimizer.step();
//
// if (++currentBatchNumber*batchSize % 1000 == 0)
// {
// fmt::print("\rcurrent epoch : {:6.2f}%", 100.0*currentBatchNumber*batchSize/nbExamples);
// std::fflush(stdout);
// }
// }
//
// fmt::print("Epoch {} : loss={:.2f}\n", epoch, totalLoss);
// }
//
// return 0;
//}
//
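Though main() above still instantiates torch::optim::Adam, the new class takes the same construction arguments, so swapping it in is mechanical. A hedged sketch (it assumes the SparseAdam and Network definitions from this file; trainWithSparseAdam is an illustrative name, not something the commit defines):

// Hypothetical usage of the SparseAdam defined above; not what this commit's
// main() actually runs.
void trainWithSparseAdam(std::vector<std::pair<torch::Tensor, torch::Tensor>> & batches)
{
  auto nn = Network();
  torch::optim::SparseAdam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));

  for (auto & batch : batches)
  {
    optimizer.zero_grad();
    auto loss = torch::nll_loss(torch::log(nn(batch.first)), batch.second);
    loss.backward();
    optimizer.step(); // dispatches to the sparse branch for any sparse gradients
  }
}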