Project author: aleung12

Project description:
Predicting flight cancellations with a generic two-layer artificial neural network
Language: Python
Repository: git://github.com/aleung12/NeuralNetwork.git
Created: 2019-08-12T11:43:14Z
Project page: https://github.com/aleung12/NeuralNetwork

License: MIT License


NeuralNetwork

Predicting flight cancellations with a generic two-layer Artificial Neural Network

neural_network.py and activation_functions.py are generic. The artificial neural network
is initialized with two hidden layers: 400 neurons in the first hidden layer and 250 neurons
in the second. This architecture is configured to work with flight delays data, which
load_data.py transforms into 368 binary input variables.
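To give a sense of the network's size, here is a sketch that lays out the weight matrices and counts the trainable parameters. It assumes 368 inputs, hidden layers of 400 and 250 neurons, and a single output neuron. The function names, the one-bias-per-neuron layout, and the small Gaussian initialization are illustrative assumptions. They are not necessarily what neural_network.py does.

```python
import random

# Layer sizes from the description: 368 binary inputs, hidden layers of
# 400 and 250 neurons, and (assumed) a single output neuron.
LAYER_SIZES = [368, 400, 250, 1]


def init_network(sizes, scale=0.01, seed=0):
    """Return a list of (weights, biases) pairs, one per layer.

    weights[j][i] connects neuron i of the previous layer to neuron j of
    the current layer. Small Gaussian weights and zero biases are an
    assumption made for this sketch.
    """
    rng = random.Random(seed)
    layers = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        weights = [[rng.gauss(0.0, scale) for _ in range(n_in)] for _ in range(n_out)]
        biases = [0.0] * n_out
        layers.append((weights, biases))
    return layers


def count_parameters(sizes):
    """Number of weights plus biases in a fully connected network."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(sizes[:-1], sizes[1:]))


net = init_network(LAYER_SIZES)
print([(len(w), len(w[0])) for w, _ in net])  # [(400, 368), (250, 400), (1, 250)]
print(count_parameters(LAYER_SIZES))          # 248101
```

Under these assumptions, almost 60% of the parameters (147,600 of 248,101) sit between the 368 inputs and the first hidden layer.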
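The README does not say how load_data.py arrives at exactly 368 binary variables. One common way to turn categorical flight fields into binary inputs is one-hot encoding. The sketch below works under that assumption. The column names, records, and helper functions are hypothetical.

```python
def build_vocab(rows, columns):
    """Map each (column, value) pair seen in the data to a feature index."""
    vocab = {}
    for col in columns:
        for value in sorted({row[col] for row in rows}):
            vocab[(col, value)] = len(vocab)
    return vocab


def one_hot(row, columns, vocab):
    """Encode one record as a list of 0/1 features; unseen values stay all-zero."""
    vec = [0] * len(vocab)
    for col in columns:
        idx = vocab.get((col, row[col]))
        if idx is not None:
            vec[idx] = 1
    return vec


# Hypothetical records and column names, for illustration only.
rows = [
    {"carrier": "AA", "origin": "SFO", "month": 1},
    {"carrier": "UA", "origin": "SFO", "month": 2},
    {"carrier": "AA", "origin": "JFK", "month": 2},
]
columns = ["carrier", "origin", "month"]
vocab = build_vocab(rows, columns)
print(len(vocab))                        # 6
print(one_hot(rows[1], columns, vocab))  # [0, 1, 0, 1, 0, 1]
```

With this scheme, the input width equals the total number of distinct values across the encoded columns.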

The neural network uses a leaky rectified linear unit (leaky ReLU) activation function to
compute the hidden layers and a sigmoid activation function to compute the output layer,
which is a binary classification when predicting flight cancellations. The activation
function for the output layer is straightforward to change (don't forget the associated
derivative in back propagation). Take care, though: if the network is adapted to perform
regression, the learning rate likely needs a lower initial value.
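Here is a minimal sketch of the activation functions described above. Each one is paired with the derivative that back propagation needs. A forward pass then applies leaky ReLU in the hidden layers and sigmoid at the output. The leaky slope of 0.01, the function names, and the (weights, biases) layer layout are assumptions. They are not taken from activation_functions.py.

```python
import math

ALPHA = 0.01  # assumed slope of the leaky ReLU for negative inputs


def leaky_relu(z):
    return z if z > 0 else ALPHA * z


def leaky_relu_prime(z):
    # Derivative used when back-propagating through a hidden layer.
    return 1.0 if z > 0 else ALPHA


def sigmoid(z):
    # Branch on the sign so math.exp never overflows for large |z|.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def sigmoid_prime(z):
    s = sigmoid(z)
    return s * (1.0 - s)


def identity(z):
    # A possible output activation if the network is adapted for regression.
    return z


def identity_prime(z):
    return 1.0


def forward(layers, x, hidden=leaky_relu, output=sigmoid):
    """Propagate one input vector through a list of (weights, biases) layers.

    weights[j] holds the incoming weights of neuron j; the last layer
    uses `output`, every earlier layer uses `hidden`.
    """
    a = x
    for k, (weights, biases) in enumerate(layers):
        act = output if k == len(layers) - 1 else hidden
        a = [act(sum(w * ai for w, ai in zip(row, a)) + b)
             for row, b in zip(weights, biases)]
    return a


# Tiny 2-2-1 network for illustration: the hidden activations are [-0.01, 1.5].
toy = [([[1.0, -1.0], [0.5, 0.5]], [0.0, 0.0]), ([[1.0, 1.0]], [0.0])]
p = forward(toy, [1.0, 2.0])                   # sigmoid output in (0, 1)
y = forward(toy, [1.0, 2.0], output=identity)  # regression-style output
```

Passing `output=identity`, and using `identity_prime` in place of `sigmoid_prime` during back propagation, is the kind of paired change the paragraph above refers to.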

  • To train the artificial neural network using k-fold cross-validation, run
    train_network.py from the command line with k as the argument
    (k >= 2; k = 2 is a single train/test split):
    python train_network.py 2
    The training data file ‘./flightdelays_data.csv’ is required; a different file can be
    specified at line 112 of train_network.py. By default, the program starts a new
    training run from scratch. To resume training, modify line 115 accordingly and specify
    _k at line 116 for the k-fold cross-validation in progress.
  • To use a trained neural network to predict claims for delayed or cancelled flights,
    run predict_claims.py from the command line:
    python predict_claims.py
    The user will be prompted for the path to a .csv file with the test data.
    Weights for the trained network need to be in ./state/best/.
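The training step in the first item can be pictured as reading k from the command line and splitting the sample indices into k folds. Here is a sketch of standard k-fold splitting. With k = 2 the data is divided into two halves, i.e. a single train/test partition. The two pairs this function yields are that one partition with the roles swapped. The README does not say whether train_network.py evaluates both. The helper names and the shuffling seed are hypothetical.

```python
import random


def parse_k(argv):
    """Read k from an argv-style list, e.g. ['train_network.py', '2']."""
    if len(argv) != 2:
        raise SystemExit("usage: python train_network.py K   (K >= 2)")
    k = int(argv[1])
    if k < 2:
        raise SystemExit("k must be at least 2")
    return k


def k_fold_indices(n, k, seed=0):
    """Yield (train_idx, test_idx) for each of the k folds over n samples."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    folds = [idx[i::k] for i in range(k)]  # k disjoint, near-equal folds
    for i in range(k):
        test = folds[i]
        train = [j for f, fold in enumerate(folds) if f != i for j in fold]
        yield train, test


k = parse_k(["train_network.py", "2"])  # same argument as: python train_network.py 2
for fold, (train, test) in enumerate(k_fold_indices(10, k)):
    print(fold, sorted(train), sorted(test))
```

Every sample lands in exactly one test fold, so each record is used for evaluation once across the k runs.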
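The prediction step in the second item prompts for a .csv path and expects trained weights under ./state/best/. Below is a hedged sketch of the input-handling side. It assumes the network's sigmoid output is read as a cancellation probability and thresholded at 0.5. Both points are assumptions, as are all helper names. The weight file format is not described in the README, so loading weights is left out.

```python
import csv
import os

WEIGHTS_DIR = os.path.join(".", "state", "best")  # location given in the README


def prompt_for_csv(ask=input):
    """Ask for the test-data path and check that it names an existing .csv file."""
    path = ask("Path to .csv file with test data: ").strip()
    if not path.lower().endswith(".csv"):
        raise ValueError(f"expected a .csv file, got {path!r}")
    if not os.path.isfile(path):
        raise FileNotFoundError(path)
    return path


def read_rows(lines):
    """Parse CSV lines (an open file or a list of strings) into dicts keyed by the header."""
    return list(csv.DictReader(lines))


def to_labels(probabilities, threshold=0.5):
    """Turn sigmoid outputs into 0/1 cancellation predictions (assumed threshold)."""
    return [1 if p >= threshold else 0 for p in probabilities]


# Typical flow (not executed here):
#   if not os.path.isdir(WEIGHTS_DIR): raise SystemExit("no trained weights found")
#   path = prompt_for_csv()
#   with open(path, newline="") as f:
#       rows = read_rows(f)
print(to_labels([0.2, 0.5, 0.9]))  # [0, 1, 1]
```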