Skip to content
GitLab
Menu
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Franck Dary
macaon
Commits
b495167c
Commit
b495167c
authored
Mar 06, 2021
by
Franck Dary
Browse files
Parallel extractExamples
parent
30e51f46
Changes
3
Hide whitespace changes
Inline
Side-by-side
common/include/Dict.hpp
View file @
b495167c
...
...
@@ -5,6 +5,7 @@
#include
<unordered_map>
#include
<vector>
#include
<filesystem>
#include
<mutex>
class
Dict
{
...
...
@@ -30,6 +31,7 @@ class Dict
std
::
unordered_map
<
std
::
string
,
int
>
elementsToIndexes
;
std
::
unordered_map
<
int
,
std
::
string
>
indexesToElements
;
std
::
vector
<
int
>
nbOccs
;
std
::
mutex
elementsMutex
;
State
state
;
bool
isCountingOccs
{
false
};
...
...
@@ -43,6 +45,7 @@ class Dict
void
readFromFile
(
const
char
*
filename
);
void
insert
(
const
std
::
string
&
element
);
void
reset
();
int
_getIndexOrInsert
(
const
std
::
string
&
element
,
const
std
::
string
&
prefix
);
public
:
...
...
common/src/Dict.cpp
View file @
b495167c
...
...
@@ -90,20 +90,33 @@ void Dict::insert(const std::string & element)
}
int
Dict
::
getIndexOrInsert
(
const
std
::
string
&
element
,
const
std
::
string
&
prefix
)
{
if
(
state
==
State
::
Open
)
elementsMutex
.
lock
();
int
index
=
_getIndexOrInsert
(
element
,
prefix
);
if
(
state
==
State
::
Open
)
elementsMutex
.
unlock
();
return
index
;
}
int
Dict
::
_getIndexOrInsert
(
const
std
::
string
&
element
,
const
std
::
string
&
prefix
)
{
if
(
element
.
empty
())
return
getIndexOrInsert
(
emptyValueStr
,
prefix
);
return
_
getIndexOrInsert
(
emptyValueStr
,
prefix
);
if
(
util
::
printedLength
(
element
)
==
1
and
util
::
isSeparator
(
util
::
utf8char
(
element
)))
{
return
getIndexOrInsert
(
separatorValueStr
,
prefix
);
return
_
getIndexOrInsert
(
separatorValueStr
,
prefix
);
}
if
(
util
::
isNumber
(
element
))
return
getIndexOrInsert
(
numberValueStr
,
prefix
);
return
_
getIndexOrInsert
(
numberValueStr
,
prefix
);
if
(
util
::
isUrl
(
element
))
return
getIndexOrInsert
(
urlValueStr
,
prefix
);
return
_
getIndexOrInsert
(
urlValueStr
,
prefix
);
auto
prefixed
=
prefix
.
empty
()
?
element
:
fmt
::
format
(
"{}({})"
,
prefix
,
element
);
const
auto
&
found
=
elementsToIndexes
.
find
(
prefixed
);
...
...
trainer/src/Trainer.cpp
View file @
b495167c
#include
"Trainer.hpp"
#include
"SubConfig.hpp"
#include
<execution>
Trainer
::
Trainer
(
ReadingMachine
&
machine
,
int
batchSize
)
:
machine
(
machine
),
batchSize
(
batchSize
)
{
...
...
@@ -35,7 +36,8 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
torch
::
AutoGradMode
useGrad
(
false
);
int
maxNbExamplesPerFile
=
50000
;
std
::
map
<
std
::
string
,
Examples
>
examplesPerState
;
std
::
unordered_map
<
std
::
string
,
Examples
>
examplesPerState
;
std
::
mutex
examplesMutex
;
std
::
filesystem
::
create_directories
(
dir
);
...
...
@@ -46,144 +48,152 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
fmt
::
print
(
stderr
,
"[{}] Starting to extract examples{}
\n
"
,
util
::
getTime
(),
dynamicOracle
?
", dynamic oracle"
:
""
);
int
totalNbExamples
=
0
;
std
::
atomic
<
int
>
totalNbExamples
=
0
;
for
(
auto
&
config
:
configs
)
{
config
.
addPredicted
(
machine
.
getPredicted
());
config
.
setStrategy
(
machine
.
getStrategyDefinition
());
config
.
setState
(
config
.
getStrategy
().
getInitialState
());
while
(
true
)
NeuralNetworkImpl
::
device
=
torch
::
kCPU
;
machine
.
to
(
NeuralNetworkImpl
::
device
);
std
::
for_each
(
std
::
execution
::
par_unseq
,
configs
.
begin
(),
configs
.
end
(),
[
this
,
maxNbExamplesPerFile
,
&
examplesPerState
,
&
totalNbExamples
,
debug
,
dynamicOracle
,
explorationThreshold
,
dir
,
epoch
,
&
examplesMutex
](
SubConfig
&
config
)
{
if
(
debug
)
config
.
printForDebug
(
stderr
);
config
.
addPredicted
(
machine
.
getPredicted
());
config
.
setStrategy
(
machine
.
getStrategyDefinition
());
config
.
setState
(
config
.
getStrategy
().
getInitialState
());
if
(
machine
.
hasSplitWordTransitionSet
())
config
.
setAppliableSplitTransitions
(
machine
.
getSplitWordTransitionSet
().
getNAppliableTransitions
(
config
,
Config
::
maxNbAppliableSplitTransitions
));
while
(
true
)
{
if
(
debug
)
config
.
printForDebug
(
stderr
);
auto
appliableTransitions
=
machine
.
getTransitionSet
(
config
.
getState
()).
getAppliableTransitions
(
config
);
config
.
setAppliableTransitions
(
appliableTransitions
);
if
(
machine
.
hasSplitWordTransitionSet
())
config
.
setAppliable
Split
Transitions
(
machine
.
getSplitWordTransitionSet
().
getNAppliableTransitions
(
config
,
Config
::
maxNbAppliableSplitTransitions
)
);
torch
::
Tensor
context
;
auto
appliableTransitions
=
machine
.
getTransitionSet
(
config
.
getState
()).
getAppliableTransitions
(
config
);
config
.
setAppliableTransitions
(
appliableTransitions
);
try
{
context
=
machine
.
getClassifier
(
config
.
getState
())
->
getNN
()
->
extractContext
(
config
);
}
catch
(
std
::
exception
&
e
)
{
util
::
myThrow
(
fmt
::
format
(
"Failed to extract context : {}"
,
e
.
what
()));
}
torch
::
Tensor
context
;
Transition
*
transition
=
nullptr
;
try
{
context
=
machine
.
getClassifier
(
config
.
getState
())
->
getNN
()
->
extractContext
(
config
);
}
catch
(
std
::
exception
&
e
)
{
util
::
myThrow
(
fmt
::
format
(
"Failed to extract context : {}"
,
e
.
what
()));
}
auto
gold
Transition
s
=
machine
.
getTransitionSet
(
config
.
getState
()).
getBestAppliableTransitions
(
config
,
appliableTransitions
,
true
or
dynamicOracle
)
;
Transition
*
transition
=
nullptr
;
Transition
*
goldTransition
=
goldTransitions
[
0
];
if
(
config
.
getState
()
==
"parser"
)
goldTransitions
[
std
::
rand
()
%
goldTransitions
.
size
()];
auto
goldTransitions
=
machine
.
getTransitionSet
(
config
.
getState
()).
getBestAppliableTransitions
(
config
,
appliableTransitions
,
true
or
dynamicOracle
);
int
nbClasses
=
machine
.
getTransitionSet
(
config
.
getState
()).
size
();
Transition
*
goldTransition
=
goldTransitions
[
0
];
if
(
config
.
getState
()
==
"parser"
)
goldTransitions
[
std
::
rand
()
%
goldTransitions
.
size
()];
float
bestScore
=
-
std
::
numeric_limits
<
float
>::
max
();
int
nbClasses
=
machine
.
getTransitionSet
(
config
.
getState
()).
size
();
float
entropy
=
0.0
;
if
(
dynamicOracle
and
util
::
choiceWithProbability
(
1.0
)
and
config
.
getState
()
!=
"tokenizer"
and
config
.
getState
()
!=
"segmenter"
)
{
auto
&
classifier
=
*
machine
.
getClassifier
(
config
.
getState
());
auto
prediction
=
classifier
.
isRegression
()
?
classifier
.
getNN
()
->
forward
(
context
,
config
.
getState
()).
squeeze
(
0
)
:
torch
::
softmax
(
classifier
.
getNN
()
->
forward
(
context
,
config
.
getState
()).
squeeze
(
0
),
0
);
entropy
=
NeuralNetworkImpl
::
entropy
(
prediction
);
std
::
vector
<
int
>
candidates
;
float
bestScore
=
-
std
::
numeric_limits
<
float
>::
max
();
for
(
unsigned
int
i
=
0
;
i
<
prediction
.
size
(
0
);
i
++
)
float
entropy
=
0.0
;
if
(
dynamicOracle
and
util
::
choiceWithProbability
(
1.0
)
and
config
.
getState
()
!=
"tokenizer"
and
config
.
getState
()
!=
"segmenter"
)
{
float
score
=
prediction
[
i
].
item
<
float
>
();
if
(
score
>
bestScore
and
appliableTransitions
[
i
])
bestScore
=
score
;
auto
&
classifier
=
*
machine
.
getClassifier
(
config
.
getState
());
auto
prediction
=
classifier
.
isRegression
()
?
classifier
.
getNN
()
->
forward
(
context
,
config
.
getState
()).
squeeze
(
0
)
:
torch
::
softmax
(
classifier
.
getNN
()
->
forward
(
context
,
config
.
getState
()).
squeeze
(
0
),
0
);
entropy
=
NeuralNetworkImpl
::
entropy
(
prediction
);
std
::
vector
<
int
>
candidates
;
for
(
unsigned
int
i
=
0
;
i
<
prediction
.
size
(
0
);
i
++
)
{
float
score
=
prediction
[
i
].
item
<
float
>
();
if
(
score
>
bestScore
and
appliableTransitions
[
i
])
bestScore
=
score
;
}
for
(
unsigned
int
i
=
0
;
i
<
prediction
.
size
(
0
);
i
++
)
{
float
score
=
prediction
[
i
].
item
<
float
>
();
if
(
appliableTransitions
[
i
]
and
bestScore
-
score
<=
explorationThreshold
)
candidates
.
emplace_back
(
i
);
}
transition
=
machine
.
getTransitionSet
(
config
.
getState
()).
getTransition
(
candidates
[
std
::
rand
()
%
candidates
.
size
()]);
}
for
(
unsigned
int
i
=
0
;
i
<
prediction
.
size
(
0
);
i
++
)
else
{
float
score
=
prediction
[
i
].
item
<
float
>
();
if
(
appliableTransitions
[
i
]
and
bestScore
-
score
<=
explorationThreshold
)
candidates
.
emplace_back
(
i
);
transition
=
goldTransition
;
}
transition
=
machine
.
getTransitionSet
(
config
.
getState
()).
getTransition
(
candidates
[
std
::
rand
()
%
candidates
.
size
()]);
}
else
{
transition
=
goldTransition
;
}
if
(
!
transition
or
!
goldTransition
)
{
config
.
printForDebug
(
stderr
);
util
::
myThrow
(
"No transition appliable !"
);
}
if
(
!
transition
or
!
goldTransition
)
{
config
.
printForDebug
(
stderr
);
util
::
myThrow
(
"No transition appliable !"
);
}
std
::
vector
<
long
>
goldIndexes
;
bool
exampleIsBanned
=
machine
.
getClassifier
(
config
.
getState
())
->
exampleIsBanned
(
config
);
std
::
vector
<
long
>
goldIndexes
;
bool
exampleIsBanned
=
machine
.
getClassifier
(
config
.
getState
())
->
exampleIsBanned
(
config
);
if
(
machine
.
getClassifier
(
config
.
getState
())
->
isRegression
())
{
entropy
=
0.0
;
auto
errMessage
=
fmt
::
format
(
"Invalid regression transition '{}'"
,
transition
->
getName
());
auto
splited
=
util
::
split
(
transition
->
getName
(),
' '
);
if
(
splited
.
size
()
!=
3
or
splited
[
0
]
!=
"WRITESCORE"
)
util
::
myThrow
(
errMessage
);
auto
col
=
splited
[
2
];
splited
=
util
::
split
(
splited
[
1
],
'.'
);
if
(
splited
.
size
()
!=
2
)
util
::
myThrow
(
errMessage
);
auto
object
=
Config
::
str2object
(
splited
[
0
]);
int
index
=
std
::
stoi
(
splited
[
1
]);
float
regressionTarget
=
std
::
stof
(
config
.
getConst
(
col
,
config
.
getRelativeWordIndex
(
object
,
index
),
0
));
goldIndexes
.
emplace_back
(
util
::
float2long
(
regressionTarget
));
}
else
{
for
(
auto
&
t
:
goldTransitions
)
goldIndexes
.
emplace_back
(
machine
.
getTransitionSet
(
config
.
getState
()).
getTransitionIndex
(
t
));
}
if
(
machine
.
getClassifier
(
config
.
getState
())
->
isRegression
())
{
entropy
=
0.0
;
auto
errMessage
=
fmt
::
format
(
"Invalid regression transition '{}'"
,
transition
->
getName
());
auto
splited
=
util
::
split
(
transition
->
getName
(),
' '
);
if
(
splited
.
size
()
!=
3
or
splited
[
0
]
!=
"WRITESCORE"
)
util
::
myThrow
(
errMessage
);
auto
col
=
splited
[
2
];
splited
=
util
::
split
(
splited
[
1
],
'.'
);
if
(
splited
.
size
()
!=
2
)
util
::
myThrow
(
errMessage
);
auto
object
=
Config
::
str2object
(
splited
[
0
]);
int
index
=
std
::
stoi
(
splited
[
1
]);
float
regressionTarget
=
std
::
stof
(
config
.
getConst
(
col
,
config
.
getRelativeWordIndex
(
object
,
index
),
0
));
goldIndexes
.
emplace_back
(
util
::
float2long
(
regressionTarget
));
}
else
{
for
(
auto
&
t
:
goldTransitions
)
goldIndexes
.
emplace_back
(
machine
.
getTransitionSet
(
config
.
getState
()).
getTransitionIndex
(
t
));
if
(
!
exampleIsBanned
)
{
totalNbExamples
+=
1
;
if
(
totalNbExamples
>=
(
int
)
safetyNbExamplesMax
)
util
::
myThrow
(
fmt
::
format
(
"Trying to extract more examples than the limit ({})"
,
util
::
int2HumanStr
(
safetyNbExamplesMax
)));
}
examplesPerState
[
config
.
getState
()].
addContext
(
context
);
examplesPerState
[
config
.
getState
()].
addClass
(
machine
.
getClassifier
(
config
.
getState
())
->
getLossFunction
(),
nbClasses
,
goldIndexes
);
examplesPerState
[
config
.
getState
()].
saveIfNeeded
(
config
.
getState
(),
dir
,
maxNbExamplesPerFile
,
epoch
,
dynamicOracle
);
}
if
(
!
exampleIsBanned
)
{
totalNbExamples
+=
1
;
if
(
totalNbExamples
>=
(
int
)
safetyNbExamplesMax
)
util
::
myThrow
(
fmt
::
format
(
"Trying to extract more examples than the limit ({})"
,
util
::
int2HumanStr
(
safetyNbExamplesMax
)));
examplesMutex
.
lock
();
examplesPerState
[
config
.
getState
()].
addContext
(
context
);
examplesPerState
[
config
.
getState
()].
addClass
(
machine
.
getClassifier
(
config
.
getState
())
->
getLossFunction
(),
nbClasses
,
goldIndexes
);
examplesPerState
[
config
.
getState
()].
saveIfNeeded
(
config
.
getState
(),
dir
,
maxNbExamplesPerFile
,
epoch
,
dynamicOracle
);
examplesMutex
.
unlock
();
}
config
.
setChosenActionScore
(
bestScore
);
config
.
setChosenActionScore
(
bestScore
);
transition
->
apply
(
config
,
entropy
);
config
.
addToHistory
(
transition
->
getName
());
transition
->
apply
(
config
,
entropy
);
config
.
addToHistory
(
transition
->
getName
());
auto
movement
=
config
.
getStrategy
().
getMovement
(
config
,
transition
->
getName
());
if
(
debug
)
fmt
::
print
(
stderr
,
"(Transition,Newstate,Movement) = ({},{},{})
\n
"
,
transition
->
getName
(),
movement
.
first
,
movement
.
second
);
if
(
movement
==
Strategy
::
endMovement
)
break
;
auto
movement
=
config
.
getStrategy
().
getMovement
(
config
,
transition
->
getName
());
if
(
debug
)
fmt
::
print
(
stderr
,
"(Transition,Newstate,Movement) = ({},{},{})
\n
"
,
transition
->
getName
(),
movement
.
first
,
movement
.
second
);
if
(
movement
==
Strategy
::
endMovement
)
break
;
config
.
setState
(
movement
.
first
);
config
.
moveWordIndexRelaxed
(
movement
.
second
);
config
.
setState
(
movement
.
first
);
config
.
moveWordIndexRelaxed
(
movement
.
second
);
if
(
config
.
needsUpdate
())
config
.
update
();
}
// End while true
}
// End for on configs
if
(
config
.
needsUpdate
())
config
.
update
();
}
// End while true
}
);
// End for on configs
for
(
auto
&
it
:
examplesPerState
)
it
.
second
.
saveIfNeeded
(
it
.
first
,
dir
,
0
,
epoch
,
dynamicOracle
);
NeuralNetworkImpl
::
device
=
NeuralNetworkImpl
::
getPreferredDevice
();
machine
.
to
(
NeuralNetworkImpl
::
device
);
std
::
FILE
*
f
=
std
::
fopen
(
currentEpochAllExtractedFile
.
c_str
(),
"w"
);
if
(
!
f
)
util
::
myThrow
(
fmt
::
format
(
"could not create file '{}'"
,
currentEpochAllExtractedFile
.
c_str
()));
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment