Skip to content
Snippets Groups Projects
Commit 0980c6a1 authored by Julien Dejasmin's avatar Julien Dejasmin
Browse files

cluster update

parent 32b57c07
Branches
No related tags found
No related merge requests found
/data1/home/julien.dejasmin/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/_reduction.py:43: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.
warnings.warn(warning.format(ret))
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning.
Namespace(batch_size=64, beta=4, cont_capacity=None, dataset='rendered_chairs', disc_capacity=None, epochs=40, experiment_name='beta_VAE', is_beta_VAE=True, latent_name='', latent_spec_cont=10, latent_spec_disc=None, lr=0.0001, print_loss_every=50, record_loss_every=50, save_model=True, save_reconstruction_image=True)
load dataset: rendered_chairs, with: 69120 train images of shape: (3, 64, 64)
use 1 gpu who named: GeForce GTX 1080 Ti
VAE(
(img_to_last_conv): Sequential(
(0): Conv2d(3, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(1): ReLU()
(2): Conv2d(32, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(3): ReLU()
(4): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(5): ReLU()
(6): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(7): ReLU()
)
(last_conv_to_continuous_features): Sequential(
(0): Conv2d(64, 256, kernel_size=(4, 4), stride=(1, 1))
(1): ReLU()
)
(features_to_hidden_continue): Sequential(
(0): Linear(in_features=256, out_features=20, bias=True)
(1): ReLU()
)
(latent_to_features): Sequential(
(0): Linear(in_features=10, out_features=256, bias=True)
(1): ReLU()
)
(features_to_img): Sequential(
(0): ConvTranspose2d(256, 64, kernel_size=(4, 4), stride=(1, 1))
(1): ReLU()
(2): ConvTranspose2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(3): ReLU()
(4): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(5): ReLU()
(6): ConvTranspose2d(32, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(7): ReLU()
(8): ConvTranspose2d(32, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(9): Sigmoid()
)
)
don't use continuous capacity
0/69092 Loss: 3070.657
3200/69092 Loss: 2948.644
6400/69092 Loss: 1024.792
9600/69092 Loss: 540.226
12800/69092 Loss: 470.434
16000/69092 Loss: 457.701
19200/69092 Loss: 450.401
22400/69092 Loss: 443.853
25600/69092 Loss: 451.292
28800/69092 Loss: 440.426
32000/69092 Loss: 432.866
35200/69092 Loss: 308.760
38400/69092 Loss: 242.492
41600/69092 Loss: 237.200
44800/69092 Loss: 229.539
48000/69092 Loss: 229.294
51200/69092 Loss: 230.983
54400/69092 Loss: 224.875
57600/69092 Loss: 220.705
60800/69092 Loss: 226.022
64000/69092 Loss: 220.454
67200/69092 Loss: 221.057
Epoch: 1 Average loss: 483.71
0/69092 Loss: 252.325
3200/69092 Loss: 220.659
6400/69092 Loss: 222.659
9600/69092 Loss: 219.368
12800/69092 Loss: 214.168
16000/69092 Loss: 212.482
19200/69092 Loss: 208.190
22400/69092 Loss: 205.885
25600/69092 Loss: 209.691
28800/69092 Loss: 203.559
32000/69092 Loss: 205.094
35200/69092 Loss: 199.194
38400/69092 Loss: 201.090
41600/69092 Loss: 197.349
44800/69092 Loss: 199.072
48000/69092 Loss: 195.767
51200/69092 Loss: 192.350
54400/69092 Loss: 188.418
57600/69092 Loss: 191.497
60800/69092 Loss: 192.091
64000/69092 Loss: 189.586
67200/69092 Loss: 192.249
Epoch: 2 Average loss: 202.72
0/69092 Loss: 192.982
3200/69092 Loss: 191.061
6400/69092 Loss: 191.400
9600/69092 Loss: 188.066
12800/69092 Loss: 193.611
16000/69092 Loss: 188.738
19200/69092 Loss: 186.129
22400/69092 Loss: 186.272
25600/69092 Loss: 188.331
28800/69092 Loss: 190.557
32000/69092 Loss: 189.598
35200/69092 Loss: 186.676
38400/69092 Loss: 191.831
41600/69092 Loss: 189.899
44800/69092 Loss: 187.892
48000/69092 Loss: 190.516
51200/69092 Loss: 185.773
54400/69092 Loss: 188.082
57600/69092 Loss: 185.783
60800/69092 Loss: 187.686
64000/69092 Loss: 185.843
67200/69092 Loss: 186.829
Epoch: 3 Average loss: 188.60
0/69092 Loss: 187.544
3200/69092 Loss: 191.071
6400/69092 Loss: 189.176
9600/69092 Loss: 188.550
12800/69092 Loss: 186.759
16000/69092 Loss: 181.701
19200/69092 Loss: 186.361
22400/69092 Loss: 182.646
25600/69092 Loss: 182.758
28800/69092 Loss: 181.681
32000/69092 Loss: 179.355
35200/69092 Loss: 181.080
38400/69092 Loss: 175.913
41600/69092 Loss: 179.204
44800/69092 Loss: 178.583
48000/69092 Loss: 175.061
51200/69092 Loss: 178.126
54400/69092 Loss: 179.390
57600/69092 Loss: 180.663
60800/69092 Loss: 180.700
64000/69092 Loss: 179.577
67200/69092 Loss: 182.585
Epoch: 4 Average loss: 181.75
0/69092 Loss: 171.749
3200/69092 Loss: 178.532
6400/69092 Loss: 175.331
9600/69092 Loss: 178.732
12800/69092 Loss: 178.259
16000/69092 Loss: 177.821
19200/69092 Loss: 178.540
22400/69092 Loss: 179.684
25600/69092 Loss: 179.641
28800/69092 Loss: 178.731
32000/69092 Loss: 177.929
35200/69092 Loss: 176.187
38400/69092 Loss: 172.829
41600/69092 Loss: 175.603
44800/69092 Loss: 175.634
48000/69092 Loss: 176.480
51200/69092 Loss: 179.569
54400/69092 Loss: 181.344
57600/69092 Loss: 176.494
60800/69092 Loss: 179.082
64000/69092 Loss: 178.474
67200/69092 Loss: 179.446
Epoch: 5 Average loss: 177.85
0/69092 Loss: 176.970
3200/69092 Loss: 176.408
6400/69092 Loss: 178.858
9600/69092 Loss: 178.331
12800/69092 Loss: 178.803
16000/69092 Loss: 176.130
19200/69092 Loss: 177.191
22400/69092 Loss: 174.921
25600/69092 Loss: 175.036
28800/69092 Loss: 174.670
32000/69092 Loss: 176.784
35200/69092 Loss: 175.621
38400/69092 Loss: 177.801
41600/69092 Loss: 178.486
44800/69092 Loss: 177.910
48000/69092 Loss: 179.626
51200/69092 Loss: 177.689
54400/69092 Loss: 177.252
57600/69092 Loss: 178.268
60800/69092 Loss: 174.528
64000/69092 Loss: 180.210
67200/69092 Loss: 172.981
Epoch: 6 Average loss: 176.96
0/69092 Loss: 172.989
3200/69092 Loss: 177.269
6400/69092 Loss: 175.520
9600/69092 Loss: 174.691
12800/69092 Loss: 176.304
16000/69092 Loss: 175.862
19200/69092 Loss: 177.363
22400/69092 Loss: 175.648
25600/69092 Loss: 179.415
28800/69092 Loss: 174.836
32000/69092 Loss: 179.038
35200/69092 Loss: 176.845
38400/69092 Loss: 174.129
41600/69092 Loss: 174.983
44800/69092 Loss: 176.457
48000/69092 Loss: 175.409
51200/69092 Loss: 179.870
54400/69092 Loss: 169.886
57600/69092 Loss: 174.812
60800/69092 Loss: 179.631
64000/69092 Loss: 175.263
67200/69092 Loss: 175.486
Epoch: 7 Average loss: 176.17
0/69092 Loss: 182.499
3200/69092 Loss: 174.724
6400/69092 Loss: 175.769
9600/69092 Loss: 175.399
12800/69092 Loss: 174.957
16000/69092 Loss: 173.714
19200/69092 Loss: 174.390
22400/69092 Loss: 173.339
25600/69092 Loss: 177.868
28800/69092 Loss: 176.928
32000/69092 Loss: 177.665
35200/69092 Loss: 177.244
38400/69092 Loss: 175.513
41600/69092 Loss: 174.195
44800/69092 Loss: 176.281
48000/69092 Loss: 174.752
51200/69092 Loss: 173.814
54400/69092 Loss: 176.664
57600/69092 Loss: 173.527
60800/69092 Loss: 180.212
64000/69092 Loss: 175.542
67200/69092 Loss: 176.622
Epoch: 8 Average loss: 175.71
0/69092 Loss: 156.728
3200/69092 Loss: 175.051
6400/69092 Loss: 173.837
9600/69092 Loss: 177.669
12800/69092 Loss: 177.154
16000/69092 Loss: 175.246
19200/69092 Loss: 175.701
22400/69092 Loss: 179.730
25600/69092 Loss: 177.038
28800/69092 Loss: 177.181
32000/69092 Loss: 175.627
35200/69092 Loss: 175.962
38400/69092 Loss: 175.364
41600/69092 Loss: 170.164
44800/69092 Loss: 173.813
48000/69092 Loss: 171.629
51200/69092 Loss: 173.023
54400/69092 Loss: 173.645
57600/69092 Loss: 174.178
60800/69092 Loss: 175.416
64000/69092 Loss: 177.290
67200/69092 Loss: 178.709
Epoch: 9 Average loss: 175.37
0/69092 Loss: 152.802
3200/69092 Loss: 176.915
6400/69092 Loss: 171.666
9600/69092 Loss: 174.777
12800/69092 Loss: 174.794
16000/69092 Loss: 178.759
19200/69092 Loss: 176.785
22400/69092 Loss: 171.137
25600/69092 Loss: 176.108
28800/69092 Loss: 176.560
32000/69092 Loss: 175.608
35200/69092 Loss: 175.323
38400/69092 Loss: 177.557
41600/69092 Loss: 171.559
44800/69092 Loss: 172.491
48000/69092 Loss: 173.053
51200/69092 Loss: 175.481
54400/69092 Loss: 175.098
57600/69092 Loss: 174.345
60800/69092 Loss: 174.417
64000/69092 Loss: 174.409
67200/69092 Loss: 173.647
Epoch: 10 Average loss: 174.87
0/69092 Loss: 179.213
3200/69092 Loss: 173.855
6400/69092 Loss: 174.073
9600/69092 Loss: 176.358
12800/69092 Loss: 177.361
16000/69092 Loss: 175.480
19200/69092 Loss: 175.529
22400/69092 Loss: 175.376
25600/69092 Loss: 171.995
28800/69092 Loss: 174.104
32000/69092 Loss: 175.416
35200/69092 Loss: 175.910
38400/69092 Loss: 176.820
41600/69092 Loss: 172.693
44800/69092 Loss: 174.116
48000/69092 Loss: 173.678
51200/69092 Loss: 174.519
54400/69092 Loss: 172.135
57600/69092 Loss: 175.865
60800/69092 Loss: 174.142
64000/69092 Loss: 172.220
67200/69092 Loss: 175.021
Epoch: 11 Average loss: 174.59
0/69092 Loss: 163.699
3200/69092 Loss: 175.630
6400/69092 Loss: 172.353
9600/69092 Loss: 173.406
12800/69092 Loss: 174.677
16000/69092 Loss: 173.659
19200/69092 Loss: 174.090
22400/69092 Loss: 174.560
25600/69092 Loss: 173.179
28800/69092 Loss: 173.921
32000/69092 Loss: 173.517
35200/69092 Loss: 173.364
38400/69092 Loss: 175.125
41600/69092 Loss: 174.068
44800/69092 Loss: 175.291
48000/69092 Loss: 173.934
51200/69092 Loss: 176.217
54400/69092 Loss: 172.097
57600/69092 Loss: 175.580
60800/69092 Loss: 174.114
64000/69092 Loss: 174.610
67200/69092 Loss: 173.205
Epoch: 12 Average loss: 174.19
0/69092 Loss: 191.185
3200/69092 Loss: 176.507
6400/69092 Loss: 173.035
9600/69092 Loss: 171.493
12800/69092 Loss: 177.913
16000/69092 Loss: 175.693
19200/69092 Loss: 172.820
22400/69092 Loss: 171.520
25600/69092 Loss: 172.230
28800/69092 Loss: 171.340
32000/69092 Loss: 174.393
35200/69092 Loss: 171.775
38400/69092 Loss: 176.086
41600/69092 Loss: 172.708
44800/69092 Loss: 173.573
48000/69092 Loss: 172.352
51200/69092 Loss: 175.626
54400/69092 Loss: 174.810
57600/69092 Loss: 173.702
60800/69092 Loss: 176.364
64000/69092 Loss: 170.513
67200/69092 Loss: 173.109
Epoch: 13 Average loss: 173.72
0/69092 Loss: 159.396
3200/69092 Loss: 175.317
6400/69092 Loss: 173.157
9600/69092 Loss: 173.722
12800/69092 Loss: 171.148
16000/69092 Loss: 175.274
19200/69092 Loss: 176.036
22400/69092 Loss: 172.847
25600/69092 Loss: 166.729
28800/69092 Loss: 169.101
32000/69092 Loss: 172.004
35200/69092 Loss: 171.771
38400/69092 Loss: 170.028
41600/69092 Loss: 168.713
44800/69092 Loss: 172.679
48000/69092 Loss: 168.519
51200/69092 Loss: 167.290
54400/69092 Loss: 168.153
57600/69092 Loss: 170.113
60800/69092 Loss: 171.238
64000/69092 Loss: 166.859
67200/69092 Loss: 171.634
Epoch: 14 Average loss: 171.06
0/69092 Loss: 156.795
3200/69092 Loss: 169.503
6400/69092 Loss: 170.706
9600/69092 Loss: 167.440
12800/69092 Loss: 169.419
16000/69092 Loss: 167.985
19200/69092 Loss: 168.267
22400/69092 Loss: 168.930
25600/69092 Loss: 168.436
28800/69092 Loss: 170.484
32000/69092 Loss: 170.663
35200/69092 Loss: 165.644
38400/69092 Loss: 169.801
41600/69092 Loss: 167.051
44800/69092 Loss: 169.969
48000/69092 Loss: 166.461
51200/69092 Loss: 169.932
54400/69092 Loss: 170.305
57600/69092 Loss: 168.247
60800/69092 Loss: 168.439
64000/69092 Loss: 169.417
67200/69092 Loss: 168.768
Epoch: 15 Average loss: 168.81
0/69092 Loss: 146.202
3200/69092 Loss: 170.373
6400/69092 Loss: 164.957
9600/69092 Loss: 168.805
12800/69092 Loss: 169.648
16000/69092 Loss: 167.061
19200/69092 Loss: 168.539
22400/69092 Loss: 166.700
25600/69092 Loss: 170.245
28800/69092 Loss: 169.705
32000/69092 Loss: 167.015
35200/69092 Loss: 165.536
38400/69092 Loss: 168.482
41600/69092 Loss: 166.300
44800/69092 Loss: 169.078
48000/69092 Loss: 170.031
51200/69092 Loss: 164.588
54400/69092 Loss: 169.601
57600/69092 Loss: 170.558
60800/69092 Loss: 169.104
64000/69092 Loss: 168.777
67200/69092 Loss: 167.852
Epoch: 16 Average loss: 168.19
0/69092 Loss: 173.891
3200/69092 Loss: 167.297
6400/69092 Loss: 167.022
9600/69092 Loss: 169.751
12800/69092 Loss: 173.189
16000/69092 Loss: 167.852
19200/69092 Loss: 166.613
22400/69092 Loss: 168.037
25600/69092 Loss: 166.963
28800/69092 Loss: 165.673
32000/69092 Loss: 168.600
35200/69092 Loss: 168.420
38400/69092 Loss: 167.863
41600/69092 Loss: 170.022
44800/69092 Loss: 166.268
48000/69092 Loss: 166.312
51200/69092 Loss: 168.968
54400/69092 Loss: 170.261
57600/69092 Loss: 167.449
60800/69092 Loss: 167.692
64000/69092 Loss: 166.189
67200/69092 Loss: 169.426
Epoch: 17 Average loss: 168.12
0/69092 Loss: 158.335
3200/69092 Loss: 166.323
6400/69092 Loss: 168.593
9600/69092 Loss: 169.252
12800/69092 Loss: 165.184
16000/69092 Loss: 166.948
19200/69092 Loss: 167.851
22400/69092 Loss: 167.890
25600/69092 Loss: 166.056
28800/69092 Loss: 166.313
32000/69092 Loss: 166.684
35200/69092 Loss: 166.046
38400/69092 Loss: 168.573
41600/69092 Loss: 170.346
44800/69092 Loss: 168.576
48000/69092 Loss: 167.955
51200/69092 Loss: 169.517
54400/69092 Loss: 170.773
57600/69092 Loss: 167.093
60800/69092 Loss: 167.219
64000/69092 Loss: 167.393
67200/69092 Loss: 170.586
Epoch: 18 Average loss: 167.79
0/69092 Loss: 156.962
3200/69092 Loss: 166.186
6400/69092 Loss: 168.813
9600/69092 Loss: 166.400
12800/69092 Loss: 167.795
16000/69092 Loss: 167.113
19200/69092 Loss: 165.320
22400/69092 Loss: 165.054
25600/69092 Loss: 168.109
28800/69092 Loss: 167.833
32000/69092 Loss: 166.894
35200/69092 Loss: 168.014
38400/69092 Loss: 168.756
41600/69092 Loss: 168.506
44800/69092 Loss: 169.647
48000/69092 Loss: 167.893
51200/69092 Loss: 167.151
54400/69092 Loss: 168.016
57600/69092 Loss: 168.288
60800/69092 Loss: 169.788
64000/69092 Loss: 165.068
67200/69092 Loss: 170.401
Epoch: 19 Average loss: 167.60
0/69092 Loss: 168.727
3200/69092 Loss: 168.674
6400/69092 Loss: 166.971
9600/69092 Loss: 164.222
12800/69092 Loss: 165.916
16000/69092 Loss: 173.111
19200/69092 Loss: 165.215
22400/69092 Loss: 168.106
25600/69092 Loss: 167.273
28800/69092 Loss: 167.498
32000/69092 Loss: 167.923
35200/69092 Loss: 168.921
38400/69092 Loss: 163.615
41600/69092 Loss: 167.147
44800/69092 Loss: 167.341
48000/69092 Loss: 169.308
51200/69092 Loss: 168.601
54400/69092 Loss: 169.787
57600/69092 Loss: 166.595
60800/69092 Loss: 169.997
64000/69092 Loss: 165.999
67200/69092 Loss: 169.111
Epoch: 20 Average loss: 167.67
0/69092 Loss: 164.874
3200/69092 Loss: 167.108
6400/69092 Loss: 166.190
9600/69092 Loss: 168.514
12800/69092 Loss: 166.697
16000/69092 Loss: 165.785
19200/69092 Loss: 166.035
22400/69092 Loss: 165.913
25600/69092 Loss: 167.381
28800/69092 Loss: 167.176
32000/69092 Loss: 168.299
35200/69092 Loss: 168.890
38400/69092 Loss: 165.954
41600/69092 Loss: 168.868
44800/69092 Loss: 166.462
48000/69092 Loss: 165.692
51200/69092 Loss: 170.486
54400/69092 Loss: 166.689
57600/69092 Loss: 168.941
60800/69092 Loss: 167.861
64000/69092 Loss: 169.267
67200/69092 Loss: 168.650
Epoch: 21 Average loss: 167.43
0/69092 Loss: 158.066
3200/69092 Loss: 166.319
6400/69092 Loss: 168.866
9600/69092 Loss: 168.782
12800/69092 Loss: 168.118
16000/69092 Loss: 168.135
19200/69092 Loss: 165.183
22400/69092 Loss: 171.452
25600/69092 Loss: 166.608
28800/69092 Loss: 169.231
32000/69092 Loss: 164.687
35200/69092 Loss: 167.388
38400/69092 Loss: 168.537
41600/69092 Loss: 166.544
44800/69092 Loss: 170.258
48000/69092 Loss: 164.814
51200/69092 Loss: 167.838
54400/69092 Loss: 167.264
57600/69092 Loss: 165.149
60800/69092 Loss: 168.617
64000/69092 Loss: 166.318
67200/69092 Loss: 167.333
Epoch: 22 Average loss: 167.44
0/69092 Loss: 164.065
3200/69092 Loss: 164.064
6400/69092 Loss: 167.517
9600/69092 Loss: 170.663
12800/69092 Loss: 167.499
16000/69092 Loss: 167.594
19200/69092 Loss: 166.922
22400/69092 Loss: 163.598
25600/69092 Loss: 168.840
28800/69092 Loss: 169.542
32000/69092 Loss: 168.083
35200/69092 Loss: 169.357
38400/69092 Loss: 166.990
41600/69092 Loss: 164.629
44800/69092 Loss: 168.236
48000/69092 Loss: 164.188
51200/69092 Loss: 166.765
54400/69092 Loss: 166.641
57600/69092 Loss: 170.710
60800/69092 Loss: 165.975
64000/69092 Loss: 168.500
67200/69092 Loss: 169.258
Epoch: 23 Average loss: 167.42
0/69092 Loss: 172.864
3200/69092 Loss: 165.477
6400/69092 Loss: 167.342
9600/69092 Loss: 167.331
12800/69092 Loss: 168.253
16000/69092 Loss: 164.884
19200/69092 Loss: 165.278
22400/69092 Loss: 169.037
25600/69092 Loss: 170.689
28800/69092 Loss: 167.800
32000/69092 Loss: 167.574
35200/69092 Loss: 165.465
38400/69092 Loss: 167.913
41600/69092 Loss: 166.253
44800/69092 Loss: 167.459
48000/69092 Loss: 167.662
51200/69092 Loss: 167.989
54400/69092 Loss: 166.499
57600/69092 Loss: 165.015
60800/69092 Loss: 167.097
64000/69092 Loss: 167.709
67200/69092 Loss: 166.537
Epoch: 24 Average loss: 167.05
0/69092 Loss: 159.477
3200/69092 Loss: 168.167
6400/69092 Loss: 165.640
9600/69092 Loss: 169.130
12800/69092 Loss: 164.949
16000/69092 Loss: 170.971
19200/69092 Loss: 168.796
22400/69092 Loss: 165.955
25600/69092 Loss: 166.466
28800/69092 Loss: 169.212
32000/69092 Loss: 167.467
35200/69092 Loss: 166.299
38400/69092 Loss: 167.788
41600/69092 Loss: 166.326
44800/69092 Loss: 166.547
48000/69092 Loss: 164.943
51200/69092 Loss: 166.213
54400/69092 Loss: 167.823
57600/69092 Loss: 165.211
60800/69092 Loss: 169.700
64000/69092 Loss: 167.593
67200/69092 Loss: 169.609
Epoch: 25 Average loss: 167.31
0/69092 Loss: 161.595
3200/69092 Loss: 167.897
6400/69092 Loss: 163.418
9600/69092 Loss: 164.333
12800/69092 Loss: 167.346
16000/69092 Loss: 168.345
19200/69092 Loss: 167.494
22400/69092 Loss: 168.143
25600/69092 Loss: 166.609
28800/69092 Loss: 162.532
32000/69092 Loss: 169.007
35200/69092 Loss: 168.024
38400/69092 Loss: 162.925
41600/69092 Loss: 169.271
44800/69092 Loss: 170.226
48000/69092 Loss: 167.290
51200/69092 Loss: 167.863
54400/69092 Loss: 168.024
57600/69092 Loss: 165.059
60800/69092 Loss: 168.377
64000/69092 Loss: 168.783
67200/69092 Loss: 165.481
Epoch: 26 Average loss: 167.07
0/69092 Loss: 181.461
3200/69092 Loss: 168.982
6400/69092 Loss: 164.568
9600/69092 Loss: 164.560
12800/69092 Loss: 169.223
16000/69092 Loss: 163.325
19200/69092 Loss: 169.991
22400/69092 Loss: 168.719
25600/69092 Loss: 164.910
28800/69092 Loss: 166.165
32000/69092 Loss: 167.390
35200/69092 Loss: 165.877
38400/69092 Loss: 166.731
41600/69092 Loss: 167.863
44800/69092 Loss: 166.774
48000/69092 Loss: 170.445
51200/69092 Loss: 169.944
54400/69092 Loss: 164.762
57600/69092 Loss: 165.063
60800/69092 Loss: 169.285
64000/69092 Loss: 167.527
67200/69092 Loss: 166.289
Epoch: 27 Average loss: 167.06
0/69092 Loss: 170.792
3200/69092 Loss: 167.467
6400/69092 Loss: 167.020
9600/69092 Loss: 168.963
12800/69092 Loss: 165.771
16000/69092 Loss: 166.994
19200/69092 Loss: 164.748
22400/69092 Loss: 165.152
25600/69092 Loss: 165.792
28800/69092 Loss: 165.728
32000/69092 Loss: 165.544
35200/69092 Loss: 169.239
38400/69092 Loss: 165.895
41600/69092 Loss: 169.427
44800/69092 Loss: 168.761
48000/69092 Loss: 167.436
51200/69092 Loss: 167.477
54400/69092 Loss: 167.437
57600/69092 Loss: 163.032
60800/69092 Loss: 168.386
64000/69092 Loss: 169.317
67200/69092 Loss: 166.414
Epoch: 28 Average loss: 166.99
0/69092 Loss: 159.176
3200/69092 Loss: 169.014
6400/69092 Loss: 165.076
9600/69092 Loss: 164.450
12800/69092 Loss: 165.767
16000/69092 Loss: 163.509
19200/69092 Loss: 164.611
22400/69092 Loss: 169.482
25600/69092 Loss: 164.828
28800/69092 Loss: 165.456
32000/69092 Loss: 166.719
35200/69092 Loss: 165.716
38400/69092 Loss: 169.506
41600/69092 Loss: 167.452
44800/69092 Loss: 167.680
48000/69092 Loss: 168.003
51200/69092 Loss: 166.531
54400/69092 Loss: 166.318
57600/69092 Loss: 167.834
60800/69092 Loss: 168.869
64000/69092 Loss: 166.379
67200/69092 Loss: 168.122
Epoch: 29 Average loss: 166.84
0/69092 Loss: 168.089
3200/69092 Loss: 164.654
6400/69092 Loss: 166.092
9600/69092 Loss: 165.645
12800/69092 Loss: 168.803
16000/69092 Loss: 166.414
19200/69092 Loss: 168.405
22400/69092 Loss: 167.098
25600/69092 Loss: 166.610
28800/69092 Loss: 167.311
32000/69092 Loss: 170.033
35200/69092 Loss: 166.208
38400/69092 Loss: 163.407
41600/69092 Loss: 166.786
44800/69092 Loss: 165.290
48000/69092 Loss: 168.366
51200/69092 Loss: 166.125
54400/69092 Loss: 164.718
57600/69092 Loss: 166.249
60800/69092 Loss: 166.015
64000/69092 Loss: 168.616
67200/69092 Loss: 169.330
Epoch: 30 Average loss: 166.81
0/69092 Loss: 178.073
3200/69092 Loss: 166.946
6400/69092 Loss: 165.648
9600/69092 Loss: 168.367
12800/69092 Loss: 166.295
16000/69092 Loss: 164.805
19200/69092 Loss: 168.451
22400/69092 Loss: 164.708
25600/69092 Loss: 166.344
28800/69092 Loss: 168.841
32000/69092 Loss: 168.623
35200/69092 Loss: 166.244
38400/69092 Loss: 163.225
41600/69092 Loss: 169.795
44800/69092 Loss: 167.962
48000/69092 Loss: 167.154
51200/69092 Loss: 168.323
54400/69092 Loss: 166.811
57600/69092 Loss: 171.028
60800/69092 Loss: 164.150
64000/69092 Loss: 166.475
67200/69092 Loss: 164.613
Epoch: 31 Average loss: 166.86
0/69092 Loss: 163.393
3200/69092 Loss: 164.554
6400/69092 Loss: 167.080
9600/69092 Loss: 171.187
12800/69092 Loss: 164.961
16000/69092 Loss: 163.417
19200/69092 Loss: 168.374
22400/69092 Loss: 165.693
25600/69092 Loss: 167.065
28800/69092 Loss: 164.642
32000/69092 Loss: 165.408
35200/69092 Loss: 166.985
38400/69092 Loss: 166.414
41600/69092 Loss: 170.011
44800/69092 Loss: 166.610
48000/69092 Loss: 169.074
51200/69092 Loss: 166.680
54400/69092 Loss: 168.125
57600/69092 Loss: 164.007
60800/69092 Loss: 166.152
64000/69092 Loss: 168.865
67200/69092 Loss: 163.801
Epoch: 32 Average loss: 166.62
0/69092 Loss: 145.340
3200/69092 Loss: 166.538
6400/69092 Loss: 167.724
9600/69092 Loss: 168.375
12800/69092 Loss: 165.096
16000/69092 Loss: 164.871
19200/69092 Loss: 164.603
22400/69092 Loss: 166.468
25600/69092 Loss: 168.188
28800/69092 Loss: 164.460
32000/69092 Loss: 166.869
35200/69092 Loss: 167.491
38400/69092 Loss: 164.241
41600/69092 Loss: 164.496
44800/69092 Loss: 164.427
48000/69092 Loss: 167.685
51200/69092 Loss: 167.454
54400/69092 Loss: 166.709
57600/69092 Loss: 166.155
60800/69092 Loss: 169.947
64000/69092 Loss: 165.286
67200/69092 Loss: 168.151
Epoch: 33 Average loss: 166.48
0/69092 Loss: 166.408
3200/69092 Loss: 167.340
6400/69092 Loss: 163.513
9600/69092 Loss: 165.043
12800/69092 Loss: 164.295
16000/69092 Loss: 164.211
19200/69092 Loss: 166.156
22400/69092 Loss: 169.719
25600/69092 Loss: 166.807
28800/69092 Loss: 165.537
32000/69092 Loss: 170.431
35200/69092 Loss: 165.410
38400/69092 Loss: 166.952
41600/69092 Loss: 168.332
44800/69092 Loss: 164.908
48000/69092 Loss: 168.677
51200/69092 Loss: 167.376
54400/69092 Loss: 166.778
57600/69092 Loss: 164.389
60800/69092 Loss: 168.517
64000/69092 Loss: 167.666
67200/69092 Loss: 168.580
Epoch: 34 Average loss: 166.73
0/69092 Loss: 160.057
3200/69092 Loss: 165.825
6400/69092 Loss: 165.326
9600/69092 Loss: 166.921
12800/69092 Loss: 166.886
16000/69092 Loss: 164.149
19200/69092 Loss: 166.317
22400/69092 Loss: 169.658
25600/69092 Loss: 167.307
28800/69092 Loss: 165.543
32000/69092 Loss: 165.719
35200/69092 Loss: 167.927
38400/69092 Loss: 166.107
41600/69092 Loss: 169.167
44800/69092 Loss: 164.462
48000/69092 Loss: 164.943
51200/69092 Loss: 167.662
54400/69092 Loss: 169.840
57600/69092 Loss: 162.048
60800/69092 Loss: 165.538
64000/69092 Loss: 169.039
67200/69092 Loss: 167.731
Epoch: 35 Average loss: 166.58
0/69092 Loss: 159.147
3200/69092 Loss: 168.711
6400/69092 Loss: 164.872
9600/69092 Loss: 166.459
12800/69092 Loss: 165.101
16000/69092 Loss: 165.007
19200/69092 Loss: 163.900
22400/69092 Loss: 165.651
25600/69092 Loss: 166.408
28800/69092 Loss: 168.436
32000/69092 Loss: 168.150
35200/69092 Loss: 166.650
38400/69092 Loss: 166.221
41600/69092 Loss: 165.079
44800/69092 Loss: 169.582
48000/69092 Loss: 166.133
51200/69092 Loss: 164.542
54400/69092 Loss: 166.025
57600/69092 Loss: 169.027
60800/69092 Loss: 166.220
64000/69092 Loss: 169.161
67200/69092 Loss: 165.857
Epoch: 36 Average loss: 166.59
0/69092 Loss: 163.261
3200/69092 Loss: 164.025
6400/69092 Loss: 165.219
9600/69092 Loss: 167.676
12800/69092 Loss: 168.915
16000/69092 Loss: 164.794
19200/69092 Loss: 171.172
22400/69092 Loss: 168.779
25600/69092 Loss: 166.745
28800/69092 Loss: 164.285
32000/69092 Loss: 165.940
35200/69092 Loss: 168.109
38400/69092 Loss: 164.635
41600/69092 Loss: 166.126
44800/69092 Loss: 167.761
48000/69092 Loss: 170.240
51200/69092 Loss: 163.132
54400/69092 Loss: 162.558
57600/69092 Loss: 169.405
60800/69092 Loss: 165.856
64000/69092 Loss: 166.046
67200/69092 Loss: 165.596
Epoch: 37 Average loss: 166.51
0/69092 Loss: 156.881
3200/69092 Loss: 164.568
6400/69092 Loss: 165.541
9600/69092 Loss: 165.421
12800/69092 Loss: 163.026
16000/69092 Loss: 163.961
19200/69092 Loss: 166.480
22400/69092 Loss: 166.765
25600/69092 Loss: 170.223
28800/69092 Loss: 166.316
32000/69092 Loss: 167.886
35200/69092 Loss: 170.375
38400/69092 Loss: 167.319
41600/69092 Loss: 166.670
44800/69092 Loss: 165.619
48000/69092 Loss: 167.852
51200/69092 Loss: 164.861
54400/69092 Loss: 164.426
57600/69092 Loss: 169.410
60800/69092 Loss: 166.011
64000/69092 Loss: 165.486
67200/69092 Loss: 164.748
Epoch: 38 Average loss: 166.34
0/69092 Loss: 154.482
3200/69092 Loss: 166.846
6400/69092 Loss: 168.174
9600/69092 Loss: 166.786
12800/69092 Loss: 164.238
16000/69092 Loss: 169.103
19200/69092 Loss: 165.084
22400/69092 Loss: 166.447
25600/69092 Loss: 166.493
28800/69092 Loss: 169.040
32000/69092 Loss: 169.643
35200/69092 Loss: 163.526
38400/69092 Loss: 166.074
41600/69092 Loss: 167.452
44800/69092 Loss: 164.863
48000/69092 Loss: 163.552
51200/69092 Loss: 166.178
54400/69092 Loss: 164.692
57600/69092 Loss: 167.591
60800/69092 Loss: 165.594
64000/69092 Loss: 167.904
67200/69092 Loss: 163.778
Epoch: 39 Average loss: 166.47
0/69092 Loss: 174.792
3200/69092 Loss: 168.603
6400/69092 Loss: 167.578
9600/69092 Loss: 166.309
12800/69092 Loss: 166.655
16000/69092 Loss: 166.937
19200/69092 Loss: 166.851
22400/69092 Loss: 168.955
25600/69092 Loss: 162.846
28800/69092 Loss: 166.400
32000/69092 Loss: 161.946
35200/69092 Loss: 164.650
38400/69092 Loss: 165.843
41600/69092 Loss: 166.236
44800/69092 Loss: 165.099
48000/69092 Loss: 165.672
51200/69092 Loss: 165.388
54400/69092 Loss: 169.332
57600/69092 Loss: 167.883
60800/69092 Loss: 167.210
64000/69092 Loss: 166.555
67200/69092 Loss: 167.086
Epoch: 40 Average loss: 166.41
"""
Code from: https://github.com/Schlumberger/joint-vae
https://github.com/1Konny/Beta-VAE
"""
# import sys
# sys.path.append('')
import os
from dataloader.dataloaders import *
import torch.nn as nn
from VAE_model.models import VAE
from torch import optim
from viz.visualize import Visualizer
from utils.training import Trainer, gpu_config
import argparse
import json
def main(args):
# continue and discrete capacity
if args.cont_capacity is not None:
cont_capacity = [float(item) for item in args.cont_capacity.split(',')]
else:
cont_capacity = args.cont_capacity
if args.disc_capacity is not None:
disc_capacity = [float(item) for item in args.disc_capacity.split(',')]
else:
disc_capacity = args.disc_capacity
# latent_spec
latent_spec = {"cont": args.latent_spec_cont}
# number of classes and image size:
if args.dataset == 'mnist' or args.dataset == 'fashion_data':
nb_classes = 10
img_size = (1, 32, 32)
elif args.dataset == 'celeba_64':
nb_classes = None
img_size = (3, 64, 64)
elif args.dataset == 'rendered_chairs':
nb_classes = 1393
img_size = (3, 64, 64)
elif args.dataset == 'dSprites':
nb_classes = 6
# create and write a json file:
if not args.load_model_checkpoint:
ckpt_dir = os.path.join('trained_models', args.dataset, args.experiment_name, args.ckpt_dir)
if not os.path.exists(ckpt_dir):
os.makedirs(ckpt_dir, exist_ok=True)
parameter = {'dataset': args.dataset, 'epochs': args.epochs, 'cont_capacity': args.cont_capacity,
'disc_capacity': args.disc_capacity, 'record_loss_every': args.record_loss_every,
'batch_size': args.batch_size, 'latent_spec_cont': args.latent_spec_cont,
'experiment_name': args.experiment_name, 'print_loss_every': args.print_loss_every,
'latent_spec_disc': args.latent_spec_disc, 'nb_classes': nb_classes}
# Save json parameters:
file_path = os.path.join('trained_models/', args.dataset, args.experiment_name, 'specs.json')
with open(file_path, 'w') as json_file:
json.dump(parameter, json_file)
print('ok')
# create model
model = VAE(img_size, latent_spec=latent_spec)
# load dataset
train_loader, test_loader, dataset_name = load_dataset(args.dataset, args.batch_size, num_worker=args.num_worker)
# Define model
model, use_gpu, device = gpu_config(model)
if args.verbose:
print(model)
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('The number of parameters of model is', num_params)
# Define optimizer and criterion
optimizer = optim.Adam(model.parameters(), lr=args.lr)
criterion = nn.CrossEntropyLoss()
# Define trainer
trainer = Trainer(model, device, optimizer, criterion, save_step=args.save_step, ckpt_dir=args.ckpt_dir,
ckpt_name=args.ckpt_name,
expe_name=args.experiment_name,
dataset=args.dataset,
cont_capacity=cont_capacity,
disc_capacity=disc_capacity,
is_beta=args.is_beta_VAE,
beta=args.beta)
# define visualizer
viz = Visualizer(model)
# Train model:
trainer.train(train_loader, args.epochs, save_training_gif=('../img_gif/' + dataset_name + '_' +
args.latent_name + args.experiment_name + '.gif', viz))
"""
# Save trained model
if args.save_model:
torch.save(trainer.model.state_dict(),
'../trained_models/' + dataset_name + '/model_' + args.experiment_name + '.pt')
"""
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='VAE')
parser.add_argument('--batch-size', type=int, default=64, metavar='integer value',
help='input batch size for training (default: 64)')
parser.add_argument('--record-loss-every', type=int, default=50, metavar='integer value',
help='Record loss every (value)')
parser.add_argument('--print-loss-every', type=int, default=50, metavar='integer value',
help='Print loss every (value)')
parser.add_argument('--epochs', type=int, default=100, metavar='integer value',
help='number of epochs to train (default: 100)')
parser.add_argument('--lr', type=float, default=5e-4, metavar='value',
help='learning rate value')
parser.add_argument('--dataset', type=str, default=None, metavar='name',
help='Dataset Name')
parser.add_argument('--save-model', type=bool, default=True, metavar='bool',
help='Save model')
parser.add_argument('--save-reconstruction-image', type=bool, default=False, metavar='bool',
help='Save reconstruction image')
parser.add_argument('--latent_spec_cont', type=int, default=10, metavar='integer value',
help='Capacity of continue latent space')
parser.add_argument('--latent_spec_disc', type=list, default=None, metavar='integer list',
help='Capacity of discrete latent space')
parser.add_argument('--cont-capacity', type=str, default=None, metavar='integer tuple',
help='capacity of continuous channels')
parser.add_argument('--disc-capacity', type=str, default=None, metavar='integer tuple',
help='capacity of discrete channels')
parser.add_argument('--experiment-name', type=str, default='', metavar='name',
help='experiment name')
parser.add_argument('--latent-name', type=str, default='', metavar='name',
help='Latent space name')
parser.add_argument('--is-beta-VAE', type=bool, default=False, metavar='beta_VAE',
help='If use beta-VAE')
parser.add_argument('--beta', type=int, default=None, metavar='beta',
help='Beta value')
parser.add_argument("--gpu-devices", type=int, nargs='+', default=None, help="GPU devices available")
parser.add_argument("--load-model-checkpoint", type=bool, default=False, help="If we use a pre trained model")
parser.add_argument("--load-expe-name", type=str, default='', help="The name expe to loading")
parser.add_argument("--num-worker", type=int, default=4, help="num worker to dataloader")
parser.add_argument("--verbose", type=bool, default=True, help="To print details model")
parser.add_argument("--save-step", type=int, default=1, help="save model every step")
parser.add_argument('--ckpt_dir', default='checkpoints', type=str, help='checkpoint directory')
parser.add_argument('--ckpt_name', default='last', type=str,
help='load previous checkpoint. insert checkpoint filename')
args = parser.parse_args()
assert args.dataset in ['mnist', 'fashion_data', 'celeba_64', 'rendered_chairs', 'dSprites'], \
"The choisen dataset is not available. Please choose a dataset from the following: ['mnist', 'fashion_data', " \
"'celeba_64', 'rendered_chairs', 'dSprites'] "
if args.is_beta_VAE:
assert args.beta is not None, 'Beta is null or if you use Beta-VAe model, please enter a beta value'
print(parser.parse_args())
gpu_devices = ','.join([str(id) for id in args.gpu_devices])
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_devices
main(args)
--batch-size=256 --dataset=rendered_chairs --epochs=40 --latent_spec_cont=10 --is-beta-VAE=True --beta=4 --lr=1e-4 --experiment-name=beta_VAE_bs_256 --gpu-devices 0 1
--batch-size=64 --dataset=rendered_chairs --epochs=40 --latent_spec_cont=10 --is-beta-VAE=True --beta=4 --lr=1e-4 --experiment-name=beta_VAE_bs_64 --gpu-devices 0 1
--batch-size=256 --dataset=rendered_chairs --epochs=40 --latent_spec_cont=10 --lr=1e-4 --experiment-name=VAE_bs_256 --gpu-devices 0 1
--batch-size=64 --dataset=rendered_chairs --epochs=40 --latent_spec_cont=10 --lr=1e-4 --experiment-name=VAE_bs_64 --gpu-devices 0 1 --load-model_checkpoint=False --experiment-name=VAE_bs_64
--batch-size=256 --dataset=rendered_chairs --epochs=400 --latent_spec_cont=10 --is-beta-VAE=True --beta=4 --lr=1e-4 --experiment-name=beta_VAE_bs_256 --gpu-devices 0 1 --experiment-name=beta_VAE_bs_256
--batch-size=64 --dataset=rendered_chairs --epochs=400 --latent_spec_cont=10 --is-beta-VAE=True --beta=4 --lr=1e-4 --experiment-name=beta_VAE_bs_64 --gpu-devices 0 1 --experiment-name=beta_VAE_bs_64
--batch-size=256 --dataset=rendered_chairs --epochs=400 --latent_spec_cont=10 --lr=1e-4 --experiment-name=VAE_bs_256 --gpu-devices 0 1 --experiment-name=VAE_bs_256
--batch-size=64 --dataset=rendered_chairs --epochs=400 --latent_spec_cont=10 --lr=1e-4 --experiment-name=VAE_bs_64 --gpu-devices 0 1 --experiment-name=VAE_bs_64
......@@ -32,7 +32,7 @@ class Trainer:
record_loss_every : int
Frequency with which loss is recorded during training.
"""
if type(model) == 'torch.nn.parallel.data_parallel.DataParallel':
if 'parallel' in str(type(model)):
self.model = model.module
else:
self.model = model
......
......@@ -15,7 +15,7 @@ class Visualizer:
----------
model : VAE_model.models.VAE instance
"""
if type(model) == 'torch.nn.parallel.data_parallel.DataParallel':
if 'parallel' in str(type(model)):
self.model = model.module
else:
self.model = model
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment