Skip to content
GitLab
Menu
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Franck Dary
macaon
Commits
675d8f42
Commit
675d8f42
authored
Jul 30, 2020
by
Franck Dary
Browse files
allow multiple pretrained embeddings file for ContextModule
parent
df3fd3cb
Changes
2
Hide whitespace changes
Inline
Side-by-side
torch_modules/include/ContextModule.hpp
View file @
675d8f42
...
...
@@ -21,7 +21,7 @@ class ContextModuleImpl : public Submodule
std
::
vector
<
std
::
tuple
<
Config
::
Object
,
int
,
std
::
optional
<
int
>>>
targets
;
int
inSize
;
std
::
filesystem
::
path
path
;
std
::
filesystem
::
path
w2vFile
;
std
::
filesystem
::
path
w2vFile
s
;
public
:
...
...
torch_modules/src/ContextModule.cpp
View file @
675d8f42
...
...
@@ -48,11 +48,13 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin
else
util
::
myThrow
(
fmt
::
format
(
"unknown sumodule type '{}'"
,
subModuleType
));
w2vFile
=
sm
.
str
(
7
);
w2vFile
s
=
sm
.
str
(
7
);
if
(
!
w2vFile
.
empty
())
if
(
!
w2vFile
s
.
empty
())
{
getDict
().
loadWord2Vec
(
this
->
path
/
w2vFile
);
auto
pathes
=
util
::
split
(
w2vFiles
.
string
(),
' '
);
for
(
auto
&
p
:
pathes
)
getDict
().
loadWord2Vec
(
this
->
path
/
p
);
getDict
().
setState
(
Dict
::
State
::
Closed
);
dictSetPretrained
(
true
);
}
...
...
@@ -138,7 +140,7 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
else
{
std
::
string
featureValue
=
functions
[
colIndex
](
config
.
getAsFeature
(
col
,
index
));
if
(
w2vFile
.
empty
())
if
(
w2vFile
s
.
empty
())
featureValue
=
fmt
::
format
(
"{}({})"
,
col
,
featureValue
);
dictIndex
=
dict
.
getIndexOrInsert
(
featureValue
);
}
...
...
@@ -161,6 +163,8 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input)
void
ContextModuleImpl
::
registerEmbeddings
()
{
wordEmbeddings
=
register_module
(
"embeddings"
,
torch
::
nn
::
Embedding
(
torch
::
nn
::
EmbeddingOptions
(
getDict
().
size
(),
inSize
)));
loadPretrainedW2vEmbeddings
(
wordEmbeddings
,
w2vFile
.
empty
()
?
""
:
path
/
w2vFile
);
auto
pathes
=
util
::
split
(
w2vFiles
.
string
(),
' '
);
for
(
auto
&
p
:
pathes
)
loadPretrainedW2vEmbeddings
(
wordEmbeddings
,
path
/
p
);
}
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment