Project author: iostapyshyn

Project description:
Fast neural network with a backpropagation learning algorithm.
Language: C
Repository: git://github.com/iostapyshyn/nn.git
Created: 2019-12-15T00:18:14Z
Community: https://github.com/iostapyshyn/nn

License: MIT License



NN

Neural Network with back propagation learning algorithm written in C.

Usage

The project is intended to be used as a library.
However, it includes a demo program, examples/digits.c, which uses the MNIST database to train a network and classify images of handwritten digits.

To build:

  git submodule update --init --recursive
  cd build
  cmake ..
  make

To run the example:

  ./digits five.png

The program performs the training on its first run.

Quickstart

See include/nn/nn.h for a more detailed explanation of the API.

  /* Allocates a new (empty) network with n inputs. */
  neuralnetwork *nn_create(int n);
  /* Deallocates the network and all of its allocated data. */
  void nn_destroy(neuralnetwork *nn);
  /* Adds a new layer with the specified number of neurons to the network.
   * Weights can be passed as a matrix stored in a double array.
   * Passing NULL initializes the weights randomly and the biases to 0.
   * Possible values for activation:
   * IDENTITY, STEP, TANH, RELU, RELU_LEAKY, GAUSSIAN, SIGMOID, SOFTPLUS
   * Xavier initialization is used for all activation functions except the ReLUs,
   * for which Kaiming initialization is used. */
  void nn_addlayer(neuralnetwork *nn, int nodes, double *weights, double *biases, int activation);
  /* Forward propagates the given input through the network.
   * Returns a pointer to the output array. */
  double *nn_forwardpropagate(neuralnetwork *nn, double *input);
  /* Performs forward propagation followed by backpropagation to train the network.
   * A learning rate of 0 skips backpropagation.
   * Returns the mean squared error of the forward pass. */
  double nn_backpropagate(neuralnetwork *nn, double *input, double *target, double learningrate);
  /* Return the number of inputs/outputs of the network. */
  int nn_ninputs(neuralnetwork *nn);
  int nn_noutputs(neuralnetwork *nn);
  /* File I/O functions. Store/read the network as a binary file to keep double precision.
   * Return 0 or NULL on failure. */
  int nn_writefile(const neuralnetwork *nn, const char *filename);
  neuralnetwork *nn_readfile(const char *filename);

License

This project is licensed under the MIT License; see the LICENSE.md file for details.