diff --git a/forward_UpDimV2_long.py b/forward_UpDimV2_long.py index 945056c3e0307740066aff175d7f63957e1790cc..f941d4f0aa1f882552ad4f33e55215becf70c664 100644 --- a/forward_UpDimV2_long.py +++ b/forward_UpDimV2_long.py @@ -17,131 +17,130 @@ from tqdm import tqdm, trange from math import ceil +class UpDimV2(torch.nn.Module): + def __init__(self, num_class): + super(UpDimV2, self).__init__() + self.activation = torch.nn.LeakyReLU(0.001, inplace=True) + + # Block 1D 1 + self.conv11 = torch.nn.Conv1d(1, 32, 3, 1, 1) + self.norm11 = torch.nn.BatchNorm1d(32) + self.conv21 = torch.nn.Conv1d(32, 32, 3, 2, 1) + self.norm21 = torch.nn.BatchNorm1d(32) + self.skip11 = torch.nn.Conv1d(1, 32, 1, 2) + + # Block 1D 2 + self.conv12 = torch.nn.Conv1d(32, 64, 3, 2, 1) + self.norm12 = torch.nn.BatchNorm1d(64) + self.conv22 = torch.nn.Conv1d(64, 128, 3, 2, 1) + self.norm22 = torch.nn.BatchNorm1d(128) + self.skip12 = torch.nn.Conv1d(32, 128, 1, 4) + + # Block 2D 1 + self.conv31 = torch.nn.Conv2d(1, 32, 3, 1, 1) + self.norm31 = torch.nn.BatchNorm2d(32) + self.conv41 = torch.nn.Conv2d(32, 32, 3, 2, 1) + self.norm41 = torch.nn.BatchNorm2d(32) + self.skip21 = torch.nn.Conv2d(1, 32, 1, 2) + + # Block 2D 2 + self.conv32 = torch.nn.Conv2d(32, 64, 3, 2, 1) + self.norm32 = torch.nn.BatchNorm2d(64) + self.conv42 = torch.nn.Conv2d(64, 128, 3, 2, 1) + self.norm42 = torch.nn.BatchNorm2d(128) + self.skip22 = torch.nn.Conv2d(32, 128, 1, 4) + + # Block 3D 1 + self.conv51 = torch.nn.Conv3d(1, 32, 3, (1, 2, 1), 1) + self.norm51 = torch.nn.BatchNorm3d(32) + self.conv61 = torch.nn.Conv3d(32, 64, 3, 2, 1) + self.norm61 = torch.nn.BatchNorm3d(64) + self.skip31 = torch.nn.Conv3d(1, 64, 1, (2, 4, 2)) + + # Block 3D 2 + self.conv52 = torch.nn.Conv3d(64, 128, 3, 2, 1) + self.norm52 = torch.nn.BatchNorm3d(128) + self.conv62 = torch.nn.Conv3d(128, 256, 3, 2, 1) + self.norm62 = torch.nn.BatchNorm3d(256) + self.skip32 = torch.nn.Conv3d(64, 256, 1, 4) + + # Fully connected + self.soft_max = torch.nn.Softmax(-1) # If the time stride is too big, the softmax will be done on a singleton + # which always ouput a 1 + self.fc1 = torch.nn.Linear(4096, 1024) + self.fc2 = torch.nn.Linear(1024, 512) + self.fc3 = torch.nn.Linear(512, num_class) + + def forward(self, x): + # Block 1D 1 + out = self.conv11(x) + out = self.norm11(out) + out = self.activation(out) + out = self.conv21(out) + out = self.norm21(out) + skip = self.skip11(x) + out = self.activation(out + skip) + + # Block 1D 2 + skip = self.skip12(out) + out = self.conv12(out) + out = self.norm12(out) + out = self.activation(out) + out = self.conv22(out) + out = self.norm22(out) + out = self.activation(out + skip) + + # Block 2D 1 + out = out.reshape((lambda b, c, h: (b, 1, c, h))(*out.shape)) + skip = self.skip21(out) + out = self.conv31(out) + out = self.norm31(out) + out = self.activation(out) + out = self.conv41(out) + out = self.norm41(out) + out = self.activation(out + skip) + + # Block 2D 2 + skip = self.skip22(out) + out = self.conv32(out) + out = self.norm32(out) + out = self.activation(out) + out = self.conv42(out) + out = self.norm42(out) + out = self.activation(out + skip) + + # Block 3D 1 + out = out.reshape((lambda b, c, w, h: (b, 1, c, w, h))(*out.shape)) + skip = self.skip31(out) + out = self.conv51(out) + out = self.norm51(out) + out = self.activation(out) + out = self.conv61(out) + out = self.norm61(out) + out = self.activation(out + skip) + + # Block 3D 2 + skip = self.skip32(out) + out = self.conv52(out) + out = self.norm52(out) + out = self.activation(out) + out = self.conv62(out) + out = self.norm62(out) + out = self.activation(out + skip) + + # Fully connected + out = torch.max(self.soft_max(out), -1)[0].reshape(-1, 4096) + out = self.activation(self.fc1(out)) + out = self.activation(self.fc2(out)) + return self.fc3(out) + + def main(args): batch_size = 64 num_feature = 4096 num_classes = 10 rng = np.random.RandomState(42) - class UpDimV2(torch.nn.Module): - - def __init__(self, num_class): - super(UpDimV2, self).__init__() - self.activation = torch.nn.LeakyReLU(0.001, inplace=True) - - # Block 1D 1 - self.conv11 = torch.nn.Conv1d(1, 32, 3, 1, 1) - self.norm11 = torch.nn.BatchNorm1d(32) - self.conv21 = torch.nn.Conv1d(32, 32, 3, 2, 1) - self.norm21 = torch.nn.BatchNorm1d(32) - self.skip11 = torch.nn.Conv1d(1, 32, 1, 2) - - # Block 1D 2 - self.conv12 = torch.nn.Conv1d(32, 64, 3, 2, 1) - self.norm12 = torch.nn.BatchNorm1d(64) - self.conv22 = torch.nn.Conv1d(64, 128, 3, 2, 1) - self.norm22 = torch.nn.BatchNorm1d(128) - self.skip12 = torch.nn.Conv1d(32, 128, 1, 4) - - # Block 2D 1 - self.conv31 = torch.nn.Conv2d(1, 32, 3, 1, 1) - self.norm31 = torch.nn.BatchNorm2d(32) - self.conv41 = torch.nn.Conv2d(32, 32, 3, 2, 1) - self.norm41 = torch.nn.BatchNorm2d(32) - self.skip21 = torch.nn.Conv2d(1, 32, 1, 2) - - # Block 2D 2 - self.conv32 = torch.nn.Conv2d(32, 64, 3, 2, 1) - self.norm32 = torch.nn.BatchNorm2d(64) - self.conv42 = torch.nn.Conv2d(64, 128, 3, 2, 1) - self.norm42 = torch.nn.BatchNorm2d(128) - self.skip22 = torch.nn.Conv2d(32, 128, 1, 4) - - # Block 3D 1 - self.conv51 = torch.nn.Conv3d(1, 32, 3, (1, 2, 1), 1) - self.norm51 = torch.nn.BatchNorm3d(32) - self.conv61 = torch.nn.Conv3d(32, 64, 3, 2, 1) - self.norm61 = torch.nn.BatchNorm3d(64) - self.skip31 = torch.nn.Conv3d(1, 64, 1, (2, 4, 2)) - - # Block 3D 2 - self.conv52 = torch.nn.Conv3d(64, 128, 3, 2, 1) - self.norm52 = torch.nn.BatchNorm3d(128) - self.conv62 = torch.nn.Conv3d(128, 256, 3, 2, 1) - self.norm62 = torch.nn.BatchNorm3d(256) - self.skip32 = torch.nn.Conv3d(64, 256, 1, 4) - - # Fully connected - self.soft_max = torch.nn.Softmax(-1) # If the time stride is too big, the softmax will be done on a singleton - # which always ouput a 1 - self.fc1 = torch.nn.Linear(4096, 1024) - self.fc2 = torch.nn.Linear(1024, 512) - self.fc3 = torch.nn.Linear(512, num_class) - - def forward(self, x): - # Block 1D 1 - out = self.conv11(x) - out = self.norm11(out) - out = self.activation(out) - out = self.conv21(out) - out = self.norm21(out) - skip = self.skip11(x) - out = self.activation(out + skip) - - # Block 1D 2 - skip = self.skip12(out) - out = self.conv12(out) - out = self.norm12(out) - out = self.activation(out) - out = self.conv22(out) - out = self.norm22(out) - out = self.activation(out + skip) - - # Block 2D 1 - out = out.reshape((lambda b, c, h: (b, 1, c, h))(*out.shape)) - skip = self.skip21(out) - out = self.conv31(out) - out = self.norm31(out) - out = self.activation(out) - out = self.conv41(out) - out = self.norm41(out) - out = self.activation(out + skip) - - # Block 2D 2 - skip = self.skip22(out) - out = self.conv32(out) - out = self.norm32(out) - out = self.activation(out) - out = self.conv42(out) - out = self.norm42(out) - out = self.activation(out + skip) - - # Block 3D 1 - out = out.reshape((lambda b, c, w, h: (b, 1, c, w, h))(*out.shape)) - skip = self.skip31(out) - out = self.conv51(out) - out = self.norm51(out) - out = self.activation(out) - out = self.conv61(out) - out = self.norm61(out) - out = self.activation(out + skip) - - # Block 3D 2 - skip = self.skip32(out) - out = self.conv52(out) - out = self.norm52(out) - out = self.activation(out) - out = self.conv62(out) - out = self.norm62(out) - out = self.activation(out + skip) - - # Fully connected - out = torch.max(self.soft_max(out), -1)[0].reshape(-1, 4096) - out = self.activation(self.fc1(out)) - out = self.activation(self.fc2(out)) - return self.fc3(out) - - model = torch.nn.DataParallel(UpDimV2(num_classes)) model.load_state_dict((torch.load(args.weight)['model'])) model.to('cuda')