Commit fa104908 authored by Maxence Ferrari

Move model class outside of main

parent 5fa2a7b0
@@ -17,131 +17,130 @@ from tqdm import tqdm, trange
from math import ceil
class UpDimV2(torch.nn.Module):
    def __init__(self, num_class):
        super(UpDimV2, self).__init__()
        self.activation = torch.nn.LeakyReLU(0.001, inplace=True)
        # Block 1D 1
        self.conv11 = torch.nn.Conv1d(1, 32, 3, 1, 1)
        self.norm11 = torch.nn.BatchNorm1d(32)
        self.conv21 = torch.nn.Conv1d(32, 32, 3, 2, 1)
        self.norm21 = torch.nn.BatchNorm1d(32)
        self.skip11 = torch.nn.Conv1d(1, 32, 1, 2)
        # Block 1D 2
        self.conv12 = torch.nn.Conv1d(32, 64, 3, 2, 1)
        self.norm12 = torch.nn.BatchNorm1d(64)
        self.conv22 = torch.nn.Conv1d(64, 128, 3, 2, 1)
        self.norm22 = torch.nn.BatchNorm1d(128)
        self.skip12 = torch.nn.Conv1d(32, 128, 1, 4)
        # Block 2D 1
        self.conv31 = torch.nn.Conv2d(1, 32, 3, 1, 1)
        self.norm31 = torch.nn.BatchNorm2d(32)
        self.conv41 = torch.nn.Conv2d(32, 32, 3, 2, 1)
        self.norm41 = torch.nn.BatchNorm2d(32)
        self.skip21 = torch.nn.Conv2d(1, 32, 1, 2)
        # Block 2D 2
        self.conv32 = torch.nn.Conv2d(32, 64, 3, 2, 1)
        self.norm32 = torch.nn.BatchNorm2d(64)
        self.conv42 = torch.nn.Conv2d(64, 128, 3, 2, 1)
        self.norm42 = torch.nn.BatchNorm2d(128)
        self.skip22 = torch.nn.Conv2d(32, 128, 1, 4)
        # Block 3D 1
        self.conv51 = torch.nn.Conv3d(1, 32, 3, (1, 2, 1), 1)
        self.norm51 = torch.nn.BatchNorm3d(32)
        self.conv61 = torch.nn.Conv3d(32, 64, 3, 2, 1)
        self.norm61 = torch.nn.BatchNorm3d(64)
        self.skip31 = torch.nn.Conv3d(1, 64, 1, (2, 4, 2))
        # Block 3D 2
        self.conv52 = torch.nn.Conv3d(64, 128, 3, 2, 1)
        self.norm52 = torch.nn.BatchNorm3d(128)
        self.conv62 = torch.nn.Conv3d(128, 256, 3, 2, 1)
        self.norm62 = torch.nn.BatchNorm3d(256)
        self.skip32 = torch.nn.Conv3d(64, 256, 1, 4)
        # Fully connected
        self.soft_max = torch.nn.Softmax(-1)  # If the time stride is too big, the softmax is applied to a
                                              # singleton axis, which always outputs 1
        self.fc1 = torch.nn.Linear(4096, 1024)
        self.fc2 = torch.nn.Linear(1024, 512)
        self.fc3 = torch.nn.Linear(512, num_class)

    def forward(self, x):
        # Block 1D 1
        out = self.conv11(x)
        out = self.norm11(out)
        out = self.activation(out)
        out = self.conv21(out)
        out = self.norm21(out)
        skip = self.skip11(x)
        out = self.activation(out + skip)
        # Block 1D 2
        skip = self.skip12(out)
        out = self.conv12(out)
        out = self.norm12(out)
        out = self.activation(out)
        out = self.conv22(out)
        out = self.norm22(out)
        out = self.activation(out + skip)
        # Block 2D 1
        out = out.reshape((lambda b, c, h: (b, 1, c, h))(*out.shape))
        skip = self.skip21(out)
        out = self.conv31(out)
        out = self.norm31(out)
        out = self.activation(out)
        out = self.conv41(out)
        out = self.norm41(out)
        out = self.activation(out + skip)
        # Block 2D 2
        skip = self.skip22(out)
        out = self.conv32(out)
        out = self.norm32(out)
        out = self.activation(out)
        out = self.conv42(out)
        out = self.norm42(out)
        out = self.activation(out + skip)
        # Block 3D 1
        out = out.reshape((lambda b, c, w, h: (b, 1, c, w, h))(*out.shape))
        skip = self.skip31(out)
        out = self.conv51(out)
        out = self.norm51(out)
        out = self.activation(out)
        out = self.conv61(out)
        out = self.norm61(out)
        out = self.activation(out + skip)
        # Block 3D 2
        skip = self.skip32(out)
        out = self.conv52(out)
        out = self.norm52(out)
        out = self.activation(out)
        out = self.conv62(out)
        out = self.norm62(out)
        out = self.activation(out + skip)
        # Fully connected
        out = torch.max(self.soft_max(out), -1)[0].reshape(-1, 4096)
        out = self.activation(self.fc1(out))
        out = self.activation(self.fc2(out))
        return self.fc3(out)
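
The 4096 input features of fc1 follow from the strided convolutions: the two channel-derived axes end at 256 × 16 × 1 regardless of input length, while the time axis is divided by roughly 512 and then collapsed by the softmax/max pair, which is why an input shorter than about 512 samples leaves the softmax acting on a singleton axis. A minimal shape-check sketch, assuming an input of shape (batch, 1, 4096); the length 4096 and the class count are only illustrative assumptions (the length is taken from the num_feature constant in main below):

import torch

model = UpDimV2(num_class=10).eval()   # 10 classes is an assumed, illustrative value
with torch.no_grad():
    x = torch.randn(2, 1, 4096)        # (batch, channel, time); length is an assumption
    logits = model(x)                  # final conv time axis: about 4096 / 512 = 8 steps
print(logits.shape)                    # torch.Size([2, 10])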

def main(args):
    batch_size = 64
    num_feature = 4096
    num_classes = 10
    rng = np.random.RandomState(42)
    model = torch.nn.DataParallel(UpDimV2(num_classes))
    model.load_state_dict(torch.load(args.weight)['model'])
    model.to('cuda')
......
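
Because the weights are loaded into a torch.nn.DataParallel wrapper above, the stored state dict presumably carries DataParallel's "module." key prefix. Below is a hedged sketch of loading the same checkpoint into a bare UpDimV2 for single-device inference; the {'model': state_dict} layout comes from the snippet above, while the hypothetical helper name load_plain, the prefix handling, and the device choice are assumptions:

import torch

def load_plain(weight_path, num_classes=10, device='cpu'):
    # Hypothetical helper; assumes the checkpoint layout {'model': state_dict}
    # with parameters saved from a DataParallel-wrapped model.
    state = torch.load(weight_path, map_location=device)['model']
    # Drop the "module." prefix that torch.nn.DataParallel adds to parameter names.
    state = {k[len('module.'):] if k.startswith('module.') else k: v
             for k, v in state.items()}
    model = UpDimV2(num_classes)
    model.load_state_dict(state)
    return model.to(device).eval()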