Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
R
RL-Parsing
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package Registry
Container Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Franck Dary
RL-Parsing
Commits
67c67305
Commit
67c67305
authored
3 years ago
by
Franck Dary
Browse files
Options
Downloads
Patches
Plain Diff
Cleaned main, put functions into Train.py and Decode.py
parent
cf60ff93
No related branches found
Branches containing commit
No related tags found
No related merge requests found
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
Decode.py
+27
-0
27 additions, 0 deletions
Decode.py
Train.py
+69
-1
69 additions, 1 deletion
Train.py
Util.py
+7
-0
7 additions, 0 deletions
Util.py
main.py
+3
-98
3 additions, 98 deletions
main.py
with
106 additions
and
99 deletions
Decode.py
+
27
−
0
View file @
67c67305
...
...
@@ -2,6 +2,7 @@ import random
import
sys
from
Transition
import
Transition
,
getMissingLinks
,
applyTransition
from
Features
import
extractFeatures
import
Config
import
torch
################################################################################
...
...
@@ -62,3 +63,29 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
EOS
.
apply
(
config
)
################################################################################
################################################################################
def decodeMode(debug, filename, type, network=None, dicts=None, output=sys.stdout):
    """Decode every sentence of a CoNLL-U file and print the resulting configurations.

    Args:
        debug: debug flag, forwarded unchanged to the decoding function.
        filename: path of the corpus, read with Config.readConllu.
        type: decoding strategy, one of "random", "oracle" or "model".
        network: forwarded to decodeModel; only used when type is "model".
        dicts: forwarded to decodeModel; only used when type is "model".
        output: stream the decoded configurations are printed to.

    Raises:
        SystemExit: with exit code 1 when type is not a known strategy.
    """
    # Validate the strategy first so that a typo fails before the corpus is read.
    if type not in ("random", "oracle", "model"):
        print("ERROR : unknown type '%s'" % type, file=sys.stderr)
        sys.exit(1)

    transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
    strategy = {"RIGHT": 1, "SHIFT": 1, "LEFT": 0, "REDUCE": 0}
    sentences = Config.readConllu(filename)

    if type == "model":
        for config in sentences:
            decodeModel(transitionSet, strategy, config, network, dicts, debug)
    else:
        decodeFunc = oracleDecode if type == "oracle" else randomDecode
        for config in sentences:
            decodeFunc(transitionSet, strategy, config, debug)

    # Only the first configuration is printed with its header. Every strategy
    # writes to `output` (the random/oracle branch used to ignore it), and an
    # empty corpus simply prints nothing instead of failing on sentences[0].
    for index, config in enumerate(sentences):
        config.print(output, header=(index == 0))
################################################################################
This diff is collapsed.
Click to expand it.
Train.py
+
69
−
1
View file @
67c67305
import
sys
import
random
import
torch
from
Transition
import
Transition
,
getMissingLinks
,
applyTransition
import
Features
from
Dicts
import
Dicts
from
Util
import
timeStamp
import
Networks
import
Decode
import
Config
import
torch
from
conll18_ud_eval
import
load_conllu
,
evaluate
################################################################################
def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, silent=False):
    """Entry point of training: build the transition set and dispatch on `type`.

    Args:
        debug: debug flag, forwarded to the training routine.
        filename: path of the training corpus, read with Config.readConllu.
        type: training type; only "oracle" is supported.
        modelDir: directory forwarded to the training routine.
        nbIter: number of epochs, forwarded to the training routine.
        batchSize: batch size, forwarded to the training routine.
        devFile: development corpus path or None, forwarded to the training routine.
        silent: forwarded to the training routine.

    Raises:
        SystemExit: with exit code 1 when type is not a known training type.
    """
    # Reject an unknown type before paying for the corpus to be parsed.
    if type != "oracle":
        print("ERROR : unknown type '%s'" % type, file=sys.stderr)
        sys.exit(1)

    transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
    strategy = {"RIGHT": 1, "SHIFT": 1, "LEFT": 0, "REDUCE": 0}
    sentences = Config.readConllu(filename)

    trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile,
                     transitionSet, strategy, sentences, silent)
################################################################################
################################################################################
def
extractExamples
(
ts
,
strat
,
config
,
dicts
,
debug
=
False
)
:
...
...
@@ -32,3 +54,49 @@ def extractExamples(ts, strat, config, dicts, debug=False) :
return
examples
################################################################################
################################################################################
def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent=False):
    """Extract training examples from `sentences` and train a BaseNet on them.

    The dictionaries built from `filename` are saved to modelDir + "/dicts.json".
    After every epoch the summed loss is printed to stderr; when `devFile` is
    given, the dev corpus is decoded to modelDir + "/predicted_dev.conllu" and
    its UAS is appended to that line.

    Args:
        debug: debug flag, forwarded to extractExamples and Decode.decodeMode.
        modelDir: directory receiving dicts.json and the dev predictions.
        filename: path of the training corpus (used to build the dictionaries).
        nbIter: number of epochs.
        batchSize: number of examples per batch.
        devFile: path of the development corpus, or None to skip evaluation.
        transitionSet: transitions, forwarded to extractExamples.
        strategy: forwarded to extractExamples.
        sentences: configurations the examples are extracted from.
        silent: when true, the in-epoch progress line is not printed.
    """
    dicts = Dicts()
    dicts.readConllu(filename, ["FORM", "UPOS"])
    dicts.save(modelDir + "/dicts.json")

    print("%s : Starting to extract examples..." % (timeStamp()), file=sys.stderr)
    examples = []
    for config in sentences:
        examples += extractExamples(transitionSet, strategy, config, dicts, debug)
    print("%s : Extracted %d examples" % (timeStamp(), len(examples)), file=sys.stderr)
    examples = torch.stack(examples)

    # Column 0 of each example is the target class, the remaining columns are
    # the network inputs (see the slicing in the batch loop below).
    network = Networks.BaseNet(dicts, examples[0].size(0) - 1, len(transitionSet))
    optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
    lossFct = torch.nn.CrossEntropyLoss()
    printInterval = 2000

    for epoch in range(1, nbIter + 1):
        # Re-assert training mode every epoch: the network is used for decoding
        # between epochs and may have been switched to eval mode there.
        network.train()
        examples = examples.index_select(0, torch.randperm(examples.size(0)))
        nbExamples = examples.size(0)
        totalLoss = 0.0
        nbEx = 0
        advancement = 0

        # Iterate up to the end of the data. The former bound
        # `size - batchSize` skipped the last full batch and trained on
        # nothing when there were no more examples than one batch.
        for batchIndex in range(0, nbExamples, batchSize):
            batch = examples[batchIndex:batchIndex + batchSize]
            targets = batch[:, :1].view(-1)
            inputs = batch[:, 1:]
            nbEx += targets.size(0)
            advancement += targets.size(0)
            if not silent and advancement >= printInterval:
                advancement = 0
                print("Curent epoch %6.2f%%" % (100.0 * nbEx / nbExamples), end="\r", file=sys.stderr)

            outputs = network(inputs)
            loss = lossFct(outputs, targets)
            network.zero_grad()
            loss.backward()
            optimizer.step()
            totalLoss += float(loss)

        devScore = ""
        if devFile is not None:
            outFilename = modelDir + "/predicted_dev.conllu"
            # Close (and therefore flush) the predictions before they are read
            # back, instead of relying on garbage collection of the handle.
            with open(outFilename, "w", encoding="utf-8") as predicted:
                Decode.decodeMode(debug, devFile, "model", network, dicts, predicted)
            with open(devFile, "r", encoding="utf-8") as goldFile, \
                 open(outFilename, "r", encoding="utf-8") as systemFile:
                res = evaluate(load_conllu(goldFile), load_conllu(systemFile), [])
            devScore = ", Dev : UAS=%.2f" % (res["UAS"][0].f1)
        print("%s : Epoch %d, loss=%.2f%s" % (timeStamp(), epoch, totalLoss, devScore), file=sys.stderr)
################################################################################
This diff is collapsed.
Click to expand it.
Util.py
0 → 100644
+
7
−
0
View file @
67c67305
from
datetime
import
datetime
################################################################################
def timeStamp():
    """Return the current local time of day formatted as "[HH:MM:SS]"."""
    now = datetime.now()
    return "[{}]".format(now.strftime("%H:%M:%S"))
################################################################################
This diff is collapsed.
Click to expand it.
main.py
+
3
−
98
View file @
67c67305
...
...
@@ -3,104 +3,9 @@
import
sys
import
os
import
argparse
from
datetime
import
datetime
import
Config
import
Decode
import
Train
from
Transition
import
Transition
import
Networks
from
Dicts
import
Dicts
from
conll18_ud_eval
import
load_conllu
,
evaluate
import
torch
################################################################################
def timeStamp():
    """Build a bracketed "[HH:MM:SS]" stamp from the current local time."""
    clock = datetime.now().strftime("%H:%M:%S")
    return "[" + clock + "]"
################################################################################
################################################################################
def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, silent=False):
    """Train a BaseNet on oracle examples extracted from a CoNLL-U corpus.

    The dictionaries built from `filename` are saved to modelDir + "/dicts.json".
    After every epoch the summed loss is printed to stderr; when `devFile` is
    given, the dev corpus is decoded to modelDir + "/predicted_dev.conllu" and
    its UAS is appended to that line.

    Args:
        debug: debug flag, forwarded to Train.extractExamples and decodeMode.
        filename: path of the training corpus, read with Config.readConllu.
        type: training type; only "oracle" is supported.
        modelDir: directory receiving dicts.json and the dev predictions.
        nbIter: number of epochs.
        batchSize: number of examples per batch.
        devFile: path of the development corpus, or None to skip evaluation.
        silent: when true, the in-epoch progress line is not printed.

    Raises:
        SystemExit: with exit code 1 when type is not a known training type.
    """
    # Reject an unknown type before paying for the corpus to be parsed.
    if type != "oracle":
        print("ERROR : unknown type '%s'" % type, file=sys.stderr)
        sys.exit(1)

    transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
    strategy = {"RIGHT": 1, "SHIFT": 1, "LEFT": 0, "REDUCE": 0}
    sentences = Config.readConllu(filename)

    dicts = Dicts()
    dicts.readConllu(filename, ["FORM", "UPOS"])
    dicts.save(modelDir + "/dicts.json")

    print("%s : Starting to extract examples..." % (timeStamp()), file=sys.stderr)
    examples = []
    for config in sentences:
        # Use the `debug` parameter: the former `args.debug` only resolved when
        # this module was run as a script that defined a global `args`.
        examples += Train.extractExamples(transitionSet, strategy, config, dicts, debug)
    print("%s : Extracted %d examples" % (timeStamp(), len(examples)), file=sys.stderr)
    examples = torch.stack(examples)

    # Column 0 of each example is the target class, the remaining columns are
    # the network inputs (see the slicing in the batch loop below).
    network = Networks.BaseNet(dicts, examples[0].size(0) - 1, len(transitionSet))
    optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
    lossFct = torch.nn.CrossEntropyLoss()
    printInterval = 2000

    for epoch in range(1, nbIter + 1):
        # Re-assert training mode every epoch: the network is used for decoding
        # between epochs and may have been switched to eval mode there.
        network.train()
        examples = examples.index_select(0, torch.randperm(examples.size(0)))
        nbExamples = examples.size(0)
        totalLoss = 0.0
        nbEx = 0
        advancement = 0

        # Iterate up to the end of the data. The former bound
        # `size - batchSize` skipped the last full batch and trained on
        # nothing when there were no more examples than one batch.
        for batchIndex in range(0, nbExamples, batchSize):
            batch = examples[batchIndex:batchIndex + batchSize]
            targets = batch[:, :1].view(-1)
            inputs = batch[:, 1:]
            nbEx += targets.size(0)
            advancement += targets.size(0)
            if not silent and advancement >= printInterval:
                advancement = 0
                print("Curent epoch %6.2f%%" % (100.0 * nbEx / nbExamples), end="\r", file=sys.stderr)

            outputs = network(inputs)
            loss = lossFct(outputs, targets)
            network.zero_grad()
            loss.backward()
            optimizer.step()
            totalLoss += float(loss)

        devScore = ""
        if devFile is not None:
            outFilename = modelDir + "/predicted_dev.conllu"
            # Close (and therefore flush) the predictions before they are read
            # back, instead of relying on garbage collection of the handle.
            with open(outFilename, "w", encoding="utf-8") as predicted:
                decodeMode(debug, devFile, "model", network, dicts, predicted)
            with open(devFile, "r", encoding="utf-8") as goldFile, \
                 open(outFilename, "r", encoding="utf-8") as systemFile:
                res = evaluate(load_conllu(goldFile), load_conllu(systemFile), [])
            devScore = ", Dev : UAS=%.2f" % (res["UAS"][0].f1)
        print("%s : Epoch %d, loss=%.2f%s" % (timeStamp(), epoch, totalLoss, devScore), file=sys.stderr)
################################################################################
################################################################################
def decodeMode(debug, filename, type, network=None, dicts=None, output=sys.stdout):
    """Decode every sentence of a CoNLL-U file and print the resulting configurations.

    Args:
        debug: debug flag, forwarded unchanged to the decoding function.
        filename: path of the corpus, read with Config.readConllu.
        type: decoding strategy, one of "random", "oracle" or "model".
        network: forwarded to Decode.decodeModel; only used when type is "model".
        dicts: forwarded to Decode.decodeModel; only used when type is "model".
        output: stream the decoded configurations are printed to.

    Raises:
        SystemExit: with exit code 1 when type is not a known strategy.
    """
    # Validate the strategy first so that a typo fails before the corpus is read.
    if type not in ("random", "oracle", "model"):
        print("ERROR : unknown type '%s'" % type, file=sys.stderr)
        sys.exit(1)

    transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
    strategy = {"RIGHT": 1, "SHIFT": 1, "LEFT": 0, "REDUCE": 0}
    sentences = Config.readConllu(filename)

    # The `debug` parameter replaces the former `args.debug`, which only
    # resolved when a global `args` happened to exist.
    if type == "model":
        for config in sentences:
            Decode.decodeModel(transitionSet, strategy, config, network, dicts, debug)
    else:
        decodeFunc = Decode.oracleDecode if type == "oracle" else Decode.randomDecode
        for config in sentences:
            decodeFunc(transitionSet, strategy, config, debug)

    # Only the first configuration is printed with its header. Every strategy
    # writes to `output` (the random/oracle branch used to ignore it), and an
    # empty corpus simply prints nothing instead of failing on sentences[0].
    for index, config in enumerate(sentences):
        config.print(output, header=(index == 0))
################################################################################
import
Decode
################################################################################
if
__name__
==
"
__main__
"
:
...
...
@@ -128,9 +33,9 @@ if __name__ == "__main__" :
os
.
makedirs
(
args
.
model
,
exist_ok
=
True
)
if
args
.
mode
==
"
train
"
:
trainMode
(
args
.
debug
,
args
.
corpus
,
args
.
type
,
args
.
model
,
int
(
args
.
iter
),
int
(
args
.
batchSize
),
args
.
dev
,
args
.
silent
)
Train
.
trainMode
(
args
.
debug
,
args
.
corpus
,
args
.
type
,
args
.
model
,
int
(
args
.
iter
),
int
(
args
.
batchSize
),
args
.
dev
,
args
.
silent
)
elif
args
.
mode
==
"
decode
"
:
decodeMode
(
args
.
debug
,
args
.
corpus
,
args
.
type
)
Decode
.
decodeMode
(
args
.
debug
,
args
.
corpus
,
args
.
type
)
else
:
print
(
"
ERROR : unknown mode
'
%s
'"
%
args
.
mode
,
file
=
sys
.
stderr
)
exit
(
1
)
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment