Skip to content
GitLab
Menu
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Franck Dary
macaon
Commits
57db2a2e
Commit
57db2a2e
authored
Jul 31, 2020
by
Franck Dary
Browse files
Changed the way prefix are handled in dicts
parent
4f4cd7c3
Changes
14
Hide whitespace changes
Inline
Side-by-side
common/include/Dict.hpp
View file @
57db2a2e
...
...
@@ -26,6 +26,7 @@ class Dict
private
:
std
::
unordered_map
<
std
::
string
,
int
>
elementsToIndexes
;
std
::
unordered_map
<
int
,
std
::
string
>
indexesToElements
;
std
::
vector
<
int
>
nbOccs
;
State
state
;
bool
isCountingOccs
{
false
};
...
...
@@ -43,7 +44,8 @@ class Dict
public
:
void
countOcc
(
bool
isCountingOccs
);
int
getIndexOrInsert
(
const
std
::
string
&
element
);
int
getIndexOrInsert
(
const
std
::
string
&
element
,
const
std
::
string
&
prefix
);
std
::
string
getElement
(
std
::
size_t
index
);
void
setState
(
State
state
);
State
getState
()
const
;
void
save
(
std
::
filesystem
::
path
path
,
Encoding
encoding
)
const
;
...
...
@@ -52,7 +54,8 @@ class Dict
std
::
size_t
size
()
const
;
int
getNbOccs
(
int
index
)
const
;
void
removeRareElements
();
void
loadWord2Vec
(
std
::
filesystem
::
path
path
);
void
loadWord2Vec
(
std
::
filesystem
::
path
path
,
std
::
string
prefix
);
bool
isSpecialValue
(
const
std
::
string
&
value
);
};
#endif
common/src/Dict.cpp
View file @
57db2a2e
...
...
@@ -42,6 +42,7 @@ void Dict::readFromFile(const char * filename)
util
::
myThrow
(
fmt
::
format
(
"file '{}' bad format"
,
filename
));
elementsToIndexes
.
reserve
(
nbEntries
);
indexesToElements
.
reserve
(
nbEntries
);
int
entryIndex
;
int
nbOccsEntry
;
...
...
@@ -52,6 +53,7 @@ void Dict::readFromFile(const char * filename)
util
::
myThrow
(
fmt
::
format
(
"file '{}' line {} bad format"
,
filename
,
i
));
elementsToIndexes
[
entryString
]
=
entryIndex
;
indexesToElements
[
entryIndex
]
=
entryString
;
while
((
int
)
nbOccs
.
size
()
<=
entryIndex
)
nbOccs
.
emplace_back
(
0
);
nbOccs
[
entryIndex
]
=
nbOccsEntry
;
...
...
@@ -66,37 +68,40 @@ void Dict::insert(const std::string & element)
util
::
myThrow
(
fmt
::
format
(
"inserting element of size={} > maxElementSize={}"
,
element
.
size
(),
maxEntrySize
));
elementsToIndexes
.
emplace
(
element
,
elementsToIndexes
.
size
());
indexesToElements
.
emplace
(
elementsToIndexes
.
size
()
-
1
,
element
);
while
(
nbOccs
.
size
()
<
elementsToIndexes
.
size
())
nbOccs
.
emplace_back
(
0
);
}
int
Dict
::
getIndexOrInsert
(
const
std
::
string
&
element
)
int
Dict
::
getIndexOrInsert
(
const
std
::
string
&
element
,
const
std
::
string
&
prefix
)
{
if
(
element
.
empty
())
return
getIndexOrInsert
(
emptyValueStr
);
return
getIndexOrInsert
(
emptyValueStr
,
prefix
);
if
(
element
.
size
()
==
1
and
util
::
isSeparator
(
util
::
utf8char
(
element
)))
return
getIndexOrInsert
(
separatorValueStr
);
return
getIndexOrInsert
(
separatorValueStr
,
prefix
);
if
(
util
::
isNumber
(
element
))
return
getIndexOrInsert
(
numberValueStr
);
return
getIndexOrInsert
(
numberValueStr
,
prefix
);
if
(
util
::
isUrl
(
element
))
return
getIndexOrInsert
(
urlValueStr
);
return
getIndexOrInsert
(
urlValueStr
,
prefix
);
const
auto
&
found
=
elementsToIndexes
.
find
(
element
);
auto
prefixed
=
prefix
.
empty
()
?
element
:
fmt
::
format
(
"{}({})"
,
prefix
,
element
);
const
auto
&
found
=
elementsToIndexes
.
find
(
prefixed
);
if
(
found
==
elementsToIndexes
.
end
())
{
if
(
state
==
State
::
Open
)
{
insert
(
element
);
insert
(
prefixed
);
if
(
isCountingOccs
)
nbOccs
[
elementsToIndexes
[
element
]]
++
;
return
elementsToIndexes
[
element
];
nbOccs
[
elementsToIndexes
[
prefixed
]]
++
;
return
elementsToIndexes
[
prefixed
];
}
const
auto
&
found2
=
elementsToIndexes
.
find
(
util
::
lower
(
element
));
prefixed
=
prefix
.
empty
()
?
util
::
lower
(
element
)
:
fmt
::
format
(
"{}({})"
,
prefix
,
util
::
lower
(
element
));
const
auto
&
found2
=
elementsToIndexes
.
find
(
prefixed
);
if
(
found2
!=
elementsToIndexes
.
end
())
{
if
(
isCountingOccs
)
...
...
@@ -104,9 +109,10 @@ int Dict::getIndexOrInsert(const std::string & element)
return
found2
->
second
;
}
prefixed
=
prefix
.
empty
()
?
unknownValueStr
:
fmt
::
format
(
"{}({})"
,
prefix
,
unknownValueStr
);
if
(
isCountingOccs
)
nbOccs
[
elementsToIndexes
[
unknownValueStr
]]
++
;
return
elementsToIndexes
[
unknownValueStr
];
nbOccs
[
elementsToIndexes
[
prefixed
]]
++
;
return
elementsToIndexes
[
prefixed
];
}
if
(
isCountingOccs
)
...
...
@@ -217,7 +223,7 @@ void Dict::removeRareElements()
nbOccs
=
newNbOccs
;
}
void
Dict
::
loadWord2Vec
(
std
::
filesystem
::
path
path
)
void
Dict
::
loadWord2Vec
(
std
::
filesystem
::
path
path
,
std
::
string
prefix
)
{
if
(
path
.
empty
())
return
;
...
...
@@ -235,6 +241,16 @@ void Dict::loadWord2Vec(std::filesystem::path path)
try
{
if
(
!
prefix
.
empty
())
{
std
::
vector
<
std
::
string
>
toAdd
;
for
(
auto
&
it
:
elementsToIndexes
)
if
(
isSpecialValue
(
it
.
first
))
toAdd
.
emplace_back
(
fmt
::
format
(
"{}({})"
,
prefix
,
it
.
first
));
for
(
auto
&
elem
:
toAdd
)
getIndexOrInsert
(
elem
,
""
);
}
while
(
!
std
::
feof
(
file
))
{
if
(
buffer
!=
std
::
fgets
(
buffer
,
100000
,
file
))
...
...
@@ -251,9 +267,13 @@ void Dict::loadWord2Vec(std::filesystem::path path)
if
(
splited
.
size
()
<
2
)
util
::
myThrow
(
fmt
::
format
(
"invalid w2v line '{}' less than 2 columns"
,
buffer
));
auto
dictIndex
=
getIndexOrInsert
(
splited
[
0
]);
if
(
splited
[
0
]
==
"<unk>"
)
continue
;
auto
toInsert
=
util
::
splitAsUtf8
(
splited
[
0
]);
toInsert
.
replace
(
"◌"
,
" "
);
auto
dictIndex
=
getIndexOrInsert
(
fmt
::
format
(
"{}"
,
toInsert
),
prefix
);
if
(
dictIndex
==
getIndexOrInsert
(
Dict
::
unknownValueStr
)
or
dictIndex
==
getIndexOrInsert
(
Dict
::
nullValueStr
)
or
dictIndex
==
getIndexOrInsert
(
Dict
::
emptyValueStr
))
if
(
dictIndex
==
getIndexOrInsert
(
Dict
::
unknownValueStr
,
prefix
)
or
dictIndex
==
getIndexOrInsert
(
Dict
::
nullValueStr
,
prefix
)
or
dictIndex
==
getIndexOrInsert
(
Dict
::
emptyValueStr
,
prefix
))
util
::
myThrow
(
fmt
::
format
(
"w2v line '{}' gave unexpected special dict index"
,
buffer
));
}
}
catch
(
std
::
exception
&
e
)
...
...
@@ -269,3 +289,18 @@ void Dict::loadWord2Vec(std::filesystem::path path)
setState
(
originalState
);
}
bool
Dict
::
isSpecialValue
(
const
std
::
string
&
value
)
{
return
value
==
unknownValueStr
||
value
==
nullValueStr
||
value
==
emptyValueStr
||
value
==
separatorValueStr
||
value
==
numberValueStr
||
value
==
urlValueStr
;
}
std
::
string
Dict
::
getElement
(
std
::
size_t
index
)
{
return
indexesToElements
[
index
];
}
torch_modules/include/ContextualModule.hpp
View file @
57db2a2e
...
...
@@ -22,7 +22,7 @@ class ContextualModuleImpl : public Submodule
int
inSize
;
int
outSize
;
std
::
filesystem
::
path
path
;
std
::
filesystem
::
path
w2vFile
;
std
::
filesystem
::
path
w2vFile
s
;
public
:
...
...
torch_modules/include/Submodule.hpp
View file @
57db2a2e
...
...
@@ -16,7 +16,7 @@ class Submodule : public torch::nn::Module, public DictHolder, public StateHolde
public
:
void
setFirstInputIndex
(
std
::
size_t
firstInputIndex
);
void
loadPretrainedW2vEmbeddings
(
torch
::
nn
::
Embedding
&
embeddings
,
std
::
filesystem
::
path
path
);
void
loadPretrainedW2vEmbeddings
(
torch
::
nn
::
Embedding
&
embeddings
,
std
::
filesystem
::
path
path
,
std
::
string
prefix
);
virtual
std
::
size_t
getOutputSize
()
=
0
;
virtual
std
::
size_t
getInputSize
()
=
0
;
virtual
void
addToContext
(
std
::
vector
<
std
::
vector
<
long
>>
&
context
,
const
Config
&
config
)
=
0
;
...
...
torch_modules/src/ContextModule.cpp
View file @
57db2a2e
...
...
@@ -54,9 +54,14 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin
{
auto
pathes
=
util
::
split
(
w2vFiles
.
string
(),
' '
);
for
(
auto
&
p
:
pathes
)
getDict
().
loadWord2Vec
(
this
->
path
/
p
);
getDict
().
setState
(
Dict
::
State
::
Closed
);
dictSetPretrained
(
true
);
{
auto
splited
=
util
::
split
(
p
,
','
);
if
(
splited
.
size
()
!=
2
)
util
::
myThrow
(
"expected 'prefix,pretrained.w2v'"
);
getDict
().
loadWord2Vec
(
this
->
path
/
splited
[
1
],
splited
[
0
]);
getDict
().
setState
(
Dict
::
State
::
Closed
);
dictSetPretrained
(
true
);
}
}
}
catch
(
std
::
exception
&
e
)
{
util
::
myThrow
(
fmt
::
format
(
"{} in '{}'"
,
e
.
what
(),
definition
));}
...
...
@@ -117,7 +122,7 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
if
(
index
==
-
1
)
{
for
(
auto
&
contextElement
:
context
)
contextElement
.
push_back
(
dict
.
getIndexOrInsert
(
fmt
::
format
(
"{}({})"
,
col
,
Dict
::
nullValueStr
)
));
contextElement
.
push_back
(
dict
.
getIndexOrInsert
(
Dict
::
nullValueStr
,
col
));
}
else
{
...
...
@@ -126,23 +131,17 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
{
std
::
string
value
;
if
(
config
.
isCommentPredicted
(
index
))
value
=
"
ID(
comment
)
"
;
value
=
"comment"
;
else
if
(
config
.
isMultiwordPredicted
(
index
))
value
=
"
ID(
multiword
)
"
;
value
=
"multiword"
;
else
if
(
config
.
isTokenPredicted
(
index
))
value
=
"ID(token)"
;
dictIndex
=
dict
.
getIndexOrInsert
(
value
);
}
else
if
(
col
==
Config
::
EOSColName
)
{
dictIndex
=
dict
.
getIndexOrInsert
(
fmt
::
format
(
"EOS({})"
,
config
.
getAsFeature
(
col
,
index
)));
value
=
"token"
;
dictIndex
=
dict
.
getIndexOrInsert
(
value
,
col
);
}
else
{
std
::
string
featureValue
=
functions
[
colIndex
](
config
.
getAsFeature
(
col
,
index
));
if
(
w2vFiles
.
empty
())
featureValue
=
fmt
::
format
(
"{}({})"
,
col
,
featureValue
);
dictIndex
=
dict
.
getIndexOrInsert
(
featureValue
);
dictIndex
=
dict
.
getIndexOrInsert
(
featureValue
,
col
);
}
for
(
auto
&
contextElement
:
context
)
...
...
@@ -165,6 +164,9 @@ void ContextModuleImpl::registerEmbeddings()
wordEmbeddings
=
register_module
(
"embeddings"
,
torch
::
nn
::
Embedding
(
torch
::
nn
::
EmbeddingOptions
(
getDict
().
size
(),
inSize
)));
auto
pathes
=
util
::
split
(
w2vFiles
.
string
(),
' '
);
for
(
auto
&
p
:
pathes
)
loadPretrainedW2vEmbeddings
(
wordEmbeddings
,
path
/
p
);
{
auto
splited
=
util
::
split
(
p
,
','
);
loadPretrainedW2vEmbeddings
(
wordEmbeddings
,
path
/
splited
[
1
],
splited
[
0
]);
}
}
torch_modules/src/ContextualModule.cpp
View file @
57db2a2e
...
...
@@ -53,13 +53,20 @@ ContextualModuleImpl::ContextualModuleImpl(std::string name, const std::string &
else
util
::
myThrow
(
fmt
::
format
(
"unknown sumodule type '{}'"
,
subModuleType
));
w2vFile
=
sm
.
str
(
7
);
w2vFile
s
=
sm
.
str
(
7
);
if
(
!
w2vFile
.
empty
())
if
(
!
w2vFile
s
.
empty
())
{
getDict
().
loadWord2Vec
(
this
->
path
/
w2vFile
);
getDict
().
setState
(
Dict
::
State
::
Closed
);
dictSetPretrained
(
true
);
auto
pathes
=
util
::
split
(
w2vFiles
.
string
(),
' '
);
for
(
auto
&
p
:
pathes
)
{
auto
splited
=
util
::
split
(
p
,
','
);
if
(
splited
.
size
()
!=
2
)
util
::
myThrow
(
"expected 'prefix,file.w2v'"
);
getDict
().
loadWord2Vec
(
this
->
path
/
splited
[
1
],
splited
[
0
]);
getDict
().
setState
(
Dict
::
State
::
Closed
);
dictSetPretrained
(
true
);
}
}
}
catch
(
std
::
exception
&
e
)
{
util
::
myThrow
(
fmt
::
format
(
"{} in '{}'"
,
e
.
what
(),
definition
));}
...
...
@@ -127,17 +134,13 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context
if
(
index
==
-
1
)
{
for
(
auto
&
contextElement
:
context
)
contextElement
.
push_back
(
dict
.
getIndexOrInsert
(
fmt
::
format
(
"{}({})"
,
col
,
Dict
::
nullValueStr
)
));
contextElement
.
push_back
(
dict
.
getIndexOrInsert
(
Dict
::
nullValueStr
,
col
));
}
else
if
(
index
==
-
2
)
{
//TODO maybe change this to a unique value like Dict::noneValueStr
for
(
auto
&
contextElement
:
context
)
{
auto
currentState
=
dict
.
getState
();
dict
.
setState
(
Dict
::
State
::
Open
);
contextElement
.
push_back
(
dict
.
getIndexOrInsert
(
fmt
::
format
(
"{}({})"
,
col
,
"_NONE_"
)));
dict
.
setState
(
currentState
);
}
contextElement
.
push_back
(
dict
.
getIndexOrInsert
(
Dict
::
nullValueStr
,
col
));
}
else
{
...
...
@@ -146,23 +149,17 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context
{
std
::
string
value
;
if
(
config
.
isCommentPredicted
(
index
))
value
=
"
ID(
comment
)
"
;
value
=
"comment"
;
else
if
(
config
.
isMultiwordPredicted
(
index
))
value
=
"
ID(
multiword
)
"
;
value
=
"multiword"
;
else
if
(
config
.
isTokenPredicted
(
index
))
value
=
"ID(token)"
;
dictIndex
=
dict
.
getIndexOrInsert
(
value
);
}
else
if
(
col
==
Config
::
EOSColName
)
{
dictIndex
=
dict
.
getIndexOrInsert
(
fmt
::
format
(
"EOS({})"
,
config
.
getAsFeature
(
col
,
index
)));
value
=
"token"
;
dictIndex
=
dict
.
getIndexOrInsert
(
value
,
col
);
}
else
{
std
::
string
featureValue
=
config
.
getAsFeature
(
col
,
index
);
if
(
w2vFile
.
empty
())
featureValue
=
fmt
::
format
(
"{}({})"
,
col
,
featureValue
);
dictIndex
=
dict
.
getIndexOrInsert
(
functions
[
colIndex
](
featureValue
));
dictIndex
=
dict
.
getIndexOrInsert
(
functions
[
colIndex
](
featureValue
),
col
);
}
for
(
auto
&
contextElement
:
context
)
...
...
@@ -214,6 +211,12 @@ torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
void
ContextualModuleImpl
::
registerEmbeddings
()
{
wordEmbeddings
=
register_module
(
"embeddings"
,
torch
::
nn
::
Embedding
(
torch
::
nn
::
EmbeddingOptions
(
getDict
().
size
(),
inSize
)));
loadPretrainedW2vEmbeddings
(
wordEmbeddings
,
w2vFile
.
empty
()
?
""
:
path
/
w2vFile
);
auto
pathes
=
util
::
split
(
w2vFiles
.
string
(),
' '
);
for
(
auto
&
p
:
pathes
)
{
auto
splited
=
util
::
split
(
p
,
','
);
loadPretrainedW2vEmbeddings
(
wordEmbeddings
,
path
/
splited
[
1
],
splited
[
0
]);
}
}
torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
View file @
57db2a2e
...
...
@@ -117,9 +117,9 @@ void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<lon
for
(
int
i
=
0
;
i
<
maxElemPerDepth
[
depth
];
i
++
)
for
(
auto
&
col
:
columns
)
if
(
i
<
(
int
)
newChilds
.
size
()
and
config
.
has
(
col
,
std
::
stoi
(
newChilds
[
i
]),
0
))
contextElement
.
emplace_back
(
dict
.
getIndexOrInsert
(
config
.
getAsFeature
(
col
,
std
::
stoi
(
newChilds
[
i
]))));
contextElement
.
emplace_back
(
dict
.
getIndexOrInsert
(
config
.
getAsFeature
(
col
,
std
::
stoi
(
newChilds
[
i
]))
,
col
));
else
contextElement
.
emplace_back
(
dict
.
getIndexOrInsert
(
Dict
::
nullValueStr
));
contextElement
.
emplace_back
(
dict
.
getIndexOrInsert
(
Dict
::
nullValueStr
,
col
));
}
}
}
...
...
torch_modules/src/DistanceModule.cpp
View file @
57db2a2e
...
...
@@ -86,6 +86,8 @@ void DistanceModuleImpl::addToContext(std::vector<std::vector<long>> & context,
else
toIndexes
.
emplace_back
(
-
1
);
std
::
string
prefix
=
"DISTANCE"
;
for
(
auto
&
contextElement
:
context
)
{
for
(
auto
from
:
fromIndexes
)
...
...
@@ -93,16 +95,16 @@ void DistanceModuleImpl::addToContext(std::vector<std::vector<long>> & context,
{
if
(
from
==
-
1
or
to
==
-
1
)
{
contextElement
.
emplace_back
(
dict
.
getIndexOrInsert
(
Dict
::
nullValueStr
));
contextElement
.
emplace_back
(
dict
.
getIndexOrInsert
(
Dict
::
nullValueStr
,
prefix
));
continue
;
}
long
dist
=
std
::
abs
(
config
.
getRelativeDistance
(
from
,
to
));
if
(
dist
<=
threshold
)
contextElement
.
emplace_back
(
dict
.
getIndexOrInsert
(
fmt
::
format
(
"
distance({})"
,
dist
)));
contextElement
.
emplace_back
(
dict
.
getIndexOrInsert
(
fmt
::
format
(
"
{}({})"
,
prefix
,
dist
)
,
""
));
else
contextElement
.
emplace_back
(
dict
.
getIndexOrInsert
(
Dict
::
unknownValueStr
));
contextElement
.
emplace_back
(
dict
.
getIndexOrInsert
(
Dict
::
unknownValueStr
,
prefix
));
}
}
}
...
...
torch_modules/src/FocusedColumnModule.cpp
View file @
57db2a2e
...
...
@@ -84,7 +84,7 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
if
(
index
==
-
1
)
{
for
(
int
i
=
0
;
i
<
maxNbElements
;
i
++
)
contextElement
.
emplace_back
(
dict
.
getIndexOrInsert
(
Dict
::
nullValueStr
));
contextElement
.
emplace_back
(
dict
.
getIndexOrInsert
(
Dict
::
nullValueStr
,
column
));
continue
;
}
...
...
@@ -93,6 +93,7 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
{
auto
asUtf8
=
util
::
splitAsUtf8
(
func
(
config
.
getAsFeature
(
column
,
index
).
get
()));
//TODO don't use nullValueStr here
for
(
int
i
=
0
;
i
<
maxNbElements
;
i
++
)
if
(
i
<
(
int
)
asUtf8
.
size
())
elements
.
emplace_back
(
fmt
::
format
(
"{}"
,
asUtf8
[
i
]));
...
...
@@ -105,23 +106,23 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
for
(
int
i
=
0
;
i
<
maxNbElements
;
i
++
)
if
(
i
<
(
int
)
splited
.
size
())
elements
.
emplace_back
(
fmt
::
format
(
"FEATS({})"
,
splited
[
i
])
)
;
elements
.
emplace_back
(
splited
[
i
]);
else
elements
.
emplace_back
(
Dict
::
nullValueStr
);
}
else
if
(
column
==
"ID"
)
{
if
(
config
.
isTokenPredicted
(
index
))
elements
.
emplace_back
(
"
ID(
TOKEN
)
"
);
elements
.
emplace_back
(
"TOKEN"
);
else
if
(
config
.
isMultiwordPredicted
(
index
))
elements
.
emplace_back
(
"
ID(
MULTIWORD
)
"
);
elements
.
emplace_back
(
"MULTIWORD"
);
else
if
(
config
.
isEmptyNodePredicted
(
index
))
elements
.
emplace_back
(
"
ID(
EMPTYNODE
)
"
);
elements
.
emplace_back
(
"EMPTYNODE"
);
}
else
if
(
column
==
"EOS"
)
{
bool
isEOS
=
func
(
config
.
getAsFeature
(
Config
::
EOSColName
,
index
))
==
Config
::
EOSSymbol1
;
elements
.
emplace_back
(
fmt
::
format
(
"
EOS(
{}
)
"
,
isEOS
));
elements
.
emplace_back
(
fmt
::
format
(
"{}"
,
isEOS
));
}
else
{
...
...
@@ -132,7 +133,7 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
util
::
myThrow
(
fmt
::
format
(
"elements.size ({}) != maxNbElements ({})"
,
elements
.
size
(),
maxNbElements
));
for
(
auto
&
element
:
elements
)
contextElement
.
emplace_back
(
dict
.
getIndexOrInsert
(
element
));
contextElement
.
emplace_back
(
dict
.
getIndexOrInsert
(
element
,
column
));
}
}
}
...
...
torch_modules/src/HistoryModule.cpp
View file @
57db2a2e
...
...
@@ -57,12 +57,14 @@ void HistoryModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
{
auto
&
dict
=
getDict
();
std
::
string
prefix
=
"HISTORY"
;
for
(
auto
&
contextElement
:
context
)
for
(
int
i
=
0
;
i
<
maxNbElements
;
i
++
)
if
(
config
.
hasHistory
(
i
))
contextElement
.
emplace_back
(
dict
.
getIndexOrInsert
(
config
.
getHistory
(
i
)));
contextElement
.
emplace_back
(
dict
.
getIndexOrInsert
(
config
.
getHistory
(
i
)
,
prefix
));
else
contextElement
.
emplace_back
(
dict
.
getIndexOrInsert
(
Dict
::
nullValueStr
));
contextElement
.
emplace_back
(
dict
.
getIndexOrInsert
(
Dict
::
nullValueStr
,
prefix
));
}
void
HistoryModuleImpl
::
registerEmbeddings
()
...
...
torch_modules/src/RawInputModule.cpp
View file @
57db2a2e
...
...
@@ -57,20 +57,22 @@ void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context,
if
(
leftWindow
<
0
or
rightWindow
<
0
)
return
;
std
::
string
prefix
=
"LETTER"
;
auto
&
dict
=
getDict
();
for
(
auto
&
contextElement
:
context
)
{
for
(
int
i
=
0
;
i
<
leftWindow
;
i
++
)
if
(
config
.
hasCharacter
(
config
.
getCharacterIndex
()
-
leftWindow
+
i
))
contextElement
.
push_back
(
dict
.
getIndexOrInsert
(
fmt
::
format
(
"{}
"
,
config
.
getLetter
(
config
.
getCharacterIndex
()
-
leftWindow
+
i
))));
contextElement
.
push_back
(
dict
.
getIndexOrInsert
(
fmt
::
format
(
"{}
({})"
,
prefix
,
config
.
getLetter
(
config
.
getCharacterIndex
()
-
leftWindow
+
i
))
,
""
));
else
contextElement
.
push_back
(
dict
.
getIndexOrInsert
(
Dict
::
nullValueStr
));
contextElement
.
push_back
(
dict
.
getIndexOrInsert
(
Dict
::
nullValueStr
,
prefix
));
for
(
int
i
=
0
;
i
<=
rightWindow
;
i
++
)
if
(
config
.
hasCharacter
(
config
.
getCharacterIndex
()
+
i
))
contextElement
.
push_back
(
dict
.
getIndexOrInsert
(
fmt
::
format
(
"{}
"
,
config
.
getLetter
(
config
.
getCharacterIndex
()
+
i
))));
contextElement
.
push_back
(
dict
.
getIndexOrInsert
(
fmt
::
format
(
"{}
({})"
,
prefix
,
config
.
getLetter
(
config
.
getCharacterIndex
()
+
i
))
,
""
));
else
contextElement
.
push_back
(
dict
.
getIndexOrInsert
(
Dict
::
nullValueStr
));
contextElement
.
push_back
(
dict
.
getIndexOrInsert
(
Dict
::
nullValueStr
,
prefix
));
}
}
...
...
torch_modules/src/SplitTransModule.cpp
View file @
57db2a2e
...
...
@@ -58,9 +58,9 @@ void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context
for
(
auto
&
contextElement
:
context
)
for
(
int
i
=
0
;
i
<
maxNbTrans
;
i
++
)
if
(
i
<
(
int
)
splitTransitions
.
size
())
contextElement
.
emplace_back
(
dict
.
getIndexOrInsert
(
splitTransitions
[
i
]
->
getName
()));
contextElement
.
emplace_back
(
dict
.
getIndexOrInsert
(
splitTransitions
[
i
]
->
getName
()
,
""
));
else
contextElement
.
emplace_back
(
dict
.
getIndexOrInsert
(
Dict
::
nullValueStr
));
contextElement
.
emplace_back
(
dict
.
getIndexOrInsert
(
Dict
::
nullValueStr
,
""
));
}
void
SplitTransModuleImpl
::
registerEmbeddings
()
...
...
torch_modules/src/StateNameModule.cpp
View file @
57db2a2e
...
...
@@ -33,7 +33,7 @@ void StateNameModuleImpl::addToContext(std::vector<std::vector<long>> & context,
{
auto
&
dict
=
getDict
();
for
(
auto
&
contextElement
:
context
)
contextElement
.
emplace_back
(
dict
.
getIndexOrInsert
(
config
.
getState
()));
contextElement
.
emplace_back
(
dict
.
getIndexOrInsert
(
config
.
getState
()
,
""
));
}
void
StateNameModuleImpl
::
registerEmbeddings
()
...
...
torch_modules/src/Submodule.cpp
View file @
57db2a2e
...
...
@@ -5,7 +5,7 @@ void Submodule::setFirstInputIndex(std::size_t firstInputIndex)
this
->
firstInputIndex
=
firstInputIndex
;
}
void
Submodule
::
loadPretrainedW2vEmbeddings
(
torch
::
nn
::
Embedding
&
embeddings
,
std
::
filesystem
::
path
path
)
void
Submodule
::
loadPretrainedW2vEmbeddings
(
torch
::
nn
::
Embedding
&
embeddings
,
std
::
filesystem
::
path
path
,
std
::
string
prefix
)
{
if
(
path
.
empty
())
return
;
...
...
@@ -44,12 +44,14 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, s
if
(
splited
.
size
()
<
2
)
util
::
myThrow
(
fmt
::
format
(
"invalid w2v line '{}' less than 2 columns"
,
buffer
));
auto
dictIndex
=
getDict
().
getIndexOrInsert
(
splited
[
0
]);
std
::
string
word
;
if
(
splited
[
0
]
==
"<unk>"
)
dictIndex
=
getDict
().
getIndexOrInsert
(
Dict
::
unknownValueStr
);
word
=
Dict
::
unknownValueStr
;
else
word
=
splited
[
0
];
if
(
splited
[
0
]
!=
"<unk>"
and
splited
[
0
]
!=
Dict
::
unknownValueStr
and
(
dictIndex
==
getDict
().
getIndexOrInsert
(
Dict
::
unknownValueStr
)
or
dictIndex
==
getDict
().
getIndexOrInsert
(
Dict
::
nullValueStr
)
or
dictIndex
==
getDict
().
getIndexOrInsert
(
Dict
::
emptyValueStr
)))
continue
;
auto
dictIndex
=
getDict
().
getIndexOrInsert
(
word
,
prefix
);
if
(
embeddingsSize
!=
splited
.
size
()
-
1
)
util
::
myThrow
(
fmt
::
format
(
"in line
\n
{}embeddingsSize='{}' mismatch pretrainedEmbeddingSize='{}'"
,
buffer
,
embeddingsSize
,
((
int
)
splited
.
size
())
-
1
));
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment