Franck Dary / macaon

Commit 5800a6f3, authored Oct 10, 2021 by Franck Dary

Special embeddings can be trained even with lockPretrained

This commit splits the word embedding layer into two tables, one for normal (pretrained) vectors and one for the Dict's special values, so that the special embeddings remain trainable even when the pretrained embeddings are locked.

Parent 31016256

Changes: 14 files
common/include/Dict.hpp

@@ -6,6 +6,7 @@
 #include <vector>
 #include <filesystem>
 #include <mutex>
+#include <set>

 class Dict
 {

@@ -34,6 +35,7 @@ class Dict
   std::mutex elementsMutex;
   State state;
   bool isCountingOccs{false};
+  std::set<std::string> prefixes{""};

   public :

@@ -50,6 +52,7 @@ class Dict
   void countOcc(bool isCountingOccs);
+  std::set<std::size_t> getSpecialIndexes();
   int getIndexOrInsert(const std::string & element, const std::string & prefix);
   std::string getElement(std::size_t index);
   void setState(State state);
common/src/Dict.cpp

@@ -94,6 +94,7 @@ int Dict::getIndexOrInsert(const std::string & element, const std::string & prefix)
   if (state == State::Open)
     elementsMutex.lock();
+  prefixes.insert(prefix);
   int index = _getIndexOrInsert(element, prefix);
   if (state == State::Open)

@@ -350,6 +351,28 @@ bool Dict::isSpecialValue(const std::string & value)
     || value == urlValueStr;
 }

+std::set<std::size_t> Dict::getSpecialIndexes()
+{
+  auto oldState = getState();
+  setState(State::Closed);
+  std::set<std::string> specials = {unknownValueStr, nullValueStr, oobValueStr, noChildValueStr, emptyValueStr, separatorValueStr, numberValueStr, urlValueStr};
+  std::set<std::size_t> res;
+  for (auto & prefix : prefixes)
+    for (auto & special : specials)
+      res.insert(getIndexOrInsert(special, prefix));
+  setState(oldState);
+  return res;
+}
+
 std::string Dict::getElement(std::size_t index)
 {
   return indexesToElements[index];
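getSpecialIndexes builds on the new prefixes member: every prefix ever passed to getIndexOrInsert is recorded (seeded with the empty prefix), so the function can enumerate each special value under each prefix. A minimal standalone sketch of that accumulation pattern, with invented prefix names purely for illustration:

#include <set>
#include <string>

int main()
{
  std::set<std::string> prefixes{""};   // seeded like Dict::prefixes
  for (auto & p : {"", "FORM", "LEMMA", "FORM"})
    prefixes.insert(p);                 // duplicates collapse: set semantics
  // prefixes now holds {"", "FORM", "LEMMA"}; getSpecialIndexes would then
  // visit 3 prefixes x 8 special values and collect their dict indexes.
}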
torch_modules/include/WordEmbeddings.hpp

@@ -13,7 +13,8 @@ class WordEmbeddingsImpl : public torch::nn::Module
   private :

-  torch::nn::Embedding embeddings{nullptr};
+  torch::nn::Embedding normalEmbeddings{nullptr};
+  torch::nn::Embedding specialEmbeddings{nullptr};

   public :

@@ -22,8 +23,8 @@ class WordEmbeddingsImpl : public torch::nn::Module
   static void setCanTrainPretrained(bool value);
   static bool getCanTrainPretrained();
-  WordEmbeddingsImpl(std::size_t vocab, std::size_t dim);
-  torch::nn::Embedding get();
+  WordEmbeddingsImpl(std::size_t vocab, std::size_t dim, std::set<std::size_t> specialIndexes);
+  torch::nn::Embedding getNormalEmbeddings();
   torch::Tensor forward(torch::Tensor input);
 };

 TORCH_MODULE(WordEmbeddings);
torch_modules/src/ContextModule.cpp

@@ -187,12 +187,12 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input)
 void ContextModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, w2vFiles.empty() ? std::set<std::size_t>() : getDict().getSpecialIndexes()));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
     auto splited = util::split(p, ',');
-    loadPretrainedW2vEmbeddings(wordEmbeddings->get(), path / splited[1], splited[0]);
+    loadPretrainedW2vEmbeddings(wordEmbeddings->getNormalEmbeddings(), path / splited[1], splited[0]);
   }
 }
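The same ternary appears in every module that can load word2vec files: with no pretrained file there is nothing to lock, so an empty set is passed and WordEmbeddings keeps its single-table behaviour; otherwise the dict's special indexes go to a separate, always-trainable table. A minimal, self-contained sketch of that decision, where FakeDict and haveW2v are hypothetical stand-ins for the module's members:

#include <cstddef>
#include <set>

// Stand-in for getDict(); only the shape of the decision matters here.
struct FakeDict
{
  std::size_t size() const { return 1000; }
  std::set<std::size_t> getSpecialIndexes() const { return {0, 1, 2, 3, 4, 5, 6, 7}; }
};

int main()
{
  FakeDict dict;
  bool haveW2v = true;  // stands for !w2vFiles.empty()

  // Pretrained vectors present: split the specials out so they stay trainable.
  std::set<std::size_t> specials = haveW2v ? dict.getSpecialIndexes()
                                           : std::set<std::size_t>();
  (void)specials;  // would be forwarded as WordEmbeddings(dict.size(), inSize, specials)
}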
torch_modules/src/ContextualModule.cpp

@@ -234,13 +234,13 @@ torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
 void ContextualModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, w2vFiles.empty() ? std::set<std::size_t>() : getDict().getSpecialIndexes()));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
     auto splited = util::split(p, ',');
-    loadPretrainedW2vEmbeddings(wordEmbeddings->get(), path / splited[1], splited[0]);
+    loadPretrainedW2vEmbeddings(wordEmbeddings->getNormalEmbeddings(), path / splited[1], splited[0]);
   }
 }
torch_modules/src/DepthLayerTreeEmbeddingModule.cpp

@@ -131,6 +131,6 @@ void DepthLayerTreeEmbeddingModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>()));
 }
torch_modules/src/DistanceModule.cpp

@@ -113,6 +113,6 @@ void DistanceModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 void DistanceModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>()));
 }
torch_modules/src/FocusedColumnModule.cpp

@@ -164,12 +164,12 @@ void FocusedColumnModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 void FocusedColumnModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, w2vFiles.empty() ? std::set<std::size_t>() : getDict().getSpecialIndexes()));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
     auto splited = util::split(p, ',');
-    loadPretrainedW2vEmbeddings(wordEmbeddings->get(), path / splited[1], splited[0]);
+    loadPretrainedW2vEmbeddings(wordEmbeddings->getNormalEmbeddings(), path / splited[1], splited[0]);
   }
 }
torch_modules/src/HistoryMineModule.cpp

@@ -69,6 +69,6 @@ void HistoryMineModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 void HistoryMineModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>()));
 }
torch_modules/src/HistoryModule.cpp

@@ -69,6 +69,6 @@ void HistoryModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 void HistoryModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>()));
 }
torch_modules/src/RawInputModule.cpp

@@ -87,6 +87,6 @@ void RawInputModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 void RawInputModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>()));
 }
torch_modules/src/SplitTransModule.cpp

@@ -65,6 +65,6 @@ void SplitTransModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 void SplitTransModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>()));
 }
torch_modules/src/StateNameModule.cpp

@@ -38,6 +38,6 @@ void StateNameModuleImpl::addToContext(torch::Tensor & context, const Config & config)
 void StateNameModuleImpl::registerEmbeddings()
 {
   if (!embeddings)
-    embeddings = register_module("embeddings", WordEmbeddings(getDict().size(), outSize));
+    embeddings = register_module("embeddings", WordEmbeddings(getDict().size(), outSize, std::set<std::size_t>()));
 }
torch_modules/src/WordEmbeddings.cpp

 #include "WordEmbeddings.hpp"
 #include "util.hpp"
 #include "NeuralNetwork.hpp"

 bool WordEmbeddingsImpl::scaleGradByFreq = false;
 bool WordEmbeddingsImpl::canTrainPretrained = false;
 float WordEmbeddingsImpl::maxNorm = std::numeric_limits<float>::max();

-WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim)
+WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim, std::set<std::size_t> specialIndexes)
 {
+  for (auto elem : specialIndexes)
+    if (elem >= specialIndexes.size())
+      util::error("Special indexes are not contiguous from zero.");
   if (maxNorm == std::numeric_limits<float>::max())
-    embeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocab, dim).scale_grad_by_freq(scaleGradByFreq)));
+  {
+    normalEmbeddings = register_module("normalEmbeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocab, dim).scale_grad_by_freq(scaleGradByFreq)));
+    specialEmbeddings = register_module("specialEmbeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(specialIndexes.size(), dim).scale_grad_by_freq(scaleGradByFreq)));
+  }
   else
-    embeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocab, dim).max_norm(maxNorm).scale_grad_by_freq(scaleGradByFreq)));
+  {
+    normalEmbeddings = register_module("normalEmbeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocab, dim).max_norm(maxNorm).scale_grad_by_freq(scaleGradByFreq)));
+    specialEmbeddings = register_module("specialEmbeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(specialIndexes.size(), dim).scale_grad_by_freq(scaleGradByFreq)));
+  }
 }
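The constructor's new guard relies on a set property worth spelling out: std::set<std::size_t> is sorted and duplicate-free, so if every element is strictly below the set's size, the indexes must be exactly 0..N-1. A minimal standalone sketch of that invariant (names are illustrative, not from the repository):

#include <cassert>
#include <cstddef>
#include <set>

// With N unique values all < N, the only possibility is {0, ..., N-1}:
// any gap would force some element to be >= N.
bool isContiguousFromZero(const std::set<std::size_t> & specialIndexes)
{
  for (auto elem : specialIndexes)
    if (elem >= specialIndexes.size())
      return false;
  return true;
}

int main()
{
  assert(isContiguousFromZero({0, 1, 2, 3}));
  assert(!isContiguousFromZero({0, 2, 3, 7}));  // 7 >= 4 exposes the gap
}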
-torch::nn::Embedding WordEmbeddingsImpl::get()
+torch::nn::Embedding WordEmbeddingsImpl::getNormalEmbeddings()
 {
-  return embeddings;
+  return normalEmbeddings;
 }

 void WordEmbeddingsImpl::setScaleGradByFreq(bool scaleGradByFreq)
@@ -34,7 +45,19 @@ void WordEmbeddingsImpl::setCanTrainPretrained(bool value)
 torch::Tensor WordEmbeddingsImpl::forward(torch::Tensor input)
 {
-  return embeddings(input);
+  if (specialEmbeddings->weight.size(0) == 0)
+    return normalEmbeddings(input);
+  auto mask = input >= specialEmbeddings->weight.size(0);
+  auto specialIndexes = torch::ones(input.sizes(), torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::getDevice()));
+  specialIndexes.index_put_({mask}, 0);
+  auto normalRes = normalEmbeddings(input);
+  auto specialRes = specialEmbeddings(input * specialIndexes);
+  auto normalIndexes = torch::ones(normalRes.sizes(), torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::getDevice()));
+  specialIndexes = torch::ones(specialRes.sizes(), torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::getDevice()));
+  specialIndexes.index_put_({mask}, 0);
+  normalIndexes.index_put_({~mask}, 0);
+  return normalIndexes * normalRes + specialIndexes * specialRes;
 }
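To make the masked lookup above concrete, here is a minimal, self-contained LibTorch sketch of the same combination scheme; it is not from the repository, and the table sizes and inputs are invented. Indexes below the special-table size are served by the special table, all others by the normal table, and the two results are merged with 0/1 gates:

#include <torch/torch.h>
#include <iostream>

int main()
{
  long nbSpecial = 4;  // special indexes are assumed contiguous: 0..3
  torch::nn::Embedding normal(torch::nn::EmbeddingOptions(10, 3));
  torch::nn::Embedding special(torch::nn::EmbeddingOptions(nbSpecial, 3));

  auto input = torch::tensor({1, 5, 3, 9}, torch::kLong);
  auto mask = input >= nbSpecial;  // true where the index is a normal word

  // Gate the input so normal positions read row 0 of the small special table,
  // keeping the lookup in range; their special result is zeroed out below.
  auto gate = torch::ones(input.sizes(), torch::kLong);
  gate.index_put_({mask}, 0);

  auto normalRes = normal(input);
  auto specialRes = special(input * gate);

  // 0/1 weights so each output row comes from exactly one of the two tables.
  auto keepNormal = torch::ones(normalRes.sizes());
  auto keepSpecial = torch::ones(specialRes.sizes());
  keepSpecial.index_put_({mask}, 0);
  keepNormal.index_put_({~mask}, 0);

  // Rows 0 and 2 (indexes 1 and 3) come from 'special'; rows 1 and 3 from 'normal'.
  std::cout << keepNormal * normalRes + keepSpecial * specialRes << '\n';
}

Because the two tables are separate parameters, training can freeze normalEmbeddings (the lockPretrained case) while specialEmbeddings keeps receiving gradients, which is exactly what the commit message promises.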
 bool WordEmbeddingsImpl::getCanTrainPretrained()