Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
R
RL-Parsing
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Container registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
GitLab community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Franck Dary
RL-Parsing
Commits
98e2ffb0
Commit
98e2ffb0
authored
Sep 18, 2021
by
Franck Dary
Browse files
Options
Downloads
Patches
Plain Diff
Feature canBack is now only used when back is available
parent
1e33b269
Branches
Branches containing commit
Tags
Tags containing commit
No related merge requests found
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
Networks.py
+20
-20
20 additions, 20 deletions
Networks.py
Train.py
+8
-8
8 additions, 8 deletions
Train.py
main.py
+6
-1
6 additions, 1 deletion
main.py
with
34 additions
and
29 deletions
Networks.py
+
20
−
20
View file @
98e2ffb0
...
@@ -37,7 +37,7 @@ def getNeededDicts(name) :
...
@@ -37,7 +37,7 @@ def getNeededDicts(name) :
################################################################################
################################################################################
################################################################################
################################################################################
def
createNetwork
(
name
,
dicts
,
outputSizes
,
incremental
,
pretrained
)
:
def
createNetwork
(
name
,
dicts
,
outputSizes
,
incremental
,
pretrained
,
hasBack
)
:
featureFunctionAll
=
"
b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1
"
featureFunctionAll
=
"
b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1
"
featureFunctionNostack
=
"
b.-2 b.-1 b.0 b.1 b.2
"
featureFunctionNostack
=
"
b.-2 b.-1 b.0 b.1 b.2
"
historyNb
=
10
historyNb
=
10
...
@@ -48,26 +48,20 @@ def createNetwork(name, dicts, outputSizes, incremental, pretrained) :
...
@@ -48,26 +48,20 @@ def createNetwork(name, dicts, outputSizes, incremental, pretrained) :
columns
=
[
"
UPOS
"
,
"
FORM
"
]
columns
=
[
"
UPOS
"
,
"
FORM
"
]
if
name
==
"
base
"
:
if
name
==
"
base
"
:
return
BaseNet
(
dicts
,
outputSizes
,
incremental
,
featureFunctionAll
,
historyNb
,
historyPopNb
,
suffixSize
,
prefixSize
,
columns
,
hiddenSize
,
pretrained
)
return
BaseNet
(
dicts
,
outputSizes
,
incremental
,
featureFunctionAll
,
historyNb
,
historyPopNb
,
suffixSize
,
prefixSize
,
columns
,
hiddenSize
,
pretrained
,
hasBack
)
elif
name
==
"
semi
"
:
elif
name
==
"
baseNoLetters
"
:
return
SemiNet
(
dicts
,
outputSizes
,
incremental
,
featureFunctionAll
,
historyNb
,
suffixSize
,
prefixSize
,
columns
,
hiddenSize
)
return
BaseNet
(
dicts
,
outputSizes
,
incremental
,
featureFunctionAll
,
historyNb
,
historyPopNb
,
0
,
0
,
columns
,
hiddenSize
,
pretrained
,
hasBack
)
elif
name
==
"
big
"
:
return
BaseNet
(
dicts
,
outputSizes
,
incremental
,
featureFunctionAll
,
historyNb
,
suffixSize
,
prefixSize
,
columns
,
hiddenSize
*
2
,
pretrained
)
elif
name
==
"
lstm
"
:
return
LSTMNet
(
dicts
,
outputSizes
,
incremental
)
elif
name
==
"
separated
"
:
return
SeparatedNet
(
dicts
,
outputSizes
,
incremental
,
featureFunctionAll
,
historyNb
,
historyPopNb
,
suffixSize
,
prefixSize
,
columns
,
hiddenSize
)
elif
name
==
"
tagger
"
:
elif
name
==
"
tagger
"
:
return
BaseNet
(
dicts
,
outputSizes
,
incremental
,
featureFunctionNostack
,
historyNb
,
historyPopNb
,
suffixSize
,
prefixSize
,
columns
,
hiddenSize
,
pretrained
)
return
BaseNet
(
dicts
,
outputSizes
,
incremental
,
featureFunctionNostack
,
historyNb
,
historyPopNb
,
suffixSize
,
prefixSize
,
columns
,
hiddenSize
,
pretrained
,
hasBack
)
elif
name
==
"
taggerLexicon
"
:
elif
name
==
"
taggerLexicon
"
:
return
BaseNet
(
dicts
,
outputSizes
,
incremental
,
featureFunctionNostack
,
historyNb
,
historyPopNb
,
suffixSize
,
prefixSize
,
[
"
UPOS
"
,
"
FORM
"
,
"
LEXICON
"
],
hiddenSize
,
pretrained
)
return
BaseNet
(
dicts
,
outputSizes
,
incremental
,
featureFunctionNostack
,
historyNb
,
historyPopNb
,
suffixSize
,
prefixSize
,
[
"
UPOS
"
,
"
FORM
"
,
"
LEXICON
"
],
hiddenSize
,
pretrained
,
hasBack
)
raise
Exception
(
"
Unknown network name
'
%s
'"
%
name
)
raise
Exception
(
"
Unknown network name
'
%s
'"
%
name
)
################################################################################
################################################################################
################################################################################
################################################################################
class
BaseNet
(
nn
.
Module
):
class
BaseNet
(
nn
.
Module
):
def
__init__
(
self
,
dicts
,
outputSizes
,
incremental
,
featureFunction
,
historyNb
,
historyPopNb
,
suffixSize
,
prefixSize
,
columns
,
hiddenSize
,
pretrained
)
:
def
__init__
(
self
,
dicts
,
outputSizes
,
incremental
,
featureFunction
,
historyNb
,
historyPopNb
,
suffixSize
,
prefixSize
,
columns
,
hiddenSize
,
pretrained
,
hasBack
)
:
super
().
__init__
()
super
().
__init__
()
self
.
dummyParam
=
nn
.
Parameter
(
torch
.
empty
(
0
),
requires_grad
=
False
)
self
.
dummyParam
=
nn
.
Parameter
(
torch
.
empty
(
0
),
requires_grad
=
False
)
...
@@ -79,6 +73,7 @@ class BaseNet(nn.Module):
...
@@ -79,6 +73,7 @@ class BaseNet(nn.Module):
self
.
suffixSize
=
suffixSize
self
.
suffixSize
=
suffixSize
self
.
prefixSize
=
prefixSize
self
.
prefixSize
=
prefixSize
self
.
columns
=
columns
self
.
columns
=
columns
self
.
hasBack
=
hasBack
self
.
embSize
=
64
self
.
embSize
=
64
embSizes
=
{}
embSizes
=
{}
...
@@ -94,10 +89,10 @@ class BaseNet(nn.Module):
...
@@ -94,10 +89,10 @@ class BaseNet(nn.Module):
else
:
else
:
embSizes
[
name
]
=
self
.
embSize
embSizes
[
name
]
=
self
.
embSize
self
.
add_module
(
"
emb_
"
+
name
,
nn
.
Embedding
(
len
(
dicts
.
dicts
[
name
]),
self
.
embSize
))
self
.
add_module
(
"
emb_
"
+
name
,
nn
.
Embedding
(
len
(
dicts
.
dicts
[
name
]),
self
.
embSize
))
self
.
inputSize
=
(
self
.
historyNb
+
self
.
historyPopNb
)
*
embSizes
[
"
HISTORY
"
]
+
(
self
.
suffixSize
+
self
.
prefixSize
)
*
embSizes
[
"
LETTER
"
]
+
sum
([
self
.
nbTargets
*
embSizes
[
col
]
for
col
in
self
.
columns
])
self
.
inputSize
=
(
self
.
historyNb
+
self
.
historyPopNb
)
*
embSizes
.
get
(
"
HISTORY
"
,
0
)
+
(
self
.
suffixSize
+
self
.
prefixSize
)
*
embSizes
.
get
(
"
LETTER
"
,
0
)
+
sum
([
self
.
nbTargets
*
embSizes
.
get
(
col
,
0
)
for
col
in
self
.
columns
])
self
.
fc1
=
nn
.
Linear
(
self
.
inputSize
,
hiddenSize
)
self
.
fc1
=
nn
.
Linear
(
self
.
inputSize
,
hiddenSize
)
for
i
in
range
(
len
(
outputSizes
))
:
for
i
in
range
(
len
(
outputSizes
))
:
self
.
add_module
(
"
output_
"
+
str
(
i
),
nn
.
Linear
(
hiddenSize
+
1
,
outputSizes
[
i
]))
self
.
add_module
(
"
output_
"
+
str
(
i
),
nn
.
Linear
(
hiddenSize
+
(
1
if
self
.
hasBack
else
0
)
,
outputSizes
[
i
]))
self
.
dropout
=
nn
.
Dropout
(
0.3
)
self
.
dropout
=
nn
.
Dropout
(
0.3
)
self
.
apply
(
self
.
initWeights
)
self
.
apply
(
self
.
initWeights
)
...
@@ -107,6 +102,7 @@ class BaseNet(nn.Module):
...
@@ -107,6 +102,7 @@ class BaseNet(nn.Module):
def
forward
(
self
,
x
)
:
def
forward
(
self
,
x
)
:
embeddings
=
[]
embeddings
=
[]
if
self
.
hasBack
:
canBack
=
x
[...,
0
:
1
]
canBack
=
x
[...,
0
:
1
]
x
=
x
[...,
1
:]
x
=
x
[...,
1
:]
...
@@ -132,6 +128,7 @@ class BaseNet(nn.Module):
...
@@ -132,6 +128,7 @@ class BaseNet(nn.Module):
curIndex
=
curIndex
+
self
.
suffixSize
curIndex
=
curIndex
+
self
.
suffixSize
y
=
self
.
dropout
(
y
)
y
=
self
.
dropout
(
y
)
y
=
F
.
relu
(
self
.
dropout
(
self
.
fc1
(
y
)))
y
=
F
.
relu
(
self
.
dropout
(
self
.
fc1
(
y
)))
if
self
.
hasBack
:
y
=
torch
.
cat
([
y
,
canBack
],
1
)
y
=
torch
.
cat
([
y
,
canBack
],
1
)
y
=
getattr
(
self
,
"
output_
"
+
str
(
self
.
state
))(
y
)
y
=
getattr
(
self
,
"
output_
"
+
str
(
self
.
state
))(
y
)
return
y
return
y
...
@@ -150,8 +147,11 @@ class BaseNet(nn.Module):
...
@@ -150,8 +147,11 @@ class BaseNet(nn.Module):
historyPopValues
=
Features
.
extractHistoryPopFeatures
(
dicts
,
config
,
self
.
historyPopNb
)
historyPopValues
=
Features
.
extractHistoryPopFeatures
(
dicts
,
config
,
self
.
historyPopNb
)
prefixValues
=
Features
.
extractPrefixFeatures
(
dicts
,
config
,
self
.
prefixSize
)
prefixValues
=
Features
.
extractPrefixFeatures
(
dicts
,
config
,
self
.
prefixSize
)
suffixValues
=
Features
.
extractSuffixFeatures
(
dicts
,
config
,
self
.
suffixSize
)
suffixValues
=
Features
.
extractSuffixFeatures
(
dicts
,
config
,
self
.
suffixSize
)
backAction
=
None
if
self
.
hasBack
:
backAction
=
torch
.
ones
(
1
,
dtype
=
torch
.
int
)
if
Transition
.
Transition
(
"
BACK 1
"
).
appliable
(
config
)
else
torch
.
zeros
(
1
,
dtype
=
torch
.
int
)
backAction
=
torch
.
ones
(
1
,
dtype
=
torch
.
int
)
if
Transition
.
Transition
(
"
BACK 1
"
).
appliable
(
config
)
else
torch
.
zeros
(
1
,
dtype
=
torch
.
int
)
return
torch
.
cat
([
backAction
,
colsValues
,
historyValues
,
historyPopValues
,
prefixValues
,
suffixValues
])
allFeatures
=
[
f
for
f
in
[
backAction
,
colsValues
,
historyValues
,
historyPopValues
,
prefixValues
,
suffixValues
]
if
f
is
not
None
]
return
torch
.
cat
(
allFeatures
)
################################################################################
################################################################################
################################################################################
################################################################################
...
...
This diff is collapsed.
Click to expand it.
Train.py
+
8
−
8
View file @
98e2ffb0
...
@@ -18,15 +18,15 @@ import Config
...
@@ -18,15 +18,15 @@ import Config
from
conll18_ud_eval
import
load_conllu
,
evaluate
from
conll18_ud_eval
import
load_conllu
,
evaluate
################################################################################
################################################################################
def
trainMode
(
debug
,
networkName
,
filename
,
type
,
transitionSet
,
strategy
,
modelDir
,
nbIter
,
batchSize
,
devFile
,
bootstrapInterval
,
incremental
,
rewardFunc
,
lr
,
gamma
,
probas
,
countBreak
,
predicted
,
pretrained
,
silent
=
False
)
:
def
trainMode
(
debug
,
networkName
,
filename
,
type
,
transitionSet
,
strategy
,
modelDir
,
nbIter
,
batchSize
,
devFile
,
bootstrapInterval
,
incremental
,
rewardFunc
,
lr
,
gamma
,
probas
,
countBreak
,
predicted
,
pretrained
,
silent
=
False
,
hasBack
=
False
)
:
sentences
=
Config
.
readConllu
(
filename
,
predicted
)
sentences
=
Config
.
readConllu
(
filename
,
predicted
)
if
type
==
"
oracle
"
:
if
type
==
"
oracle
"
:
trainModelOracle
(
debug
,
networkName
,
modelDir
,
filename
,
nbIter
,
batchSize
,
devFile
,
transitionSet
,
strategy
,
sentences
,
bootstrapInterval
,
incremental
,
rewardFunc
,
lr
,
predicted
,
pretrained
,
silent
)
trainModelOracle
(
debug
,
networkName
,
modelDir
,
filename
,
nbIter
,
batchSize
,
devFile
,
transitionSet
,
strategy
,
sentences
,
bootstrapInterval
,
incremental
,
rewardFunc
,
lr
,
predicted
,
pretrained
,
silent
,
hasBack
)
return
return
if
type
==
"
rl
"
:
if
type
==
"
rl
"
:
trainModelRl
(
debug
,
networkName
,
modelDir
,
filename
,
nbIter
,
batchSize
,
devFile
,
transitionSet
,
strategy
,
sentences
,
incremental
,
rewardFunc
,
lr
,
gamma
,
probas
,
countBreak
,
predicted
,
pretrained
,
silent
)
trainModelRl
(
debug
,
networkName
,
modelDir
,
filename
,
nbIter
,
batchSize
,
devFile
,
transitionSet
,
strategy
,
sentences
,
incremental
,
rewardFunc
,
lr
,
gamma
,
probas
,
countBreak
,
predicted
,
pretrained
,
silent
,
hasBack
)
return
return
print
(
"
ERROR : unknown type
'
%s
'"
%
type
,
file
=
sys
.
stderr
)
print
(
"
ERROR : unknown type
'
%s
'"
%
type
,
file
=
sys
.
stderr
)
...
@@ -100,7 +100,7 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss
...
@@ -100,7 +100,7 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss
################################################################################
################################################################################
################################################################################
################################################################################
def
trainModelOracle
(
debug
,
networkName
,
modelDir
,
filename
,
nbEpochs
,
batchSize
,
devFile
,
transitionSets
,
strategy
,
sentencesOriginal
,
bootstrapInterval
,
incremental
,
rewardFunc
,
lr
,
predicted
,
pretrained
,
silent
=
False
)
:
def
trainModelOracle
(
debug
,
networkName
,
modelDir
,
filename
,
nbEpochs
,
batchSize
,
devFile
,
transitionSets
,
strategy
,
sentencesOriginal
,
bootstrapInterval
,
incremental
,
rewardFunc
,
lr
,
predicted
,
pretrained
,
silent
=
False
,
hasBack
=
False
)
:
dicts
=
Dicts
()
dicts
=
Dicts
()
dicts
.
readConllu
(
filename
,
Networks
.
getNeededDicts
(
networkName
),
2
,
pretrained
)
dicts
.
readConllu
(
filename
,
Networks
.
getNeededDicts
(
networkName
),
2
,
pretrained
)
transitionNames
=
{}
transitionNames
=
{}
...
@@ -111,7 +111,7 @@ def trainModelOracle(debug, networkName, modelDir, filename, nbEpochs, batchSize
...
@@ -111,7 +111,7 @@ def trainModelOracle(debug, networkName, modelDir, filename, nbEpochs, batchSize
dicts
.
addDict
(
"
HISTORY
"
,
transitionNames
)
dicts
.
addDict
(
"
HISTORY
"
,
transitionNames
)
dicts
.
save
(
modelDir
+
"
/dicts.json
"
)
dicts
.
save
(
modelDir
+
"
/dicts.json
"
)
network
=
Networks
.
createNetwork
(
networkName
,
dicts
,
[
len
(
transitionSet
)
for
transitionSet
in
transitionSets
],
incremental
,
pretrained
).
to
(
getDevice
())
network
=
Networks
.
createNetwork
(
networkName
,
dicts
,
[
len
(
transitionSet
)
for
transitionSet
in
transitionSets
],
incremental
,
pretrained
,
hasBack
).
to
(
getDevice
())
examples
=
[[]
for
_
in
transitionSets
]
examples
=
[[]
for
_
in
transitionSets
]
sentences
=
copy
.
deepcopy
(
sentencesOriginal
)
sentences
=
copy
.
deepcopy
(
sentencesOriginal
)
print
(
"
%s : Starting to extract examples...
"
%
(
timeStamp
()),
file
=
sys
.
stderr
)
print
(
"
%s : Starting to extract examples...
"
%
(
timeStamp
()),
file
=
sys
.
stderr
)
...
@@ -187,7 +187,7 @@ def trainModelOracle(debug, networkName, modelDir, filename, nbEpochs, batchSize
...
@@ -187,7 +187,7 @@ def trainModelOracle(debug, networkName, modelDir, filename, nbEpochs, batchSize
################################################################################
################################################################################
################################################################################
################################################################################
def
trainModelRl
(
debug
,
networkName
,
modelDir
,
filename
,
nbIter
,
batchSize
,
devFile
,
transitionSets
,
strategy
,
sentencesOriginal
,
incremental
,
rewardFunc
,
lr
,
gamma
,
probas
,
countBreak
,
predicted
,
pretrained
,
silent
=
False
)
:
def
trainModelRl
(
debug
,
networkName
,
modelDir
,
filename
,
nbIter
,
batchSize
,
devFile
,
transitionSets
,
strategy
,
sentencesOriginal
,
incremental
,
rewardFunc
,
lr
,
gamma
,
probas
,
countBreak
,
predicted
,
pretrained
,
silent
=
False
,
hasBack
=
False
)
:
memory
=
None
memory
=
None
dicts
=
Dicts
()
dicts
=
Dicts
()
...
@@ -207,8 +207,8 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF
...
@@ -207,8 +207,8 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF
policy_net
=
torch
.
load
(
modelDir
+
"
/lastNetwork.pt
"
)
policy_net
=
torch
.
load
(
modelDir
+
"
/lastNetwork.pt
"
)
target_net
=
torch
.
load
(
modelDir
+
"
/lastNetwork.pt
"
)
target_net
=
torch
.
load
(
modelDir
+
"
/lastNetwork.pt
"
)
else
:
else
:
policy_net
=
Networks
.
createNetwork
(
networkName
,
dicts
,
[
len
(
transitionSet
)
for
transitionSet
in
transitionSets
],
incremental
,
pretrained
).
to
(
getDevice
())
policy_net
=
Networks
.
createNetwork
(
networkName
,
dicts
,
[
len
(
transitionSet
)
for
transitionSet
in
transitionSets
],
incremental
,
pretrained
,
hasBack
).
to
(
getDevice
())
target_net
=
Networks
.
createNetwork
(
networkName
,
dicts
,
[
len
(
transitionSet
)
for
transitionSet
in
transitionSets
],
incremental
,
pretrained
).
to
(
getDevice
())
target_net
=
Networks
.
createNetwork
(
networkName
,
dicts
,
[
len
(
transitionSet
)
for
transitionSet
in
transitionSets
],
incremental
,
pretrained
,
hasBack
).
to
(
getDevice
())
target_net
.
load_state_dict
(
policy_net
.
state_dict
())
target_net
.
load_state_dict
(
policy_net
.
state_dict
())
target_net
.
eval
()
target_net
.
eval
()
policy_net
.
train
()
policy_net
.
train
()
...
...
This diff is collapsed.
Click to expand it.
main.py
+
6
−
1
View file @
98e2ffb0
...
@@ -85,6 +85,7 @@ if __name__ == "__main__" :
...
@@ -85,6 +85,7 @@ if __name__ == "__main__" :
args
.
bootstrap
=
int
(
args
.
bootstrap
)
args
.
bootstrap
=
int
(
args
.
bootstrap
)
networkName
=
args
.
network
networkName
=
args
.
network
hasBack
=
False
if
args
.
transitions
==
"
tagger
"
:
if
args
.
transitions
==
"
tagger
"
:
tmpDicts
=
Dicts
()
tmpDicts
=
Dicts
()
...
@@ -98,6 +99,7 @@ if __name__ == "__main__" :
...
@@ -98,6 +99,7 @@ if __name__ == "__main__" :
networkName
=
"
tagger
"
networkName
=
"
tagger
"
probas
=
[[
list
(
map
(
float
,
args
.
probaRandom
.
split
(
'
,
'
))),
list
(
map
(
float
,
args
.
probaOracle
.
split
(
'
,
'
)))]]
probas
=
[[
list
(
map
(
float
,
args
.
probaRandom
.
split
(
'
,
'
))),
list
(
map
(
float
,
args
.
probaOracle
.
split
(
'
,
'
)))]]
elif
args
.
transitions
==
"
taggerbt
"
:
elif
args
.
transitions
==
"
taggerbt
"
:
hasBack
=
True
tmpDicts
=
Dicts
()
tmpDicts
=
Dicts
()
tmpDicts
.
readConllu
(
args
.
corpus
,
[
"
UPOS
"
],
0
)
tmpDicts
.
readConllu
(
args
.
corpus
,
[
"
UPOS
"
],
0
)
tagActions
=
[
"
TAG UPOS %s
"
%
p
for
p
in
tmpDicts
.
getElementsOf
(
"
UPOS
"
)
if
"
__
"
not
in
p
and
not
isEmpty
(
p
)]
tagActions
=
[
"
TAG UPOS %s
"
%
p
for
p
in
tmpDicts
.
getElementsOf
(
"
UPOS
"
)
if
"
__
"
not
in
p
and
not
isEmpty
(
p
)]
...
@@ -118,6 +120,7 @@ if __name__ == "__main__" :
...
@@ -118,6 +120,7 @@ if __name__ == "__main__" :
networkName
=
"
base
"
networkName
=
"
base
"
probas
=
[[
list
(
map
(
float
,
args
.
probaRandom
.
split
(
'
,
'
))),
list
(
map
(
float
,
args
.
probaOracle
.
split
(
'
,
'
)))]]
probas
=
[[
list
(
map
(
float
,
args
.
probaRandom
.
split
(
'
,
'
))),
list
(
map
(
float
,
args
.
probaOracle
.
split
(
'
,
'
)))]]
elif
args
.
transitions
==
"
eagerbt
"
:
elif
args
.
transitions
==
"
eagerbt
"
:
hasBack
=
True
transitionSets
=
[[
Transition
(
"
NOBACK
"
),
Transition
(
"
BACK
"
+
args
.
backSize
)],
[
Transition
(
elem
)
for
elem
in
[
"
SHIFT
"
,
"
REDUCE
"
,
"
LEFT
"
,
"
RIGHT
"
]
if
len
(
elem
)
>
0
]]
transitionSets
=
[[
Transition
(
"
NOBACK
"
),
Transition
(
"
BACK
"
+
args
.
backSize
)],
[
Transition
(
elem
)
for
elem
in
[
"
SHIFT
"
,
"
REDUCE
"
,
"
LEFT
"
,
"
RIGHT
"
]
if
len
(
elem
)
>
0
]]
args
.
predictedStr
=
"
HEAD
"
args
.
predictedStr
=
"
HEAD
"
args
.
states
=
[
"
backer
"
,
"
parser
"
]
args
.
states
=
[
"
backer
"
,
"
parser
"
]
...
@@ -155,6 +158,7 @@ if __name__ == "__main__" :
...
@@ -155,6 +158,7 @@ if __name__ == "__main__" :
[
list
(
map
(
float
,
args
.
probaRandom
.
split
(
'
,
'
))),
list
(
map
(
float
,
args
.
probaOracle
.
split
(
'
,
'
)))]]
[
list
(
map
(
float
,
args
.
probaRandom
.
split
(
'
,
'
))),
list
(
map
(
float
,
args
.
probaOracle
.
split
(
'
,
'
)))]]
elif
args
.
transitions
==
"
tagparserbt
"
:
elif
args
.
transitions
==
"
tagparserbt
"
:
hasBack
=
True
tmpDicts
=
Dicts
()
tmpDicts
=
Dicts
()
tmpDicts
.
readConllu
(
args
.
corpus
,
[
"
UPOS
"
],
0
)
tmpDicts
.
readConllu
(
args
.
corpus
,
[
"
UPOS
"
],
0
)
tagActions
=
[
"
TAG UPOS %s
"
%
p
for
p
in
tmpDicts
.
getElementsOf
(
"
UPOS
"
)
if
"
__
"
not
in
p
and
not
isEmpty
(
p
)]
tagActions
=
[
"
TAG UPOS %s
"
%
p
for
p
in
tmpDicts
.
getElementsOf
(
"
UPOS
"
)
if
"
__
"
not
in
p
and
not
isEmpty
(
p
)]
...
@@ -168,6 +172,7 @@ if __name__ == "__main__" :
...
@@ -168,6 +172,7 @@ if __name__ == "__main__" :
[
list
(
map
(
float
,
args
.
probaRandom
.
split
(
'
,
'
))),
list
(
map
(
float
,
args
.
probaOracle
.
split
(
'
,
'
)))],
[
list
(
map
(
float
,
args
.
probaRandom
.
split
(
'
,
'
))),
list
(
map
(
float
,
args
.
probaOracle
.
split
(
'
,
'
)))],
[
list
(
map
(
float
,
args
.
probaRandom
.
split
(
'
,
'
))),
list
(
map
(
float
,
args
.
probaOracle
.
split
(
'
,
'
)))]]
[
list
(
map
(
float
,
args
.
probaRandom
.
split
(
'
,
'
))),
list
(
map
(
float
,
args
.
probaOracle
.
split
(
'
,
'
)))]]
elif
args
.
transitions
==
"
recovery
"
:
elif
args
.
transitions
==
"
recovery
"
:
hasBack
=
True
tmpDicts
=
Dicts
()
tmpDicts
=
Dicts
()
tmpDicts
.
readConllu
(
args
.
corpus
,
[
"
UPOS
"
],
0
)
tmpDicts
.
readConllu
(
args
.
corpus
,
[
"
UPOS
"
],
0
)
tagActions
=
[
"
TAG UPOS %s
"
%
p
for
p
in
tmpDicts
.
getElementsOf
(
"
UPOS
"
)
if
"
__
"
not
in
p
and
not
isEmpty
(
p
)]
tagActions
=
[
"
TAG UPOS %s
"
%
p
for
p
in
tmpDicts
.
getElementsOf
(
"
UPOS
"
)
if
"
__
"
not
in
p
and
not
isEmpty
(
p
)]
...
@@ -197,7 +202,7 @@ if __name__ == "__main__" :
...
@@ -197,7 +202,7 @@ if __name__ == "__main__" :
json
.
dump
([
args
.
predictedStr
,
[[
str
(
t
)
for
t
in
transitionSet
]
for
transitionSet
in
transitionSets
]],
open
(
args
.
model
+
"
/transitions.json
"
,
"
w
"
))
json
.
dump
([
args
.
predictedStr
,
[[
str
(
t
)
for
t
in
transitionSet
]
for
transitionSet
in
transitionSets
]],
open
(
args
.
model
+
"
/transitions.json
"
,
"
w
"
))
json
.
dump
(
strategy
,
open
(
args
.
model
+
"
/strategy.json
"
,
"
w
"
))
json
.
dump
(
strategy
,
open
(
args
.
model
+
"
/strategy.json
"
,
"
w
"
))
printTS
(
transitionSets
,
sys
.
stderr
)
printTS
(
transitionSets
,
sys
.
stderr
)
Train
.
trainMode
(
args
.
debug
,
networkName
,
args
.
corpus
,
args
.
type
,
transitionSets
,
strategy
,
args
.
model
,
int
(
args
.
iter
),
int
(
args
.
batchSize
),
args
.
dev
,
args
.
bootstrap
,
args
.
incr
,
args
.
reward
,
float
(
args
.
lr
),
float
(
args
.
gamma
),
probas
,
int
(
args
.
countBreak
),
args
.
predicted
,
args
.
pretrained
,
args
.
silent
)
Train
.
trainMode
(
args
.
debug
,
networkName
,
args
.
corpus
,
args
.
type
,
transitionSets
,
strategy
,
args
.
model
,
int
(
args
.
iter
),
int
(
args
.
batchSize
),
args
.
dev
,
args
.
bootstrap
,
args
.
incr
,
args
.
reward
,
float
(
args
.
lr
),
float
(
args
.
gamma
),
probas
,
int
(
args
.
countBreak
),
args
.
predicted
,
args
.
pretrained
,
args
.
silent
,
hasBack
)
elif
args
.
mode
==
"
decode
"
:
elif
args
.
mode
==
"
decode
"
:
transInfos
=
json
.
load
(
open
(
args
.
model
+
"
/transitions.json
"
,
"
r
"
))
transInfos
=
json
.
load
(
open
(
args
.
model
+
"
/transitions.json
"
,
"
r
"
))
transNames
=
json
.
load
(
open
(
args
.
model
+
"
/transitions.json
"
,
"
r
"
))[
1
]
transNames
=
json
.
load
(
open
(
args
.
model
+
"
/transitions.json
"
,
"
r
"
))[
1
]
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment