Project author: ruihangdu

Project description:
A PyTorch implementation of the iterative pruning method described in Han et al. (2015)
Language: Python
Repository: git://github.com/ruihangdu/PyTorch-Deep-Compression.git
Created: 2018-01-20T22:28:13Z
Community: https://github.com/ruihangdu/PyTorch-Deep-Compression

License:


PyTorch Deep Compression

A PyTorch implementation of the iterative pruning method described in Han et al. (2015).
The original paper: Learning both Weights and Connections for Efficient Neural Networks

Usage

The libs package contains the required utilities,
and compressor.py defines a Compressor class that supports pruning a network layer by layer.
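
In Han et al.'s scheme, a layer is pruned by zeroing all weights whose magnitude falls below a threshold derived from the layer's weight statistics. Below is a minimal sketch of that idea, assuming a threshold of alpha times the layer's weight standard deviation; make_mask and apply_mask are illustrative names, not the Compressor class's actual API.

  import torch
  import torch.nn as nn

  # Hypothetical sketch of magnitude-based pruning for one layer, in the
  # spirit of Han et al. (2015). The threshold alpha * std(weights) is an
  # assumption about the pruning criterion, not this repo's verified code.
  def make_mask(layer: nn.Module, alpha: float = 1.0) -> torch.Tensor:
      """Keep only weights whose magnitude is at least alpha * std."""
      w = layer.weight.data
      return (w.abs() >= alpha * w.std()).float()

  def apply_mask(layer: nn.Module, mask: torch.Tensor) -> None:
      """Zero the pruned connections in place."""
      layer.weight.data.mul_(mask)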

The file iterative_pruning.py contains the function iter_prune, which implements iterative pruning.

An example use of the function is described in the main function in the same file.
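
For orientation, here is a minimal sketch of one prune-retrain cycle, the pattern that iterative pruning repeats; it reuses make_mask and apply_mask from the sketch above, and model and train_loader are assumed inputs rather than this repo's actual signatures.

  import torch.nn as nn
  import torch.optim as optim

  # Hypothetical sketch of a single prune-retrain cycle; not iter_prune's
  # real signature. Hyper-parameter values are examples only.
  def prune_retrain_cycle(model, train_loader, alpha=1.0, lr=0.001):
      criterion = nn.CrossEntropyLoss()
      optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

      # Prune: mask each prunable layer by weight magnitude.
      masks = {}
      for name, layer in model.named_modules():
          if isinstance(layer, (nn.Linear, nn.Conv2d)):
              masks[name] = make_mask(layer, alpha)
              apply_mask(layer, masks[name])

      # Retrain: fine-tune the surviving weights, re-applying the masks
      # after every optimizer step so pruned connections stay at zero.
      for inputs, targets in train_loader:
          optimizer.zero_grad()
          loss = criterion(model(inputs), targets)
          loss.backward()
          optimizer.step()
          for name, layer in model.named_modules():
              if name in masks:
                  apply_mask(layer, masks[name])
      return masks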
Please devise your own script and do

  from iterative_pruning import *

to import all necessary modules, then run your script as follows:

  python your_script.py [-h] [--data DIR] [--arch ARCH] [-j N] [-b N]
                        [-o O] [-m E] [-c I] [--lr LR] [--momentum M]
                        [--weight_decay W] [--resume PATH] [--pretrained]
                        [-t T [T ...]] [--cuda]

optional arguments:

  -h, --help            show this help message and exit
  --data DIR, -d DIR    path to dataset
  --arch ARCH, -a ARCH  model architecture: alexnet | densenet121 |
                        densenet161 | densenet169 | densenet201 |
                        inception_v3 | resnet101 | resnet152 | resnet18 |
                        resnet34 | resnet50 | squeezenet1_0 | squeezenet1_1 |
                        vgg11 | vgg11_bn | vgg13 | vgg13_bn | vgg16 |
                        vgg16_bn | vgg19 | vgg19_bn
  -j N, --workers N     number of data loading workers (default: 4)
  -b N, --batch-size N  mini-batch size (default: 256)
  -o O, --optimizer O   optimizers: ASGD | Adadelta | Adagrad | Adam |
                        Adamax | LBFGS | Optimizer | RMSprop | Rprop |
                        SGD | SparseAdam (default: SGD)
  -m E, --max_epochs E  max number of epochs while training
  -c I, --interval I    checkpointing interval
  --lr LR, --learning-rate LR
                        initial learning rate
  --momentum M          momentum
  --weight_decay W, --wd W
                        weight decay
  --resume PATH         path to latest checkpoint (default: none)
  --pretrained          use pre-trained model
  -t T [T ...], --topk T [T ...]
                        top-k precision metrics
  --cuda

Other architectures from the torchvision package can also be chosen, but they have not been experimented with. The --data argument should point to the location of the ImageNet dataset on your machine.
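
For instance, pruning a pre-trained AlexNet on ImageNet might be invoked as follows (the dataset path is a placeholder and the hyper-parameter values are examples only):

  python your_script.py --data /path/to/imagenet --arch alexnet --pretrained \
      -b 256 --lr 0.001 -t 1 5 --cuda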

Results

  Model          Top-1   Top-5   Compression Rate
  LeNet-300-100  92%     N/A     92%
  LeNet-5        98.8%   N/A     92%
  AlexNet        39%     63%     85.99%

Note: To achieve better results, try tweaking the alpha hyper-parameter in the prune() function, which changes the pruning rate of each layer.
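
Under the threshold rule assumed in the sketches above (threshold = alpha * std of a layer's weights), alpha directly sets what fraction of each layer gets pruned. A quick illustration, assuming roughly Gaussian weights:

  import torch

  w = torch.randn(1000, 1000)  # stand-in for one layer's weight matrix
  for alpha in (0.5, 1.0, 1.5):
      pruned = (w.abs() < alpha * w.std()).float().mean().item()
      print(f"alpha={alpha}: ~{pruned:.0%} of weights pruned")
  # For standard-normal weights: ~38%, ~68%, ~87% respectively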

Any comments, thoughts, and improvements are appreciated.