Project author: swasun

Project description:
PyTorch implementation of VQ-VAE applied to the CIFAR10 dataset
Primary language: Python
Project URL: git://github.com/swasun/VQ-VAE-Images.git
Created: 2019-02-18T15:10:17Z
Project community: https://github.com/swasun/VQ-VAE-Images

License: MIT License



PyTorch implementation of VQ-VAE [van den Oord et al., 2017] applied to the CIFAR10 dataset [Alex Krizhevsky, 2009], structured with classes and inspired by the code of [zalandoresearch/pytorch-vq-vae] and [deepmind/sonnet].
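At the core of VQ-VAE is the vector quantization step: each encoder output vector is replaced by its nearest codebook vector, and a straight-through estimator copies the decoder's gradients back to the encoder. Below is a minimal PyTorch sketch of the non-EMA variant, loosely following the VectorQuantizer of [zalandoresearch/pytorch-vq-vae] and [deepmind/sonnet] rather than reproducing this repository's code. The class name `VectorQuantizerSketch` and the uniform codebook initialization are assumptions; `num_embeddings=512`, `embedding_dim=64` and `commitment_cost=0.25` mirror the defaults of the command-line options listed further down.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class VectorQuantizerSketch(nn.Module):
    """Hypothetical non-EMA vector quantizer: nearest-code lookup + straight-through."""

    def __init__(self, num_embeddings=512, embedding_dim=64, commitment_cost=0.25):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost
        self.codebook = nn.Embedding(num_embeddings, embedding_dim)
        # Assumed initialization: small uniform values around zero
        nn.init.uniform_(self.codebook.weight, -1 / num_embeddings, 1 / num_embeddings)

    def forward(self, z):
        # z: encoder output of shape (B, C, H, W) with C == embedding_dim
        z = z.permute(0, 2, 3, 1).contiguous()            # (B, H, W, C)
        flat = z.view(-1, self.embedding_dim)             # (B*H*W, C)
        weight = self.codebook.weight                     # (K, C)
        # Squared Euclidean distance from every vector to every code
        distances = (flat.pow(2).sum(1, keepdim=True)
                     - 2 * flat @ weight.t()
                     + weight.pow(2).sum(1))
        indices = distances.argmin(dim=1)                 # nearest code per vector
        quantized = self.codebook(indices).view(z.shape)  # (B, H, W, C)
        # Codebook term moves the codes toward the (frozen) encoder outputs;
        # commitment term keeps the encoder outputs close to their (frozen) codes
        codebook_loss = F.mse_loss(quantized, z.detach())
        commitment_loss = F.mse_loss(z, quantized.detach())
        loss = codebook_loss + self.commitment_cost * commitment_loss
        # Straight-through estimator: the decoder's gradient is copied to z
        quantized = z + (quantized - z).detach()
        return loss, quantized.permute(0, 3, 1, 2).contiguous(), indices


vq = VectorQuantizerSketch()
z = torch.randn(2, 64, 8, 8, requires_grad=True)
loss, quantized, indices = vq(z)
print(quantized.shape, indices.shape)  # torch.Size([2, 64, 8, 8]) torch.Size([128])
```

The expanded distance formula (|x|² − 2x·e + |e|²) avoids materializing an N×K×C difference tensor, which keeps memory use low for 512 codes.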

Results

The trained models used in the following experiments are saved in the results/shuffled/ and results/unshuffled/ directories.

The experiments were shorter than necessary, as they were only for educational purposes. To obtain better image reconstructions, it is necessary to increase the number of residual hidden neurons (i.e., 256 instead of 32) and the number of training updates (i.e., 250K instead of 25K).

The results shown below (results/unshuffled/) are slightly worse than those in results/shuffled/.

Using original version

Reconstruction loss plot using the original version by [van den Oord et al., 2017]:

[Figure: reconstruction loss, original version]

The original images:

[Figure: original images]

The reconstructed images:

[Figure: reconstructed images, original version]

Using EMA updates

In my experiments, using the EMA updates proposed in [Roy et al., 2018] made the final reconstruction loss 2.66 times lower on the shuffled dataset (0.235 instead of 0.627), and similarly for the unshuffled dataset:

[Figure: reconstruction loss, EMA updates]

The original images:

[Figure: original images]

As we can see, the reconstructed images are less blurred than the previous ones:

[Figure: reconstructed images, EMA updates]
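In the EMA variant, as in the VectorQuantizerEMA modules of [deepmind/sonnet] and [zalandoresearch/pytorch-vq-vae], the codebook is no longer learned by gradient descent: each code is replaced by an exponential moving average of the encoder outputs assigned to it, and only the commitment term remains in the quantizer loss. Below is a minimal PyTorch sketch of one update step, not this repository's exact code. The function name `ema_codebook_update`, its argument layout and the `epsilon=1e-5` smoothing constant are assumptions; `decay` plays the role of the `--decay` option (default 0.99).

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def ema_codebook_update(flat, indices, codebook, ema_cluster_size, ema_w,
                        decay=0.99, epsilon=1e-5):
    """Hypothetical helper: one in-place EMA update of the codebook.

    flat:             (N, D) encoder outputs
    indices:          (N,)   index of the nearest code for each row of flat
    codebook:         (K, D) code vectors, overwritten in place
    ema_cluster_size: (K,)   running count of vectors assigned to each code
    ema_w:            (K, D) running sum of the vectors assigned to each code
    """
    num_embeddings = codebook.shape[0]
    encodings = F.one_hot(indices, num_embeddings).type_as(flat)  # (N, K)
    # Moving average of how many vectors each code received in this batch
    ema_cluster_size.mul_(decay).add_(encodings.sum(0), alpha=1 - decay)
    # Laplace smoothing keeps unused codes from dividing by zero
    n = ema_cluster_size.sum()
    smoothed = (ema_cluster_size + epsilon) / (n + num_embeddings * epsilon) * n
    # Moving average of the sum of the vectors assigned to each code
    ema_w.mul_(decay).add_(encodings.t() @ flat, alpha=1 - decay)
    # Each code becomes the running mean of the vectors assigned to it
    codebook.copy_(ema_w / smoothed.unsqueeze(1))


# Toy check: two 1-D codes with fixed assignments converge to the cluster means
flat = torch.tensor([[1.0], [3.0], [10.0], [12.0]])
indices = torch.tensor([0, 0, 1, 1])
codebook = torch.zeros(2, 1)
ema_cluster_size = torch.zeros(2)
ema_w = torch.zeros(2, 1)
for _ in range(2000):
    ema_codebook_update(flat, indices, codebook, ema_cluster_size, ema_w)
print(codebook.squeeze(1).tolist())  # close to [2.0, 11.0]
```

Because the codes track running means of the encoder outputs, they tend to move more smoothly than with the gradient-based codebook loss.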

Using EMA updates + Kaiming normal

One can also use the Kaiming normal weight initialization proposed by [He, K et al., 2015], which makes the model converge a little faster.

[Figure: reconstruction loss, EMA updates + Kaiming normal]

The original images:

[Figure: original images]

The reconstructed images:

[Figure: reconstructed images, EMA updates + Kaiming normal]
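Kaiming normal initialization draws each weight from a normal distribution with standard deviation sqrt(2 / fan_in), which suits ReLU layers. Below is a minimal sketch of applying it to every convolution of a model with `torch.nn.Module.apply`. The helper `init_kaiming_normal` and the two-layer encoder are hypothetical; zeroing the biases is an extra assumption.

```python
import torch.nn as nn


def init_kaiming_normal(module):
    # Hypothetical helper: Kaiming normal weights and zero biases for conv layers
    if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)


# Hypothetical two-layer encoder, only used to show the call
encoder = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1), nn.ReLU(),
    nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), nn.ReLU(),
)
encoder.apply(init_kaiming_normal)  # visits every submodule recursively
```

For the second convolution, fan_in is 64 × 4 × 4 = 1024, so its weights should have a standard deviation close to sqrt(2 / 1024) ≈ 0.044.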

I also applied nn.utils.weight_norm() before each call of kaiming_normal(), as is done in [ksw0306/ClariNet], because the model converged better. In my experiments, EMA + Kaiming normal without this additional normalization reduces performance, as the additional results show.
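Below is a minimal sketch of a Conv2d that combines Kaiming normal initialization with weight normalization, assuming the Kaiming-initialized weight should carry over into the weight-norm reparameterization. The helper name `kaiming_weight_norm_conv` is hypothetical. PyTorch's `nn.utils.weight_norm` replaces `weight` with two parameters, `weight_g` (magnitude) and `weight_v` (direction), and recomputes `weight` from them before every forward pass. For that reason, this sketch initializes the raw weight *before* wrapping it. This ordering is an assumption of the sketch, not necessarily the repository's. Recent PyTorch releases deprecate `nn.utils.weight_norm` in favour of `torch.nn.utils.parametrizations.weight_norm`.

```python
import torch
import torch.nn as nn


def kaiming_weight_norm_conv(in_channels, out_channels, **conv_kwargs):
    """Hypothetical helper: Conv2d with Kaiming-normal weights, then weight_norm."""
    conv = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
    # Initialize the raw weight first ...
    nn.init.kaiming_normal_(conv.weight, nonlinearity="relu")
    # ... then split it into weight_g (per-output-channel norm) and weight_v
    # (direction); both are derived from the Kaiming-initialized tensor
    return nn.utils.weight_norm(conv)


conv = kaiming_weight_norm_conv(64, 32, kernel_size=3, padding=1)
out = conv(torch.randn(1, 64, 8, 8))
print(out.shape)  # torch.Size([1, 32, 8, 8])
print(sorted(name for name, _ in conv.named_parameters()))
# ['bias', 'weight_g', 'weight_v']
```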

Installation

It requires python3, python3-pip and the packages listed in requirements.txt.

To install the required packages:

  pip3 install -r requirements.txt

Examples of usage

First, move to the source directory and display the available options:

  cd src
  python3 main.py --help

Output:

  usage: main.py [-h] [--batch_size [BATCH_SIZE]]
                 [--num_training_updates [NUM_TRAINING_UPDATES]]
                 [--num_hiddens [NUM_HIDDENS]]
                 [--num_residual_hiddens [NUM_RESIDUAL_HIDDENS]]
                 [--num_residual_layers [NUM_RESIDUAL_LAYERS]]
                 [--embedding_dim [EMBEDDING_DIM]]
                 [--num_embeddings [NUM_EMBEDDINGS]]
                 [--commitment_cost [COMMITMENT_COST]] [--decay [DECAY]]
                 [--learning_rate [LEARNING_RATE]]
                 [--use_kaiming_normal [USE_KAIMING_NORMAL]]
                 [--shuffle_dataset [SHUFFLE_DATASET]] [--data_path [DATA_PATH]]
                 [--results_path [RESULTS_PATH]]
                 [--loss_plot_name [LOSS_PLOT_NAME]] [--model_name [MODEL_NAME]]
                 [--original_images_name [ORIGINAL_IMAGES_NAME]]
                 [--validation_images_name [VALIDATION_IMAGES_NAME]]
                 [--use_cuda_if_available [USE_CUDA_IF_AVAILABLE]]

  optional arguments:
    -h, --help            show this help message and exit
    --batch_size [BATCH_SIZE]
                          The size of the batch during training (default: 32)
    --num_training_updates [NUM_TRAINING_UPDATES]
                          The number of updates during training (default: 25000)
    --num_hiddens [NUM_HIDDENS]
                          The number of hidden neurons in each layer (default:
                          128)
    --num_residual_hiddens [NUM_RESIDUAL_HIDDENS]
                          The number of hidden neurons in each layer within a
                          residual block (default: 32)
    --num_residual_layers [NUM_RESIDUAL_LAYERS]
                          The number of residual layers in a residual stack
                          (default: 2)
    --embedding_dim [EMBEDDING_DIM]
                          Representing the dimensionality of the tensors in the
                          quantized space (default: 64)
    --num_embeddings [NUM_EMBEDDINGS]
                          The number of vectors in the quantized space (default:
                          512)
    --commitment_cost [COMMITMENT_COST]
                          Controls the weighting of the loss terms (default:
                          0.25)
    --decay [DECAY]       Decay for the moving averages (set to 0.0 to not use
                          EMA) (default: 0.99)
    --learning_rate [LEARNING_RATE]
                          The learning rate of the optimizer during training
                          updates (default: 0.0003)
    --use_kaiming_normal [USE_KAIMING_NORMAL]
                          Use the weight normalization proposed in [He, K et
                          al., 2015] (default: True)
    --unshuffle_dataset
                          Do not shuffle the dataset before training (default: False)
    --data_path [DATA_PATH]
                          The path of the data directory (default: data)
    --results_path [RESULTS_PATH]
                          The path of the results directory (default: results)
    --loss_plot_name [LOSS_PLOT_NAME]
                          The file name of the training loss plot (default:
                          loss.png)
    --model_name [MODEL_NAME]
                          The file name of trained model (default: model.pth)
    --original_images_name [ORIGINAL_IMAGES_NAME]
                          The file name of the original images used in
                          evaluation (default: original_images.png)
    --validation_images_name [VALIDATION_IMAGES_NAME]
                          The file name of the reconstructed images used in
                          evaluation (default: validation_images.png)
    --use_cuda_if_available [USE_CUDA_IF_AVAILABLE]
                          Specify if GPU will be used if available (default:
                          True)
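The `[--option [VALUE]]` form in the usage line is what Python's argparse prints for options declared with `nargs="?"`. Below is a minimal sketch, not the repository's actual parser, of how a few of these options could be declared; `build_parser` and `str_to_bool` are hypothetical helpers. A custom converter matters for values such as `--use_kaiming_normal=False`: with `type=bool`, argparse would call `bool("False")`, which is `True`.

```python
import argparse


def str_to_bool(value):
    # Hypothetical converter for "True"/"False"-style option values
    if value.lower() in ("true", "yes", "1"):
        return True
    if value.lower() in ("false", "no", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser():
    # Hypothetical subset of the options above, not the repository's parser
    parser = argparse.ArgumentParser(prog="main.py")
    # nargs="?" produces the "[--batch_size [BATCH_SIZE]]" form in the usage line
    parser.add_argument("--batch_size", nargs="?", type=int, default=32,
                        help="The size of the batch during training")
    parser.add_argument("--decay", nargs="?", type=float, default=0.99,
                        help="Decay for the moving averages (set to 0.0 to not use EMA)")
    # const=True makes a bare "--use_kaiming_normal" also mean True
    parser.add_argument("--use_kaiming_normal", nargs="?", type=str_to_bool,
                        default=True, const=True,
                        help="Use Kaiming normal initialization")
    parser.add_argument("--unshuffle_dataset", action="store_true",
                        help="Do not shuffle the dataset before training")
    return parser


args = build_parser().parse_args(
    ["--use_kaiming_normal=False", "--decay=0.0", "--unshuffle_dataset"])
print(args.use_kaiming_normal, args.decay, args.unshuffle_dataset)  # False 0.0 True
```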

Use the default vector quantization algorithm, do not shuffle the dataset and do not use Kaiming normal initialization [He, K et al., 2015]:

  python3 main.py --results_path="results/unshuffled/" --use_kaiming_normal=False --decay=0.0 --unshuffle_dataset

Use the EMA vector quantization algorithm, do not shuffle the dataset and do not use Kaiming normal initialization [He, K et al., 2015]:

  python3 main.py --results_path="results/unshuffled/" --use_kaiming_normal=False --decay=0.99 --loss_plot_name="loss_ema.png" --model_name="model_ema.pth" --original_images_name="original_images_ema.png" --validation_images_name="validation_images_ema.png" --unshuffle_dataset

Use the EMA vector quantization algorithm, do not shuffle the dataset and use Kaiming normal initialization [He, K et al., 2015]:

  python3 main.py --results_path="results/unshuffled/" --use_kaiming_normal=True --decay=0.99 --loss_plot_name="loss_ema_norm_he-et-al.png" --model_name="model_ema_norm_he-et-al.pth" --original_images_name="original_images_ema_norm_he-et-al.png" --validation_images_name="validation_images_ema_norm_he-et-al.png" --unshuffle_dataset

Code usage

Example of usage (see here for the complete example):

  import os
  import torch
  import torch.optim as optim
  # Configuration, Cifar10Dataset, AutoEncoder, Trainer and Evaluator are the project's own classes (see main.py)
  # args holds the parsed command-line options listed above
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Use the GPU if one is available
  dataset_path = args.data_path # Path of the data directory
  results_path = args.results_path # Path of the results directory
  configuration = Configuration.build_from_args(args) # Get the dataset and model hyperparameters
  dataset = Cifar10Dataset(configuration.batch_size, dataset_path) # Create an instance of the CIFAR10 dataset
  auto_encoder = AutoEncoder(device, configuration).to(device) # Create an AutoEncoder model on the selected device
  optimizer = optim.Adam(auto_encoder.parameters(), lr=configuration.learning_rate, amsgrad=True) # Create an Adam optimizer instance
  trainer = Trainer(device, auto_encoder, optimizer, dataset) # Create a trainer instance
  trainer.train(configuration.num_training_updates) # Train our model on the CIFAR10 dataset
  trainer.save_loss_plot(os.path.join(results_path, 'loss.png')) # Save the loss plot
  auto_encoder.save(os.path.join(results_path, 'model.pth')) # Save our trained model
  evaluator = Evaluator(device, auto_encoder, dataset) # Create an Evaluator instance to evaluate our trained model
  evaluator.reconstruct() # Reconstruct our images from the embedded space
  evaluator.save_original_images_plot(os.path.join(results_path, 'original_images.png')) # Save the original images for comparison
  evaluator.save_validation_reconstructions_plot(os.path.join(results_path, 'validation_images.png')) # Reconstruct the decoded images and save them

References