Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
M
macaon
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Container registry
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Franck Dary
macaon
Commits
b495167c
Commit
b495167c
authored
4 years ago
by
Franck Dary
Browse files
Options
Downloads
Patches
Plain Diff
Parallel extractExamples
parent
30e51f46
No related branches found
No related tags found
No related merge requests found
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
common/include/Dict.hpp
+3
-0
3 additions, 0 deletions
common/include/Dict.hpp
common/src/Dict.cpp
+17
-4
17 additions, 4 deletions
common/src/Dict.cpp
trainer/src/Trainer.cpp
+118
-108
118 additions, 108 deletions
trainer/src/Trainer.cpp
with
138 additions
and
112 deletions
common/include/Dict.hpp
+
3
−
0
View file @
b495167c
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
#include
<unordered_map>
#include
<unordered_map>
#include
<vector>
#include
<vector>
#include
<filesystem>
#include
<filesystem>
#include
<mutex>
class
Dict
class
Dict
{
{
...
@@ -30,6 +31,7 @@ class Dict
...
@@ -30,6 +31,7 @@ class Dict
std
::
unordered_map
<
std
::
string
,
int
>
elementsToIndexes
;
std
::
unordered_map
<
std
::
string
,
int
>
elementsToIndexes
;
std
::
unordered_map
<
int
,
std
::
string
>
indexesToElements
;
std
::
unordered_map
<
int
,
std
::
string
>
indexesToElements
;
std
::
vector
<
int
>
nbOccs
;
std
::
vector
<
int
>
nbOccs
;
std
::
mutex
elementsMutex
;
State
state
;
State
state
;
bool
isCountingOccs
{
false
};
bool
isCountingOccs
{
false
};
...
@@ -43,6 +45,7 @@ class Dict
...
@@ -43,6 +45,7 @@ class Dict
void
readFromFile
(
const
char
*
filename
);
void
readFromFile
(
const
char
*
filename
);
void
insert
(
const
std
::
string
&
element
);
void
insert
(
const
std
::
string
&
element
);
void
reset
();
void
reset
();
int
_getIndexOrInsert
(
const
std
::
string
&
element
,
const
std
::
string
&
prefix
);
public
:
public
:
...
...
This diff is collapsed.
Click to expand it.
common/src/Dict.cpp
+
17
−
4
View file @
b495167c
...
@@ -90,20 +90,33 @@ void Dict::insert(const std::string & element)
...
@@ -90,20 +90,33 @@ void Dict::insert(const std::string & element)
}
}
int
Dict
::
getIndexOrInsert
(
const
std
::
string
&
element
,
const
std
::
string
&
prefix
)
int
Dict
::
getIndexOrInsert
(
const
std
::
string
&
element
,
const
std
::
string
&
prefix
)
{
if
(
state
==
State
::
Open
)
elementsMutex
.
lock
();
int
index
=
_getIndexOrInsert
(
element
,
prefix
);
if
(
state
==
State
::
Open
)
elementsMutex
.
unlock
();
return
index
;
}
int
Dict
::
_getIndexOrInsert
(
const
std
::
string
&
element
,
const
std
::
string
&
prefix
)
{
{
if
(
element
.
empty
())
if
(
element
.
empty
())
return
getIndexOrInsert
(
emptyValueStr
,
prefix
);
return
_
getIndexOrInsert
(
emptyValueStr
,
prefix
);
if
(
util
::
printedLength
(
element
)
==
1
and
util
::
isSeparator
(
util
::
utf8char
(
element
)))
if
(
util
::
printedLength
(
element
)
==
1
and
util
::
isSeparator
(
util
::
utf8char
(
element
)))
{
{
return
getIndexOrInsert
(
separatorValueStr
,
prefix
);
return
_
getIndexOrInsert
(
separatorValueStr
,
prefix
);
}
}
if
(
util
::
isNumber
(
element
))
if
(
util
::
isNumber
(
element
))
return
getIndexOrInsert
(
numberValueStr
,
prefix
);
return
_
getIndexOrInsert
(
numberValueStr
,
prefix
);
if
(
util
::
isUrl
(
element
))
if
(
util
::
isUrl
(
element
))
return
getIndexOrInsert
(
urlValueStr
,
prefix
);
return
_
getIndexOrInsert
(
urlValueStr
,
prefix
);
auto
prefixed
=
prefix
.
empty
()
?
element
:
fmt
::
format
(
"{}({})"
,
prefix
,
element
);
auto
prefixed
=
prefix
.
empty
()
?
element
:
fmt
::
format
(
"{}({})"
,
prefix
,
element
);
const
auto
&
found
=
elementsToIndexes
.
find
(
prefixed
);
const
auto
&
found
=
elementsToIndexes
.
find
(
prefixed
);
...
...
This diff is collapsed.
Click to expand it.
trainer/src/Trainer.cpp
+
118
−
108
View file @
b495167c
#include
"Trainer.hpp"
#include
"Trainer.hpp"
#include
"SubConfig.hpp"
#include
"SubConfig.hpp"
#include
<execution>
Trainer
::
Trainer
(
ReadingMachine
&
machine
,
int
batchSize
)
:
machine
(
machine
),
batchSize
(
batchSize
)
Trainer
::
Trainer
(
ReadingMachine
&
machine
,
int
batchSize
)
:
machine
(
machine
),
batchSize
(
batchSize
)
{
{
...
@@ -35,7 +36,8 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
...
@@ -35,7 +36,8 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
torch
::
AutoGradMode
useGrad
(
false
);
torch
::
AutoGradMode
useGrad
(
false
);
int
maxNbExamplesPerFile
=
50000
;
int
maxNbExamplesPerFile
=
50000
;
std
::
map
<
std
::
string
,
Examples
>
examplesPerState
;
std
::
unordered_map
<
std
::
string
,
Examples
>
examplesPerState
;
std
::
mutex
examplesMutex
;
std
::
filesystem
::
create_directories
(
dir
);
std
::
filesystem
::
create_directories
(
dir
);
...
@@ -46,9 +48,12 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
...
@@ -46,9 +48,12 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
fmt
::
print
(
stderr
,
"[{}] Starting to extract examples{}
\n
"
,
util
::
getTime
(),
dynamicOracle
?
", dynamic oracle"
:
""
);
fmt
::
print
(
stderr
,
"[{}] Starting to extract examples{}
\n
"
,
util
::
getTime
(),
dynamicOracle
?
", dynamic oracle"
:
""
);
int
totalNbExamples
=
0
;
std
::
atomic
<
int
>
totalNbExamples
=
0
;
for
(
auto
&
config
:
configs
)
NeuralNetworkImpl
::
device
=
torch
::
kCPU
;
machine
.
to
(
NeuralNetworkImpl
::
device
);
std
::
for_each
(
std
::
execution
::
par_unseq
,
configs
.
begin
(),
configs
.
end
(),
[
this
,
maxNbExamplesPerFile
,
&
examplesPerState
,
&
totalNbExamples
,
debug
,
dynamicOracle
,
explorationThreshold
,
dir
,
epoch
,
&
examplesMutex
](
SubConfig
&
config
)
{
{
config
.
addPredicted
(
machine
.
getPredicted
());
config
.
addPredicted
(
machine
.
getPredicted
());
config
.
setStrategy
(
machine
.
getStrategyDefinition
());
config
.
setStrategy
(
machine
.
getStrategyDefinition
());
...
@@ -157,9 +162,11 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
...
@@ -157,9 +162,11 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
if
(
totalNbExamples
>=
(
int
)
safetyNbExamplesMax
)
if
(
totalNbExamples
>=
(
int
)
safetyNbExamplesMax
)
util
::
myThrow
(
fmt
::
format
(
"Trying to extract more examples than the limit ({})"
,
util
::
int2HumanStr
(
safetyNbExamplesMax
)));
util
::
myThrow
(
fmt
::
format
(
"Trying to extract more examples than the limit ({})"
,
util
::
int2HumanStr
(
safetyNbExamplesMax
)));
examplesMutex
.
lock
();
examplesPerState
[
config
.
getState
()].
addContext
(
context
);
examplesPerState
[
config
.
getState
()].
addContext
(
context
);
examplesPerState
[
config
.
getState
()].
addClass
(
machine
.
getClassifier
(
config
.
getState
())
->
getLossFunction
(),
nbClasses
,
goldIndexes
);
examplesPerState
[
config
.
getState
()].
addClass
(
machine
.
getClassifier
(
config
.
getState
())
->
getLossFunction
(),
nbClasses
,
goldIndexes
);
examplesPerState
[
config
.
getState
()].
saveIfNeeded
(
config
.
getState
(),
dir
,
maxNbExamplesPerFile
,
epoch
,
dynamicOracle
);
examplesPerState
[
config
.
getState
()].
saveIfNeeded
(
config
.
getState
(),
dir
,
maxNbExamplesPerFile
,
epoch
,
dynamicOracle
);
examplesMutex
.
unlock
();
}
}
config
.
setChosenActionScore
(
bestScore
);
config
.
setChosenActionScore
(
bestScore
);
...
@@ -179,11 +186,14 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
...
@@ -179,11 +186,14 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
if
(
config
.
needsUpdate
())
if
(
config
.
needsUpdate
())
config
.
update
();
config
.
update
();
}
// End while true
}
// End while true
}
// End for on configs
}
);
// End for on configs
for
(
auto
&
it
:
examplesPerState
)
for
(
auto
&
it
:
examplesPerState
)
it
.
second
.
saveIfNeeded
(
it
.
first
,
dir
,
0
,
epoch
,
dynamicOracle
);
it
.
second
.
saveIfNeeded
(
it
.
first
,
dir
,
0
,
epoch
,
dynamicOracle
);
NeuralNetworkImpl
::
device
=
NeuralNetworkImpl
::
getPreferredDevice
();
machine
.
to
(
NeuralNetworkImpl
::
device
);
std
::
FILE
*
f
=
std
::
fopen
(
currentEpochAllExtractedFile
.
c_str
(),
"w"
);
std
::
FILE
*
f
=
std
::
fopen
(
currentEpochAllExtractedFile
.
c_str
(),
"w"
);
if
(
!
f
)
if
(
!
f
)
util
::
myThrow
(
fmt
::
format
(
"could not create file '{}'"
,
currentEpochAllExtractedFile
.
c_str
()));
util
::
myThrow
(
fmt
::
format
(
"could not create file '{}'"
,
currentEpochAllExtractedFile
.
c_str
()));
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment