Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
M
ML Quant Sep
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Container registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
GitLab community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Hachem Kadri
ML Quant Sep
Commits
f28b94a9
Commit
f28b94a9
authored
2 years ago
by
Balthazar Casale
Browse files
Options
Downloads
Patches
Plain Diff
Upload New File
parent
724b1d6f
No related branches found
No related tags found
No related merge requests found
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
src/pipeline.py
+206
-0
206 additions, 0 deletions
src/pipeline.py
with
206 additions
and
0 deletions
src/pipeline.py
0 → 100644
+
206
−
0
View file @
f28b94a9
from
types
import
DMStack
import
numpy
as
np
import
time
class
Pipeline
:
"""
A Pipeline is a sampler which apply a serie of transformations to an initial sample.
A Pipeline repeatedly sample from the same distribution until the desired number of sample is obtained.
If the Pipeline produce more samples than initially desired, then it will return the n_sample first samples.
"""
def
__init__
(
self
,
steps
,
name
=
''
,
verbosity
=
0
,
max_batch_size
=
None
,
sep_name
=
'
__
'
):
"""
:param steps: a list of tuple (name, transfo_func). The first member of the list must be a tuple (name, sample_func).
:param name: name of the pipeline (printed if verbose)
:param verbosity: number of message outputed :
0 - no message
1 - a message at every batch (nb of states to be sampled / nb of states produced and time taken)
2 - a message at every step (nb of states produced and time taken)
:param max_batch_size: if not None, limit the number of states produced each batch.
:param sep name: any new information produced by a step will be added to the dictionary info in the key
step_name + sep_name + key (equal to
'
__
'
by default)
"""
self
.
name
=
name
self
.
steps
=
steps
self
.
max_batch_size
=
max_batch_size
self
.
verbose
=
verbosity
self
.
sep_name
=
sep_name
def
sample
(
self
,
n_samples
,
dims
=
None
):
states
=
None
infos
=
None
n_sampled
=
0
if
self
.
verbose
>=
1
:
print
(
f
'
BEGIN GENERATION PROCESS
{
self
.
name
}
[
{
n_samples
}
]
'
)
# BATCHES CREATION
while
n_sampled
<
n_samples
:
batch_size
=
self
.
batch_size
(
n_samples
,
n_sampled
)
if
self
.
verbose
>=
1
:
print
(
f
'
BEGIN BATCH [
{
batch_size
}
]...
'
,
end
=
''
)
if
self
.
verbose
>=
2
:
print
(
''
)
batch_beg
=
time
.
time
()
name
,
sample
=
self
.
steps
[
0
]
if
self
.
verbose
>=
2
:
print
(
f
'
{
name
}
step...
'
,
end
=
''
)
beg
=
time
.
time
()
batch_states
,
batch_infos
=
sample
(
batch_size
,
dims
)
end
=
time
.
time
()
if
self
.
verbose
>=
2
:
print
(
f
'
[
{
len
(
batch_states
)
}
] (time :
{
end
-
beg
:
.
2
f
}
s)
'
)
# formate initial infos keys
keys
=
list
(
batch_infos
.
keys
())
for
key
in
keys
:
batch_infos
[
name
+
self
.
sep_name
+
key
]
=
batch_infos
.
pop
(
key
)
add
=
True
# perform each steps
for
i
in
range
(
1
,
len
(
self
.
steps
)):
name
,
transfo
=
self
.
steps
[
i
]
if
self
.
verbose
>=
2
:
print
(
f
'
{
name
}
step...
'
,
end
=
''
)
beg
=
time
.
time
()
batch_states
,
batch_infos
,
new_infos
=
transfo
(
batch_states
,
batch_infos
)
end
=
time
.
time
()
if
self
.
verbose
>=
2
:
print
(
f
'
done. [
{
len
(
batch_states
)
}
] (time :
{
end
-
beg
:
.
2
f
}
s)
'
)
# no states remaining at the end of the step : make another batch.
if
len
(
batch_states
)
==
0
:
add
=
False
break
for
key
in
new_infos
.
keys
():
batch_infos
[
name
+
self
.
sep_name
+
key
]
=
new_infos
[
key
]
batch_end
=
time
.
time
()
if
self
.
verbose
>=
1
:
print
(
f
'
END BATCH [
{
len
(
batch_states
)
}
] (
{
batch_end
-
batch_beg
:
.
2
f
}
s)
'
)
# add states to previous samples.
if
add
:
if
states
is
None
:
states
=
batch_states
infos
=
batch_infos
else
:
states
=
DMStack
(
np
.
concatenate
((
states
,
batch_states
)),
dims
)
for
key
in
infos
.
keys
():
infos
[
key
]
=
np
.
concatenate
((
infos
[
key
],
batch_infos
[
key
]))
n_sampled
=
len
(
states
)
if
n_sampled
>
n_samples
:
states
=
DMStack
(
states
[:
n_samples
],
dims
)
for
key
in
infos
.
keys
():
infos
[
key
]
=
infos
[
key
][:
n_samples
]
if
self
.
verbose
>=
1
:
print
(
f
'
END GENERATION PROCESS
{
self
.
name
}
[
{
len
(
states
)
}
]
'
)
return
states
,
infos
def
batch_size
(
self
,
n_samples
,
n_sampled
):
if
self
.
max_batch_size
is
None
:
return
n_samples
-
n_sampled
else
:
return
min
(
n_samples
-
n_sampled
,
self
.
max_batch_size
)
# utility functions
# select the states associated to the desired label by a classification function
def
select
(
label_func
,
value
):
"""
select the states associated to a certain label by a model function
:param label_func: a model function
:param value: desired label
:return: tuple (states, infos, new_infos)
"""
def
transform
(
states
,
infos
=
{})
:
labels
,
infos_labels
=
label_func
(
states
,
infos
)
states
=
states
[
labels
==
value
]
for
key
in
infos
.
keys
():
infos
[
key
]
=
infos
[
key
][
labels
==
value
]
for
key
in
infos_labels
.
keys
():
infos_labels
[
key
]
=
infos_labels
[
key
][
labels
==
value
]
return
states
,
infos
,
infos_labels
return
transform
# add the transformation of a state as additional informations
def
add
(
transfo_func
,
key
):
"""
add the result of a transformer function as a key in the information dictionary
:param transfo_func: a transformer function
:param key: the key in which the result will be added
:return: tuple (state, infos, new_infos)
"""
def
transform
(
states
,
infos
=
{}):
tran
,
tran_infos
=
transfo_func
(
states
,
infos
)
return
states
,
infos
,
{
key
:
tran
,
**
tran_infos
}
return
transform
# replace the states by their transformation
def
apply
(
transfo_func
):
"""
replace the states of the set by the result of a transformer function
:param transfo_func: a transformer function
:return: tuple (state, infos, new_infos)
"""
def
transform
(
states
,
infos
=
{}):
tran
,
tran_infos
=
transfo_func
(
states
,
infos
)
return
tran
,
infos
,
tran_infos
return
transform
# order the state by increasing ('+') or decreasing ('-') value of an information.
def
order
(
key
,
first
=
'
+
'
):
"""
order the states of the set by increasing/decreasing order of the content in a key
:param key: the key containing the values
:param first:
'
+
'
: increasing order,
'
-
'
: decreasing order
:return: a tuple (states, infos, new_infos)
"""
def
transform
(
states
,
infos
=
{}):
indexes
=
np
.
argsort
(
infos
[
key
])
if
first
==
'
-
'
:
indexes
=
np
.
flip
(
indexes
)
states
=
states
[
indexes
]
for
ikey
in
infos
.
keys
():
infos
[
ikey
]
=
infos
[
ikey
][
indexes
]
return
states
,
infos
,
{}
return
transform
# randomly shuffle the states
def
shuffle
():
"""
randomly shuffle the states of the set
:return: a tuple (states, infos, new_infos)
"""
def
transform
(
states
,
infos
=
{}):
indexes
=
np
.
arange
(
len
(
states
))
np
.
random
.
shuffle
(
indexes
)
states
=
states
[
indexes
]
for
key
in
infos
.
keys
():
infos
[
key
]
=
infos
[
key
][
indexes
]
return
states
,
infos
,
{}
return
transform
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment