Franck Dary / macaon / Commits

Commit d4ec0a24, authored 5 years ago by Franck Dary

    Made SparseAdam optimizer

Parent: be63334b
No related branches, tags, or merge requests found.

Showing 1 changed file: dev/src/dev.cpp, with 240 additions and 72 deletions.
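For orientation (not part of the commit itself): the step() added below applies the usual Adam moment updates — to every coordinate of the parameter in the dense branch, and only to the embedding rows actually touched by the gradient in the sparse branch. In textbook Adam notation (standard symbols, not identifiers from the diff), with gradient g_t, learning rate \eta and moment buffers m_t, v_t:

    m_t = \beta_1 m_{t-1} + (1-\beta_1)\, g_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2
    \theta_t = \theta_{t-1} - \eta \, \frac{m_t/(1-\beta_1^t)}{\sqrt{v_t/(1-\beta_2^t)} + \epsilon}

The code folds the bias corrections 1-\beta_1^t and 1-\beta_2^t into step_size and denom instead of rescaling the buffers, which matches the formula above up to where \epsilon enters the denominator.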
dev/src/dev.cpp
@@ -8,105 +8,273 @@

#include "TestNetwork.hpp"
#include "ConfigDataset.hpp"

Removed (the 72 deleted lines): the previous transition-based main() — the ReadingMachine / BaseConfig / SubConfig setup, the gold-configuration dataset generation loop, and the TestNetwork training loop. It survives only as the commented-out block at the end of the file, shown further below.

Added (the 240 new lines):

namespace torch
{
namespace optim
{

class SparseAdam : public Optimizer
{
  public:

  template <typename ParameterContainer>
  explicit SparseAdam(ParameterContainer && parameters, const AdamOptions & options_)
      : Optimizer(std::forward<ParameterContainer>(parameters)), options(options_) {}

  void step() override
  {
    for (size_t i = 0; i < parameters_.size(); ++i)
    {
      Tensor p = parameters_.at(i);
      if (!p.grad().defined())
        continue;

      auto & exp_average = buffer_at(exp_average_buffers, i);
      auto & exp_average_sq = buffer_at(exp_average_sq_buffers, i);

      buffer_at(step_buffers, i) += 1;
      const auto bias_correction1 = 1 - std::pow(options.beta1(), buffer_at(step_buffers, i));
      const auto bias_correction2 = 1 - std::pow(options.beta2(), buffer_at(step_buffers, i));

      if (p.grad().is_sparse())
      {
        NoGradGuard guard;
        p.grad() = p.grad().coalesce();
        auto indices = p.grad().indices().squeeze();
        auto values = p.grad().values();

        auto old_exp_average_values = exp_average.sparse_mask(p.grad())._values();
        auto exp_average_update_values = values.sub(old_exp_average_values).mul_(1 - options.beta1());
        for (unsigned int j = 0; j < indices.size(0); j++)
          exp_average[indices[j].item<long>()] += exp_average_update_values[j];

        auto old_exp_average_sq_values = exp_average_sq.sparse_mask(p.grad())._values();
        auto exp_average_sq_update_values = values.pow(2).sub_(old_exp_average_sq_values).mul_(1 - options.beta2());
        for (unsigned int j = 0; j < indices.size(0); j++)
          exp_average_sq[indices[j].item<long>()] += exp_average_sq_update_values[j];

        auto numer = exp_average_update_values.add_(old_exp_average_values);
        exp_average_sq_update_values.add_(old_exp_average_sq_values);
        auto denom = exp_average_sq_update_values.sqrt_().add_(options.eps());

        const auto step_size = options.learning_rate() * std::sqrt(bias_correction2) / bias_correction1;

        auto divided = numer.div(denom);
        for (unsigned int j = 0; j < indices.size(0); j++)
          p.data()[indices[j].item<long>()] += -step_size * divided[j];
      }
      else
      {
        if (options.weight_decay() > 0)
        {
          NoGradGuard guard;
          p.grad() = p.grad() + options.weight_decay() * p;
        }

        exp_average.mul_(options.beta1()).add_(p.grad(), 1 - options.beta1());
        exp_average_sq.mul_(options.beta2()).addcmul_(p.grad(), p.grad(), 1 - options.beta2());

        Tensor denom;
        if (options.amsgrad())
        {
          auto & max_exp_average_sq = buffer_at(max_exp_average_sq_buffers, i);
          max_exp_average_sq = torch::max(max_exp_average_sq, exp_average_sq);
          denom = max_exp_average_sq / bias_correction2;
        }
        else
        {
          denom = exp_average_sq / bias_correction2;
        }

        const auto step_size = options.learning_rate() / bias_correction1;

        NoGradGuard guard;
        p.addcdiv_(exp_average, denom.sqrt() + options.eps(), -step_size);
      }
    }
  }

  void save(serialize::OutputArchive & archive) const override
  {
    //serialize(*this, archive);
  }

  void load(serialize::InputArchive & archive) override
  {
    //serialize(*this, archive);
  }

  public:

  AdamOptions options;

  std::vector<int64_t> step_buffers;
  std::vector<Tensor> exp_average_buffers;
  std::vector<Tensor> exp_average_sq_buffers;
  std::vector<Tensor> max_exp_average_sq_buffers;

  private:

  SparseAdam() : options(0) {}

  template <typename Self, typename Archive>
  static void serialize(Self & self, Archive & archive)
  {
    _TORCH_OPTIM_SERIALIZE(step_buffers);
    _TORCH_OPTIM_SERIALIZE(exp_average_buffers);
    _TORCH_OPTIM_SERIALIZE(exp_average_sq_buffers);
    _TORCH_OPTIM_SERIALIZE(max_exp_average_sq_buffers);
  }
};

} // torch
} // optim

constexpr int batchSize = 50;
constexpr int nbExamples = 350000;
constexpr int embeddingSize = 20;
constexpr int nbClasses = 15;
constexpr int nbWordsPerDatapoint = 5;
constexpr int maxNbEmbeddings = 1000000;

//3m15s
struct NetworkImpl : torch::nn::Module
{
  torch::nn::Linear linear{nullptr};
  torch::nn::Embedding wordEmbeddings{nullptr};

  NetworkImpl()
  {
    linear = register_module("linear", torch::nn::Linear(embeddingSize, nbClasses));
    wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingSize).sparse(false)));
  };

  torch::Tensor forward(const torch::Tensor & input)
  {
    // I have a batch of sentences (list of word embeddings), so as the sentence embedding I take the mean of the embeddings of its words
    auto embeddingsOfInput = wordEmbeddings(input).mean(1);
    return torch::softmax(linear(embeddingsOfInput), 1);
  }
};
TORCH_MODULE(Network);

int main(int argc, char * argv[])
{
  auto nn = Network();
  torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5).weight_decay(0.1));

  std::vector<std::pair<torch::Tensor, torch::Tensor>> batches;
  for (int nbBatch = 0; nbBatch < nbExamples / batchSize; ++nbBatch)
    batches.emplace_back(std::make_pair(torch::randint(maxNbEmbeddings, {batchSize, nbWordsPerDatapoint}, at::kLong), torch::randint(nbClasses, batchSize, at::kLong)));

  for (auto & batch : batches)
  {
    optimizer.zero_grad();
    auto prediction = nn(batch.first);
    auto loss = torch::nll_loss(torch::log(prediction), batch.second);
    loss.backward();
    optimizer.step();
  }

  return 0;
}
//int main(int argc, char * argv[])
//{
// if (argc != 5)
// {
// fmt::print(stderr, "needs 4 arguments.\n");
// exit(1);
// }
//
// at::init_num_threads();
//
// std::string machineFile = argv[1];
// std::string mcdFile = argv[2];
// std::string tsvFile = argv[3];
// //std::string rawFile = argv[4];
// std::string rawFile = "";
//
// ReadingMachine machine(machineFile);
//
// BaseConfig goldConfig(mcdFile, tsvFile, rawFile);
// SubConfig config(goldConfig);
//
// config.setState(machine.getStrategy().getInitialState());
//
// std::vector<torch::Tensor> contexts;
// std::vector<torch::Tensor> classes;
//
// fmt::print("Generating dataset...\n");
//
// Dict dict(Dict::State::Open);
//
// while (true)
// {
// auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
// if (!transition)
// util::myThrow("No transition appliable !");
//
// auto context = config.extractContext(5,5,dict);
// contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
//
// int goldIndex = 3;
// auto gold = torch::from_blob(&goldIndex, {1}, at::kLong).clone();
//
// classes.emplace_back(gold);
//
// transition->apply(config);
// config.addToHistory(transition->getName());
//
// auto movement = machine.getStrategy().getMovement(config, transition->getName());
// if (movement == Strategy::endMovement)
// break;
//
// config.setState(movement.first);
// if (!config.moveWordIndex(movement.second))
// util::myThrow("Cannot move word index !");
//
// if (config.needsUpdate())
// config.update();
// }
//
// auto dataset = ConfigDataset(contexts, classes).map(torch::data::transforms::Stack<>());
//
// int nbExamples = *dataset.size();
// fmt::print("Done! size={}\n", nbExamples);
//
// int batchSize = 100;
// auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
//
// TestNetwork nn(machine.getTransitionSet().size(), 5);
// torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
//
// for (int epoch = 1; epoch <= 1; ++epoch)
// {
// float totalLoss = 0.0;
// torch::Tensor example;
// int currentBatchNumber = 0;
//
// for (auto & batch : *dataLoader)
// {
// optimizer.zero_grad();
//
// auto data = batch.data;
// auto labels = batch.target.squeeze();
//
// auto prediction = nn(data);
// example = prediction[0];
//
// auto loss = torch::nll_loss(torch::log(prediction), labels);
// totalLoss += loss.item<float>();
// loss.backward();
// optimizer.step();
//
// if (++currentBatchNumber*batchSize % 1000 == 0)
// {
// fmt::print("\rcurrent epoch : {:6.2f}%", 100.0*currentBatchNumber*batchSize/nbExamples);
// std::fflush(stdout);
// }
// }
//
// fmt::print("Epoch {} : loss={:.2f}\n", epoch, totalLoss);
// }
//
// return 0;
//}
//
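A minimal usage sketch (hypothetical, not part of this commit) of how the new optimizer could replace torch::optim::Adam in the training loop above. It assumes the SparseAdam class, the Network module and the constexpr sizes from this file are in scope, and that the embedding layer is built with .sparse(true) so its gradient actually arrives as a sparse tensor — as committed, sparse(false) keeps it dense and the update takes the dense branch.

#include <torch/torch.h>

int main()
{
  auto nn = Network();
  // SparseAdam lives in namespace torch::optim, as defined in the diff above.
  torch::optim::SparseAdam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));

  // One synthetic batch, mirroring the random batches used in the new main().
  auto input  = torch::randint(maxNbEmbeddings, {batchSize, nbWordsPerDatapoint}, at::kLong);
  auto target = torch::randint(nbClasses, {batchSize}, at::kLong);

  optimizer.zero_grad();
  auto loss = torch::nll_loss(torch::log(nn(input)), target);
  loss.backward();
  optimizer.step(); // with a sparse gradient, only the embedding rows seen in `input` would be updated

  return 0;
}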