Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
R
RL-Parsing
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Container registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
GitLab community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Franck Dary
RL-Parsing
Commits
bf43cbf5
Commit
bf43cbf5
authored
Apr 27, 2021
by
Franck Dary
Browse files
Options
Downloads
Patches
Plain Diff
Added Back transitions, problem : system can cycle
parent
39e488e9
No related branches found
No related tags found
No related merge requests found
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
Config.py
+12
-10
12 additions, 10 deletions
Config.py
Decode.py
+2
-4
2 additions, 4 deletions
Decode.py
Train.py
+11
-10
11 additions, 10 deletions
Train.py
Transition.py
+69
-10
69 additions, 10 deletions
Transition.py
main.py
+6
-2
6 additions, 2 deletions
main.py
with
100 additions
and
36 deletions
Config.py
+
12
−
10
View file @
bf43cbf5
...
@@ -14,6 +14,8 @@ class Config :
...
@@ -14,6 +14,8 @@ class Config :
self
.
stack
=
[]
self
.
stack
=
[]
self
.
comments
=
[]
self
.
comments
=
[]
self
.
history
=
[]
self
.
history
=
[]
self
.
historyHistory
=
set
()
self
.
historyPop
=
[]
def
addLine
(
self
,
cols
)
:
def
addLine
(
self
,
cols
)
:
self
.
lines
.
append
([[
val
,
""
]
for
val
in
cols
])
self
.
lines
.
append
([[
val
,
""
]
for
val
in
cols
])
...
@@ -22,8 +24,7 @@ class Config :
...
@@ -22,8 +24,7 @@ class Config :
def
get
(
self
,
lineIndex
,
colname
,
predicted
)
:
def
get
(
self
,
lineIndex
,
colname
,
predicted
)
:
if
lineIndex
not
in
range
(
len
(
self
.
lines
))
:
if
lineIndex
not
in
range
(
len
(
self
.
lines
))
:
print
(
"
Line index %d is out of range (0,%d)
"
%
(
lineIndex
,
len
(
self
.
lines
)),
file
=
sys
.
stderr
)
raise
(
Exception
(
"
Line index %d is out of range (0,%d)
"
%
(
lineIndex
,
len
(
self
.
lines
))))
exit
(
1
)
if
colname
not
in
self
.
col2index
:
if
colname
not
in
self
.
col2index
:
print
(
"
Unknown colname
'
%s
'"
%
(
colname
),
file
=
sys
.
stderr
)
print
(
"
Unknown colname
'
%s
'"
%
(
colname
),
file
=
sys
.
stderr
)
exit
(
1
)
exit
(
1
)
...
@@ -32,8 +33,7 @@ class Config :
...
@@ -32,8 +33,7 @@ class Config :
def
set
(
self
,
lineIndex
,
colname
,
value
,
predicted
=
True
)
:
def
set
(
self
,
lineIndex
,
colname
,
value
,
predicted
=
True
)
:
if
lineIndex
not
in
range
(
len
(
self
.
lines
))
:
if
lineIndex
not
in
range
(
len
(
self
.
lines
))
:
print
(
"
Line index %d is out of range (0,%d)
"
%
(
lineIndex
,
len
(
self
.
lines
)),
file
=
sys
.
stderr
)
raise
(
Exception
(
"
Line index %d is out of range (0,%d)
"
%
(
lineIndex
,
len
(
self
.
lines
))))
exit
(
1
)
if
colname
not
in
self
.
col2index
:
if
colname
not
in
self
.
col2index
:
print
(
"
Unknown colname
'
%s
'"
%
(
colname
),
file
=
sys
.
stderr
)
print
(
"
Unknown colname
'
%s
'"
%
(
colname
),
file
=
sys
.
stderr
)
exit
(
1
)
exit
(
1
)
...
@@ -50,22 +50,23 @@ class Config :
...
@@ -50,22 +50,23 @@ class Config :
self
.
stack
.
append
(
self
.
wordIndex
)
self
.
stack
.
append
(
self
.
wordIndex
)
def
popStack
(
self
)
:
def
popStack
(
self
)
:
self
.
stack
.
pop
()
return
self
.
stack
.
pop
()
# Move wordIndex by a relative forward movement if possible. Ignore multiwords.
# Move wordIndex by a relative forward movement if possible. Ignore multiwords.
# Don't go out of bounds, but don't fail either.
# Don't go out of bounds, but don't fail either.
# Return true if movement was completed.
# Return true if movement was completed.
def
moveWordIndex
(
self
,
movement
)
:
def
moveWordIndex
(
self
,
movement
)
:
done
=
0
done
=
0
relMov
=
1
if
movement
==
0
else
movement
//
abs
(
movement
)
if
self
.
isMultiword
(
self
.
wordIndex
)
:
if
self
.
isMultiword
(
self
.
wordIndex
)
:
self
.
wordIndex
+=
1
self
.
wordIndex
+=
relMov
while
done
!=
movement
:
while
done
!=
abs
(
movement
)
:
if
self
.
wordIndex
<
len
(
self
.
lines
)
-
1
:
if
self
.
wordIndex
+
relMov
in
range
(
0
,
len
(
(
self
.
lines
)
))
:
self
.
wordIndex
+=
1
self
.
wordIndex
+=
relMov
else
:
else
:
return
False
return
False
if
self
.
isMultiword
(
self
.
wordIndex
)
:
if
self
.
isMultiword
(
self
.
wordIndex
)
:
self
.
wordIndex
+=
1
self
.
wordIndex
+=
relMov
done
+=
1
done
+=
1
return
True
return
True
...
@@ -81,6 +82,7 @@ class Config :
...
@@ -81,6 +82,7 @@ class Config :
right
=
5
right
=
5
print
(
"
stack :
"
,[
self
.
getAsFeature
(
ind
,
"
ID
"
)
for
ind
in
self
.
stack
],
file
=
output
)
print
(
"
stack :
"
,[
self
.
getAsFeature
(
ind
,
"
ID
"
)
for
ind
in
self
.
stack
],
file
=
output
)
print
(
"
history :
"
,[
trans
.
name
for
trans
in
self
.
history
[
-
10
:]],
file
=
output
)
print
(
"
history :
"
,[
trans
.
name
for
trans
in
self
.
history
[
-
10
:]],
file
=
output
)
print
(
"
historyPop :
"
,[(
c
[
0
].
name
,
c
[
1
])
for
c
in
self
.
historyPop
[
-
10
:]],
file
=
output
)
toPrint
=
[]
toPrint
=
[]
for
lineIndex
in
range
(
self
.
wordIndex
-
left
,
self
.
wordIndex
+
right
)
:
for
lineIndex
in
range
(
self
.
wordIndex
-
left
,
self
.
wordIndex
+
right
)
:
if
lineIndex
not
in
range
(
len
(
self
.
lines
))
:
if
lineIndex
not
in
range
(
len
(
self
.
lines
))
:
...
...
This diff is collapsed.
Click to expand it.
Decode.py
+
2
−
4
View file @
bf43cbf5
...
@@ -67,15 +67,13 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
...
@@ -67,15 +67,13 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
print
(
"
"
.
join
([
"
%s%.2f:%s
"
%
(
"
*
"
if
score
[
1
]
else
"
"
,
score
[
0
],
score
[
2
])
for
score
in
scores
])
+
"
\n
"
+
(
"
-
"
*
80
)
+
"
\n
"
,
file
=
sys
.
stderr
)
print
(
"
"
.
join
([
"
%s%.2f:%s
"
%
(
"
*
"
if
score
[
1
]
else
"
"
,
score
[
0
],
score
[
2
])
for
score
in
scores
])
+
"
\n
"
+
(
"
-
"
*
80
)
+
"
\n
"
,
file
=
sys
.
stderr
)
moved
=
applyTransition
(
ts
,
strat
,
config
,
candidate
)
moved
=
applyTransition
(
ts
,
strat
,
config
,
candidate
)
EOS
.
apply
(
config
)
EOS
.
apply
(
config
,
strat
)
network
.
to
(
currentDevice
)
network
.
to
(
currentDevice
)
################################################################################
################################################################################
################################################################################
################################################################################
def
decodeMode
(
debug
,
filename
,
type
,
modelDir
=
None
,
network
=
None
,
dicts
=
None
,
output
=
sys
.
stdout
)
:
def
decodeMode
(
debug
,
filename
,
type
,
transitionSet
,
strategy
,
modelDir
=
None
,
network
=
None
,
dicts
=
None
,
output
=
sys
.
stdout
)
:
transitionSet
=
[
Transition
(
elem
)
for
elem
in
[
"
RIGHT
"
,
"
LEFT
"
,
"
SHIFT
"
,
"
REDUCE
"
]]
strategy
=
{
"
RIGHT
"
:
1
,
"
SHIFT
"
:
1
,
"
LEFT
"
:
0
,
"
REDUCE
"
:
0
}
sentences
=
Config
.
readConllu
(
filename
)
sentences
=
Config
.
readConllu
(
filename
)
...
...
This diff is collapsed.
Click to expand it.
Train.py
+
11
−
10
View file @
bf43cbf5
...
@@ -16,10 +16,7 @@ import Config
...
@@ -16,10 +16,7 @@ import Config
from
conll18_ud_eval
import
load_conllu
,
evaluate
from
conll18_ud_eval
import
load_conllu
,
evaluate
################################################################################
################################################################################
def
trainMode
(
debug
,
filename
,
type
,
modelDir
,
nbIter
,
batchSize
,
devFile
,
bootstrapInterval
,
silent
=
False
)
:
def
trainMode
(
debug
,
filename
,
type
,
transitionSet
,
strategy
,
modelDir
,
nbIter
,
batchSize
,
devFile
,
bootstrapInterval
,
silent
=
False
)
:
transitionSet
=
[
Transition
(
elem
)
for
elem
in
[
"
RIGHT
"
,
"
LEFT
"
,
"
SHIFT
"
,
"
REDUCE
"
]]
strategy
=
{
"
RIGHT
"
:
1
,
"
SHIFT
"
:
1
,
"
LEFT
"
:
0
,
"
REDUCE
"
:
0
}
sentences
=
Config
.
readConllu
(
filename
)
sentences
=
Config
.
readConllu
(
filename
)
if
type
==
"
oracle
"
:
if
type
==
"
oracle
"
:
...
@@ -43,7 +40,7 @@ def extractExamples(debug, ts, strat, config, dicts, network, dynamic) :
...
@@ -43,7 +40,7 @@ def extractExamples(debug, ts, strat, config, dicts, network, dynamic) :
moved
=
True
moved
=
True
while
moved
:
while
moved
:
missingLinks
=
getMissingLinks
(
config
)
missingLinks
=
getMissingLinks
(
config
)
candidates
=
sorted
([[
trans
.
getOracleScore
(
config
,
missingLinks
),
trans
]
for
trans
in
ts
if
trans
.
appliable
(
config
)])
candidates
=
sorted
([[
trans
.
getOracleScore
(
config
,
missingLinks
),
trans
]
for
trans
in
ts
if
trans
.
appliable
(
config
)
and
"
BACK
"
not
in
trans
.
name
])
if
len
(
candidates
)
==
0
:
if
len
(
candidates
)
==
0
:
break
break
best
=
min
([
cand
[
0
]
for
cand
in
candidates
])
best
=
min
([
cand
[
0
]
for
cand
in
candidates
])
...
@@ -67,19 +64,19 @@ def extractExamples(debug, ts, strat, config, dicts, network, dynamic) :
...
@@ -67,19 +64,19 @@ def extractExamples(debug, ts, strat, config, dicts, network, dynamic) :
moved
=
applyTransition
(
ts
,
strat
,
config
,
candidate
)
moved
=
applyTransition
(
ts
,
strat
,
config
,
candidate
)
EOS
.
apply
(
config
)
EOS
.
apply
(
config
,
strat
)
return
examples
return
examples
################################################################################
################################################################################
################################################################################
################################################################################
def
evalModelAndSave
(
debug
,
model
,
dicts
,
modelDir
,
devFile
,
bestLoss
,
totalLoss
,
bestScore
,
epoch
,
nbIter
)
:
def
evalModelAndSave
(
debug
,
model
,
ts
,
strat
,
dicts
,
modelDir
,
devFile
,
bestLoss
,
totalLoss
,
bestScore
,
epoch
,
nbIter
)
:
devScore
=
""
devScore
=
""
saved
=
True
if
bestLoss
is
None
else
totalLoss
<
bestLoss
saved
=
True
if
bestLoss
is
None
else
totalLoss
<
bestLoss
bestLoss
=
totalLoss
if
bestLoss
is
None
else
min
(
bestLoss
,
totalLoss
)
bestLoss
=
totalLoss
if
bestLoss
is
None
else
min
(
bestLoss
,
totalLoss
)
if
devFile
is
not
None
:
if
devFile
is
not
None
:
outFilename
=
modelDir
+
"
/predicted_dev.conllu
"
outFilename
=
modelDir
+
"
/predicted_dev.conllu
"
Decode
.
decodeMode
(
debug
,
devFile
,
"
model
"
,
modelDir
,
model
,
dicts
,
open
(
outFilename
,
"
w
"
))
Decode
.
decodeMode
(
debug
,
devFile
,
"
model
"
,
ts
,
strat
,
modelDir
,
model
,
dicts
,
open
(
outFilename
,
"
w
"
))
res
=
evaluate
(
load_conllu
(
open
(
devFile
,
"
r
"
)),
load_conllu
(
open
(
outFilename
,
"
r
"
)),
[])
res
=
evaluate
(
load_conllu
(
open
(
devFile
,
"
r
"
)),
load_conllu
(
open
(
outFilename
,
"
r
"
)),
[])
UAS
=
res
[
"
UAS
"
][
0
].
f1
UAS
=
res
[
"
UAS
"
][
0
].
f1
score
=
UAS
score
=
UAS
...
@@ -145,7 +142,7 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
...
@@ -145,7 +142,7 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
optimizer
.
step
()
optimizer
.
step
()
totalLoss
+=
float
(
loss
)
totalLoss
+=
float
(
loss
)
bestLoss
,
bestScore
=
evalModelAndSave
(
debug
,
network
,
dicts
,
modelDir
,
devFile
,
bestLoss
,
totalLoss
,
bestScore
,
epoch
,
nbEpochs
)
bestLoss
,
bestScore
=
evalModelAndSave
(
debug
,
network
,
transitionSet
,
strategy
,
dicts
,
modelDir
,
devFile
,
bestLoss
,
totalLoss
,
bestScore
,
epoch
,
nbEpochs
)
################################################################################
################################################################################
################################################################################
################################################################################
...
@@ -193,9 +190,13 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
...
@@ -193,9 +190,13 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
if
debug
:
if
debug
:
sentence
.
printForDebug
(
sys
.
stderr
)
sentence
.
printForDebug
(
sys
.
stderr
)
action
=
selectAction
(
policy_net
,
state
,
transitionSet
,
sentence
,
missingLinks
,
probaRandom
,
probaOracle
)
action
=
selectAction
(
policy_net
,
state
,
transitionSet
,
sentence
,
missingLinks
,
probaRandom
,
probaOracle
)
if
action
is
None
:
if
action
is
None
:
break
break
if
debug
:
print
(
"
Selected action : %s
"
%
action
.
name
,
file
=
sys
.
stderr
)
appliable
=
action
.
appliable
(
sentence
)
appliable
=
action
.
appliable
(
sentence
)
# Reward for doing an illegal action
# Reward for doing an illegal action
...
@@ -227,6 +228,6 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
...
@@ -227,6 +228,6 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
if
i
>=
nbExByEpoch
:
if
i
>=
nbExByEpoch
:
break
break
sentIndex
+=
1
sentIndex
+=
1
bestLoss
,
bestScore
=
evalModelAndSave
(
debug
,
policy_net
,
dicts
,
modelDir
,
devFile
,
bestLoss
,
totalLoss
,
bestScore
,
epoch
,
nbIter
)
bestLoss
,
bestScore
=
evalModelAndSave
(
debug
,
policy_net
,
transitionSet
,
strategy
,
dicts
,
modelDir
,
devFile
,
bestLoss
,
totalLoss
,
bestScore
,
epoch
,
nbIter
)
################################################################################
################################################################################
This diff is collapsed.
Click to expand it.
Transition.py
+
69
−
10
View file @
bf43cbf5
...
@@ -4,7 +4,7 @@ from Util import isEmpty
...
@@ -4,7 +4,7 @@ from Util import isEmpty
################################################################################
################################################################################
class
Transition
:
class
Transition
:
available
=
set
({
"
RIGHT
"
,
"
LEFT
"
,
"
SHIFT
"
,
"
REDUCE
"
,
"
EOS
"
})
available
=
set
({
"
RIGHT
"
,
"
LEFT
"
,
"
SHIFT
"
,
"
REDUCE
"
,
"
EOS
"
,
"
BACK 2
"
})
def
__init__
(
self
,
name
)
:
def
__init__
(
self
,
name
)
:
if
name
not
in
self
.
available
:
if
name
not
in
self
.
available
:
...
@@ -15,21 +15,31 @@ class Transition :
...
@@ -15,21 +15,31 @@ class Transition :
def
__lt__
(
self
,
other
)
:
def
__lt__
(
self
,
other
)
:
return
self
.
name
<
other
.
name
return
self
.
name
<
other
.
name
def
apply
(
self
,
config
)
:
def
apply
(
self
,
config
,
strategy
)
:
data
=
None
if
"
BACK
"
not
in
self
.
name
:
config
.
historyHistory
.
add
(
str
([
t
[
0
].
name
for
t
in
config
.
historyPop
]))
if
self
.
name
==
"
RIGHT
"
:
if
self
.
name
==
"
RIGHT
"
:
applyRight
(
config
)
applyRight
(
config
)
elif
self
.
name
==
"
LEFT
"
:
elif
self
.
name
==
"
LEFT
"
:
applyLeft
(
config
)
data
=
applyLeft
(
config
)
elif
self
.
name
==
"
SHIFT
"
:
elif
self
.
name
==
"
SHIFT
"
:
applyShift
(
config
)
applyShift
(
config
)
elif
self
.
name
==
"
REDUCE
"
:
elif
self
.
name
==
"
REDUCE
"
:
applyReduce
(
config
)
data
=
applyReduce
(
config
)
elif
self
.
name
==
"
EOS
"
:
elif
self
.
name
==
"
EOS
"
:
applyEOS
(
config
)
applyEOS
(
config
)
elif
"
BACK
"
in
self
.
name
:
size
=
int
(
self
.
name
.
split
()[
-
1
])
applyBack
(
config
,
strategy
,
size
)
else
:
else
:
print
(
"
ERROR : nothing to apply for
'
%s
'"
%
self
.
name
,
file
=
sys
.
stderr
)
print
(
"
ERROR : nothing to apply for
'
%s
'"
%
self
.
name
,
file
=
sys
.
stderr
)
exit
(
1
)
exit
(
1
)
config
.
history
.
append
(
self
)
config
.
history
.
append
(
self
)
if
"
BACK
"
not
in
self
.
name
:
config
.
historyPop
.
append
((
self
,
data
))
def
appliable
(
self
,
config
)
:
def
appliable
(
self
,
config
)
:
if
self
.
name
==
"
RIGHT
"
:
if
self
.
name
==
"
RIGHT
"
:
...
@@ -42,8 +52,13 @@ class Transition :
...
@@ -42,8 +52,13 @@ class Transition :
return
len
(
config
.
stack
)
>
0
and
not
isEmpty
(
config
.
getAsFeature
(
config
.
stack
[
-
1
],
"
HEAD
"
))
return
len
(
config
.
stack
)
>
0
and
not
isEmpty
(
config
.
getAsFeature
(
config
.
stack
[
-
1
],
"
HEAD
"
))
if
self
.
name
==
"
EOS
"
:
if
self
.
name
==
"
EOS
"
:
return
config
.
wordIndex
==
len
(
config
.
lines
)
-
1
return
config
.
wordIndex
==
len
(
config
.
lines
)
-
1
if
"
BACK
"
in
self
.
name
:
size
=
int
(
self
.
name
.
split
()[
-
1
])
if
len
(
config
.
historyPop
)
<
size
:
return
False
return
str
([
t
[
0
].
name
for
t
in
config
.
historyPop
])
not
in
config
.
historyHistory
print
(
"
ERROR : unknown name
'
%s
'"
%
self
.
name
,
file
=
sys
.
stderr
)
print
(
"
ERROR :
appliable,
unknown name
'
%s
'"
%
self
.
name
,
file
=
sys
.
stderr
)
exit
(
1
)
exit
(
1
)
def
getOracleScore
(
self
,
config
,
missingLinks
)
:
def
getOracleScore
(
self
,
config
,
missingLinks
)
:
...
@@ -55,8 +70,10 @@ class Transition :
...
@@ -55,8 +70,10 @@ class Transition :
return
scoreOracleShift
(
config
,
missingLinks
)
return
scoreOracleShift
(
config
,
missingLinks
)
if
self
.
name
==
"
REDUCE
"
:
if
self
.
name
==
"
REDUCE
"
:
return
scoreOracleReduce
(
config
,
missingLinks
)
return
scoreOracleReduce
(
config
,
missingLinks
)
if
"
BACK
"
in
self
.
name
:
return
1
print
(
"
ERROR : unknown name
'
%s
'"
%
self
.
name
,
file
=
sys
.
stderr
)
print
(
"
ERROR :
oracle,
unknown name
'
%s
'"
%
self
.
name
,
file
=
sys
.
stderr
)
exit
(
1
)
exit
(
1
)
################################################################################
################################################################################
...
@@ -126,6 +143,48 @@ def scoreOracleReduce(config, ml) :
...
@@ -126,6 +143,48 @@ def scoreOracleReduce(config, ml) :
return
ml
[
"
StackRight
"
]
return
ml
[
"
StackRight
"
]
################################################################################
################################################################################
################################################################################
def
applyBack
(
config
,
strategy
,
size
)
:
for
i
in
range
(
size
)
:
trans
,
data
=
config
.
historyPop
.
pop
()
config
.
moveWordIndex
(
-
strategy
[
trans
.
name
])
if
trans
.
name
==
"
RIGHT
"
:
applyBackRight
(
config
)
elif
trans
.
name
==
"
LEFT
"
:
applyBackLeft
(
config
,
data
)
elif
trans
.
name
==
"
SHIFT
"
:
applyBackShift
(
config
)
elif
trans
.
name
==
"
REDUCE
"
:
applyBackReduce
(
config
,
data
)
else
:
print
(
"
ERROR : trying to apply BACK to
'
%s
'"
%
trans
.
name
,
file
=
sys
.
stderr
)
exit
(
1
)
################################################################################
################################################################################
def
applyBackRight
(
config
)
:
config
.
stack
.
pop
()
config
.
set
(
config
.
wordIndex
,
"
HEAD
"
,
""
)
config
.
predChilds
[
config
.
stack
[
-
1
]].
pop
()
################################################################################
################################################################################
def
applyBackLeft
(
config
,
data
)
:
config
.
stack
.
append
(
data
)
config
.
set
(
config
.
stack
[
-
1
],
"
HEAD
"
,
""
)
config
.
predChilds
[
config
.
wordIndex
].
pop
()
################################################################################
################################################################################
def
applyBackShift
(
config
)
:
config
.
stack
.
pop
()
################################################################################
################################################################################
def
applyBackReduce
(
config
,
data
)
:
config
.
stack
.
append
(
data
)
################################################################################
################################################################################
################################################################################
def
applyRight
(
config
)
:
def
applyRight
(
config
)
:
config
.
set
(
config
.
wordIndex
,
"
HEAD
"
,
config
.
stack
[
-
1
])
config
.
set
(
config
.
wordIndex
,
"
HEAD
"
,
config
.
stack
[
-
1
])
...
@@ -137,7 +196,7 @@ def applyRight(config) :
...
@@ -137,7 +196,7 @@ def applyRight(config) :
def
applyLeft
(
config
)
:
def
applyLeft
(
config
)
:
config
.
set
(
config
.
stack
[
-
1
],
"
HEAD
"
,
config
.
wordIndex
)
config
.
set
(
config
.
stack
[
-
1
],
"
HEAD
"
,
config
.
wordIndex
)
config
.
predChilds
[
config
.
wordIndex
].
append
(
config
.
stack
[
-
1
])
config
.
predChilds
[
config
.
wordIndex
].
append
(
config
.
stack
[
-
1
])
config
.
popStack
()
return
config
.
popStack
()
################################################################################
################################################################################
################################################################################
################################################################################
...
@@ -147,7 +206,7 @@ def applyShift(config) :
...
@@ -147,7 +206,7 @@ def applyShift(config) :
################################################################################
################################################################################
def
applyReduce
(
config
)
:
def
applyReduce
(
config
)
:
config
.
popStack
()
return
config
.
popStack
()
################################################################################
################################################################################
################################################################################
################################################################################
...
@@ -175,8 +234,8 @@ def applyEOS(config) :
...
@@ -175,8 +234,8 @@ def applyEOS(config) :
################################################################################
################################################################################
def
applyTransition
(
ts
,
strat
,
config
,
name
)
:
def
applyTransition
(
ts
,
strat
,
config
,
name
)
:
transition
=
[
trans
for
trans
in
ts
if
trans
.
name
==
name
][
0
]
transition
=
[
trans
for
trans
in
ts
if
trans
.
name
==
name
][
0
]
movement
=
strat
[
transition
.
name
]
movement
=
strat
[
transition
.
name
]
if
transition
.
name
in
strat
else
0
transition
.
apply
(
config
)
transition
.
apply
(
config
,
strat
)
return
config
.
moveWordIndex
(
movement
)
return
config
.
moveWordIndex
(
movement
)
################################################################################
################################################################################
This diff is collapsed.
Click to expand it.
main.py
+
6
−
2
View file @
bf43cbf5
...
@@ -9,6 +9,7 @@ import torch
...
@@ -9,6 +9,7 @@ import torch
import
Util
import
Util
import
Train
import
Train
import
Decode
import
Decode
from
Transition
import
Transition
################################################################################
################################################################################
if
__name__
==
"
__main__
"
:
if
__name__
==
"
__main__
"
:
...
@@ -47,10 +48,13 @@ if __name__ == "__main__" :
...
@@ -47,10 +48,13 @@ if __name__ == "__main__" :
if
args
.
bootstrap
is
not
None
:
if
args
.
bootstrap
is
not
None
:
args
.
bootstrap
=
int
(
args
.
bootstrap
)
args
.
bootstrap
=
int
(
args
.
bootstrap
)
transitionSet
=
[
Transition
(
elem
)
for
elem
in
[
"
RIGHT
"
,
"
LEFT
"
,
"
SHIFT
"
,
"
REDUCE
"
,
"
BACK 2
"
]]
strategy
=
{
"
RIGHT
"
:
1
,
"
SHIFT
"
:
1
,
"
LEFT
"
:
0
,
"
REDUCE
"
:
0
}
if
args
.
mode
==
"
train
"
:
if
args
.
mode
==
"
train
"
:
Train
.
trainMode
(
args
.
debug
,
args
.
corpus
,
args
.
type
,
args
.
model
,
int
(
args
.
iter
),
int
(
args
.
batchSize
),
args
.
dev
,
args
.
bootstrap
,
args
.
silent
)
Train
.
trainMode
(
args
.
debug
,
args
.
corpus
,
args
.
type
,
transitionSet
,
strategy
,
args
.
model
,
int
(
args
.
iter
),
int
(
args
.
batchSize
),
args
.
dev
,
args
.
bootstrap
,
args
.
silent
)
elif
args
.
mode
==
"
decode
"
:
elif
args
.
mode
==
"
decode
"
:
Decode
.
decodeMode
(
args
.
debug
,
args
.
corpus
,
args
.
type
,
args
.
model
)
Decode
.
decodeMode
(
args
.
debug
,
args
.
corpus
,
args
.
type
,
transitionSet
,
strategy
,
args
.
model
)
else
:
else
:
print
(
"
ERROR : unknown mode
'
%s
'"
%
args
.
mode
,
file
=
sys
.
stderr
)
print
(
"
ERROR : unknown mode
'
%s
'"
%
args
.
mode
,
file
=
sys
.
stderr
)
exit
(
1
)
exit
(
1
)
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment