com.penguinwerks.jodene
Class NeuronTraining

java.lang.Object
  extended by com.penguinwerks.jodene.NeuronTraining

public class NeuronTraining
extends java.lang.Object

This class is used by the neuron to facilitate training. The NeuronTraining class maintains the weight adjustments, the prior weight adjustments, the partial derivatives, and the learning rate and momentum used during training.

Author:
Paul Hoehne

Constructor Summary
NeuronTraining()
          Default constructor.
 
Method Summary
 void adjustWeights(double[] weights)
          Called to adjust the weights.
 java.lang.Double calculateDelta(double[] feedback, double derivative)
          Calculates the delta given the deltas of the subsequent layers and the weight connecting this neuron to the next neuron.
 java.lang.Double calculateDelta(double expected, double actual, double derivative)
          Calculates the delta given the expected values, the actual values and the derivative of the activation value.
 void calculatePartials(double[] inputs, double delta)
          Initializes the partials, if necessary.
 void clearPartials()
          Clears the partial derivatives, called after weights are adjusted.
 void endEpoch()
          Called to end the epoch.
 double getLearningRate()
          The learning rate used by this Neuron trainer.
 double getMomentum()
          The momentum rate used by this Neuron trainer.
 double[] getPartials()
          Return the partial derivatives.
 Trainer getTrainer()
          The trainer that owns this neuron trainer.
 double[] getWeightUpdates()
          Returns the weight updates - another specious function.
 void setLearningRate(double learningRate)
          The learning rate used by this Neuron trainer.
 void setMomentum(double momentum)
          The momentum rate used by this Neuron trainer.
 void setSize(int sz)
          Sizes the arrays for the weight adjustments, etc.
 void updateWeightAdjustments()
          Updates the weight adjustments.
 
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 

Constructor Detail

NeuronTraining

public NeuronTraining()
Default constructor.

Method Detail

getTrainer

public Trainer getTrainer()
The trainer that owns this neuron trainer. Neuron trainers are children of trainers, much the same way that neurons are the children of networks.

Returns:
Returns the trainer.

setSize

public void setSize(int sz)
Sizes the arrays for the weight adjustments, etc. Normally, this is the number of inputs to the attached neuron.

Parameters:
sz - The size (in inputs) of the neuron to train.
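A minimal sketch of what sizing might involve, assuming (from the class description) that one partial derivative, one current weight adjustment, and one prior weight adjustment are kept per input. The class `SizedTrainerSketch` and its field names are hypothetical, not the library's actual members.

```java
// Hypothetical stand-in for the per-input arrays NeuronTraining maintains.
class SizedTrainerSketch {
    double[] partials = new double[0];
    double[] weightUpdates = new double[0];
    double[] priorWeightUpdates = new double[0];

    // Allocates one slot per input of the attached neuron; Java zero-fills new arrays.
    void setSize(int sz) {
        if (sz < 0) {
            throw new IllegalArgumentException("size must be non-negative: " + sz);
        }
        partials = new double[sz];
        weightUpdates = new double[sz];
        priorWeightUpdates = new double[sz];
    }
}
```

Keeping all three arrays the same length lets the later per-weight loops share a single index.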

getLearningRate

public double getLearningRate()
The learning rate used by this Neuron trainer. Currently all neuron trainers attached to a given trainer use the same rate, but that may change in the future.

Returns:
Returns the learningRate.

setLearningRate

public void setLearningRate(double learningRate)
The learning rate used by this Neuron trainer. Currently all neuron trainers attached to a given trainer use the same rate, but that may change in the future.

Parameters:
learningRate - The learningRate to set.

getMomentum

public double getMomentum()
The momentum rate used by this Neuron trainer. Currently all neuron trainers attached to a given trainer use the same rate, but that may change in the future.

Returns:
Returns the momentum.

setMomentum

public void setMomentum(double momentum)
The momentum rate used by this Neuron trainer. Currently all neuron trainers attached to a given trainer use the same rate, but that may change in the future.

Parameters:
momentum - The momentum to set.
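Since neuron trainers are children of a trainer and currently share its learning rate and momentum, one way to picture that relationship is a parent that pushes its rates to every child. This is a hypothetical sketch: `ParentTrainerSketch`, `ChildTrainerSketch`, `createChild`, `setRates`, and the default rates (0.1 and 0.0) are all assumptions, not the library's Trainer API.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical child trainer: keeps its own copy of the rates.
class ChildTrainerSketch {
    private double learningRate;
    private double momentum;

    double getLearningRate() { return learningRate; }
    void setLearningRate(double learningRate) { this.learningRate = learningRate; }
    double getMomentum() { return momentum; }
    void setMomentum(double momentum) { this.momentum = momentum; }
}

// Hypothetical parent trainer: applies one learning rate and momentum to all children.
class ParentTrainerSketch {
    private final List<ChildTrainerSketch> children = new ArrayList<>();
    private double learningRate = 0.1; // assumed default
    private double momentum = 0.0;     // assumed default

    // Creates a child that starts with the parent's current rates.
    ChildTrainerSketch createChild() {
        ChildTrainerSketch child = new ChildTrainerSketch();
        child.setLearningRate(learningRate);
        child.setMomentum(momentum);
        children.add(child);
        return child;
    }

    // Updates the parent's rates and copies them to every existing child.
    void setRates(double learningRate, double momentum) {
        this.learningRate = learningRate;
        this.momentum = momentum;
        for (ChildTrainerSketch child : children) {
            child.setLearningRate(learningRate);
            child.setMomentum(momentum);
        }
    }
}
```

Keeping per-child setters, rather than having children read the parent's fields, leaves room for per-neuron rates later.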

calculateDelta

public java.lang.Double calculateDelta(double expected,
                                       double actual,
                                       double derivative)
Calculates the delta given the expected values, the actual values and the derivative of the activation value.

Parameters:
expected - The expected value.
actual - The actual value.
derivative - The derivative of the activation.
Returns:
The delta value.
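For an output neuron, the usual backpropagation rule for a squared-error cost multiplies the error by the derivative of the activation. The sketch below assumes that rule, with the error taken as `expected - actual`; the class name `OutputDelta` is hypothetical, and the library may use the opposite sign. The documented method returns a boxed `java.lang.Double`; the sketch returns a primitive `double` for simplicity.

```java
// Hypothetical helper illustrating the output-layer delta rule.
final class OutputDelta {
    private OutputDelta() {}

    // delta = (expected - actual) * f'(net), where `derivative` is f'(net).
    static double calculateDelta(double expected, double actual, double derivative) {
        return (expected - actual) * derivative;
    }
}
```

For a sigmoid unit whose output is 0.8 against a target of 1.0, the derivative is 0.8 × (1 − 0.8) = 0.16, so the delta is 0.2 × 0.16 = 0.032 (up to floating-point rounding).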

calculateDelta

public java.lang.Double calculateDelta(double[] feedback,
                                       double derivative)
Calculates the delta given the deltas of the subsequent layers and the weight connecting this neuron to the next neuron. This is used on internal or "hidden" neurons to back-propagate the error. Normally this is defined as the sum of delta[i] * weight[i] over each connection i from this neuron to the next layer. We pre-multiply the delta and weight when we back-propagate the feedback, and therefore we need only sum it.

Parameters:
feedback - The feedback values, each a downstream delta already multiplied by its connecting weight.
derivative - The derivative of the activation.
Returns:
The delta value.
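Because each feedback entry already holds a downstream delta multiplied by its connecting weight, the hidden-layer delta reduces to a sum followed by a multiplication by the activation derivative. A minimal sketch, assuming the derivative is applied to the summed feedback as in standard backpropagation; `HiddenDelta` is a hypothetical name, and a primitive `double` is returned instead of `java.lang.Double`.

```java
// Hypothetical helper illustrating the hidden-layer delta rule.
final class HiddenDelta {
    private HiddenDelta() {}

    // Each feedback[i] is assumed to be delta[i] * weight[i] for the i-th outgoing
    // connection, multiplied when the error was propagated back.
    static double calculateDelta(double[] feedback, double derivative) {
        double sum = 0.0;
        for (double f : feedback) {
            sum += f;
        }
        return sum * derivative;
    }
}
```

With two downstream neurons whose deltas are 0.5 and −0.25, connected by weights 0.4 and 0.8, the pre-multiplied feedback is {0.2, −0.2}; the sum is zero, so this neuron receives no error signal.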

clearPartials

public void clearPartials()
Clears the partial derivatives, called after weights are adjusted.


calculatePartials

public void calculatePartials(double[] inputs,
                              double delta)
Initializes the partials, if necessary, then calculates the partials by adding this example's contribution to the running sums. During an epoch the partial derivatives are cumulative. For on-line training, they are recalculated with every input.

Parameters:
inputs - The inputs for this training example.
delta - The calculated delta for this neuron.
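The partial derivative for each weight is the neuron's delta times the corresponding input. The sketch below accumulates those products across an epoch and clears them on request; for on-line training a caller would clear after every example. The class name `PartialsSketch` is hypothetical, and treating a missing or wrongly sized array as the "if necessary" case is an assumption.

```java
import java.util.Arrays;

// Hypothetical sketch of cumulative partial-derivative bookkeeping.
class PartialsSketch {
    private double[] partials;

    // Lazily sizes the array, then adds delta * input[i] to each running sum.
    void calculatePartials(double[] inputs, double delta) {
        if (partials == null || partials.length != inputs.length) {
            partials = new double[inputs.length];
        }
        for (int i = 0; i < inputs.length; i++) {
            partials[i] += inputs[i] * delta;
        }
    }

    // Resets the accumulated partials to zero, e.g. after the weights are adjusted.
    void clearPartials() {
        if (partials != null) {
            Arrays.fill(partials, 0.0);
        }
    }

    double[] getPartials() {
        return partials;
    }
}
```

Two examples with inputs {1, 2} and {3, −1}, both with delta 0.5, accumulate to partials {2.0, 0.5}.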

getPartials

public double[] getPartials()
Return the partial derivatives. This function seems pretty useless.

Returns:
The partial derivatives.

updateWeightAdjustments

public void updateWeightAdjustments()
Updates the weight adjustments and also stores off the prior weight adjustments. The weights are passed in only for the size. This could be changed to an int parameter, or the weight adjustment array could be pre-sized earlier.
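With momentum, each new adjustment combines the current gradient step with a fraction of the previous adjustment. A minimal sketch, assuming the common form newAdjustment = learningRate * partial + momentum * priorAdjustment, and assuming the partials already point in the direction that reduces the error. `MomentumSketch`, its fields, and the rates 0.5 and 0.9 are hypothetical example values.

```java
// Hypothetical sketch of a momentum-based weight-adjustment update.
class MomentumSketch {
    double learningRate = 0.5; // example value
    double momentum = 0.9;     // example value
    final double[] partials;
    final double[] weightUpdates;
    final double[] priorWeightUpdates;

    MomentumSketch(int size) {
        partials = new double[size];
        weightUpdates = new double[size];
        priorWeightUpdates = new double[size];
    }

    // Stores the current adjustment as the prior one, then computes the new one.
    void updateWeightAdjustments() {
        for (int i = 0; i < weightUpdates.length; i++) {
            priorWeightUpdates[i] = weightUpdates[i];
            weightUpdates[i] = learningRate * partials[i] + momentum * priorWeightUpdates[i];
        }
    }
}
```

With a steady partial of 0.2, the first adjustment is 0.1 and the second is 0.1 + 0.9 × 0.1 = 0.19, showing how momentum builds up along a consistent gradient.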


getWeightUpdates

public double[] getWeightUpdates()
Returns the weight updates - another specious function.

Returns:
The weight adjustments.

endEpoch

public void endEpoch()
Called to end the epoch. Clears the partial derivatives.
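The call order across an epoch is not spelled out on this page. The sketch below assumes a batch epoch of: accumulate partials over every example, compute the momentum adjustments, apply them to the weights, then discard the partials (the role of endEpoch). It uses a single linear neuron (activation derivative 1) so that it stays self-contained; `BatchEpochSketch` and `runEpoch` are hypothetical names.

```java
// Hypothetical end-to-end batch epoch for one linear neuron (f'(net) = 1).
final class BatchEpochSketch {
    private BatchEpochSketch() {}

    static void runEpoch(double[] weights, double[][] inputs, double[] targets,
                         double learningRate, double momentum, double[] priorUpdates) {
        double[] partials = new double[weights.length];
        // 1. Accumulate partials over every example in the epoch.
        for (int n = 0; n < inputs.length; n++) {
            double actual = 0.0;
            for (int i = 0; i < weights.length; i++) {
                actual += weights[i] * inputs[n][i];
            }
            double delta = (targets[n] - actual) * 1.0; // linear activation derivative
            for (int i = 0; i < weights.length; i++) {
                partials[i] += inputs[n][i] * delta;
            }
        }
        // 2. Compute each adjustment with momentum, remember it, and 3. apply it.
        for (int i = 0; i < weights.length; i++) {
            double update = learningRate * partials[i] + momentum * priorUpdates[i];
            priorUpdates[i] = update;
            weights[i] += update;
        }
        // 4. The partials are local to this call, so ending the epoch simply drops them.
    }
}
```

With a learning rate of 0.1 and no momentum, one epoch on the examples (1 → 2) and (2 → 4) moves a zero weight to 1.0, and each further epoch halves the remaining distance to the target weight of 2.0.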


adjustWeights

public void adjustWeights(double[] weights)
Called to adjust the weights. Uses the weight adjustments to update, in place, the weights array passed in.

Parameters:
weights - The weights to adjust.
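Java passes a reference to the array by value, so changing its elements is visible to the caller, while reassigning the parameter would not be. A minimal sketch, assuming each stored adjustment is simply added to its weight; `AdjustWeightsSketch` and the explicit `weightUpdates` parameter are hypothetical (the real method reads its adjustments from the trainer's own state).

```java
// Hypothetical sketch: applies stored adjustments to the caller's array in place.
final class AdjustWeightsSketch {
    private AdjustWeightsSketch() {}

    static void adjustWeights(double[] weights, double[] weightUpdates) {
        if (weights.length != weightUpdates.length) {
            throw new IllegalArgumentException(
                "length mismatch: " + weights.length + " vs " + weightUpdates.length);
        }
        for (int i = 0; i < weights.length; i++) {
            weights[i] += weightUpdates[i]; // mutates the caller's array
        }
    }
}
```

Applying adjustments {0.1, 0.05} to weights {0.5, −0.25} leaves the caller holding {0.6, −0.2}.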