My final project is to take the two-layer neural net code below, which I showed on Thursday, April 21, and add training to it.
The neural net attempts to classify the irises in the famous iris dataset — Citation: Fisher, R. A. (1936) “The use of multiple measurements in taxonomic problems.” Annals of Eugenics, 7, Part II, 179–188. It does this with 4 neurons in the input layer (one for each feature of the dataset) and 3 neurons in the output layer.
The theory for training is pretty complex. I was told by someone who knows AI that I will need to compute the “cross-entropy between a one-hot encoding of the true labels and the output of the softmax.” Whoosh! Fortunately, this type of theory has been written up in many places. Here is one of the top Google hits on medium.com for cross-entropy and artificial intelligence.
In plain English, we need a function, used during training, that takes the true value (which of the three iris species the data record belongs to) and compares it with the neural net’s output (a vector of three probabilities adding up to 1).
The cross-entropy is super-simple, actually. Suppose that for a record, the second of the three species is the correct answer. If we use zero-based indexing, the second of the three species has index 1. Suppose the softmax output is q0, q1, q2 (with q0 + q1 + q2 = 1). Then the cross-entropy is simply -log q1. If the third species were the correct answer, the cross-entropy would simply be -log q2.
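To make that concrete, here is a minimal sketch of the computation in Processing code (the name crossEntropy and its arguments are just for illustration; the real code below does the same thing inside the Trainer class):

// Cross-entropy of a one-hot true label against a softmax output.
// softMaxOutput holds the probabilities q0, q1, q2 (which sum to 1);
// correctIndex is the index of the true species, so only that one
// entry contributes to the result.
float crossEntropy(float[] softMaxOutput, int correctIndex) {
  return -1.0 * log(softMaxOutput[correctIndex]);
}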
Committed:
At present, this neural net has no training! The 23 weights — 8 in the input layer (4 of those are biases) and 15 in the output layer (3 of those are biases) — are all assigned randomly. So of course it does a completely random job of classifying.
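(To spell out the weight count: each of the 4 input-layer neurons has 1 input weight plus 1 bias, so 4 × 2 = 8; each of the 3 output-layer neurons has 4 input weights plus 1 bias, so 3 × 5 = 15; and 8 + 15 = 23.)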
Stretch:
As one additional feature, I want to add some kind of visual representation of the quality of classification. At present the code just logs what it is doing and has no visual representation at all.
The project is implemented, including the visual representation of the learning process.
This plot shows the regions color coded for correct classifications:
Red is setosa. Green is versicolor. Blue is virginica.
Here is how the neural net performs initially:
It is terrible and random, as one would expect with no training.
Here is how it does after about a thousand training steps:
This is very encouraging. Not only is setosa well classified, but versicolor is starting to be separated from virginica.
However, here is how it is doing after tens of thousands of training steps:
Sadly, it is still having a hard time distinguishing versicolor from virginica.
On some runs, I have seen it be partially successful at distinguishing versicolor from virginica:
So it is possible for it to do better. I have experimented with the training method (cutting the learning rate, normalizing each datum), but neither change makes it more likely to do better. That’s where I’ll leave the project!
Before getting into the implementation, it would be good to think about what this all means. At some level it is just a big minimization problem. The neural net isn’t doing something that I understand as learning. I don’t call a linear regression learning, even though it can make striking predictions about a dataset. A buzzword that expresses the skepticism that this is really learning is “generalizability.” An example is that recently a Tesla crashed into a private plane. After training a Tesla to avoid other cars, people, posts, walls, curbs, etc., apparently it didn’t have the common sense not to crash into a plane. A human, even one who had never seen a plane, would not make that mistake. However, the intermediate layers in deep neural nets are perhaps creating representations of the inputs that do have generalizability?
Some things that the friend who helped me choose the initial problem suggested I look at next are: deep neural nets, residual nets (resnets), ReLU activation functions (which achieve a degree of sparsity in the neural net and don’t suffer from regimes with very small gradients), better initialization (choice of the initial random weights), the GPT-3 neural network, convolutional neural nets for images, and transformers for natural language processing. I don’t see myself going in this direction, but I have a gold mine of search terms if I do.
Add this JSON file to the sketch
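For reference, the loader below expects the JSON to be shaped roughly like this (the numbers are only illustrative; the key names are the ones in DATA_KEY, DATUM_KEYS, and SPECIES_KEY):

{
  "data": [
    { "sepalLength": 5.1, "sepalWidth": 3.5, "petalLength": 1.4, "petalWidth": 0.2, "species": "setosa" },
    { "sepalLength": 7.0, "sepalWidth": 3.2, "petalLength": 4.7, "petalWidth": 1.4, "species": "versicolor" }
  ]
}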
static final int DISPLAY_CORRECT_REGIONS = 300;
static final int TRAINING_STEPS_PER_FRAME = 200;
static final int FRAMES_PER_REFRESH = 30;
static final int FRAME_RATE = 30;
static float NEURON_LEARNING_RATE = 0.01;
static int FRAME_FOR_LEARNING_RATE_CUT = 3600;
static float LEARNING_RATE_CUT = 10.0;
// Filename for the iris dataset in JSON
static final String IRIS_DATASET_FILENAME = "iris.json";
// Key used at the top level of the JSON dictionary
static final String DATA_KEY = "data";
// Keys used in each iris data record
static final String[] DATUM_KEYS = {"sepalLength", "sepalWidth", "petalLength", "petalWidth"};
static final float[] DATUM_NORMALIZERS = { 5.0, 4.0, 1.5, 0.2 };
static final String SPECIES_KEY = "species";
// All sorts of things are data sources. The layer in a neural net can
// be a data source for another layer. The neural net itself needs a
// data source, and in this code, that will be the iris dataset.
interface DataSource {
// To be a data source for something
// an object must be able to compute outputs
float getOutput(int index);
// Mostly for debugging purposes, it must also be able to report its size
int getSize();
}
// All sorts of things are trainable. The whole two-layer neural net is
// trainable, a layer of the neural net is trainable, and each neuron is
// trainable. To be trainable, something must be able to generate a
// perturbation, and then be told whether that perturbation is an improvement
// or not.
interface Trainable {
void perturb();
void train(float carrotOrStick);
}
// This code is just complicated enough, and our Processing IDE is just primitive
// enough, that I added this class to help me with debugging. When an input source
// is incorrectly sized, it helps narrow down the problem.
class DataSourceSizeMismatchException extends Exception {
String sizeMismatchDescription;
DataSourceSizeMismatchException(String sizeMismatchDescription_) {
super(sizeMismatchDescription_); // so the description shows up in printStackTrace()
sizeMismatchDescription = sizeMismatchDescription_;
}
}
class Neuron implements Trainable {
float[] weights;
float[] perturbations;
DataSource dataSource = null;
int position;
// Constructor. Takes three arguments: the DataSource shared by all
// neurons in the layer, this neuron's position in the layer, and the
// number of inputs. NB: There is actually one more weight than the
// number of inputs because there is a weight for the bias. The
// constructor initializes the weights randomly. Later the weights
// must be trained.
Neuron(DataSource dataSource_, int position_, int n) {
dataSource = dataSource_;
position = position_;
weights = new float[n + 1];
perturbations = new float[n + 1];
for (int i = 0; i <= n; ++i) {
weights[i] = random(-1, 1);
perturbations[i] = 0.0;
}
}
float feedForward() {
float sum = 0.0;
for (int i = 0; i < weights.length - 1; ++i) {
int index = position * (weights.length - 1) + i;
sum += dataSource.getOutput(index) * (weights[i] + perturbations[i]);
}
sum += weights[weights.length - 1] + perturbations[weights.length - 1];
float output = activate(sum);
return output;
}
float activate(float sum) {
// We use the classic sigmoid function for activation. See:
// https://machinelearningmastery.com/a-gentle-introduction-to-sigmoid-function/
return 1.0 / (1.0 + exp(-sum));
}
void perturb() {
for (int i = 0; i < perturbations.length; ++i) {
perturbations[i] = random(-1.0 * NEURON_LEARNING_RATE, NEURON_LEARNING_RATE);
}
}
void train(float carrotOrStick) {
float accelerant = carrotOrStick / NEURON_LEARNING_RATE;
if (frameCount > FRAME_FOR_LEARNING_RATE_CUT) {
accelerant /= LEARNING_RATE_CUT;
}
for (int i = 0; i < perturbations.length; ++i) {
weights[i] += perturbations[i] * accelerant;
perturbations[i] = 0.0;
}
}
}
class Layer implements DataSource, Trainable {
Neuron[] neurons;
DataSource dataSource = null;
Layer(DataSource dataSource_, int inputsPerNeuron, int neuronsCount)
throws DataSourceSizeMismatchException {
neurons = new Neuron[neuronsCount];
if (dataSource_.getSize() != neurons.length * inputsPerNeuron) {
throw new DataSourceSizeMismatchException("Layer has neurons.length " + neurons.length +
"and inputsPerNeuron is " + inputsPerNeuron +
", but input source size is " + dataSource_.getSize() + ".");
}
dataSource = dataSource_;
for (int i = 0; i < neuronsCount; ++i) {
neurons[i] = new Neuron(dataSource, i, inputsPerNeuron);
}
}
float getOutput(int index) {
return neurons[index].feedForward();
}
int getSize() {
return neurons.length;
}
void perturb() {
for (int i = 0; i < neurons.length; ++i) {
neurons[i].perturb();
}
}
void train(float carrotOrStick) {
for (int i = 0; i < neurons.length; ++i) {
neurons[i].train(carrotOrStick);
}
}
}
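// Wiring repeats the outputs of one layer so that the next layer can be
// fully connected to it. It reports a size of outputsCount * layer.getSize(),
// and getOutput(k) returns the output of neuron k % layer.getSize(), so each
// downstream neuron ends up seeing every upstream output.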
class Wiring implements DataSource {
int outputsCount;
Layer layer;
Wiring(int outputsCount_, Layer layer_) {
outputsCount = outputsCount_;
layer = layer_;
}
float getOutput(int index) {
index = index % layer.neurons.length;
return layer.neurons[index].feedForward();
}
int getSize() {
return outputsCount * layer.getSize();
}
}
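// A NeuralNet is a DataSource whose raw outputs can be pushed through a
// softmax (exponentiate, then normalize so the values sum to 1) and then
// turned into a classification by picking the index of the largest value.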
abstract class NeuralNet implements DataSource {
float[] softMax() {
float[] values = new float[getSize()];
float sum = 0.0;
for (int i = 0; i < values.length; ++i) {
values[i] = exp(getOutput(i));
sum += values[i];
}
for (int i = 0; i < values.length; ++i) {
values[i] /= sum;
}
return values;
}
int classify() {
float[] softMax = softMax();
float biggest = softMax[0];
int classification = 0;
for (int i = 1; i < softMax.length; ++i) {
if (softMax[i] > biggest) {
biggest = softMax[i];
classification = i;
}
}
return classification;
}
}
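// A TwoLayerNeuralNet chains the pieces together: an input layer with one
// input per neuron reading from the dataset, a Wiring that fans the input
// layer's outputs out to the next layer, and an output layer whose neurons
// each see all of the input layer's outputs.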
class TwoLayerNeuralNet extends NeuralNet implements Trainable {
DataSource dataSource;
Layer inputLayer;
Wiring wiring;
Layer outputLayer;
TwoLayerNeuralNet(DataSource dataSource_, int inputsCount, int outputsCount)
throws DataSourceSizeMismatchException {
dataSource = dataSource_;
inputLayer = new Layer(dataSource, 1, inputsCount);
wiring = new Wiring(outputsCount, inputLayer);
outputLayer = new Layer(wiring, inputsCount, outputsCount);
}
float getOutput(int index) {
return outputLayer.getOutput(index);
}
int getSize() {
return outputLayer.getSize();
}
void perturb() {
inputLayer.perturb();
outputLayer.perturb();
}
void train(float carrotOrStick) {
inputLayer.train(carrotOrStick);
outputLayer.train(carrotOrStick);
}
}
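// A FloatingPointDataset holds a table of float records along with an integer
// classification for each record. As a DataSource it serves up the fields of
// whichever record the recordPointer currently selects.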
abstract class FloatingPointDataset implements DataSource {
float[][] records;
int[] classifications;
int recordPointer = 0;
int recordCount = 0;
float getOutput(int index) {
return records[recordPointer][index];
}
int getClassification() {
return classifications[recordPointer];
}
int getSize() {
return records[0].length;
}
int recordPointer() {
return recordPointer;
}
void setRecordPointer(int recordPointer_) {
recordPointer = recordPointer_;
}
}
class IrisDataset extends FloatingPointDataset {
ArrayList<String> speciesList = new ArrayList<String>();
void configure(JSONObject json) {
JSONArray data = json.getJSONArray(DATA_KEY);
recordCount = data.size();
classifications = new int[recordCount];
records = new float[recordCount][DATUM_KEYS.length];
for (int i = 0; i < recordCount; ++i) {
JSONObject irisRecord = data.getJSONObject(i);
String species = irisRecord.getString(SPECIES_KEY);
if (!speciesList.contains(species)) {
speciesList.add(species);
}
classifications[i] = speciesList.indexOf(species);
for (int j = 0; j < DATUM_KEYS.length; ++j) {
records[i][j] = irisRecord.getFloat(DATUM_KEYS[j]) / DATUM_NORMALIZERS[j];
}
}
}
}
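// The Trainer performs one step of the perturbation-based training described
// above: pick a random record, compute the cross-entropy of the net's softmax
// output against the true species, perturb the net, recompute the
// cross-entropy, and reward or punish the perturbation according to how much
// the cross-entropy dropped.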
class Trainer {
FloatingPointDataset trainingSet;
TwoLayerNeuralNet neuralNet;
Trainer(FloatingPointDataset trainingSet_, TwoLayerNeuralNet neuralNet_) {
trainingSet = trainingSet_;
neuralNet = neuralNet_;
}
float trainAndAssess() {
int recordCount = trainingSet.recordCount;
int randomRecordPointer = (int)random(recordCount);
// println("random record pointer is " + randomRecordPointer);
trainingSet.setRecordPointer(randomRecordPointer);
float[] softMax = neuralNet.softMax();
int correctClassification = trainingSet.getClassification();
float q_i = softMax[correctClassification];
float entropy = -1.0 * log(q_i);
// println("softMax[i] is " + q_i + ", entropy is " + entropy);
neuralNet.perturb();
float[] perturbedSoftMax = neuralNet.softMax();
float perturbed_q_i = perturbedSoftMax[correctClassification];
float perturbedEntropy = -1.0 * log(perturbed_q_i);
// println("perturbedSoftMax[i] is " + perturbed_q_i + ", perturbedEntropy is " + perturbedEntropy);
float carrotOrStick = entropy - perturbedEntropy;
neuralNet.train(carrotOrStick);
return perturbedEntropy;
}
}
TwoLayerNeuralNet neuralNet;
IrisDataset irisDataset;
Trainer trainer;
void setup() {
size(600, 400);
frameRate(FRAME_RATE);
rectMode(CENTER);
irisDataset = new IrisDataset();
JSONObject json = loadJSONObject(IRIS_DATASET_FILENAME);
irisDataset.configure(json);
try {
neuralNet = new TwoLayerNeuralNet(irisDataset, DATUM_KEYS.length, irisDataset.speciesList.size());
}
catch (DataSourceSizeMismatchException e) {
e.printStackTrace(System.out);
}
trainer = new Trainer(irisDataset, neuralNet);
}
int framesToNextRefresh = 0;
int displayingCorrectRegions = DISPLAY_CORRECT_REGIONS;
void draw() {
if (displayingCorrectRegions > 0) {
background(200);
for (float xx = 0.0; xx < 1.0; xx += 0.01) {
for (float yy = 0.0; yy < 1.0; yy += 0.01) {
float zz = 1.0 - xx - yy;
if (zz < 0.0) continue;
if (xx > yy && xx > zz) {
// xx is the largest
fill(255, 0, 0);
} else if (yy > zz) {
// yy is the largest
fill(0, 255, 0);
} else {
// zz is the largest
fill(0, 0, 255);
}
rect(600 * xx + 2, 400 * (1 - yy) + 2, 4, 4);
}
}
displayingCorrectRegions -= 1;
return;
}
if (framesToNextRefresh == 0) {
refresh();
framesToNextRefresh = FRAMES_PER_REFRESH;
} else {
for (int i = 0; i < TRAINING_STEPS_PER_FRAME; ++i) {
trainer.trainAndAssess();
}
--framesToNextRefresh;
}
}
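// refresh() plots every record at (softMax[0], softMax[1]) scaled to the
// window (with the y axis flipped), colored by its true species, and draws
// the boundaries of the three classification regions along with the edge of
// the softmax simplex.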
void refresh() {
background(200);
stroke(0);
line(600 / 3, 2 * 400 / 3, 600 / 2, 400 / 2);
line(600 / 3, 2 * 400 / 3, 0, 400 / 2);
line(600 / 3, 2 * 400 / 3, 600 / 2, 400);
line(0, 0, 600, 400);
int recordCount = irisDataset.recordCount;
for (int recordPointer = 0; recordPointer < recordCount; ++recordPointer) {
irisDataset.setRecordPointer(recordPointer);
float[] softMax = neuralNet.softMax();
float x = softMax[0] * 600;
float y = (1.0 - softMax[1]) * 400;
int correctClassification = irisDataset.getClassification();
int r = correctClassification == 0 ? 255 : 0;
int g = correctClassification == 1 ? 255 : 0;
int b = correctClassification == 2 ? 255 : 0;
fill(r, g, b);
rect(x, y, 4, 4);
}
}