diff --git a/Experiments/OAR.2063553.stderr b/Experiments/OAR.2063553.stderr deleted file mode 100644 index f23c8c2b48bbf3644f706e52d36cdaa1146ef50c..0000000000000000000000000000000000000000 --- a/Experiments/OAR.2063553.stderr +++ /dev/null @@ -1,42 +0,0 @@ -/data1/home/julien.dejasmin/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/_reduction.py:43: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead. - warnings.warn(warning.format(ret)) -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. -Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. diff --git a/Experiments/OAR.2063553.stdout b/Experiments/OAR.2063553.stdout deleted file mode 100644 index e66b1fed58f8a754e5a80c01e1a068c942719988..0000000000000000000000000000000000000000 --- a/Experiments/OAR.2063553.stdout +++ /dev/null @@ -1,960 +0,0 @@ -Namespace(batch_size=64, beta=4, cont_capacity=None, dataset='rendered_chairs', disc_capacity=None, epochs=40, experiment_name='beta_VAE', is_beta_VAE=True, latent_name='', latent_spec_cont=10, latent_spec_disc=None, lr=0.0001, print_loss_every=50, record_loss_every=50, save_model=True, save_reconstruction_image=True) -load dataset: rendered_chairs, with: 69120 train images of shape: (3, 64, 64) -use 1 gpu who named: GeForce GTX 1080 Ti -VAE( - (img_to_last_conv): Sequential( - (0): Conv2d(3, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)) - (1): ReLU() - (2): Conv2d(32, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)) - (3): ReLU() - (4): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)) - (5): ReLU() - (6): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)) - (7): ReLU() - ) - (last_conv_to_continuous_features): Sequential( - (0): Conv2d(64, 256, kernel_size=(4, 4), stride=(1, 1)) - (1): ReLU() - ) - (features_to_hidden_continue): Sequential( - (0): Linear(in_features=256, out_features=20, bias=True) - (1): ReLU() - ) - (latent_to_features): Sequential( - (0): Linear(in_features=10, out_features=256, bias=True) - (1): ReLU() - ) - (features_to_img): Sequential( - (0): ConvTranspose2d(256, 64, kernel_size=(4, 4), stride=(1, 1)) - (1): ReLU() - (2): ConvTranspose2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)) - (3): ReLU() - (4): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)) - (5): ReLU() - (6): ConvTranspose2d(32, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)) - (7): ReLU() - (8): ConvTranspose2d(32, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)) - (9): Sigmoid() - ) -) -don't use continuous capacity -0/69092 Loss: 3070.657 -3200/69092 Loss: 2948.644 -6400/69092 Loss: 1024.792 -9600/69092 Loss: 540.226 -12800/69092 Loss: 470.434 -16000/69092 Loss: 457.701 -19200/69092 Loss: 450.401 -22400/69092 Loss: 443.853 -25600/69092 Loss: 451.292 -28800/69092 Loss: 440.426 -32000/69092 Loss: 432.866 -35200/69092 Loss: 308.760 -38400/69092 Loss: 242.492 -41600/69092 Loss: 237.200 -44800/69092 Loss: 229.539 -48000/69092 Loss: 229.294 -51200/69092 Loss: 230.983 -54400/69092 Loss: 224.875 -57600/69092 Loss: 220.705 -60800/69092 Loss: 226.022 -64000/69092 Loss: 220.454 -67200/69092 Loss: 221.057 -Epoch: 1 Average loss: 483.71 -0/69092 Loss: 252.325 -3200/69092 Loss: 220.659 -6400/69092 Loss: 222.659 -9600/69092 Loss: 219.368 -12800/69092 Loss: 214.168 -16000/69092 Loss: 212.482 -19200/69092 Loss: 208.190 -22400/69092 Loss: 205.885 -25600/69092 Loss: 209.691 -28800/69092 Loss: 203.559 -32000/69092 Loss: 205.094 -35200/69092 Loss: 199.194 -38400/69092 Loss: 201.090 -41600/69092 Loss: 197.349 -44800/69092 Loss: 199.072 -48000/69092 Loss: 195.767 -51200/69092 Loss: 192.350 -54400/69092 Loss: 188.418 -57600/69092 Loss: 191.497 -60800/69092 Loss: 192.091 -64000/69092 Loss: 189.586 -67200/69092 Loss: 192.249 -Epoch: 2 Average loss: 202.72 -0/69092 Loss: 192.982 -3200/69092 Loss: 191.061 -6400/69092 Loss: 191.400 -9600/69092 Loss: 188.066 -12800/69092 Loss: 193.611 -16000/69092 Loss: 188.738 -19200/69092 Loss: 186.129 -22400/69092 Loss: 186.272 -25600/69092 Loss: 188.331 -28800/69092 Loss: 190.557 -32000/69092 Loss: 189.598 -35200/69092 Loss: 186.676 -38400/69092 Loss: 191.831 -41600/69092 Loss: 189.899 -44800/69092 Loss: 187.892 -48000/69092 Loss: 190.516 -51200/69092 Loss: 185.773 -54400/69092 Loss: 188.082 -57600/69092 Loss: 185.783 -60800/69092 Loss: 187.686 -64000/69092 Loss: 185.843 -67200/69092 Loss: 186.829 -Epoch: 3 Average loss: 188.60 -0/69092 Loss: 187.544 -3200/69092 Loss: 191.071 -6400/69092 Loss: 189.176 -9600/69092 Loss: 188.550 -12800/69092 Loss: 186.759 -16000/69092 Loss: 181.701 -19200/69092 Loss: 186.361 -22400/69092 Loss: 182.646 -25600/69092 Loss: 182.758 -28800/69092 Loss: 181.681 -32000/69092 Loss: 179.355 -35200/69092 Loss: 181.080 -38400/69092 Loss: 175.913 -41600/69092 Loss: 179.204 -44800/69092 Loss: 178.583 -48000/69092 Loss: 175.061 -51200/69092 Loss: 178.126 -54400/69092 Loss: 179.390 -57600/69092 Loss: 180.663 -60800/69092 Loss: 180.700 -64000/69092 Loss: 179.577 -67200/69092 Loss: 182.585 -Epoch: 4 Average loss: 181.75 -0/69092 Loss: 171.749 -3200/69092 Loss: 178.532 -6400/69092 Loss: 175.331 -9600/69092 Loss: 178.732 -12800/69092 Loss: 178.259 -16000/69092 Loss: 177.821 -19200/69092 Loss: 178.540 -22400/69092 Loss: 179.684 -25600/69092 Loss: 179.641 -28800/69092 Loss: 178.731 -32000/69092 Loss: 177.929 -35200/69092 Loss: 176.187 -38400/69092 Loss: 172.829 -41600/69092 Loss: 175.603 -44800/69092 Loss: 175.634 -48000/69092 Loss: 176.480 -51200/69092 Loss: 179.569 -54400/69092 Loss: 181.344 -57600/69092 Loss: 176.494 -60800/69092 Loss: 179.082 -64000/69092 Loss: 178.474 -67200/69092 Loss: 179.446 -Epoch: 5 Average loss: 177.85 -0/69092 Loss: 176.970 -3200/69092 Loss: 176.408 -6400/69092 Loss: 178.858 -9600/69092 Loss: 178.331 -12800/69092 Loss: 178.803 -16000/69092 Loss: 176.130 -19200/69092 Loss: 177.191 -22400/69092 Loss: 174.921 -25600/69092 Loss: 175.036 -28800/69092 Loss: 174.670 -32000/69092 Loss: 176.784 -35200/69092 Loss: 175.621 -38400/69092 Loss: 177.801 -41600/69092 Loss: 178.486 -44800/69092 Loss: 177.910 -48000/69092 Loss: 179.626 -51200/69092 Loss: 177.689 -54400/69092 Loss: 177.252 -57600/69092 Loss: 178.268 -60800/69092 Loss: 174.528 -64000/69092 Loss: 180.210 -67200/69092 Loss: 172.981 -Epoch: 6 Average loss: 176.96 -0/69092 Loss: 172.989 -3200/69092 Loss: 177.269 -6400/69092 Loss: 175.520 -9600/69092 Loss: 174.691 -12800/69092 Loss: 176.304 -16000/69092 Loss: 175.862 -19200/69092 Loss: 177.363 -22400/69092 Loss: 175.648 -25600/69092 Loss: 179.415 -28800/69092 Loss: 174.836 -32000/69092 Loss: 179.038 -35200/69092 Loss: 176.845 -38400/69092 Loss: 174.129 -41600/69092 Loss: 174.983 -44800/69092 Loss: 176.457 -48000/69092 Loss: 175.409 -51200/69092 Loss: 179.870 -54400/69092 Loss: 169.886 -57600/69092 Loss: 174.812 -60800/69092 Loss: 179.631 -64000/69092 Loss: 175.263 -67200/69092 Loss: 175.486 -Epoch: 7 Average loss: 176.17 -0/69092 Loss: 182.499 -3200/69092 Loss: 174.724 -6400/69092 Loss: 175.769 -9600/69092 Loss: 175.399 -12800/69092 Loss: 174.957 -16000/69092 Loss: 173.714 -19200/69092 Loss: 174.390 -22400/69092 Loss: 173.339 -25600/69092 Loss: 177.868 -28800/69092 Loss: 176.928 -32000/69092 Loss: 177.665 -35200/69092 Loss: 177.244 -38400/69092 Loss: 175.513 -41600/69092 Loss: 174.195 -44800/69092 Loss: 176.281 -48000/69092 Loss: 174.752 -51200/69092 Loss: 173.814 -54400/69092 Loss: 176.664 -57600/69092 Loss: 173.527 -60800/69092 Loss: 180.212 -64000/69092 Loss: 175.542 -67200/69092 Loss: 176.622 -Epoch: 8 Average loss: 175.71 -0/69092 Loss: 156.728 -3200/69092 Loss: 175.051 -6400/69092 Loss: 173.837 -9600/69092 Loss: 177.669 -12800/69092 Loss: 177.154 -16000/69092 Loss: 175.246 -19200/69092 Loss: 175.701 -22400/69092 Loss: 179.730 -25600/69092 Loss: 177.038 -28800/69092 Loss: 177.181 -32000/69092 Loss: 175.627 -35200/69092 Loss: 175.962 -38400/69092 Loss: 175.364 -41600/69092 Loss: 170.164 -44800/69092 Loss: 173.813 -48000/69092 Loss: 171.629 -51200/69092 Loss: 173.023 -54400/69092 Loss: 173.645 -57600/69092 Loss: 174.178 -60800/69092 Loss: 175.416 -64000/69092 Loss: 177.290 -67200/69092 Loss: 178.709 -Epoch: 9 Average loss: 175.37 -0/69092 Loss: 152.802 -3200/69092 Loss: 176.915 -6400/69092 Loss: 171.666 -9600/69092 Loss: 174.777 -12800/69092 Loss: 174.794 -16000/69092 Loss: 178.759 -19200/69092 Loss: 176.785 -22400/69092 Loss: 171.137 -25600/69092 Loss: 176.108 -28800/69092 Loss: 176.560 -32000/69092 Loss: 175.608 -35200/69092 Loss: 175.323 -38400/69092 Loss: 177.557 -41600/69092 Loss: 171.559 -44800/69092 Loss: 172.491 -48000/69092 Loss: 173.053 -51200/69092 Loss: 175.481 -54400/69092 Loss: 175.098 -57600/69092 Loss: 174.345 -60800/69092 Loss: 174.417 -64000/69092 Loss: 174.409 -67200/69092 Loss: 173.647 -Epoch: 10 Average loss: 174.87 -0/69092 Loss: 179.213 -3200/69092 Loss: 173.855 -6400/69092 Loss: 174.073 -9600/69092 Loss: 176.358 -12800/69092 Loss: 177.361 -16000/69092 Loss: 175.480 -19200/69092 Loss: 175.529 -22400/69092 Loss: 175.376 -25600/69092 Loss: 171.995 -28800/69092 Loss: 174.104 -32000/69092 Loss: 175.416 -35200/69092 Loss: 175.910 -38400/69092 Loss: 176.820 -41600/69092 Loss: 172.693 -44800/69092 Loss: 174.116 -48000/69092 Loss: 173.678 -51200/69092 Loss: 174.519 -54400/69092 Loss: 172.135 -57600/69092 Loss: 175.865 -60800/69092 Loss: 174.142 -64000/69092 Loss: 172.220 -67200/69092 Loss: 175.021 -Epoch: 11 Average loss: 174.59 -0/69092 Loss: 163.699 -3200/69092 Loss: 175.630 -6400/69092 Loss: 172.353 -9600/69092 Loss: 173.406 -12800/69092 Loss: 174.677 -16000/69092 Loss: 173.659 -19200/69092 Loss: 174.090 -22400/69092 Loss: 174.560 -25600/69092 Loss: 173.179 -28800/69092 Loss: 173.921 -32000/69092 Loss: 173.517 -35200/69092 Loss: 173.364 -38400/69092 Loss: 175.125 -41600/69092 Loss: 174.068 -44800/69092 Loss: 175.291 -48000/69092 Loss: 173.934 -51200/69092 Loss: 176.217 -54400/69092 Loss: 172.097 -57600/69092 Loss: 175.580 -60800/69092 Loss: 174.114 -64000/69092 Loss: 174.610 -67200/69092 Loss: 173.205 -Epoch: 12 Average loss: 174.19 -0/69092 Loss: 191.185 -3200/69092 Loss: 176.507 -6400/69092 Loss: 173.035 -9600/69092 Loss: 171.493 -12800/69092 Loss: 177.913 -16000/69092 Loss: 175.693 -19200/69092 Loss: 172.820 -22400/69092 Loss: 171.520 -25600/69092 Loss: 172.230 -28800/69092 Loss: 171.340 -32000/69092 Loss: 174.393 -35200/69092 Loss: 171.775 -38400/69092 Loss: 176.086 -41600/69092 Loss: 172.708 -44800/69092 Loss: 173.573 -48000/69092 Loss: 172.352 -51200/69092 Loss: 175.626 -54400/69092 Loss: 174.810 -57600/69092 Loss: 173.702 -60800/69092 Loss: 176.364 -64000/69092 Loss: 170.513 -67200/69092 Loss: 173.109 -Epoch: 13 Average loss: 173.72 -0/69092 Loss: 159.396 -3200/69092 Loss: 175.317 -6400/69092 Loss: 173.157 -9600/69092 Loss: 173.722 -12800/69092 Loss: 171.148 -16000/69092 Loss: 175.274 -19200/69092 Loss: 176.036 -22400/69092 Loss: 172.847 -25600/69092 Loss: 166.729 -28800/69092 Loss: 169.101 -32000/69092 Loss: 172.004 -35200/69092 Loss: 171.771 -38400/69092 Loss: 170.028 -41600/69092 Loss: 168.713 -44800/69092 Loss: 172.679 -48000/69092 Loss: 168.519 -51200/69092 Loss: 167.290 -54400/69092 Loss: 168.153 -57600/69092 Loss: 170.113 -60800/69092 Loss: 171.238 -64000/69092 Loss: 166.859 -67200/69092 Loss: 171.634 -Epoch: 14 Average loss: 171.06 -0/69092 Loss: 156.795 -3200/69092 Loss: 169.503 -6400/69092 Loss: 170.706 -9600/69092 Loss: 167.440 -12800/69092 Loss: 169.419 -16000/69092 Loss: 167.985 -19200/69092 Loss: 168.267 -22400/69092 Loss: 168.930 -25600/69092 Loss: 168.436 -28800/69092 Loss: 170.484 -32000/69092 Loss: 170.663 -35200/69092 Loss: 165.644 -38400/69092 Loss: 169.801 -41600/69092 Loss: 167.051 -44800/69092 Loss: 169.969 -48000/69092 Loss: 166.461 -51200/69092 Loss: 169.932 -54400/69092 Loss: 170.305 -57600/69092 Loss: 168.247 -60800/69092 Loss: 168.439 -64000/69092 Loss: 169.417 -67200/69092 Loss: 168.768 -Epoch: 15 Average loss: 168.81 -0/69092 Loss: 146.202 -3200/69092 Loss: 170.373 -6400/69092 Loss: 164.957 -9600/69092 Loss: 168.805 -12800/69092 Loss: 169.648 -16000/69092 Loss: 167.061 -19200/69092 Loss: 168.539 -22400/69092 Loss: 166.700 -25600/69092 Loss: 170.245 -28800/69092 Loss: 169.705 -32000/69092 Loss: 167.015 -35200/69092 Loss: 165.536 -38400/69092 Loss: 168.482 -41600/69092 Loss: 166.300 -44800/69092 Loss: 169.078 -48000/69092 Loss: 170.031 -51200/69092 Loss: 164.588 -54400/69092 Loss: 169.601 -57600/69092 Loss: 170.558 -60800/69092 Loss: 169.104 -64000/69092 Loss: 168.777 -67200/69092 Loss: 167.852 -Epoch: 16 Average loss: 168.19 -0/69092 Loss: 173.891 -3200/69092 Loss: 167.297 -6400/69092 Loss: 167.022 -9600/69092 Loss: 169.751 -12800/69092 Loss: 173.189 -16000/69092 Loss: 167.852 -19200/69092 Loss: 166.613 -22400/69092 Loss: 168.037 -25600/69092 Loss: 166.963 -28800/69092 Loss: 165.673 -32000/69092 Loss: 168.600 -35200/69092 Loss: 168.420 -38400/69092 Loss: 167.863 -41600/69092 Loss: 170.022 -44800/69092 Loss: 166.268 -48000/69092 Loss: 166.312 -51200/69092 Loss: 168.968 -54400/69092 Loss: 170.261 -57600/69092 Loss: 167.449 -60800/69092 Loss: 167.692 -64000/69092 Loss: 166.189 -67200/69092 Loss: 169.426 -Epoch: 17 Average loss: 168.12 -0/69092 Loss: 158.335 -3200/69092 Loss: 166.323 -6400/69092 Loss: 168.593 -9600/69092 Loss: 169.252 -12800/69092 Loss: 165.184 -16000/69092 Loss: 166.948 -19200/69092 Loss: 167.851 -22400/69092 Loss: 167.890 -25600/69092 Loss: 166.056 -28800/69092 Loss: 166.313 -32000/69092 Loss: 166.684 -35200/69092 Loss: 166.046 -38400/69092 Loss: 168.573 -41600/69092 Loss: 170.346 -44800/69092 Loss: 168.576 -48000/69092 Loss: 167.955 -51200/69092 Loss: 169.517 -54400/69092 Loss: 170.773 -57600/69092 Loss: 167.093 -60800/69092 Loss: 167.219 -64000/69092 Loss: 167.393 -67200/69092 Loss: 170.586 -Epoch: 18 Average loss: 167.79 -0/69092 Loss: 156.962 -3200/69092 Loss: 166.186 -6400/69092 Loss: 168.813 -9600/69092 Loss: 166.400 -12800/69092 Loss: 167.795 -16000/69092 Loss: 167.113 -19200/69092 Loss: 165.320 -22400/69092 Loss: 165.054 -25600/69092 Loss: 168.109 -28800/69092 Loss: 167.833 -32000/69092 Loss: 166.894 -35200/69092 Loss: 168.014 -38400/69092 Loss: 168.756 -41600/69092 Loss: 168.506 -44800/69092 Loss: 169.647 -48000/69092 Loss: 167.893 -51200/69092 Loss: 167.151 -54400/69092 Loss: 168.016 -57600/69092 Loss: 168.288 -60800/69092 Loss: 169.788 -64000/69092 Loss: 165.068 -67200/69092 Loss: 170.401 -Epoch: 19 Average loss: 167.60 -0/69092 Loss: 168.727 -3200/69092 Loss: 168.674 -6400/69092 Loss: 166.971 -9600/69092 Loss: 164.222 -12800/69092 Loss: 165.916 -16000/69092 Loss: 173.111 -19200/69092 Loss: 165.215 -22400/69092 Loss: 168.106 -25600/69092 Loss: 167.273 -28800/69092 Loss: 167.498 -32000/69092 Loss: 167.923 -35200/69092 Loss: 168.921 -38400/69092 Loss: 163.615 -41600/69092 Loss: 167.147 -44800/69092 Loss: 167.341 -48000/69092 Loss: 169.308 -51200/69092 Loss: 168.601 -54400/69092 Loss: 169.787 -57600/69092 Loss: 166.595 -60800/69092 Loss: 169.997 -64000/69092 Loss: 165.999 -67200/69092 Loss: 169.111 -Epoch: 20 Average loss: 167.67 -0/69092 Loss: 164.874 -3200/69092 Loss: 167.108 -6400/69092 Loss: 166.190 -9600/69092 Loss: 168.514 -12800/69092 Loss: 166.697 -16000/69092 Loss: 165.785 -19200/69092 Loss: 166.035 -22400/69092 Loss: 165.913 -25600/69092 Loss: 167.381 -28800/69092 Loss: 167.176 -32000/69092 Loss: 168.299 -35200/69092 Loss: 168.890 -38400/69092 Loss: 165.954 -41600/69092 Loss: 168.868 -44800/69092 Loss: 166.462 -48000/69092 Loss: 165.692 -51200/69092 Loss: 170.486 -54400/69092 Loss: 166.689 -57600/69092 Loss: 168.941 -60800/69092 Loss: 167.861 -64000/69092 Loss: 169.267 -67200/69092 Loss: 168.650 -Epoch: 21 Average loss: 167.43 -0/69092 Loss: 158.066 -3200/69092 Loss: 166.319 -6400/69092 Loss: 168.866 -9600/69092 Loss: 168.782 -12800/69092 Loss: 168.118 -16000/69092 Loss: 168.135 -19200/69092 Loss: 165.183 -22400/69092 Loss: 171.452 -25600/69092 Loss: 166.608 -28800/69092 Loss: 169.231 -32000/69092 Loss: 164.687 -35200/69092 Loss: 167.388 -38400/69092 Loss: 168.537 -41600/69092 Loss: 166.544 -44800/69092 Loss: 170.258 -48000/69092 Loss: 164.814 -51200/69092 Loss: 167.838 -54400/69092 Loss: 167.264 -57600/69092 Loss: 165.149 -60800/69092 Loss: 168.617 -64000/69092 Loss: 166.318 -67200/69092 Loss: 167.333 -Epoch: 22 Average loss: 167.44 -0/69092 Loss: 164.065 -3200/69092 Loss: 164.064 -6400/69092 Loss: 167.517 -9600/69092 Loss: 170.663 -12800/69092 Loss: 167.499 -16000/69092 Loss: 167.594 -19200/69092 Loss: 166.922 -22400/69092 Loss: 163.598 -25600/69092 Loss: 168.840 -28800/69092 Loss: 169.542 -32000/69092 Loss: 168.083 -35200/69092 Loss: 169.357 -38400/69092 Loss: 166.990 -41600/69092 Loss: 164.629 -44800/69092 Loss: 168.236 -48000/69092 Loss: 164.188 -51200/69092 Loss: 166.765 -54400/69092 Loss: 166.641 -57600/69092 Loss: 170.710 -60800/69092 Loss: 165.975 -64000/69092 Loss: 168.500 -67200/69092 Loss: 169.258 -Epoch: 23 Average loss: 167.42 -0/69092 Loss: 172.864 -3200/69092 Loss: 165.477 -6400/69092 Loss: 167.342 -9600/69092 Loss: 167.331 -12800/69092 Loss: 168.253 -16000/69092 Loss: 164.884 -19200/69092 Loss: 165.278 -22400/69092 Loss: 169.037 -25600/69092 Loss: 170.689 -28800/69092 Loss: 167.800 -32000/69092 Loss: 167.574 -35200/69092 Loss: 165.465 -38400/69092 Loss: 167.913 -41600/69092 Loss: 166.253 -44800/69092 Loss: 167.459 -48000/69092 Loss: 167.662 -51200/69092 Loss: 167.989 -54400/69092 Loss: 166.499 -57600/69092 Loss: 165.015 -60800/69092 Loss: 167.097 -64000/69092 Loss: 167.709 -67200/69092 Loss: 166.537 -Epoch: 24 Average loss: 167.05 -0/69092 Loss: 159.477 -3200/69092 Loss: 168.167 -6400/69092 Loss: 165.640 -9600/69092 Loss: 169.130 -12800/69092 Loss: 164.949 -16000/69092 Loss: 170.971 -19200/69092 Loss: 168.796 -22400/69092 Loss: 165.955 -25600/69092 Loss: 166.466 -28800/69092 Loss: 169.212 -32000/69092 Loss: 167.467 -35200/69092 Loss: 166.299 -38400/69092 Loss: 167.788 -41600/69092 Loss: 166.326 -44800/69092 Loss: 166.547 -48000/69092 Loss: 164.943 -51200/69092 Loss: 166.213 -54400/69092 Loss: 167.823 -57600/69092 Loss: 165.211 -60800/69092 Loss: 169.700 -64000/69092 Loss: 167.593 -67200/69092 Loss: 169.609 -Epoch: 25 Average loss: 167.31 -0/69092 Loss: 161.595 -3200/69092 Loss: 167.897 -6400/69092 Loss: 163.418 -9600/69092 Loss: 164.333 -12800/69092 Loss: 167.346 -16000/69092 Loss: 168.345 -19200/69092 Loss: 167.494 -22400/69092 Loss: 168.143 -25600/69092 Loss: 166.609 -28800/69092 Loss: 162.532 -32000/69092 Loss: 169.007 -35200/69092 Loss: 168.024 -38400/69092 Loss: 162.925 -41600/69092 Loss: 169.271 -44800/69092 Loss: 170.226 -48000/69092 Loss: 167.290 -51200/69092 Loss: 167.863 -54400/69092 Loss: 168.024 -57600/69092 Loss: 165.059 -60800/69092 Loss: 168.377 -64000/69092 Loss: 168.783 -67200/69092 Loss: 165.481 -Epoch: 26 Average loss: 167.07 -0/69092 Loss: 181.461 -3200/69092 Loss: 168.982 -6400/69092 Loss: 164.568 -9600/69092 Loss: 164.560 -12800/69092 Loss: 169.223 -16000/69092 Loss: 163.325 -19200/69092 Loss: 169.991 -22400/69092 Loss: 168.719 -25600/69092 Loss: 164.910 -28800/69092 Loss: 166.165 -32000/69092 Loss: 167.390 -35200/69092 Loss: 165.877 -38400/69092 Loss: 166.731 -41600/69092 Loss: 167.863 -44800/69092 Loss: 166.774 -48000/69092 Loss: 170.445 -51200/69092 Loss: 169.944 -54400/69092 Loss: 164.762 -57600/69092 Loss: 165.063 -60800/69092 Loss: 169.285 -64000/69092 Loss: 167.527 -67200/69092 Loss: 166.289 -Epoch: 27 Average loss: 167.06 -0/69092 Loss: 170.792 -3200/69092 Loss: 167.467 -6400/69092 Loss: 167.020 -9600/69092 Loss: 168.963 -12800/69092 Loss: 165.771 -16000/69092 Loss: 166.994 -19200/69092 Loss: 164.748 -22400/69092 Loss: 165.152 -25600/69092 Loss: 165.792 -28800/69092 Loss: 165.728 -32000/69092 Loss: 165.544 -35200/69092 Loss: 169.239 -38400/69092 Loss: 165.895 -41600/69092 Loss: 169.427 -44800/69092 Loss: 168.761 -48000/69092 Loss: 167.436 -51200/69092 Loss: 167.477 -54400/69092 Loss: 167.437 -57600/69092 Loss: 163.032 -60800/69092 Loss: 168.386 -64000/69092 Loss: 169.317 -67200/69092 Loss: 166.414 -Epoch: 28 Average loss: 166.99 -0/69092 Loss: 159.176 -3200/69092 Loss: 169.014 -6400/69092 Loss: 165.076 -9600/69092 Loss: 164.450 -12800/69092 Loss: 165.767 -16000/69092 Loss: 163.509 -19200/69092 Loss: 164.611 -22400/69092 Loss: 169.482 -25600/69092 Loss: 164.828 -28800/69092 Loss: 165.456 -32000/69092 Loss: 166.719 -35200/69092 Loss: 165.716 -38400/69092 Loss: 169.506 -41600/69092 Loss: 167.452 -44800/69092 Loss: 167.680 -48000/69092 Loss: 168.003 -51200/69092 Loss: 166.531 -54400/69092 Loss: 166.318 -57600/69092 Loss: 167.834 -60800/69092 Loss: 168.869 -64000/69092 Loss: 166.379 -67200/69092 Loss: 168.122 -Epoch: 29 Average loss: 166.84 -0/69092 Loss: 168.089 -3200/69092 Loss: 164.654 -6400/69092 Loss: 166.092 -9600/69092 Loss: 165.645 -12800/69092 Loss: 168.803 -16000/69092 Loss: 166.414 -19200/69092 Loss: 168.405 -22400/69092 Loss: 167.098 -25600/69092 Loss: 166.610 -28800/69092 Loss: 167.311 -32000/69092 Loss: 170.033 -35200/69092 Loss: 166.208 -38400/69092 Loss: 163.407 -41600/69092 Loss: 166.786 -44800/69092 Loss: 165.290 -48000/69092 Loss: 168.366 -51200/69092 Loss: 166.125 -54400/69092 Loss: 164.718 -57600/69092 Loss: 166.249 -60800/69092 Loss: 166.015 -64000/69092 Loss: 168.616 -67200/69092 Loss: 169.330 -Epoch: 30 Average loss: 166.81 -0/69092 Loss: 178.073 -3200/69092 Loss: 166.946 -6400/69092 Loss: 165.648 -9600/69092 Loss: 168.367 -12800/69092 Loss: 166.295 -16000/69092 Loss: 164.805 -19200/69092 Loss: 168.451 -22400/69092 Loss: 164.708 -25600/69092 Loss: 166.344 -28800/69092 Loss: 168.841 -32000/69092 Loss: 168.623 -35200/69092 Loss: 166.244 -38400/69092 Loss: 163.225 -41600/69092 Loss: 169.795 -44800/69092 Loss: 167.962 -48000/69092 Loss: 167.154 -51200/69092 Loss: 168.323 -54400/69092 Loss: 166.811 -57600/69092 Loss: 171.028 -60800/69092 Loss: 164.150 -64000/69092 Loss: 166.475 -67200/69092 Loss: 164.613 -Epoch: 31 Average loss: 166.86 -0/69092 Loss: 163.393 -3200/69092 Loss: 164.554 -6400/69092 Loss: 167.080 -9600/69092 Loss: 171.187 -12800/69092 Loss: 164.961 -16000/69092 Loss: 163.417 -19200/69092 Loss: 168.374 -22400/69092 Loss: 165.693 -25600/69092 Loss: 167.065 -28800/69092 Loss: 164.642 -32000/69092 Loss: 165.408 -35200/69092 Loss: 166.985 -38400/69092 Loss: 166.414 -41600/69092 Loss: 170.011 -44800/69092 Loss: 166.610 -48000/69092 Loss: 169.074 -51200/69092 Loss: 166.680 -54400/69092 Loss: 168.125 -57600/69092 Loss: 164.007 -60800/69092 Loss: 166.152 -64000/69092 Loss: 168.865 -67200/69092 Loss: 163.801 -Epoch: 32 Average loss: 166.62 -0/69092 Loss: 145.340 -3200/69092 Loss: 166.538 -6400/69092 Loss: 167.724 -9600/69092 Loss: 168.375 -12800/69092 Loss: 165.096 -16000/69092 Loss: 164.871 -19200/69092 Loss: 164.603 -22400/69092 Loss: 166.468 -25600/69092 Loss: 168.188 -28800/69092 Loss: 164.460 -32000/69092 Loss: 166.869 -35200/69092 Loss: 167.491 -38400/69092 Loss: 164.241 -41600/69092 Loss: 164.496 -44800/69092 Loss: 164.427 -48000/69092 Loss: 167.685 -51200/69092 Loss: 167.454 -54400/69092 Loss: 166.709 -57600/69092 Loss: 166.155 -60800/69092 Loss: 169.947 -64000/69092 Loss: 165.286 -67200/69092 Loss: 168.151 -Epoch: 33 Average loss: 166.48 -0/69092 Loss: 166.408 -3200/69092 Loss: 167.340 -6400/69092 Loss: 163.513 -9600/69092 Loss: 165.043 -12800/69092 Loss: 164.295 -16000/69092 Loss: 164.211 -19200/69092 Loss: 166.156 -22400/69092 Loss: 169.719 -25600/69092 Loss: 166.807 -28800/69092 Loss: 165.537 -32000/69092 Loss: 170.431 -35200/69092 Loss: 165.410 -38400/69092 Loss: 166.952 -41600/69092 Loss: 168.332 -44800/69092 Loss: 164.908 -48000/69092 Loss: 168.677 -51200/69092 Loss: 167.376 -54400/69092 Loss: 166.778 -57600/69092 Loss: 164.389 -60800/69092 Loss: 168.517 -64000/69092 Loss: 167.666 -67200/69092 Loss: 168.580 -Epoch: 34 Average loss: 166.73 -0/69092 Loss: 160.057 -3200/69092 Loss: 165.825 -6400/69092 Loss: 165.326 -9600/69092 Loss: 166.921 -12800/69092 Loss: 166.886 -16000/69092 Loss: 164.149 -19200/69092 Loss: 166.317 -22400/69092 Loss: 169.658 -25600/69092 Loss: 167.307 -28800/69092 Loss: 165.543 -32000/69092 Loss: 165.719 -35200/69092 Loss: 167.927 -38400/69092 Loss: 166.107 -41600/69092 Loss: 169.167 -44800/69092 Loss: 164.462 -48000/69092 Loss: 164.943 -51200/69092 Loss: 167.662 -54400/69092 Loss: 169.840 -57600/69092 Loss: 162.048 -60800/69092 Loss: 165.538 -64000/69092 Loss: 169.039 -67200/69092 Loss: 167.731 -Epoch: 35 Average loss: 166.58 -0/69092 Loss: 159.147 -3200/69092 Loss: 168.711 -6400/69092 Loss: 164.872 -9600/69092 Loss: 166.459 -12800/69092 Loss: 165.101 -16000/69092 Loss: 165.007 -19200/69092 Loss: 163.900 -22400/69092 Loss: 165.651 -25600/69092 Loss: 166.408 -28800/69092 Loss: 168.436 -32000/69092 Loss: 168.150 -35200/69092 Loss: 166.650 -38400/69092 Loss: 166.221 -41600/69092 Loss: 165.079 -44800/69092 Loss: 169.582 -48000/69092 Loss: 166.133 -51200/69092 Loss: 164.542 -54400/69092 Loss: 166.025 -57600/69092 Loss: 169.027 -60800/69092 Loss: 166.220 -64000/69092 Loss: 169.161 -67200/69092 Loss: 165.857 -Epoch: 36 Average loss: 166.59 -0/69092 Loss: 163.261 -3200/69092 Loss: 164.025 -6400/69092 Loss: 165.219 -9600/69092 Loss: 167.676 -12800/69092 Loss: 168.915 -16000/69092 Loss: 164.794 -19200/69092 Loss: 171.172 -22400/69092 Loss: 168.779 -25600/69092 Loss: 166.745 -28800/69092 Loss: 164.285 -32000/69092 Loss: 165.940 -35200/69092 Loss: 168.109 -38400/69092 Loss: 164.635 -41600/69092 Loss: 166.126 -44800/69092 Loss: 167.761 -48000/69092 Loss: 170.240 -51200/69092 Loss: 163.132 -54400/69092 Loss: 162.558 -57600/69092 Loss: 169.405 -60800/69092 Loss: 165.856 -64000/69092 Loss: 166.046 -67200/69092 Loss: 165.596 -Epoch: 37 Average loss: 166.51 -0/69092 Loss: 156.881 -3200/69092 Loss: 164.568 -6400/69092 Loss: 165.541 -9600/69092 Loss: 165.421 -12800/69092 Loss: 163.026 -16000/69092 Loss: 163.961 -19200/69092 Loss: 166.480 -22400/69092 Loss: 166.765 -25600/69092 Loss: 170.223 -28800/69092 Loss: 166.316 -32000/69092 Loss: 167.886 -35200/69092 Loss: 170.375 -38400/69092 Loss: 167.319 -41600/69092 Loss: 166.670 -44800/69092 Loss: 165.619 -48000/69092 Loss: 167.852 -51200/69092 Loss: 164.861 -54400/69092 Loss: 164.426 -57600/69092 Loss: 169.410 -60800/69092 Loss: 166.011 -64000/69092 Loss: 165.486 -67200/69092 Loss: 164.748 -Epoch: 38 Average loss: 166.34 -0/69092 Loss: 154.482 -3200/69092 Loss: 166.846 -6400/69092 Loss: 168.174 -9600/69092 Loss: 166.786 -12800/69092 Loss: 164.238 -16000/69092 Loss: 169.103 -19200/69092 Loss: 165.084 -22400/69092 Loss: 166.447 -25600/69092 Loss: 166.493 -28800/69092 Loss: 169.040 -32000/69092 Loss: 169.643 -35200/69092 Loss: 163.526 -38400/69092 Loss: 166.074 -41600/69092 Loss: 167.452 -44800/69092 Loss: 164.863 -48000/69092 Loss: 163.552 -51200/69092 Loss: 166.178 -54400/69092 Loss: 164.692 -57600/69092 Loss: 167.591 -60800/69092 Loss: 165.594 -64000/69092 Loss: 167.904 -67200/69092 Loss: 163.778 -Epoch: 39 Average loss: 166.47 -0/69092 Loss: 174.792 -3200/69092 Loss: 168.603 -6400/69092 Loss: 167.578 -9600/69092 Loss: 166.309 -12800/69092 Loss: 166.655 -16000/69092 Loss: 166.937 -19200/69092 Loss: 166.851 -22400/69092 Loss: 168.955 -25600/69092 Loss: 162.846 -28800/69092 Loss: 166.400 -32000/69092 Loss: 161.946 -35200/69092 Loss: 164.650 -38400/69092 Loss: 165.843 -41600/69092 Loss: 166.236 -44800/69092 Loss: 165.099 -48000/69092 Loss: 165.672 -51200/69092 Loss: 165.388 -54400/69092 Loss: 169.332 -57600/69092 Loss: 167.883 -60800/69092 Loss: 167.210 -64000/69092 Loss: 166.555 -67200/69092 Loss: 167.086 -Epoch: 40 Average loss: 166.41 diff --git a/Experiments/main.py b/Experiments/main.py deleted file mode 100644 index 8c30e48060b87139bed86b9377973d305437519e..0000000000000000000000000000000000000000 --- a/Experiments/main.py +++ /dev/null @@ -1,164 +0,0 @@ -""" -Code from: https://github.com/Schlumberger/joint-vae -https://github.com/1Konny/Beta-VAE -""" - -# import sys -# sys.path.append('') -import os -from dataloader.dataloaders import * -import torch.nn as nn -from VAE_model.models import VAE -from torch import optim -from viz.visualize import Visualizer -from utils.training import Trainer, gpu_config -import argparse -import json - - -def main(args): - - # continue and discrete capacity - if args.cont_capacity is not None: - cont_capacity = [float(item) for item in args.cont_capacity.split(',')] - else: - cont_capacity = args.cont_capacity - if args.disc_capacity is not None: - disc_capacity = [float(item) for item in args.disc_capacity.split(',')] - else: - disc_capacity = args.disc_capacity - - # latent_spec - latent_spec = {"cont": args.latent_spec_cont} - - # number of classes and image size: - if args.dataset == 'mnist' or args.dataset == 'fashion_data': - nb_classes = 10 - img_size = (1, 32, 32) - elif args.dataset == 'celeba_64': - nb_classes = None - img_size = (3, 64, 64) - elif args.dataset == 'rendered_chairs': - nb_classes = 1393 - img_size = (3, 64, 64) - elif args.dataset == 'dSprites': - nb_classes = 6 - - # create and write a json file: - if not args.load_model_checkpoint: - ckpt_dir = os.path.join('trained_models', args.dataset, args.experiment_name, args.ckpt_dir) - if not os.path.exists(ckpt_dir): - os.makedirs(ckpt_dir, exist_ok=True) - - parameter = {'dataset': args.dataset, 'epochs': args.epochs, 'cont_capacity': args.cont_capacity, - 'disc_capacity': args.disc_capacity, 'record_loss_every': args.record_loss_every, - 'batch_size': args.batch_size, 'latent_spec_cont': args.latent_spec_cont, - 'experiment_name': args.experiment_name, 'print_loss_every': args.print_loss_every, - 'latent_spec_disc': args.latent_spec_disc, 'nb_classes': nb_classes} - - # Save json parameters: - file_path = os.path.join('trained_models/', args.dataset, args.experiment_name, 'specs.json') - with open(file_path, 'w') as json_file: - json.dump(parameter, json_file) - print('ok') - - # create model - model = VAE(img_size, latent_spec=latent_spec) - # load dataset - train_loader, test_loader, dataset_name = load_dataset(args.dataset, args.batch_size, num_worker=args.num_worker) - - # Define model - model, use_gpu, device = gpu_config(model) - - if args.verbose: - print(model) - num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - print('The number of parameters of model is', num_params) - - # Define optimizer and criterion - optimizer = optim.Adam(model.parameters(), lr=args.lr) - criterion = nn.CrossEntropyLoss() - - # Define trainer - trainer = Trainer(model, device, optimizer, criterion, save_step=args.save_step, ckpt_dir=args.ckpt_dir, - ckpt_name=args.ckpt_name, - expe_name=args.experiment_name, - dataset=args.dataset, - cont_capacity=cont_capacity, - disc_capacity=disc_capacity, - is_beta=args.is_beta_VAE, - beta=args.beta) - - # define visualizer - viz = Visualizer(model) - - # Train model: - trainer.train(train_loader, args.epochs, save_training_gif=('../img_gif/' + dataset_name + '_' + - args.latent_name + args.experiment_name + '.gif', viz)) - - """ - # Save trained model - if args.save_model: - torch.save(trainer.model.state_dict(), - '../trained_models/' + dataset_name + '/model_' + args.experiment_name + '.pt') - """ - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description='VAE') - parser.add_argument('--batch-size', type=int, default=64, metavar='integer value', - help='input batch size for training (default: 64)') - parser.add_argument('--record-loss-every', type=int, default=50, metavar='integer value', - help='Record loss every (value)') - parser.add_argument('--print-loss-every', type=int, default=50, metavar='integer value', - help='Print loss every (value)') - parser.add_argument('--epochs', type=int, default=100, metavar='integer value', - help='number of epochs to train (default: 100)') - parser.add_argument('--lr', type=float, default=5e-4, metavar='value', - help='learning rate value') - parser.add_argument('--dataset', type=str, default=None, metavar='name', - help='Dataset Name') - parser.add_argument('--save-model', type=bool, default=True, metavar='bool', - help='Save model') - parser.add_argument('--save-reconstruction-image', type=bool, default=False, metavar='bool', - help='Save reconstruction image') - parser.add_argument('--latent_spec_cont', type=int, default=10, metavar='integer value', - help='Capacity of continue latent space') - parser.add_argument('--latent_spec_disc', type=list, default=None, metavar='integer list', - help='Capacity of discrete latent space') - parser.add_argument('--cont-capacity', type=str, default=None, metavar='integer tuple', - help='capacity of continuous channels') - parser.add_argument('--disc-capacity', type=str, default=None, metavar='integer tuple', - help='capacity of discrete channels') - parser.add_argument('--experiment-name', type=str, default='', metavar='name', - help='experiment name') - parser.add_argument('--latent-name', type=str, default='', metavar='name', - help='Latent space name') - parser.add_argument('--is-beta-VAE', type=bool, default=False, metavar='beta_VAE', - help='If use beta-VAE') - parser.add_argument('--beta', type=int, default=None, metavar='beta', - help='Beta value') - parser.add_argument("--gpu-devices", type=int, nargs='+', default=None, help="GPU devices available") - parser.add_argument("--load-model-checkpoint", type=bool, default=False, help="If we use a pre trained model") - parser.add_argument("--load-expe-name", type=str, default='', help="The name expe to loading") - parser.add_argument("--num-worker", type=int, default=4, help="num worker to dataloader") - parser.add_argument("--verbose", type=bool, default=True, help="To print details model") - parser.add_argument("--save-step", type=int, default=1, help="save model every step") - parser.add_argument('--ckpt_dir', default='checkpoints', type=str, help='checkpoint directory') - parser.add_argument('--ckpt_name', default='last', type=str, - help='load previous checkpoint. insert checkpoint filename') - - args = parser.parse_args() - - assert args.dataset in ['mnist', 'fashion_data', 'celeba_64', 'rendered_chairs', 'dSprites'], \ - "The choisen dataset is not available. Please choose a dataset from the following: ['mnist', 'fashion_data', " \ - "'celeba_64', 'rendered_chairs', 'dSprites'] " - if args.is_beta_VAE: - assert args.beta is not None, 'Beta is null or if you use Beta-VAe model, please enter a beta value' - - print(parser.parse_args()) - - gpu_devices = ','.join([str(id) for id in args.gpu_devices]) - os.environ["CUDA_VISIBLE_DEVICES"] = gpu_devices - - main(args) diff --git a/parameters_combinations/param_combinations_chairs.txt b/parameters_combinations/param_combinations_chairs.txt index 83a03d337830d90303b8e1c1389455b2f1474027..3564326e4e7db5ddce4e8e13009f245304c9926b 100644 --- a/parameters_combinations/param_combinations_chairs.txt +++ b/parameters_combinations/param_combinations_chairs.txt @@ -1,4 +1,4 @@ ---batch-size=256 --dataset=rendered_chairs --epochs=40 --latent_spec_cont=10 --is-beta-VAE=True --beta=4 --lr=1e-4 --experiment-name=beta_VAE_bs_256 --gpu-devices 0 1 ---batch-size=64 --dataset=rendered_chairs --epochs=40 --latent_spec_cont=10 --is-beta-VAE=True --beta=4 --lr=1e-4 --experiment-name=beta_VAE_bs_64 --gpu-devices 0 1 ---batch-size=256 --dataset=rendered_chairs --epochs=40 --latent_spec_cont=10 --lr=1e-4 --experiment-name=VAE_bs_256 --gpu-devices 0 1 ---batch-size=64 --dataset=rendered_chairs --epochs=40 --latent_spec_cont=10 --lr=1e-4 --experiment-name=VAE_bs_64 --gpu-devices 0 1 --load-model_checkpoint=False --experiment-name=VAE_bs_64 +--batch-size=256 --dataset=rendered_chairs --epochs=400 --latent_spec_cont=10 --is-beta-VAE=True --beta=4 --lr=1e-4 --experiment-name=beta_VAE_bs_256 --gpu-devices 0 1 --experiment-name=beta_VAE_bs_256 +--batch-size=64 --dataset=rendered_chairs --epochs=400 --latent_spec_cont=10 --is-beta-VAE=True --beta=4 --lr=1e-4 --experiment-name=beta_VAE_bs_64 --gpu-devices 0 1 --experiment-name=beta_VAE_bs_64 +--batch-size=256 --dataset=rendered_chairs --epochs=400 --latent_spec_cont=10 --lr=1e-4 --experiment-name=VAE_bs_256 --gpu-devices 0 1 --experiment-name=VAE_bs_256 +--batch-size=64 --dataset=rendered_chairs --epochs=400 --latent_spec_cont=10 --lr=1e-4 --experiment-name=VAE_bs_64 --gpu-devices 0 1 --experiment-name=VAE_bs_64 diff --git a/utils/training.py b/utils/training.py index 5871de5b039b2295bbe05e20d12700d02550a7fb..aca5b92fcc8cc0f3709f9dd8bb706c48f98a5d52 100644 --- a/utils/training.py +++ b/utils/training.py @@ -32,7 +32,7 @@ class Trainer: record_loss_every : int Frequency with which loss is recorded during training. """ - if type(model) == 'torch.nn.parallel.data_parallel.DataParallel': + if 'parallel' in str(type(model)): self.model = model.module else: self.model = model diff --git a/viz/visualize.py b/viz/visualize.py index f2dd36e41c5dee61d88b4c09ddfda4cacb985738..560749a3f903560c6c98aa3620109ab711181be1 100644 --- a/viz/visualize.py +++ b/viz/visualize.py @@ -15,7 +15,7 @@ class Visualizer: ---------- model : VAE_model.models.VAE instance """ - if type(model) == 'torch.nn.parallel.data_parallel.DataParallel': + if 'parallel' in str(type(model)): self.model = model.module else: self.model = model