diff --git a/forward_UpDimV2_long.py b/forward_UpDimV2_long.py
index 945056c3e0307740066aff175d7f63957e1790cc..f941d4f0aa1f882552ad4f33e55215becf70c664 100644
--- a/forward_UpDimV2_long.py
+++ b/forward_UpDimV2_long.py
@@ -17,131 +17,130 @@ from tqdm import tqdm, trange
 from math import ceil
 
 
+class UpDimV2(torch.nn.Module):
+    def __init__(self, num_class):
+        super(UpDimV2, self).__init__()
+        self.activation = torch.nn.LeakyReLU(0.001, inplace=True)
+
+        # Block 1D 1
+        self.conv11 = torch.nn.Conv1d(1, 32, 3, 1, 1)
+        self.norm11 = torch.nn.BatchNorm1d(32)
+        self.conv21 = torch.nn.Conv1d(32, 32, 3, 2, 1)
+        self.norm21 = torch.nn.BatchNorm1d(32)
+        self.skip11 = torch.nn.Conv1d(1, 32, 1, 2)
+
+        # Block 1D 2
+        self.conv12 = torch.nn.Conv1d(32, 64, 3, 2, 1)
+        self.norm12 = torch.nn.BatchNorm1d(64)
+        self.conv22 = torch.nn.Conv1d(64, 128, 3, 2, 1)
+        self.norm22 = torch.nn.BatchNorm1d(128)
+        self.skip12 = torch.nn.Conv1d(32, 128, 1, 4)
+
+        # Block 2D 1
+        self.conv31 = torch.nn.Conv2d(1, 32, 3, 1, 1)
+        self.norm31 = torch.nn.BatchNorm2d(32)
+        self.conv41 = torch.nn.Conv2d(32, 32, 3, 2, 1)
+        self.norm41 = torch.nn.BatchNorm2d(32)
+        self.skip21 = torch.nn.Conv2d(1, 32, 1, 2)
+
+        # Block 2D 2
+        self.conv32 = torch.nn.Conv2d(32, 64, 3, 2, 1)
+        self.norm32 = torch.nn.BatchNorm2d(64)
+        self.conv42 = torch.nn.Conv2d(64, 128, 3, 2, 1)
+        self.norm42 = torch.nn.BatchNorm2d(128)
+        self.skip22 = torch.nn.Conv2d(32, 128, 1, 4)
+
+        # Block 3D 1
+        self.conv51 = torch.nn.Conv3d(1, 32, 3, (1, 2, 1), 1)
+        self.norm51 = torch.nn.BatchNorm3d(32)
+        self.conv61 = torch.nn.Conv3d(32, 64, 3, 2, 1)
+        self.norm61 = torch.nn.BatchNorm3d(64)
+        self.skip31 = torch.nn.Conv3d(1, 64, 1, (2, 4, 2))
+
+        # Block 3D 2
+        self.conv52 = torch.nn.Conv3d(64, 128, 3, 2, 1)
+        self.norm52 = torch.nn.BatchNorm3d(128)
+        self.conv62 = torch.nn.Conv3d(128, 256, 3, 2, 1)
+        self.norm62 = torch.nn.BatchNorm3d(256)
+        self.skip32 = torch.nn.Conv3d(64, 256, 1, 4)
+
+        # Fully connected
+        self.soft_max = torch.nn.Softmax(-1)  # If the time stride is too big, the softmax will be done on a singleton
+        # which always ouput a 1
+        self.fc1 = torch.nn.Linear(4096, 1024)
+        self.fc2 = torch.nn.Linear(1024, 512)
+        self.fc3 = torch.nn.Linear(512, num_class)
+
+    def forward(self, x):
+        # Block 1D 1
+        out = self.conv11(x)
+        out = self.norm11(out)
+        out = self.activation(out)
+        out = self.conv21(out)
+        out = self.norm21(out)
+        skip = self.skip11(x)
+        out = self.activation(out + skip)
+
+        # Block 1D 2
+        skip = self.skip12(out)
+        out = self.conv12(out)
+        out = self.norm12(out)
+        out = self.activation(out)
+        out = self.conv22(out)
+        out = self.norm22(out)
+        out = self.activation(out + skip)
+
+        # Block 2D 1
+        out = out.reshape((lambda b, c, h: (b, 1, c, h))(*out.shape))
+        skip = self.skip21(out)
+        out = self.conv31(out)
+        out = self.norm31(out)
+        out = self.activation(out)
+        out = self.conv41(out)
+        out = self.norm41(out)
+        out = self.activation(out + skip)
+
+        # Block 2D 2
+        skip = self.skip22(out)
+        out = self.conv32(out)
+        out = self.norm32(out)
+        out = self.activation(out)
+        out = self.conv42(out)
+        out = self.norm42(out)
+        out = self.activation(out + skip)
+
+        # Block 3D 1
+        out = out.reshape((lambda b, c, w, h: (b, 1, c, w, h))(*out.shape))
+        skip = self.skip31(out)
+        out = self.conv51(out)
+        out = self.norm51(out)
+        out = self.activation(out)
+        out = self.conv61(out)
+        out = self.norm61(out)
+        out = self.activation(out + skip)
+
+        # Block 3D 2
+        skip = self.skip32(out)
+        out = self.conv52(out)
+        out = self.norm52(out)
+        out = self.activation(out)
+        out = self.conv62(out)
+        out = self.norm62(out)
+        out = self.activation(out + skip)
+
+        # Fully connected
+        out = torch.max(self.soft_max(out), -1)[0].reshape(-1, 4096)
+        out = self.activation(self.fc1(out))
+        out = self.activation(self.fc2(out))
+        return self.fc3(out)
+
+
 def main(args):
     batch_size = 64
     num_feature = 4096
     num_classes = 10
     rng = np.random.RandomState(42)
 
-    class UpDimV2(torch.nn.Module):
-
-        def __init__(self, num_class):
-            super(UpDimV2, self).__init__()
-            self.activation = torch.nn.LeakyReLU(0.001, inplace=True)
-
-            # Block 1D 1
-            self.conv11 = torch.nn.Conv1d(1, 32, 3, 1, 1)
-            self.norm11 = torch.nn.BatchNorm1d(32)
-            self.conv21 = torch.nn.Conv1d(32, 32, 3, 2, 1)
-            self.norm21 = torch.nn.BatchNorm1d(32)
-            self.skip11 = torch.nn.Conv1d(1, 32, 1, 2)
-
-            # Block 1D 2
-            self.conv12 = torch.nn.Conv1d(32, 64, 3, 2, 1)
-            self.norm12 = torch.nn.BatchNorm1d(64)
-            self.conv22 = torch.nn.Conv1d(64, 128, 3, 2, 1)
-            self.norm22 = torch.nn.BatchNorm1d(128)
-            self.skip12 = torch.nn.Conv1d(32, 128, 1, 4)
-
-            # Block 2D 1
-            self.conv31 = torch.nn.Conv2d(1, 32, 3, 1, 1)
-            self.norm31 = torch.nn.BatchNorm2d(32)
-            self.conv41 = torch.nn.Conv2d(32, 32, 3, 2, 1)
-            self.norm41 = torch.nn.BatchNorm2d(32)
-            self.skip21 = torch.nn.Conv2d(1, 32, 1, 2)
-
-            # Block 2D 2
-            self.conv32 = torch.nn.Conv2d(32, 64, 3, 2, 1)
-            self.norm32 = torch.nn.BatchNorm2d(64)
-            self.conv42 = torch.nn.Conv2d(64, 128, 3, 2, 1)
-            self.norm42 = torch.nn.BatchNorm2d(128)
-            self.skip22 = torch.nn.Conv2d(32, 128, 1, 4)
-
-            # Block 3D 1
-            self.conv51 = torch.nn.Conv3d(1, 32, 3, (1, 2, 1), 1)
-            self.norm51 = torch.nn.BatchNorm3d(32)
-            self.conv61 = torch.nn.Conv3d(32, 64, 3, 2, 1)
-            self.norm61 = torch.nn.BatchNorm3d(64)
-            self.skip31 = torch.nn.Conv3d(1, 64, 1, (2, 4, 2))
-
-            # Block 3D 2
-            self.conv52 = torch.nn.Conv3d(64, 128, 3, 2, 1)
-            self.norm52 = torch.nn.BatchNorm3d(128)
-            self.conv62 = torch.nn.Conv3d(128, 256, 3, 2, 1)
-            self.norm62 = torch.nn.BatchNorm3d(256)
-            self.skip32 = torch.nn.Conv3d(64, 256, 1, 4)
-
-            # Fully connected
-            self.soft_max = torch.nn.Softmax(-1)  # If the time stride is too big, the softmax will be done on a singleton
-            # which always ouput a 1
-            self.fc1 = torch.nn.Linear(4096, 1024)
-            self.fc2 = torch.nn.Linear(1024, 512)
-            self.fc3 = torch.nn.Linear(512, num_class)
-
-        def forward(self, x):
-            # Block 1D 1
-            out = self.conv11(x)
-            out = self.norm11(out)
-            out = self.activation(out)
-            out = self.conv21(out)
-            out = self.norm21(out)
-            skip = self.skip11(x)
-            out = self.activation(out + skip)
-
-            # Block 1D 2
-            skip = self.skip12(out)
-            out = self.conv12(out)
-            out = self.norm12(out)
-            out = self.activation(out)
-            out = self.conv22(out)
-            out = self.norm22(out)
-            out = self.activation(out + skip)
-
-            # Block 2D 1
-            out = out.reshape((lambda b, c, h: (b, 1, c, h))(*out.shape))
-            skip = self.skip21(out)
-            out = self.conv31(out)
-            out = self.norm31(out)
-            out = self.activation(out)
-            out = self.conv41(out)
-            out = self.norm41(out)
-            out = self.activation(out + skip)
-
-            # Block 2D 2
-            skip = self.skip22(out)
-            out = self.conv32(out)
-            out = self.norm32(out)
-            out = self.activation(out)
-            out = self.conv42(out)
-            out = self.norm42(out)
-            out = self.activation(out + skip)
-
-            # Block 3D 1
-            out = out.reshape((lambda b, c, w, h: (b, 1, c, w, h))(*out.shape))
-            skip = self.skip31(out)
-            out = self.conv51(out)
-            out = self.norm51(out)
-            out = self.activation(out)
-            out = self.conv61(out)
-            out = self.norm61(out)
-            out = self.activation(out + skip)
-
-            # Block 3D 2
-            skip = self.skip32(out)
-            out = self.conv52(out)
-            out = self.norm52(out)
-            out = self.activation(out)
-            out = self.conv62(out)
-            out = self.norm62(out)
-            out = self.activation(out + skip)
-
-            # Fully connected
-            out = torch.max(self.soft_max(out), -1)[0].reshape(-1, 4096)
-            out = self.activation(self.fc1(out))
-            out = self.activation(self.fc2(out))
-            return self.fc3(out)
-
-
     model = torch.nn.DataParallel(UpDimV2(num_classes))
     model.load_state_dict((torch.load(args.weight)['model']))
     model.to('cuda')