Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Franck Dary
macaon
Commits
b13669bd
Commit
b13669bd
authored
Aug 04, 2020
by
Franck Dary
Browse files
Added program arguments: scaleGrad and maxNorm
parent
397e390f
Changes
23
Hide whitespace changes
Inline
Side-by-side
torch_modules/src/Submodule.cpp
View file @
b13669bd
...
...
@@ -5,7 +5,7 @@ void Submodule::setFirstInputIndex(std::size_t firstInputIndex)
this
->
firstInputIndex
=
firstInputIndex
;
}
void
Submodule
::
loadPretrainedW2vEmbeddings
(
torch
::
nn
::
Embedding
&
embeddings
,
std
::
filesystem
::
path
path
,
std
::
string
prefix
)
void
Submodule
::
loadPretrainedW2vEmbeddings
(
torch
::
nn
::
Embedding
embeddings
,
std
::
filesystem
::
path
path
,
std
::
string
prefix
)
{
if
(
path
.
empty
())
return
;
...
...
torch_modules/src/WordEmbeddings.cpp
0 → 100644
View file @
b13669bd
#include <limits>

#include "WordEmbeddings.hpp"
// Process-wide embedding options shared by every WordEmbeddingsImpl
// instance. They are configured through the static setters (see
// setScaleGradByFreq / setMaxNorm below) before any model is built,
// as done from the trainer's program-option handling.
// Default: do not scale embedding gradients by in-minibatch frequency.
bool WordEmbeddingsImpl::scaleGradByFreq = false;
// Sentinel default: float-max means "no max-norm bound requested".
float WordEmbeddingsImpl::maxNorm = std::numeric_limits<float>::max();
// Build the embedding table: `vocab` rows of `dim` floats.
// The static options scaleGradByFreq / maxNorm must be set (via the
// static setters) BEFORE construction to take effect, since they are
// baked into the module's options here.
WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim)
{
  auto options = torch::nn::EmbeddingOptions(vocab, dim)
                   .scale_grad_by_freq(scaleGradByFreq);
  // Only enable max-norm renormalization when a bound was actually
  // requested: passing the float-max sentinel would still switch on
  // torch's renorm pass (norm computation + in-place weight mutation)
  // on every lookup, for no effect.
  if (maxNorm != std::numeric_limits<float>::max())
    options = options.max_norm(maxNorm);
  embeddings = register_module("embeddings", torch::nn::Embedding(options));
}
// Hand out the underlying torch embedding module.
// torch::nn::Embedding is a module holder (shared handle), so the
// returned copy refers to the same parameters as this instance.
torch::nn::Embedding WordEmbeddingsImpl::get()
{
  auto handle = embeddings;
  return handle;
}
void
WordEmbeddingsImpl
::
setScaleGradByFreq
(
bool
scaleGradByFreq
)
{
WordEmbeddingsImpl
::
scaleGradByFreq
=
scaleGradByFreq
;
}
void
WordEmbeddingsImpl
::
setMaxNorm
(
float
maxNorm
)
{
WordEmbeddingsImpl
::
maxNorm
=
maxNorm
;
}
// Look up the embedding vectors for the given tensor of indices.
torch::Tensor WordEmbeddingsImpl::forward(torch::Tensor input)
{
  // Calling forward() on the held module is equivalent to invoking the
  // holder's call operator.
  return embeddings->forward(input);
}
trainer/src/MacaonTrain.cpp
View file @
b13669bd
...
...
@@ -2,6 +2,7 @@
#include
<filesystem>
#include
"util.hpp"
#include
"NeuralNetwork.hpp"
#include
"WordEmbeddings.hpp"
namespace
po
=
boost
::
program_options
;
...
...
@@ -43,6 +44,9 @@ po::options_description MacaonTrain::getOptionsDescription()
"Loss function to use during training : CrossEntropy | bce | mse | hinge"
)
(
"seed"
,
po
::
value
<
int
>
()
->
default_value
(
100
),
"Number of examples per batch"
)
(
"scaleGrad"
,
"Scale embedding's gradient with its frequence in the minibatch"
)
(
"maxNorm"
,
po
::
value
<
float
>
()
->
default_value
(
std
::
numeric_limits
<
float
>::
max
()),
"Max norm for the embeddings"
)
(
"help,h"
,
"Produce this help message"
);
desc
.
add
(
req
).
add
(
opt
);
...
...
@@ -134,6 +138,8 @@ int MacaonTrain::main()
auto
lossFunction
=
variables
[
"loss"
].
as
<
std
::
string
>
();
auto
explorationThreshold
=
variables
[
"explorationThreshold"
].
as
<
float
>
();
auto
seed
=
variables
[
"seed"
].
as
<
int
>
();
WordEmbeddingsImpl
::
setMaxNorm
(
variables
[
"maxNorm"
].
as
<
float
>
());
WordEmbeddingsImpl
::
setScaleGradByFreq
(
variables
.
count
(
"scaleGrad"
)
!=
0
);
std
::
srand
(
seed
);
torch
::
manual_seed
(
seed
);
...
...
Prev
1
2
Next
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment