Skip to content
GitLab
Menu
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Franck Dary
macaon
Commits
397e390f
Commit
397e390f
authored
Jul 31, 2020
by
Franck Dary
Browse files
FocusedModule can now have pretrained word embeddings
parent
57db2a2e
Changes
4
Hide whitespace changes
Inline
Side-by-side
torch_modules/include/FocusedColumnModule.hpp
View file @
397e390f
...
@@ -19,10 +19,12 @@ class FocusedColumnModuleImpl : public Submodule
...
@@ -19,10 +19,12 @@ class FocusedColumnModuleImpl : public Submodule
std
::
function
<
std
::
string
(
const
std
::
string
&
)
>
func
{[](
const
std
::
string
&
s
){
return
s
;}};
std
::
function
<
std
::
string
(
const
std
::
string
&
)
>
func
{[](
const
std
::
string
&
s
){
return
s
;}};
int
maxNbElements
;
int
maxNbElements
;
int
inSize
;
int
inSize
;
std
::
filesystem
::
path
path
;
std
::
filesystem
::
path
w2vFiles
;
public
:
public
:
FocusedColumnModuleImpl
(
std
::
string
name
,
const
std
::
string
&
definition
);
FocusedColumnModuleImpl
(
std
::
string
name
,
const
std
::
string
&
definition
,
std
::
filesystem
::
path
path
);
torch
::
Tensor
forward
(
torch
::
Tensor
input
);
torch
::
Tensor
forward
(
torch
::
Tensor
input
);
std
::
size_t
getOutputSize
()
override
;
std
::
size_t
getOutputSize
()
override
;
std
::
size_t
getInputSize
()
override
;
std
::
size_t
getInputSize
()
override
;
...
...
torch_modules/src/FocusedColumnModule.cpp
View file @
397e390f
#include
"FocusedColumnModule.hpp"
#include
"FocusedColumnModule.hpp"
FocusedColumnModuleImpl
::
FocusedColumnModuleImpl
(
std
::
string
name
,
const
std
::
string
&
definition
)
FocusedColumnModuleImpl
::
FocusedColumnModuleImpl
(
std
::
string
name
,
const
std
::
string
&
definition
,
std
::
filesystem
::
path
path
)
:
path
(
path
)
{
{
setName
(
name
);
setName
(
name
);
std
::
regex
regex
(
"(?:(?:
\\
s|
\\
t)*)Column
\\
{(.*)
\\
}(?:(?:
\\
s|
\\
t)*)NbElem
\\
{(.*)
\\
}(?:(?:
\\
s|
\\
t)*)Buffer
\\
{(.*)
\\
}(?:(?:
\\
s|
\\
t)*)Stack
\\
{(.*)
\\
}(?:(?:
\\
s|
\\
t)*)(
\\
S+)
\\
{(.*)
\\
}(?:(?:
\\
s|
\\
t)*)In
\\
{(.*)
\\
}(?:(?:
\\
s|
\\
t)*)Out
\\
{(.*)
\\
}(?:(?:
\\
s|
\\
t)*)"
);
std
::
regex
regex
(
"(?:(?:
\\
s|
\\
t)*)Column
\\
{(.*)
\\
}(?:(?:
\\
s|
\\
t)*)NbElem
\\
{(.*)
\\
}(?:(?:
\\
s|
\\
t)*)Buffer
\\
{(.*)
\\
}(?:(?:
\\
s|
\\
t)*)Stack
\\
{(.*)
\\
}(?:(?:
\\
s|
\\
t)*)(
\\
S+)
\\
{(.*)
\\
}(?:(?:
\\
s|
\\
t)*)In
\\
{(.*)
\\
}(?:(?:
\\
s|
\\
t)*)Out
\\
{(.*)
\\
}(?:(?:
\\
s|
\\
t)*)
w2v
\\
{(.*)
\\
}(?:(?:
\\
s|
\\
t)*)
"
);
if
(
!
util
::
doIfNameMatch
(
regex
,
definition
,
[
this
,
&
definition
](
auto
sm
)
if
(
!
util
::
doIfNameMatch
(
regex
,
definition
,
[
this
,
&
definition
](
auto
sm
)
{
{
try
try
...
@@ -39,6 +39,22 @@ FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::st
...
@@ -39,6 +39,22 @@ FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::st
else
else
util
::
myThrow
(
fmt
::
format
(
"unknown sumodule type '{}'"
,
subModuleType
));
util
::
myThrow
(
fmt
::
format
(
"unknown sumodule type '{}'"
,
subModuleType
));
w2vFiles
=
sm
.
str
(
9
);
if
(
!
w2vFiles
.
empty
())
{
auto
pathes
=
util
::
split
(
w2vFiles
.
string
(),
' '
);
for
(
auto
&
p
:
pathes
)
{
auto
splited
=
util
::
split
(
p
,
','
);
if
(
splited
.
size
()
!=
2
)
util
::
myThrow
(
"expected 'prefix,pretrained.w2v'"
);
getDict
().
loadWord2Vec
(
this
->
path
/
splited
[
1
],
splited
[
0
]);
getDict
().
setState
(
Dict
::
State
::
Closed
);
dictSetPretrained
(
true
);
}
}
}
catch
(
std
::
exception
&
e
)
{
util
::
myThrow
(
fmt
::
format
(
"{} in '{}'"
,
e
.
what
(),
definition
));}
}
catch
(
std
::
exception
&
e
)
{
util
::
myThrow
(
fmt
::
format
(
"{} in '{}'"
,
e
.
what
(),
definition
));}
}))
}))
util
::
myThrow
(
fmt
::
format
(
"invalid definition '{}'"
,
definition
));
util
::
myThrow
(
fmt
::
format
(
"invalid definition '{}'"
,
definition
));
...
@@ -141,5 +157,11 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
...
@@ -141,5 +157,11 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
void
FocusedColumnModuleImpl
::
registerEmbeddings
()
void
FocusedColumnModuleImpl
::
registerEmbeddings
()
{
{
wordEmbeddings
=
register_module
(
"embeddings"
,
torch
::
nn
::
Embedding
(
torch
::
nn
::
EmbeddingOptions
(
getDict
().
size
(),
inSize
)));
wordEmbeddings
=
register_module
(
"embeddings"
,
torch
::
nn
::
Embedding
(
torch
::
nn
::
EmbeddingOptions
(
getDict
().
size
(),
inSize
)));
auto
pathes
=
util
::
split
(
w2vFiles
.
string
(),
' '
);
for
(
auto
&
p
:
pathes
)
{
auto
splited
=
util
::
split
(
p
,
','
);
loadPretrainedW2vEmbeddings
(
wordEmbeddings
,
path
/
splited
[
1
],
splited
[
0
]);
}
}
}
torch_modules/src/ModularNetwork.cpp
View file @
397e390f
...
@@ -40,7 +40,7 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st
...
@@ -40,7 +40,7 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st
else
if
(
splited
.
first
==
"UppercaseRate"
)
else
if
(
splited
.
first
==
"UppercaseRate"
)
modules
.
emplace_back
(
register_module
(
name
,
UppercaseRateModule
(
nameH
,
splited
.
second
)));
modules
.
emplace_back
(
register_module
(
name
,
UppercaseRateModule
(
nameH
,
splited
.
second
)));
else
if
(
splited
.
first
==
"Focused"
)
else
if
(
splited
.
first
==
"Focused"
)
modules
.
emplace_back
(
register_module
(
name
,
FocusedColumnModule
(
nameH
,
splited
.
second
)));
modules
.
emplace_back
(
register_module
(
name
,
FocusedColumnModule
(
nameH
,
splited
.
second
,
path
)));
else
if
(
splited
.
first
==
"RawInput"
)
else
if
(
splited
.
first
==
"RawInput"
)
modules
.
emplace_back
(
register_module
(
name
,
RawInputModule
(
nameH
,
splited
.
second
)));
modules
.
emplace_back
(
register_module
(
name
,
RawInputModule
(
nameH
,
splited
.
second
)));
else
if
(
splited
.
first
==
"SplitTrans"
)
else
if
(
splited
.
first
==
"SplitTrans"
)
...
...
torch_modules/src/Submodule.cpp
View file @
397e390f
...
@@ -51,6 +51,10 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, s
...
@@ -51,6 +51,10 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, s
else
else
word
=
splited
[
0
];
word
=
splited
[
0
];
auto
toInsert
=
util
::
splitAsUtf8
(
word
);
toInsert
.
replace
(
"◌"
,
" "
);
word
=
fmt
::
format
(
"{}"
,
toInsert
);
auto
dictIndex
=
getDict
().
getIndexOrInsert
(
word
,
prefix
);
auto
dictIndex
=
getDict
().
getIndexOrInsert
(
word
,
prefix
);
if
(
embeddingsSize
!=
splited
.
size
()
-
1
)
if
(
embeddingsSize
!=
splited
.
size
()
-
1
)
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment