Franck Dary / macaon

Commit 5b723ac5, authored Oct 09, 2020 by Franck Dary

Added program argument to lock pretrained embeddings

Parent: 032ca410
Changes: 4 files
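In practice the new switch is passed on the trainer's command line like the existing scaleGrad and maxNorm options. A hypothetical invocation (the binary name and the other arguments are placeholders, not shown in this commit):

    macaon train ... --lockPretrained

With the flag present, pretrained word embeddings are loaded but excluded from fine tuning; without it, behavior is unchanged.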
torch_modules/include/WordEmbeddings.hpp

@@ -8,6 +8,7 @@ class WordEmbeddingsImpl : public torch::nn::Module
   private :

   static bool scaleGradByFreq;
+  static bool canTrainPretrained;
   static float maxNorm;

   private :
@@ -18,6 +19,8 @@ class WordEmbeddingsImpl : public torch::nn::Module
   static void setScaleGradByFreq(bool scaleGradByFreq);
   static void setMaxNorm(float maxNorm);
+  static void setCanTrainPretrained(bool value);
+  static bool getCanTrainPretrained();

   WordEmbeddingsImpl(std::size_t vocab, std::size_t dim);
   torch::nn::Embedding get();
torch_modules/src/Submodule.cpp

 #include "Submodule.hpp"
+#include "WordEmbeddings.hpp"

 void Submodule::setFirstInputIndex(std::size_t firstInputIndex)
 {
@@ -74,6 +75,7 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std
     util::myThrow(fmt::format("file '{}' is empty", path.string()));

   getDict().setState(originalState);
+  embeddings->weight.set_requires_grad(WordEmbeddingsImpl::getCanTrainPretrained());
 }

 std::function<std::string(const std::string &)> Submodule::getFunction(const std::string functionNames)
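The mechanism doing the real work here is Tensor::set_requires_grad: once the pretrained vectors have been copied into the embedding matrix, turning requires_grad off removes the whole matrix from the autograd graph, so no optimizer step can touch it. A minimal standalone sketch of that behavior, assuming LibTorch (all names below are illustrative, not from macaon):

    #include <torch/torch.h>
    #include <iostream>

    int main()
    {
      // Small embedding table standing in for pretrained vectors.
      torch::nn::Embedding embeddings(100, 8);
      // A second, trainable layer so the graph still has something to differentiate.
      torch::nn::Linear head(8, 1);

      // Freeze the embedding matrix, as the diff does when --lockPretrained
      // is given (getCanTrainPretrained() then returns false).
      embeddings->weight.set_requires_grad(false);

      auto input = torch::tensor({1, 2, 3}, torch::kLong);
      auto loss = head(embeddings(input)).sum();
      loss.backward();

      // The linear layer received a gradient; the frozen embeddings did not.
      std::cout << head->weight.grad().defined() << '\n';       // 1
      std::cout << embeddings->weight.grad().defined() << '\n'; // 0
      return 0;
    }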
torch_modules/src/WordEmbeddings.cpp

 #include "WordEmbeddings.hpp"

 bool WordEmbeddingsImpl::scaleGradByFreq = false;
+bool WordEmbeddingsImpl::canTrainPretrained = false;
 float WordEmbeddingsImpl::maxNorm = std::numeric_limits<float>::max();

 WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim)
@@ -23,8 +24,18 @@ void WordEmbeddingsImpl::setMaxNorm(float maxNorm)
   WordEmbeddingsImpl::maxNorm = maxNorm;
 }

+void WordEmbeddingsImpl::setCanTrainPretrained(bool value)
+{
+  WordEmbeddingsImpl::canTrainPretrained = value;
+}
+
 torch::Tensor WordEmbeddingsImpl::forward(torch::Tensor input)
 {
   return embeddings(input);
 }

+bool WordEmbeddingsImpl::getCanTrainPretrained()
+{
+  return canTrainPretrained;
+}
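The scaleGradByFreq and maxNorm statics mirror options that LibTorch exposes per embedding module, so the constructor (elided above) presumably forwards them. A sketch of that mapping, assuming torch::nn::EmbeddingOptions; this illustrates the LibTorch API, not the actual macaon constructor:

    #include <torch/torch.h>
    #include <limits>

    int main()
    {
      bool scaleGradByFreq = false;
      float maxNorm = std::numeric_limits<float>::max(); // default: effectively no clipping

      // The static flags map one-to-one onto LibTorch's embedding options.
      torch::nn::Embedding embeddings(
          torch::nn::EmbeddingOptions(/*num_embeddings=*/1000, /*embedding_dim=*/64)
              .max_norm(maxNorm)
              .scale_grad_by_freq(scaleGradByFreq));

      auto out = embeddings(torch::tensor({0, 1, 2}, torch::kLong)); // shape [3, 64]
      return 0;
    }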
trainer/src/MacaonTrain.cpp

@@ -45,6 +45,7 @@ po::options_description MacaonTrain::getOptionsDescription()
   ("scaleGrad", "Scale embedding's gradient with its frequence in the minibatch")
   ("maxNorm", po::value<float>()->default_value(std::numeric_limits<float>::max()), "Max norm for the embeddings")
+  ("lockPretrained", "Disable fine tuning of all pretrained word embeddings.")
   ("help,h", "Produce this help message");

   desc.add(req).add(opt);
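lockPretrained is declared without a po::value<>, so Boost.Program_options treats it as a pure switch: it carries no argument and only its presence is recorded, which is what main() tests below with variables.count(...). A minimal sketch of that pattern, assuming Boost.Program_options (stripped down from the option list above):

    #include <boost/program_options.hpp>
    #include <iostream>

    namespace po = boost::program_options;

    int main(int argc, char * argv[])
    {
      po::options_description desc("Options");
      desc.add_options()
        ("lockPretrained", "Disable fine tuning of all pretrained word embeddings.")
        ("help,h", "Produce this help message");

      po::variables_map variables;
      po::store(po::parse_command_line(argc, argv, desc), variables);
      po::notify(variables);

      // A switch declared without po::value<> takes no argument; count()
      // reports whether it appeared on the command line at all.
      bool lock = variables.count("lockPretrained") != 0;
      std::cout << "pretrained embeddings are "
                << (lock ? "frozen" : "trainable") << '\n';
      return 0;
    }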
@@ -137,6 +138,7 @@ int MacaonTrain::main()
   auto seed = variables["seed"].as<int>();

   WordEmbeddingsImpl::setMaxNorm(variables["maxNorm"].as<float>());
   WordEmbeddingsImpl::setScaleGradByFreq(variables.count("scaleGrad") != 0);
+  WordEmbeddingsImpl::setCanTrainPretrained(variables.count("lockPretrained") == 0);

   std::srand(seed);
   torch::manual_seed(seed);
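Note the inverted sense: the user-facing switch says lock, while the internal static says canTrain, so the flag's presence maps to false. A one-line restatement of that mapping (illustrative only):

    #include <cassert>

    // canTrainPretrained is simply the negation of the switch's presence.
    static bool canTrain(bool lockPretrainedGiven) { return !lockPretrainedGiven; }

    int main()
    {
      assert(canTrain(false));   // no --lockPretrained: fine tuning allowed
      assert(!canTrain(true));   // --lockPretrained given: embeddings frozen
      return 0;
    }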