Training models with ternary quantized weights using PyTorch
Training models with ternary quantized weights. PyTorch implementation of https://arxiv.org/abs/1612.01064
A model (`model_full`) defined in `model.py` was trained on MNIST data using full-precision weights. The trained weights are stored as `weights/original.ckpt`. The training code is in `main_original.py`.
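For reference, a minimal sketch of loading that checkpoint, assuming `model_full` is the model class and the checkpoint holds a plain `state_dict` (both assumptions; see `model.py` and `main_original.py` for the exact details):

```python
import torch
from model import model_full  # architecture defined in model.py

# Restore the full-precision MNIST weights saved during training
model = model_full()
state_dict = torch.load('weights/original.ckpt', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()
```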
A second model (`model_to_quantify`) was trained using quantization. The trained weights are stored as `weights/quantized.ckpt`. The training code is in `main_ternary.py`. The training logs can be found in `logs/quantized_wp_wn_trainable.txt`.
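For context, the paper quantizes each full-precision weight to one of three values, +w_p, 0 or -w_n, using a per-layer threshold, where w_p and w_n are learned scaling factors. A minimal sketch of that step (the function name, the threshold factor `t` and the sign convention for `w_n` are assumptions; the repository's actual implementation may differ):

```python
import torch

def ternarize(w_full, w_p, w_n, t=0.05):
    """Quantize a full-precision weight tensor to the values {+w_p, 0, -w_n}."""
    delta = t * w_full.abs().max()    # per-layer threshold
    a = (w_full > delta).float()      # mask of weights mapped to +w_p
    b = (w_full < -delta).float()     # mask of weights mapped to -w_n
    w_ternary = w_p * a - w_n * b     # everything in between becomes 0
    return w_ternary, a, b
```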
In a separate run, the gradients were replaced by their sign scaled by a fixed step size (0.001) like so: `param.grad.data = torch.sign(param.grad.data) * 0.001`. The trained weights are stored as `weights/autoquantize.ckpt` and the quantization code can be found in `quantification.py`.
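A sketch of how that gradient replacement can be wired into a training loop, overwriting the gradients after the backward pass and before the optimizer step (the surrounding loop and the choice of parameters it is applied to are assumptions):

```python
import torch

def train_one_epoch(model, loader, criterion, optimizer, sign_params, step=0.001):
    """One epoch where the gradients of `sign_params` (a hypothetical list of
    parameters) are replaced by their sign times a fixed step size."""
    for images, labels in loader:
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        for param in sign_params:
            if param.grad is not None:
                param.grad.data = torch.sign(param.grad.data) * step
        optimizer.step()
```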
Taking the mean instead of the sum when computing the gradients for the scaling factors gave better results:

```python
w_p_grad = (a * grad_data).mean()  # not (a * grad_data).sum()
w_n_grad = (b * grad_data).mean()  # not (b * grad_data).sum()
```
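Here `a` and `b` would be the masks of weights mapped to +w_p and -w_n, and `grad_data` the gradient with respect to the ternarized weight tensor; averaging keeps the update magnitude independent of the layer size, whereas summing scales it with the number of weights. A sketch of that step under those assumptions (names and the plain SGD update are hypothetical):

```python
def update_scaling_factors(w_p, w_n, a, b, grad_data, lr=0.001):
    """Update the scalar scaling factors from the gradient of the ternarized
    weight tensor; `a` and `b` are the masks from the ternarization step."""
    w_p_grad = (a * grad_data).mean()  # averaged instead of summed
    w_n_grad = (b * grad_data).mean()
    # Plain SGD step; the sign convention for w_n depends on how it is stored
    w_p.data -= lr * w_p_grad
    w_n.data -= lr * w_n_grad
```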