Project author: soskek

Project description:
Learning to learn by gradient descent by gradient descent, Andrychowicz et al., NIPS 2016
Primary language: Python
Project URL: git://github.com/soskek/learning_to_learn.git
Created: 2019-07-04T00:34:24Z
Project community: https://github.com/soskek/learning_to_learn

License: MIT License


Learning to Learn in Chainer

A Chainer implementation of "Learning to learn by gradient descent by gradient descent" by Andrychowicz et al.
It trains and tests an LSTM-based optimizer whose learnable parameters transform a series of gradients into an update value.

Figure: test loss comparing Adam and the LSTM-based optimizer.

What is an LSTM-based optimizer?

SGD is the simplest transformation of a gradient: it just multiplies the gradient by a constant learning rate.
Momentum SGD and more sophisticated methods like Adam use a series of gradients.
The LSTM-based optimizer is a general case of these.
While they define the update formula in advance (by hand), the LSTM-based optimizer learns the formula, i.e., how to merge a history of gradients so that the optimizee converges efficiently.
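
As a rough illustration (not the repository's actual code), the sketch below shows what a coordinatewise learned optimizer can look like in Chainer: each parameter's scalar gradient is fed to a shared LSTM, and a linear head emits the update that is added to the parameter. The class name LSTMOptimizer and the layer sizes are placeholders chosen for this sketch.

    import numpy as np
    import chainer
    import chainer.links as L

    # Minimal sketch of a learned, coordinatewise optimizer (hypothetical names;
    # not the repository's actual classes). The update rule theta <- theta + g(grad)
    # is itself a small network with trainable weights.
    class LSTMOptimizer(chainer.Chain):
        def __init__(self, hidden_size=20):
            super().__init__()
            with self.init_scope():
                self.lstm = L.LSTM(1, hidden_size)    # one scalar gradient per coordinate
                self.head = L.Linear(hidden_size, 1)  # one scalar update per coordinate

        def __call__(self, grad):
            # grad: (n_params, 1); every coordinate is treated as a batch element,
            # so all coordinates share the same LSTM weights.
            h = self.lstm(grad)
            return self.head(h)

    # One hand-rolled step on a toy quadratic f(theta) = 0.5 * ||theta||^2,
    # whose gradient is simply theta itself.
    opt = LSTMOptimizer()
    theta = np.random.randn(5, 1).astype(np.float32)
    update = opt(chainer.Variable(theta.copy()))
    theta = theta + update.array  # parameters move by the learned update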

Dependencies

See requirements.txt

  • chainer>=6.0.0
  • cupy>=6.0.0
  • seaborn==0.9.0
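
Assuming a standard pip setup (and a CUDA-capable GPU for cupy), the pinned dependencies can be installed with:

  pip install -r requirements.txt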

Training

This repository includes an experiment using a fully connected network on MNIST.

  python train_mnist.py -g 0

The script trains an LSTM-based optimizer as follows (a rough sketch of this loop appears after the list):

  1. init optimizer
  2. for-loop
     1. init model
     2. for-loop
        1. update model by optimizer
        2. update optimizer by Adam (at every 20 steps)
     3. test optimizer (through another training of models)
     4. save optimizer if it is best
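
Under simplifying assumptions (a random quadratic objective instead of the MNIST network, the hypothetical LSTMOptimizer class from the sketch above, and a 20-step unroll per outer iteration), the core of that loop might look like the sketch below; it is meant to show where the meta-loss is accumulated and where Adam updates the learned optimizer, not to reproduce train_mnist.py.

    import numpy as np
    import chainer
    import chainer.functions as F
    from chainer import optimizers

    def meta_train_step(lstm_opt, adam, n_inner=20, n_params=10):
        # One outer iteration: unroll the learned optimizer on a fresh problem,
        # sum the losses of every inner step, and backprop that meta-loss into
        # the LSTM optimizer's own parameters, which Adam then updates.
        lstm_opt.lstm.reset_state()
        A = np.random.randn(n_params, n_params).astype(np.float32)
        b = np.random.randn(n_params, 1).astype(np.float32)
        theta = chainer.Variable(np.random.randn(n_params, 1).astype(np.float32))

        meta_loss = 0
        for _ in range(n_inner):
            residual = F.matmul(A, theta) - b
            loss = 0.5 * F.sum(residual ** 2)    # loss of the "model" being optimized
            meta_loss += loss
            grad = F.matmul(A.T, residual)       # closed-form gradient of the quadratic
            grad = chainer.Variable(grad.array)  # detached, so no second derivatives flow here
            theta = theta + lstm_opt(grad)       # "update model by optimizer"

        lstm_opt.cleargrads()
        meta_loss.backward()
        adam.update()                            # "update optimizer by Adam"
        return float(meta_loss.array)

    lstm_opt = LSTMOptimizer()                   # hypothetical class from the sketch above
    adam = optimizers.Adam()
    adam.setup(lstm_opt)
    for step in range(100):
        print(step, meta_train_step(lstm_opt, adam))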

Citation

  @incollection{NIPS2016_6461,
    title = {Learning to learn by gradient descent by gradient descent},
    author = {Andrychowicz, Marcin and Denil, Misha and G\'{o}mez, Sergio and Hoffman, Matthew W and Pfau, David and Schaul, Tom and Shillingford, Brendan and de Freitas, Nando},
    booktitle = {Advances in Neural Information Processing Systems 29},
    year = {2016},
    url = {http://papers.nips.cc/paper/6461-learning-to-learn-by-gradient-descent-by-gradient-descent.pdf}
  }