Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
R
RL-Parsing
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Container registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
GitLab community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Franck Dary
RL-Parsing
Commits
b1c976ae
Commit
b1c976ae
authored
4 years ago
by
Franck Dary
Browse files
Options
Downloads
Patches
Plain Diff
Added bootstrap mode for oracle training
parent
e0926705
Branches
Branches containing commit
No related tags found
No related merge requests found
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
Train.py
+48
-27
48 additions, 27 deletions
Train.py
main.py
+3
-1
3 additions, 1 deletion
main.py
with
51 additions
and
28 deletions
Train.py
+
48
−
27
View file @
b1c976ae
...
...
@@ -15,14 +15,14 @@ import Config
from
conll18_ud_eval
import
load_conllu
,
evaluate
################################################################################
def
trainMode
(
debug
,
filename
,
type
,
modelDir
,
nbIter
,
batchSize
,
devFile
,
silent
=
False
)
:
def
trainMode
(
debug
,
filename
,
type
,
modelDir
,
nbIter
,
batchSize
,
devFile
,
bootstrapInterval
,
silent
=
False
)
:
transitionSet
=
[
Transition
(
elem
)
for
elem
in
[
"
RIGHT
"
,
"
LEFT
"
,
"
SHIFT
"
,
"
REDUCE
"
]]
strategy
=
{
"
RIGHT
"
:
1
,
"
SHIFT
"
:
1
,
"
LEFT
"
:
0
,
"
REDUCE
"
:
0
}
sentences
=
Config
.
readConllu
(
filename
)
if
type
==
"
oracle
"
:
trainModelOracle
(
debug
,
modelDir
,
filename
,
nbIter
,
batchSize
,
devFile
,
transitionSet
,
strategy
,
sentences
,
silent
)
trainModelOracle
(
debug
,
modelDir
,
filename
,
nbIter
,
batchSize
,
devFile
,
transitionSet
,
strategy
,
sentences
,
bootstrapInterval
,
silent
)
return
if
type
==
"
rl
"
:
...
...
@@ -34,25 +34,36 @@ def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, silen
################################################################################
################################################################################
def
extractExamples
(
ts
,
strat
,
config
,
dicts
,
debug
=
Fals
e
)
:
def
extractExamples
(
debug
,
ts
,
strat
,
config
,
dicts
,
network
=
Non
e
)
:
examples
=
[]
with
torch
.
no_grad
()
:
EOS
=
Transition
(
"
EOS
"
)
config
.
moveWordIndex
(
0
)
moved
=
True
while
moved
:
missingLinks
=
getMissingLinks
(
config
)
candidates
=
sorted
([[
trans
.
getOracleScore
(
config
,
missingLinks
),
trans
.
name
]
for
trans
in
ts
if
trans
.
appliable
(
config
)])
candidates
=
sorted
([[
trans
.
getOracleScore
(
config
,
missingLinks
),
trans
]
for
trans
in
ts
if
trans
.
appliable
(
config
)])
if
len
(
candidates
)
==
0
:
break
candidate
=
candidates
[
0
][
1
]
candidate
Index
=
[
t
ran
s
.
name
for
tr
an
s
in
ts
].
index
(
candidate
)
best
=
min
([
cand
[
0
]
for
cand
in
candidates
]
)
candidate
Oracle
=
ran
dom
.
sample
([
cand
for
c
an
d
in
candidates
if
cand
[
0
]
==
best
],
1
)[
0
][
1
]
features
=
Features
.
extractFeatures
(
dicts
,
config
)
example
=
torch
.
cat
([
torch
.
LongTensor
([
candidateIndex
]),
features
])
examples
.
append
(
example
)
candidate
=
candidateOracle
.
name
if
debug
:
config
.
printForDebug
(
sys
.
stderr
)
print
(
str
(
candidates
)
+
"
\n
"
+
(
"
-
"
*
80
)
+
"
\n
"
,
file
=
sys
.
stderr
)
print
(
str
([[
c
[
0
],
c
[
1
].
name
]
for
c
in
candidates
])
+
"
\n
"
+
(
"
-
"
*
80
)
+
"
\n
"
,
file
=
sys
.
stderr
)
if
network
is
not
None
:
output
=
network
(
features
.
unsqueeze
(
0
).
to
(
getDevice
()))
scores
=
sorted
([[
float
(
output
[
0
][
index
]),
ts
[
index
].
appliable
(
config
),
ts
[
index
].
name
]
for
index
in
range
(
len
(
ts
))])[::
-
1
]
candidate
=
[[
cand
[
0
],
cand
[
2
]]
for
cand
in
scores
if
cand
[
1
]][
0
][
1
]
if
debug
:
print
(
candidate
.
name
,
file
=
sys
.
stderr
)
goldIndex
=
[
trans
.
name
for
trans
in
ts
].
index
(
candidateOracle
.
name
)
candidateIndex
=
[
trans
.
name
for
trans
in
ts
].
index
(
candidate
)
example
=
torch
.
cat
([
torch
.
LongTensor
([
goldIndex
]),
features
])
examples
.
append
(
example
)
moved
=
applyTransition
(
ts
,
strat
,
config
,
candidate
)
EOS
.
apply
(
config
)
...
...
@@ -82,14 +93,15 @@ def evalModelAndSave(debug, model, dicts, modelDir, devFile, bestLoss, totalLoss
################################################################################
################################################################################
def
trainModelOracle
(
debug
,
modelDir
,
filename
,
nbEpochs
,
batchSize
,
devFile
,
transitionSet
,
strategy
,
sentences
,
silent
=
False
)
:
examples
=
[]
def
trainModelOracle
(
debug
,
modelDir
,
filename
,
nbEpochs
,
batchSize
,
devFile
,
transitionSet
,
strategy
,
sentencesOriginal
,
bootstrapInterval
,
silent
=
False
)
:
dicts
=
Dicts
()
dicts
.
readConllu
(
filename
,
[
"
FORM
"
,
"
UPOS
"
])
dicts
.
save
(
modelDir
+
"
/dicts.json
"
)
examples
=
[]
sentences
=
copy
.
deepcopy
(
sentencesOriginal
)
print
(
"
%s : Starting to extract examples...
"
%
(
timeStamp
()),
file
=
sys
.
stderr
)
for
config
in
sentences
:
examples
+=
extractExamples
(
transitionSet
,
strategy
,
config
,
dicts
,
debug
)
examples
+=
extractExamples
(
debug
,
transitionSet
,
strategy
,
config
,
dicts
)
print
(
"
%s : Extracted %s examples
"
%
(
timeStamp
(),
prettyInt
(
len
(
examples
),
3
)),
file
=
sys
.
stderr
)
examples
=
torch
.
stack
(
examples
)
...
...
@@ -100,6 +112,15 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
bestLoss
=
None
bestScore
=
None
for
epoch
in
range
(
1
,
nbEpochs
+
1
)
:
if
bootstrapInterval
is
not
None
and
epoch
>
1
and
(
epoch
-
1
)
%
bootstrapInterval
==
0
:
examples
=
[]
sentences
=
copy
.
deepcopy
(
sentencesOriginal
)
print
(
"
%s : Starting to extract dynamic examples...
"
%
(
timeStamp
()),
file
=
sys
.
stderr
)
for
config
in
sentences
:
examples
+=
extractExamples
(
debug
,
transitionSet
,
strategy
,
config
,
dicts
,
network
)
print
(
"
%s : Extracted %s examples
"
%
(
timeStamp
(),
prettyInt
(
len
(
examples
),
3
)),
file
=
sys
.
stderr
)
examples
=
torch
.
stack
(
examples
)
network
.
train
()
examples
=
examples
.
index_select
(
0
,
torch
.
randperm
(
examples
.
size
(
0
)))
totalLoss
=
0.0
...
...
This diff is collapsed.
Click to expand it.
main.py
+
3
−
1
View file @
b1c976ae
...
...
@@ -27,6 +27,8 @@ if __name__ == "__main__" :
help
=
"
Size of each batch.
"
)
parser
.
add_argument
(
"
--seed
"
,
default
=
100
,
help
=
"
Random seed.
"
)
parser
.
add_argument
(
"
--bootstrap
"
,
default
=
None
,
help
=
"
If not none, extract examples in bootstrap mode (oracle train only).
"
)
parser
.
add_argument
(
"
--dev
"
,
default
=
None
,
help
=
"
Name of the CoNLL-U file of the dev corpus.
"
)
parser
.
add_argument
(
"
--debug
"
,
"
-d
"
,
default
=
False
,
action
=
"
store_true
"
,
...
...
@@ -43,7 +45,7 @@ if __name__ == "__main__" :
torch
.
manual_seed
(
args
.
seed
)
if
args
.
mode
==
"
train
"
:
Train
.
trainMode
(
args
.
debug
,
args
.
corpus
,
args
.
type
,
args
.
model
,
int
(
args
.
iter
),
int
(
args
.
batchSize
),
args
.
dev
,
args
.
silent
)
Train
.
trainMode
(
args
.
debug
,
args
.
corpus
,
args
.
type
,
args
.
model
,
int
(
args
.
iter
),
int
(
args
.
batchSize
),
args
.
dev
,
args
.
bootstrap
,
args
.
silent
)
elif
args
.
mode
==
"
decode
"
:
Decode
.
decodeMode
(
args
.
debug
,
args
.
corpus
,
args
.
type
,
args
.
model
)
else
:
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment