Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
O
old_macaon
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Code
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Deploy
Releases
Container registry
Model registry
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
GitLab community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Franck Dary
old_macaon
Commits
535710e6
Commit
535710e6
authored
Feb 3, 2019
by
Franck Dary
Browse files
Options
Downloads
Patches
Plain Diff
Refactored code of Decoder
parent
7d08c2a0
Branches
Branches containing commit
No related tags found
No related merge requests found
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
decoder/src/Decoder.cpp
+129
-87
129 additions, 87 deletions
decoder/src/Decoder.cpp
with
129 additions
and
87 deletions
decoder/src/Decoder.cpp
+
129
−
87
View file @
535710e6
...
@@ -9,34 +9,43 @@ Decoder::Decoder(TransitionMachine & tm, Config & config)
...
@@ -9,34 +9,43 @@ Decoder::Decoder(TransitionMachine & tm, Config & config)
{
{
}
}
void
Decoder
::
decode
()
struct
EndOfDecode
:
public
std
::
exception
{
{
float
entropyAccumulator
=
0.0
;
const
char
*
what
()
const
throw
()
int
nbActionsInSequence
=
0
;
bool
justFlipped
=
false
;
Errors
errors
;
errors
.
newSequence
();
int
nbActions
=
0
;
int
nbActionsCutoff
=
200
;
float
currentSpeed
=
0.0
;
auto
pastTime
=
std
::
chrono
::
high_resolution_clock
::
now
();
while
(
!
config
.
isFinal
())
{
{
TransitionMachine
::
State
*
currentState
=
tm
.
getCurrentState
();
return
"End of Decode"
;
Classifier
*
classifier
=
currentState
->
classifier
;
}
config
.
setCurrentStateName
(
&
currentState
->
name
);
};
Dict
::
currentClassifierName
=
classifier
->
name
;
if
(
ProgramParameters
::
debug
)
void
checkAndRecordError
(
Config
&
config
,
Classifier
*
classifier
,
Classifier
::
WeightedActions
&
weightedActions
,
Action
*
action
,
Errors
&
errors
)
{
if
(
classifier
->
needsTrain
()
&&
ProgramParameters
::
errorAnalysis
&&
(
classifier
->
name
==
ProgramParameters
::
classifierName
||
ProgramParameters
::
classifierName
.
empty
()))
{
auto
zeroCostActions
=
classifier
->
getZeroCostActions
(
config
);
if
(
zeroCostActions
.
empty
())
{
{
fprintf
(
stderr
,
"ERROR (%s) : could not find zero cost action for classifier
\'
%s
\'
. Aborting.
\n
"
,
ERRINFO
,
classifier
->
name
.
c_str
());
config
.
printForDebug
(
stderr
);
config
.
printForDebug
(
stderr
);
fprintf
(
stderr
,
"State :
\'
%s
\'\n
"
,
currentState
->
name
.
c_str
());
for
(
auto
&
a
:
weightedActions
)
{
fprintf
(
stderr
,
"%s : "
,
a
.
second
.
second
.
c_str
());
Oracle
::
explainCostOfAction
(
stderr
,
config
,
a
.
second
.
second
);
}
exit
(
1
);
}
std
::
string
oAction
=
zeroCostActions
[
0
];
for
(
auto
&
s
:
zeroCostActions
)
if
(
action
->
name
==
s
)
oAction
=
s
;
int
actionCost
=
classifier
->
getActionCost
(
config
,
action
->
name
);
int
linkLengthPrediction
=
ActionBank
::
getLinkLength
(
config
,
action
->
name
);
int
linkLengthGold
=
ActionBank
::
getLinkLength
(
config
,
oAction
);
errors
.
add
({
action
->
name
,
oAction
,
weightedActions
,
actionCost
,
linkLengthPrediction
,
linkLengthGold
});
}
}
}
auto
weightedActions
=
classifier
->
weightActions
(
config
);
void
printAdvancement
(
Config
&
config
,
float
currentSpeed
)
{
// Print current iter advancement in percentage
if
(
ProgramParameters
::
interactive
)
if
(
ProgramParameters
::
interactive
)
{
{
int
totalSize
=
config
.
tapes
[
0
].
hyp
.
size
();
int
totalSize
=
config
.
tapes
[
0
].
hyp
.
size
();
...
@@ -44,13 +53,24 @@ void Decoder::decode()
...
@@ -44,13 +53,24 @@ void Decoder::decode()
if
(
steps
&&
(
steps
%
200
==
0
||
totalSize
-
steps
<
200
))
if
(
steps
&&
(
steps
%
200
==
0
||
totalSize
-
steps
<
200
))
fprintf
(
stderr
,
"Decode : %.2f%% speed : %s actions/s
\r
"
,
100.0
*
steps
/
totalSize
,
int2humanStr
((
int
)
currentSpeed
).
c_str
());
fprintf
(
stderr
,
"Decode : %.2f%% speed : %s actions/s
\r
"
,
100.0
*
steps
/
totalSize
,
int2humanStr
((
int
)
currentSpeed
).
c_str
());
}
}
}
void
printDebugInfos
(
FILE
*
output
,
Config
&
config
,
TransitionMachine
&
tm
,
Classifier
::
WeightedActions
&
weightedActions
)
{
if
(
ProgramParameters
::
debug
)
if
(
ProgramParameters
::
debug
)
{
{
Classifier
::
printWeightedActions
(
stderr
,
weightedActions
);
TransitionMachine
::
State
*
currentState
=
tm
.
getCurrentState
();
fprintf
(
stderr
,
"
\n
"
);
config
.
printForDebug
(
output
);
fprintf
(
output
,
"State :
\'
%s
\'\n
"
,
currentState
->
name
.
c_str
());
Classifier
::
printWeightedActions
(
output
,
weightedActions
);
fprintf
(
output
,
"
\n
"
);
}
}
}
std
::
string
&
getClassifierAction
(
Config
&
config
,
Classifier
::
WeightedActions
&
weightedActions
,
Classifier
*
classifier
)
{
std
::
string
&
predictedAction
=
weightedActions
[
0
].
second
.
second
;
std
::
string
&
predictedAction
=
weightedActions
[
0
].
second
.
second
;
Action
*
action
=
classifier
->
getAction
(
predictedAction
);
Action
*
action
=
classifier
->
getAction
(
predictedAction
);
...
@@ -70,7 +90,7 @@ void Decoder::decode()
...
@@ -70,7 +90,7 @@ void Decoder::decode()
{
{
while
(
!
config
.
stackEmpty
())
while
(
!
config
.
stackEmpty
())
config
.
stackPop
();
config
.
stackPop
();
continue
;
throw
EndOfDecode
()
;
}
}
else
else
{
{
...
@@ -79,44 +99,11 @@ void Decoder::decode()
...
@@ -79,44 +99,11 @@ void Decoder::decode()
}
}
}
}
if
(
classifier
->
needsTrain
()
&&
ProgramParameters
::
errorAnalysis
&&
(
classifier
->
name
==
ProgramParameters
::
classifierName
||
ProgramParameters
::
classifierName
.
empty
()))
return
predictedAction
;
{
auto
zeroCostActions
=
classifier
->
getZeroCostActions
(
config
);
if
(
zeroCostActions
.
empty
())
{
fprintf
(
stderr
,
"ERROR (%s) : could not find zero cost action for classifier
\'
%s
\'
. Aborting.
\n
"
,
ERRINFO
,
classifier
->
name
.
c_str
());
config
.
printForDebug
(
stderr
);
for
(
auto
&
a
:
weightedActions
)
{
fprintf
(
stderr
,
"%s : "
,
a
.
second
.
second
.
c_str
());
Oracle
::
explainCostOfAction
(
stderr
,
config
,
a
.
second
.
second
);
}
exit
(
1
);
}
std
::
string
oAction
=
zeroCostActions
[
0
];
for
(
auto
&
s
:
zeroCostActions
)
if
(
action
->
name
==
s
)
oAction
=
s
;
int
actionCost
=
classifier
->
getActionCost
(
config
,
action
->
name
);
int
linkLengthPrediction
=
ActionBank
::
getLinkLength
(
config
,
action
->
name
);
int
linkLengthGold
=
ActionBank
::
getLinkLength
(
config
,
oAction
);
errors
.
add
({
action
->
name
,
oAction
,
weightedActions
,
actionCost
,
linkLengthPrediction
,
linkLengthGold
});
}
}
TransitionMachine
::
Transition
*
transition
=
tm
.
getTransition
(
predictedAction
);
void
computeSpeed
(
std
::
chrono
::
time_point
<
std
::
chrono
::
system_clock
>
&
pastTime
,
int
&
nbActions
,
int
&
nbActionsCutoff
,
float
&
currentSpeed
)
{
action
->
setInfos
(
transition
->
headMvt
,
currentState
->
name
);
action
->
apply
(
config
);
tm
.
takeTransition
(
transition
);
float
entropy
=
Classifier
::
computeEntropy
(
weightedActions
);
config
.
addToEntropyHistory
(
entropy
);
nbActionsInSequence
++
;
nbActions
++
;
if
(
nbActions
>=
nbActionsCutoff
)
if
(
nbActions
>=
nbActionsCutoff
)
{
{
auto
actualTime
=
std
::
chrono
::
high_resolution_clock
::
now
();
auto
actualTime
=
std
::
chrono
::
high_resolution_clock
::
now
();
...
@@ -128,9 +115,10 @@ void Decoder::decode()
...
@@ -128,9 +115,10 @@ void Decoder::decode()
nbActions
=
0
;
nbActions
=
0
;
}
}
}
entropyAccumulator
+=
entropy
;
void
computeAndPrintSequenceEntropy
(
Config
&
config
,
bool
&
justFlipped
,
Errors
&
errors
,
float
&
entropyAccumulator
,
int
&
nbActionsInSequence
)
{
if
(
ProgramParameters
::
printEntropy
||
ProgramParameters
::
errorAnalysis
)
if
(
ProgramParameters
::
printEntropy
||
ProgramParameters
::
errorAnalysis
)
{
{
if
(
config
.
head
>=
1
&&
config
.
getTape
(
ProgramParameters
::
sequenceDelimiterTape
)[
config
.
head
-
1
]
!=
ProgramParameters
::
sequenceDelimiter
)
if
(
config
.
head
>=
1
&&
config
.
getTape
(
ProgramParameters
::
sequenceDelimiterTape
)[
config
.
head
-
1
]
!=
ProgramParameters
::
sequenceDelimiter
)
...
@@ -147,7 +135,61 @@ void Decoder::decode()
...
@@ -147,7 +135,61 @@ void Decoder::decode()
entropyAccumulator
=
0.0
;
entropyAccumulator
=
0.0
;
}
}
}
}
}
void
computeAndRecordEntropy
(
Config
&
config
,
Classifier
::
WeightedActions
&
weightedActions
,
float
&
entropyAccumulator
)
{
float
entropy
=
Classifier
::
computeEntropy
(
weightedActions
);
config
.
addToEntropyHistory
(
entropy
);
entropyAccumulator
+=
entropy
;
}
void
applyActionAndTakeTransition
(
TransitionMachine
&
tm
,
Action
*
action
,
Config
&
config
)
{
TransitionMachine
::
Transition
*
transition
=
tm
.
getTransition
(
action
->
name
);
action
->
setInfos
(
transition
->
headMvt
,
tm
.
getCurrentState
()
->
name
);
action
->
apply
(
config
);
tm
.
takeTransition
(
transition
);
}
void
Decoder
::
decode
()
{
float
entropyAccumulator
=
0.0
;
int
nbActionsInSequence
=
0
;
bool
justFlipped
=
false
;
Errors
errors
;
errors
.
newSequence
();
int
nbActions
=
0
;
int
nbActionsCutoff
=
200
;
float
currentSpeed
=
0.0
;
auto
pastTime
=
std
::
chrono
::
high_resolution_clock
::
now
();
while
(
!
config
.
isFinal
())
{
TransitionMachine
::
State
*
currentState
=
tm
.
getCurrentState
();
Classifier
*
classifier
=
currentState
->
classifier
;
config
.
setCurrentStateName
(
&
currentState
->
name
);
Dict
::
currentClassifierName
=
classifier
->
name
;
auto
weightedActions
=
classifier
->
weightActions
(
config
);
printAdvancement
(
config
,
currentSpeed
);
printDebugInfos
(
stderr
,
config
,
tm
,
weightedActions
);
std
::
string
predictedAction
;
try
{
predictedAction
=
getClassifierAction
(
config
,
weightedActions
,
classifier
);}
catch
(
EndOfDecode
&
)
{
continue
;};
Action
*
action
=
classifier
->
getAction
(
predictedAction
);
checkAndRecordError
(
config
,
classifier
,
weightedActions
,
action
,
errors
);
applyActionAndTakeTransition
(
tm
,
action
,
config
);
nbActionsInSequence
++
;
nbActions
++
;
computeSpeed
(
pastTime
,
nbActions
,
nbActionsCutoff
,
currentSpeed
);
computeAndRecordEntropy
(
config
,
weightedActions
,
entropyAccumulator
);
computeAndPrintSequenceEntropy
(
config
,
justFlipped
,
errors
,
entropyAccumulator
,
nbActionsInSequence
);
}
}
if
(
ProgramParameters
::
errorAnalysis
)
if
(
ProgramParameters
::
errorAnalysis
)
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment