Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Maxence Ferrari
CARIMAM_DOCC10
Commits
fa104908
Commit
fa104908
authored
Jan 10, 2022
by
Maxence Ferrari
Browse files
Move model class outside of main
parent
5fa2a7b0
Changes
1
Hide whitespace changes
Inline
Side-by-side
forward_UpDimV2_long.py
View file @
fa104908
...
...
@@ -17,131 +17,130 @@ from tqdm import tqdm, trange
from
math
import
ceil
class
UpDimV2
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
num_class
):
super
(
UpDimV2
,
self
).
__init__
()
self
.
activation
=
torch
.
nn
.
LeakyReLU
(
0.001
,
inplace
=
True
)
# Block 1D 1
self
.
conv11
=
torch
.
nn
.
Conv1d
(
1
,
32
,
3
,
1
,
1
)
self
.
norm11
=
torch
.
nn
.
BatchNorm1d
(
32
)
self
.
conv21
=
torch
.
nn
.
Conv1d
(
32
,
32
,
3
,
2
,
1
)
self
.
norm21
=
torch
.
nn
.
BatchNorm1d
(
32
)
self
.
skip11
=
torch
.
nn
.
Conv1d
(
1
,
32
,
1
,
2
)
# Block 1D 2
self
.
conv12
=
torch
.
nn
.
Conv1d
(
32
,
64
,
3
,
2
,
1
)
self
.
norm12
=
torch
.
nn
.
BatchNorm1d
(
64
)
self
.
conv22
=
torch
.
nn
.
Conv1d
(
64
,
128
,
3
,
2
,
1
)
self
.
norm22
=
torch
.
nn
.
BatchNorm1d
(
128
)
self
.
skip12
=
torch
.
nn
.
Conv1d
(
32
,
128
,
1
,
4
)
# Block 2D 1
self
.
conv31
=
torch
.
nn
.
Conv2d
(
1
,
32
,
3
,
1
,
1
)
self
.
norm31
=
torch
.
nn
.
BatchNorm2d
(
32
)
self
.
conv41
=
torch
.
nn
.
Conv2d
(
32
,
32
,
3
,
2
,
1
)
self
.
norm41
=
torch
.
nn
.
BatchNorm2d
(
32
)
self
.
skip21
=
torch
.
nn
.
Conv2d
(
1
,
32
,
1
,
2
)
# Block 2D 2
self
.
conv32
=
torch
.
nn
.
Conv2d
(
32
,
64
,
3
,
2
,
1
)
self
.
norm32
=
torch
.
nn
.
BatchNorm2d
(
64
)
self
.
conv42
=
torch
.
nn
.
Conv2d
(
64
,
128
,
3
,
2
,
1
)
self
.
norm42
=
torch
.
nn
.
BatchNorm2d
(
128
)
self
.
skip22
=
torch
.
nn
.
Conv2d
(
32
,
128
,
1
,
4
)
# Block 3D 1
self
.
conv51
=
torch
.
nn
.
Conv3d
(
1
,
32
,
3
,
(
1
,
2
,
1
),
1
)
self
.
norm51
=
torch
.
nn
.
BatchNorm3d
(
32
)
self
.
conv61
=
torch
.
nn
.
Conv3d
(
32
,
64
,
3
,
2
,
1
)
self
.
norm61
=
torch
.
nn
.
BatchNorm3d
(
64
)
self
.
skip31
=
torch
.
nn
.
Conv3d
(
1
,
64
,
1
,
(
2
,
4
,
2
))
# Block 3D 2
self
.
conv52
=
torch
.
nn
.
Conv3d
(
64
,
128
,
3
,
2
,
1
)
self
.
norm52
=
torch
.
nn
.
BatchNorm3d
(
128
)
self
.
conv62
=
torch
.
nn
.
Conv3d
(
128
,
256
,
3
,
2
,
1
)
self
.
norm62
=
torch
.
nn
.
BatchNorm3d
(
256
)
self
.
skip32
=
torch
.
nn
.
Conv3d
(
64
,
256
,
1
,
4
)
# Fully connected
self
.
soft_max
=
torch
.
nn
.
Softmax
(
-
1
)
# If the time stride is too big, the softmax will be done on a singleton
# which always ouput a 1
self
.
fc1
=
torch
.
nn
.
Linear
(
4096
,
1024
)
self
.
fc2
=
torch
.
nn
.
Linear
(
1024
,
512
)
self
.
fc3
=
torch
.
nn
.
Linear
(
512
,
num_class
)
def
forward
(
self
,
x
):
# Block 1D 1
out
=
self
.
conv11
(
x
)
out
=
self
.
norm11
(
out
)
out
=
self
.
activation
(
out
)
out
=
self
.
conv21
(
out
)
out
=
self
.
norm21
(
out
)
skip
=
self
.
skip11
(
x
)
out
=
self
.
activation
(
out
+
skip
)
# Block 1D 2
skip
=
self
.
skip12
(
out
)
out
=
self
.
conv12
(
out
)
out
=
self
.
norm12
(
out
)
out
=
self
.
activation
(
out
)
out
=
self
.
conv22
(
out
)
out
=
self
.
norm22
(
out
)
out
=
self
.
activation
(
out
+
skip
)
# Block 2D 1
out
=
out
.
reshape
((
lambda
b
,
c
,
h
:
(
b
,
1
,
c
,
h
))(
*
out
.
shape
))
skip
=
self
.
skip21
(
out
)
out
=
self
.
conv31
(
out
)
out
=
self
.
norm31
(
out
)
out
=
self
.
activation
(
out
)
out
=
self
.
conv41
(
out
)
out
=
self
.
norm41
(
out
)
out
=
self
.
activation
(
out
+
skip
)
# Block 2D 2
skip
=
self
.
skip22
(
out
)
out
=
self
.
conv32
(
out
)
out
=
self
.
norm32
(
out
)
out
=
self
.
activation
(
out
)
out
=
self
.
conv42
(
out
)
out
=
self
.
norm42
(
out
)
out
=
self
.
activation
(
out
+
skip
)
# Block 3D 1
out
=
out
.
reshape
((
lambda
b
,
c
,
w
,
h
:
(
b
,
1
,
c
,
w
,
h
))(
*
out
.
shape
))
skip
=
self
.
skip31
(
out
)
out
=
self
.
conv51
(
out
)
out
=
self
.
norm51
(
out
)
out
=
self
.
activation
(
out
)
out
=
self
.
conv61
(
out
)
out
=
self
.
norm61
(
out
)
out
=
self
.
activation
(
out
+
skip
)
# Block 3D 2
skip
=
self
.
skip32
(
out
)
out
=
self
.
conv52
(
out
)
out
=
self
.
norm52
(
out
)
out
=
self
.
activation
(
out
)
out
=
self
.
conv62
(
out
)
out
=
self
.
norm62
(
out
)
out
=
self
.
activation
(
out
+
skip
)
# Fully connected
out
=
torch
.
max
(
self
.
soft_max
(
out
),
-
1
)[
0
].
reshape
(
-
1
,
4096
)
out
=
self
.
activation
(
self
.
fc1
(
out
))
out
=
self
.
activation
(
self
.
fc2
(
out
))
return
self
.
fc3
(
out
)
def
main
(
args
):
batch_size
=
64
num_feature
=
4096
num_classes
=
10
rng
=
np
.
random
.
RandomState
(
42
)
class
UpDimV2
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
num_class
):
super
(
UpDimV2
,
self
).
__init__
()
self
.
activation
=
torch
.
nn
.
LeakyReLU
(
0.001
,
inplace
=
True
)
# Block 1D 1
self
.
conv11
=
torch
.
nn
.
Conv1d
(
1
,
32
,
3
,
1
,
1
)
self
.
norm11
=
torch
.
nn
.
BatchNorm1d
(
32
)
self
.
conv21
=
torch
.
nn
.
Conv1d
(
32
,
32
,
3
,
2
,
1
)
self
.
norm21
=
torch
.
nn
.
BatchNorm1d
(
32
)
self
.
skip11
=
torch
.
nn
.
Conv1d
(
1
,
32
,
1
,
2
)
# Block 1D 2
self
.
conv12
=
torch
.
nn
.
Conv1d
(
32
,
64
,
3
,
2
,
1
)
self
.
norm12
=
torch
.
nn
.
BatchNorm1d
(
64
)
self
.
conv22
=
torch
.
nn
.
Conv1d
(
64
,
128
,
3
,
2
,
1
)
self
.
norm22
=
torch
.
nn
.
BatchNorm1d
(
128
)
self
.
skip12
=
torch
.
nn
.
Conv1d
(
32
,
128
,
1
,
4
)
# Block 2D 1
self
.
conv31
=
torch
.
nn
.
Conv2d
(
1
,
32
,
3
,
1
,
1
)
self
.
norm31
=
torch
.
nn
.
BatchNorm2d
(
32
)
self
.
conv41
=
torch
.
nn
.
Conv2d
(
32
,
32
,
3
,
2
,
1
)
self
.
norm41
=
torch
.
nn
.
BatchNorm2d
(
32
)
self
.
skip21
=
torch
.
nn
.
Conv2d
(
1
,
32
,
1
,
2
)
# Block 2D 2
self
.
conv32
=
torch
.
nn
.
Conv2d
(
32
,
64
,
3
,
2
,
1
)
self
.
norm32
=
torch
.
nn
.
BatchNorm2d
(
64
)
self
.
conv42
=
torch
.
nn
.
Conv2d
(
64
,
128
,
3
,
2
,
1
)
self
.
norm42
=
torch
.
nn
.
BatchNorm2d
(
128
)
self
.
skip22
=
torch
.
nn
.
Conv2d
(
32
,
128
,
1
,
4
)
# Block 3D 1
self
.
conv51
=
torch
.
nn
.
Conv3d
(
1
,
32
,
3
,
(
1
,
2
,
1
),
1
)
self
.
norm51
=
torch
.
nn
.
BatchNorm3d
(
32
)
self
.
conv61
=
torch
.
nn
.
Conv3d
(
32
,
64
,
3
,
2
,
1
)
self
.
norm61
=
torch
.
nn
.
BatchNorm3d
(
64
)
self
.
skip31
=
torch
.
nn
.
Conv3d
(
1
,
64
,
1
,
(
2
,
4
,
2
))
# Block 3D 2
self
.
conv52
=
torch
.
nn
.
Conv3d
(
64
,
128
,
3
,
2
,
1
)
self
.
norm52
=
torch
.
nn
.
BatchNorm3d
(
128
)
self
.
conv62
=
torch
.
nn
.
Conv3d
(
128
,
256
,
3
,
2
,
1
)
self
.
norm62
=
torch
.
nn
.
BatchNorm3d
(
256
)
self
.
skip32
=
torch
.
nn
.
Conv3d
(
64
,
256
,
1
,
4
)
# Fully connected
self
.
soft_max
=
torch
.
nn
.
Softmax
(
-
1
)
# If the time stride is too big, the softmax will be done on a singleton
# which always ouput a 1
self
.
fc1
=
torch
.
nn
.
Linear
(
4096
,
1024
)
self
.
fc2
=
torch
.
nn
.
Linear
(
1024
,
512
)
self
.
fc3
=
torch
.
nn
.
Linear
(
512
,
num_class
)
def
forward
(
self
,
x
):
# Block 1D 1
out
=
self
.
conv11
(
x
)
out
=
self
.
norm11
(
out
)
out
=
self
.
activation
(
out
)
out
=
self
.
conv21
(
out
)
out
=
self
.
norm21
(
out
)
skip
=
self
.
skip11
(
x
)
out
=
self
.
activation
(
out
+
skip
)
# Block 1D 2
skip
=
self
.
skip12
(
out
)
out
=
self
.
conv12
(
out
)
out
=
self
.
norm12
(
out
)
out
=
self
.
activation
(
out
)
out
=
self
.
conv22
(
out
)
out
=
self
.
norm22
(
out
)
out
=
self
.
activation
(
out
+
skip
)
# Block 2D 1
out
=
out
.
reshape
((
lambda
b
,
c
,
h
:
(
b
,
1
,
c
,
h
))(
*
out
.
shape
))
skip
=
self
.
skip21
(
out
)
out
=
self
.
conv31
(
out
)
out
=
self
.
norm31
(
out
)
out
=
self
.
activation
(
out
)
out
=
self
.
conv41
(
out
)
out
=
self
.
norm41
(
out
)
out
=
self
.
activation
(
out
+
skip
)
# Block 2D 2
skip
=
self
.
skip22
(
out
)
out
=
self
.
conv32
(
out
)
out
=
self
.
norm32
(
out
)
out
=
self
.
activation
(
out
)
out
=
self
.
conv42
(
out
)
out
=
self
.
norm42
(
out
)
out
=
self
.
activation
(
out
+
skip
)
# Block 3D 1
out
=
out
.
reshape
((
lambda
b
,
c
,
w
,
h
:
(
b
,
1
,
c
,
w
,
h
))(
*
out
.
shape
))
skip
=
self
.
skip31
(
out
)
out
=
self
.
conv51
(
out
)
out
=
self
.
norm51
(
out
)
out
=
self
.
activation
(
out
)
out
=
self
.
conv61
(
out
)
out
=
self
.
norm61
(
out
)
out
=
self
.
activation
(
out
+
skip
)
# Block 3D 2
skip
=
self
.
skip32
(
out
)
out
=
self
.
conv52
(
out
)
out
=
self
.
norm52
(
out
)
out
=
self
.
activation
(
out
)
out
=
self
.
conv62
(
out
)
out
=
self
.
norm62
(
out
)
out
=
self
.
activation
(
out
+
skip
)
# Fully connected
out
=
torch
.
max
(
self
.
soft_max
(
out
),
-
1
)[
0
].
reshape
(
-
1
,
4096
)
out
=
self
.
activation
(
self
.
fc1
(
out
))
out
=
self
.
activation
(
self
.
fc2
(
out
))
return
self
.
fc3
(
out
)
model
=
torch
.
nn
.
DataParallel
(
UpDimV2
(
num_classes
))
model
.
load_state_dict
((
torch
.
load
(
args
.
weight
)[
'model'
]))
model
.
to
(
'cuda'
)
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment