function train_nn_classification_model(learning_rate,
                                        steps,
                                        batch_size,
                                        hidden_units,
                                        is_embedding,
                                        keep_probability,
                                        training_examples,
                                        training_targets,
                                        validation_examples,
                                        validation_targets)
"""Trains a neural network classification model.
Args:
learning_rate: A `float`, the learning rate.
steps: A non-zero `int`, the total number of training steps. A training step
consists of a forward and backward pass using a single batch.
batch_size: A non-zero `int`, the batch size.
hidden_units: A vector describing the layout of the neural network.
is_embedding: 'true' or 'false' depending on if the first layer of the NN is an embedding layer.
keep_probability: A `float`, the probability of keeping a node active during one training step.
Returns:
p1: Plot of the loss function for the different periods.
y: The final layer of the TensorFlow network.
final_probabilities: Final predicted probabilities on the validation examples.
weight_export: The weights of the first layer of the NN
feature_columns: TensorFlow feature columns.
target_columns: TensorFlow target columns.
"""
    periods = 10
    steps_per_period = steps ÷ periods  # integer division keeps the loop bounds below Ints
    # Create placeholders for the input features and the targets.
    feature_columns = placeholder(Float32, shape=[-1, size(training_examples, 2)])
    target_columns = placeholder(Float32, shape=[-1, size(training_targets, 2)])
    # Network parameters
    push!(hidden_units, size(training_targets, 2))  # append an output layer sized to the targets
    activation_functions = Vector{Function}(undef, length(hidden_units))
    activation_functions[1:end-1] .= z -> nn.dropout(nn.relu(z), keep_probability)
    activation_functions[end] = nn.sigmoid  # sigmoid output, so `y` holds probabilities rather than logits
    # Create the network: each layer multiplies the previous output by a weight
    # matrix, adds a bias, and applies its activation function.
    weight_export = Variable([1])  # placeholder value; replaced by the first layer's weights below
    Zs = [feature_columns]
    for (ii, (hlsize, actfun)) in enumerate(zip(hidden_units, activation_functions))
        Wii = get_variable("W_$ii"*randstring(4), [get_shape(Zs[end], 2), hlsize], Float32)
        bii = get_variable("b_$ii"*randstring(4), [hlsize], Float32)
        if is_embedding && ii == 1
            # An embedding layer is a pure linear map: no bias, no activation.
            Zii = Zs[end]*Wii
        else
            Zii = actfun(Zs[end]*Wii + bii)
        end
        push!(Zs, Zii)
        if ii == 1
            weight_export = Wii  # keep the first layer's weights for later inspection
        end
    end
    y = Zs[end]  # network output: the last layer's (sigmoid) activations
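    # Hand-rolled binary cross entropy, averaged over the batch and output nodes:
    #   L = -mean(t .* log(y) + (1 - t) .* log(1 - y))
    # The small eps below keeps the logs finite when the sigmoid saturates at 0 or 1.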
    eps = 1e-8
    cross_entropy = -reduce_mean(log(y + eps).*target_columns + log(1 - y + eps).*(1 - target_columns))
    features_batches, targets_batches = create_batches(training_examples, training_targets, steps, batch_size)
    # Standard Adam optimizer minimizing the cross entropy.
    my_optimizer = train.minimize(train.AdamOptimizer(learning_rate), cross_entropy)
    run(sess, global_variables_initializer())  # `sess` is the Session created outside this function
    # Train the model, but do so inside a loop so that we can periodically assess
    # loss metrics.
    println("Training model...")
    println("LogLoss error (on validation data):")
    training_log_losses = []
    validation_log_losses = []
    for period in 1:periods
        # Train the model, starting from the prior state.
        for i in 1:steps_per_period
            features, labels = my_input_fn(features_batches, targets_batches, (period-1)*steps_per_period + i, batch_size)
            run(sess, my_optimizer, Dict(feature_columns=>construct_feature_columns(features), target_columns=>construct_feature_columns(labels)))
        end
        # Take a break and compute log loss on the training and validation sets.
        training_log_loss = run(sess, cross_entropy, Dict(feature_columns=>construct_feature_columns(training_examples), target_columns=>construct_feature_columns(training_targets)))
        validation_log_loss = run(sess, cross_entropy, Dict(feature_columns=>construct_feature_columns(validation_examples), target_columns=>construct_feature_columns(validation_targets)))
        # Print the training loss for this period.
        println("  period ", period, ": ", training_log_loss)
        # Add the loss metrics from this period to our list.
        push!(training_log_losses, training_log_loss)
        push!(validation_log_losses, validation_log_loss)
    end
println("Model training finished.")
# Calculate final predictions (not probabilities, as above).
final_probabilities = run(sess, y, Dict(feature_columns=> validation_examples, target_columns=>validation_targets))
final_predictions=0.0.*copy(final_probabilities)
final_predictions=castto01(final_probabilities)
accuracy = sklm.accuracy_score(validation_targets, final_predictions)
println("Final accuracy (on validation data): ", accuracy)
    # Output a graph of loss metrics over periods.
    p1 = plot(training_log_losses, label="training", title="LogLoss vs. Periods", ylabel="LogLoss", xlabel="Periods")
    p1 = plot!(validation_log_losses, label="validation")
    return p1, y, final_probabilities, weight_export, feature_columns, target_columns
end
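
# Example invocation (a minimal sketch, not part of the original code: the
# hyperparameter values and the data bindings `training_examples`,
# `training_targets`, `validation_examples`, and `validation_targets` are
# hypothetical and assumed to be prepared beforehand, along with the global
# `sess`):
#
# p1, y, final_probabilities, weight_export, feature_columns, target_columns =
#     train_nn_classification_model(
#         0.003,      # learning_rate
#         1000,       # steps
#         30,         # batch_size
#         [20, 10],   # hidden_units (the output layer is appended automatically)
#         false,      # is_embedding
#         1.0,        # keep_probability (1.0 disables dropout)
#         training_examples, training_targets,
#         validation_examples, validation_targets)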