com.penguinwerks.jodene
Class Trainer

java.lang.Object
  extended by com.penguinwerks.jodene.Trainer

public class Trainer
extends java.lang.Object

Trains the network. The trainer configures the neurons in the network for training, with the given training parameters, and runs the training loop.

The trainer is given a set of training examples and presents them repeatedly in the training loop. It continues until it is told to terminate or the hardMax limit is reached. To control how training stops, add EpochEventListeners: at the end of each training epoch, each listener is sent the endEpoch event. The MaxIterationsListener, for example, terminates training once a preset number of iterations has occurred.
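The listener-based stopping described above can be sketched in plain Java. This is a minimal, self-contained illustration, not jodene's code: `SketchEpochListener`, `SketchTrainer`, and `MaxEpochsListener` are hypothetical stand-ins. The event is a plain `java.util.EventObject` whose source is the trainer, and the sketch counts epochs where the real MaxIterationsListener counts iterations. Following the description on this page, hardMax applies only when no listener is registered.

```java
import java.util.ArrayList;
import java.util.EventObject;
import java.util.List;

// Hypothetical listener interface; the real EpochEventListener signature may differ.
interface SketchEpochListener {
    void endEpoch(EventObject event);
}

// Hypothetical trainer stand-in: runs epochs until a listener requests
// termination, or until hardMax epochs when no listener is registered.
class SketchTrainer {
    private final List<SketchEpochListener> listeners = new ArrayList<>();
    private boolean terminate = false;
    private int epochCount = 0;
    private int hardMax = 10000;

    void addEpochEventListener(SketchEpochListener listener) { listeners.add(listener); }
    void requestTermination() { terminate = true; }
    int getEpochCount() { return epochCount; }
    void setHardMax(int hardMax) { this.hardMax = hardMax; }

    void train() {
        terminate = false;
        while (!terminate && (!listeners.isEmpty() || epochCount < hardMax)) {
            // ... present examples and adjust weights here ...
            epochCount++;
            EventObject event = new EventObject(this); // the trainer is the event source
            for (SketchEpochListener listener : listeners) {
                listener.endEpoch(event);
            }
        }
    }
}

// Hypothetical analogue of MaxIterationsListener: stops after a fixed number of epochs.
class MaxEpochsListener implements SketchEpochListener {
    private final int maxEpochs;

    MaxEpochsListener(int maxEpochs) { this.maxEpochs = maxEpochs; }

    @Override
    public void endEpoch(EventObject event) {
        SketchTrainer trainer = (SketchTrainer) event.getSource();
        if (trainer.getEpochCount() >= maxEpochs) {
            trainer.requestTermination();
        }
    }
}
```

Because the listener receives the trainer through the event source, it needs no separate reference to the trainer in order to call requestTermination().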

The trainer is configured by setting the epochSize, the learningRate, and the momentum. For on-line training, the epochSize is 1; otherwise, it defaults to the size of the trainingExamples. Implementing an appropriate EpochEventListener makes techniques such as adjusting the learning rate or momentum during training possible.
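As an example of adjusting the learning rate from an end-of-epoch hook, the sketch below decays the rate geometrically down to a floor. `RateTrainer` and `LearningRateDecay` are hypothetical; only the getLearningRate/setLearningRate names and the 0.20 default come from this page. Because the learning rate is described as being passed to the neuron trainers when they are created, this page does not say whether a change made mid-training reaches existing neuron trainers, so that should be checked before relying on it.

```java
// Hypothetical stand-in exposing only the learning-rate accessors named on this page.
class RateTrainer {
    private double learningRate = 0.20; // documented default

    double getLearningRate() { return learningRate; }
    void setLearningRate(double val) { learningRate = val; }
}

// Hypothetical end-of-epoch hook: multiplies the learning rate by a decay factor
// after every epoch, but never lets it fall below a floor.
class LearningRateDecay {
    private final double factor;
    private final double floor;

    LearningRateDecay(double factor, double floor) {
        if (factor <= 0.0 || factor > 1.0) {
            throw new IllegalArgumentException("factor must be in (0, 1]");
        }
        this.factor = factor;
        this.floor = floor;
    }

    void endEpoch(RateTrainer trainer) {
        trainer.setLearningRate(Math.max(floor, trainer.getLearningRate() * factor));
    }
}
```

The same pattern would apply to momentum, using getMomentum/setMomentum.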

Change on 04/24/04: Error handling has been switched to an error manager with a default error name. The error history has also been refactored into the error manager.

Author:
Paul Hoehne

Constructor Summary
Trainer()
          Default constructor.
 
Method Summary
 void addCalculator(ErrorCalculator calculator)
          The error calculator computes the error from the input values and the output values.
 void addCalculator(java.lang.String name, ErrorCalculator calculator)
          The error calculator computes the error from the input values and the output values.
 void addEpochEventListener(EpochEventListener listener)
          Adds an epoch event listener to the list of listeners.
 void endEpoch(Network network)
          Called at the end of an epoch.
 void fireEpochEvent(Network network)
          Fires an epoch training event to all the listeners.
 int getEpochCount()
          Returns the epoch count so far.
 int getEpochSize()
          The epoch size is the number of examples to present during training before the error is saved and the weights are adjusted.
 ErrorManager getErrorManager()
          Returns the error manager used by the trainer when calculating training error.
 int getHardMax()
          If no EpochEventListeners are added to the trainer, the hardMax limit is used to safely stop training.
 double getLearningRate()
          Returns the default learning rate set for training.
 double getMomentum()
          Returns the momentum used during training.
 NeuronTraining getTraining()
          Returns the default neuron training for this trainer.
 ExampleSet getTrainingExamples()
          The training examples are a set of known inputs and their outputs used to train the network.
 void requestTermination()
          Used primarily by the epoch event listeners to request that the trainer stop training.
 void setDefaultErrorName(java.lang.String name)
          Sets the default error name to examine.
 void setEpochSize(int epochSize)
          The epoch size is the number of examples to present during training before the error is saved and the weights are adjusted.
 void setHardMax(int hardMax)
          If no EpochEventListeners are added to the trainer, the hardMax limit is used to safely stop training.
 void setLearningRate(double val)
          Sets the default learning rate used for training.
 void setMomentum(double val)
          Sets the momentum used during training.
 void setTrainingExamples(ExampleSet trainingExamples)
          The training examples are a set of known inputs and their outputs used to train the network.
 void train(Network network)
          Called to train the network.
 void updateWeightAdjustments()
          Causes the weight adjustments to be updated at all the neuron trainers.
 
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 

Constructor Detail

Trainer

public Trainer()
Default constructor.

Method Detail

getHardMax

public int getHardMax()
If no EpochEventListeners are added to the trainer, the hardMax limit is used to safely stop training.

Returns:
Returns the hardMax.

setHardMax

public void setHardMax(int hardMax)
If no EpochEventListeners are added to the trainer, the hardMax limit is used to safely stop training.

Parameters:
hardMax - The hardMax to set.

getLearningRate

public double getLearningRate()
Returns the default learning rate set for training. The default is 0.20. This is the learning rate passed to all the neuron trainers when they are created.

Returns:
The learning rate used by this trainer.

setLearningRate

public void setLearningRate(double val)
Sets the default learning rate used for training. The default is 0.20. This is the learning rate passed to all the neuron trainers when they are created.

Parameters:
val - The new learning rate.

getMomentum

public double getMomentum()
Returns the momentum used during training. This is the default momentum passed to the neuron trainers when they are created. The default is 0.00.

Returns:
The momentum used by this trainer.

setMomentum

public void setMomentum(double val)
Sets the momentum used during training. This is the default momentum passed to the neuron trainers when they are created. The default is 0.00.

Parameters:
val - The new momentum.
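This page does not give the weight-update formula. Assuming the conventional gradient-descent-with-momentum rule, delta(t) = -learningRate * dE/dw + momentum * delta(t-1), a single weight would be updated as sketched below. `MomentumWeight` is hypothetical and is not jodene's neuron trainer. With the documented default momentum of 0.00, the rule reduces to plain gradient descent.

```java
// Hypothetical single-weight sketch of the conventional momentum rule
// (assumed; this page does not give jodene's update formula):
//   delta(t) = -learningRate * dE/dw + momentum * delta(t-1)
class MomentumWeight {
    private final double learningRate;
    private final double momentum;
    private double weight;
    private double previousDelta = 0.0; // no previous step before the first update

    MomentumWeight(double initialWeight, double learningRate, double momentum) {
        this.weight = initialWeight;
        this.learningRate = learningRate;
        this.momentum = momentum;
    }

    // Applies one update for the given error gradient dE/dw and returns the step taken.
    double update(double gradient) {
        double delta = -learningRate * gradient + momentum * previousDelta;
        weight += delta;
        previousDelta = delta;
        return delta;
    }

    double getWeight() {
        return weight;
    }
}
```

With a learning rate of 0.20, a momentum of 0.9, and a constant gradient of 0.5, the first step is -0.1 and the second is -0.19, because 90% of the previous step is carried forward.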

addCalculator

public void addCalculator(ErrorCalculator calculator)
The error calculator computes the error from the input values and the output values. By default, this is the sum of squares. This version of addCalculator registers the calculator under the default name.

Parameters:
calculator - The calculator to set.
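A sum-of-squares calculator can be sketched as follows, assuming it compares an example's expected (target) outputs with the network's actual outputs. `SketchErrorCalculator` and `SumOfSquares` are hypothetical; the real ErrorCalculator interface is not shown on this page. Some implementations halve or average this sum, but the page says only "sum of squares".

```java
// Hypothetical calculator interface; the real ErrorCalculator signature is not shown here.
interface SketchErrorCalculator {
    double error(double[] expected, double[] actual);
}

// Sum of squared differences between expected and actual output values.
class SumOfSquares implements SketchErrorCalculator {
    @Override
    public double error(double[] expected, double[] actual) {
        if (expected.length != actual.length) {
            throw new IllegalArgumentException("expected and actual lengths differ");
        }
        double sum = 0.0;
        for (int i = 0; i < expected.length; i++) {
            double diff = expected[i] - actual[i];
            sum += diff * diff;
        }
        return sum;
    }
}
```

For expected outputs {1.0, 0.0} and actual outputs {0.5, 0.5}, the error is 0.25 + 0.25 = 0.5.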

addCalculator

public void addCalculator(java.lang.String name,
                          ErrorCalculator calculator)
The error calculator computes the error from the input values and the output values. By default this is the sum of squares.

Parameters:
name - The name of the error calculator.
calculator - The calculator to set.

getEpochSize

public int getEpochSize()
The epoch size is the number of examples to present during training before the error is saved and the weights are adjusted. Normally, this is equal to the size of the training set.

Returns:
Returns the epochSize.

setEpochSize

public void setEpochSize(int epochSize)
The epoch size is the number of examples to present during training before the error is saved and the weights are adjusted. Normally, this is equal to the size of the training set.

Parameters:
epochSize - The epochSize to set.

getTrainingExamples

public ExampleSet getTrainingExamples()
The training examples are a set of known inputs and their outputs used to train the network. Normally the data is divided into a training set and a validation set. The training set is used to train the network, while the validation set is used to test the quality of the training.

Returns:
Returns the trainingExamples.

setTrainingExamples

public void setTrainingExamples(ExampleSet trainingExamples)
The training examples are a set of known inputs and their outputs used to train the network. Normally the data is divided into a training set and a validation set. The training set is used to train the network, while the validation set is used to test the quality of the training.

Parameters:
trainingExamples - The trainingExamples to set.
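One common way to divide the data into training and validation sets, shown here as an assumption rather than a jodene feature, is a seeded random holdout split. `HoldoutSplit` is a hypothetical helper that works on a plain `java.util.List`, not on ExampleSet.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical helper (not part of jodene): a seeded random holdout split.
final class HoldoutSplit<T> {
    final List<T> training;
    final List<T> validation;

    private HoldoutSplit(List<T> training, List<T> validation) {
        this.training = training;
        this.validation = validation;
    }

    // Shuffles a copy of the data (the input list is left unchanged) and puts
    // roughly validationFraction of the items into the validation part.
    static <E> HoldoutSplit<E> split(List<E> data, double validationFraction, long seed) {
        if (validationFraction < 0.0 || validationFraction > 1.0) {
            throw new IllegalArgumentException("validationFraction must be in [0, 1]");
        }
        List<E> shuffled = new ArrayList<>(data);
        Collections.shuffle(shuffled, new Random(seed));
        int validationSize = (int) Math.round(shuffled.size() * validationFraction);
        int cut = shuffled.size() - validationSize;
        List<E> training = new ArrayList<>(shuffled.subList(0, cut));
        List<E> validation = new ArrayList<>(shuffled.subList(cut, shuffled.size()));
        return new HoldoutSplit<>(training, validation);
    }
}
```

Fixing the seed makes the split reproducible, so repeated training runs are evaluated against the same validation examples.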

train

public void train(Network network)
           throws TrainingException

Called to train the network. Iterates until the terminating flag is set. If no epoch event listeners have been added, a hard termination limit also applies: 10000 iterations by default.

The training algorithm is simple. It presents the training examples in order, up to epochSize examples, before ending the epoch. At that point, the error for the epoch is calculated and the EndEpochEvent is sent to all the listeners. The weights of the network are then adjusted. Training stops when the hard limit is reached (if no listeners are registered) or when a listener requests termination.

Parameters:
network - The network to train.
Throws:
TrainingException - Thrown if a problem with training occurs.
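The loop described for train can be sketched on a one-weight model y = w * x. `LoopSketch` is hypothetical. Its listeners are plain `DoubleConsumer`s that receive the epoch error, and the gradient is averaged over the epoch. Two behaviours are assumptions not stated on this page: presentation wraps around to the first example when the set is exhausted, and hardMax counts epochs.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.DoubleConsumer;

// Hypothetical sketch of the documented loop on a single-weight model y = w * x.
class LoopSketch {
    private final double[] inputs;
    private final double[] targets;
    private final List<DoubleConsumer> listeners = new ArrayList<>(); // receive the epoch error
    private final double learningRate = 0.20; // documented default
    private double weight = 0.0;
    private int epochSize;
    private int hardMax = 10000;
    private int epochCount = 0;
    private boolean terminate = false;

    LoopSketch(double[] inputs, double[] targets) {
        this.inputs = inputs;
        this.targets = targets;
        this.epochSize = inputs.length; // default: the whole training set
    }

    void setEpochSize(int epochSize) {
        if (epochSize < 1) throw new IllegalArgumentException("epochSize must be >= 1");
        this.epochSize = epochSize;
    }
    void setHardMax(int hardMax) { this.hardMax = hardMax; }
    void addListener(DoubleConsumer listener) { listeners.add(listener); }
    void requestTermination() { terminate = true; }
    int getEpochCount() { return epochCount; }
    double getWeight() { return weight; }

    // With listeners registered, only a termination request ends the loop.
    void train() {
        int next = 0; // index of the next example; wraps around (assumption)
        while (!terminate && (!listeners.isEmpty() || epochCount < hardMax)) {
            double error = 0.0;
            double gradient = 0.0;
            for (int i = 0; i < epochSize; i++) { // present epochSize examples in order
                double diff = targets[next] - weight * inputs[next];
                error += diff * diff;                    // sum-of-squares epoch error
                gradient += -2.0 * diff * inputs[next];  // accumulate dE/dw
                next = (next + 1) % inputs.length;
            }
            epochCount++;
            for (DoubleConsumer listener : listeners) {
                listener.accept(error);                  // end-of-epoch notification
            }
            weight -= learningRate * gradient / epochSize; // then adjust the weight
        }
    }
}
```

With epochSize equal to the number of examples, this is batch training with one weight adjustment per pass through the data; with epochSize set to 1, it becomes on-line training with one adjustment per example.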

updateWeightAdjustments

public void updateWeightAdjustments()
Causes the weight adjustments to be updated at all the neuron trainers.


endEpoch

public void endEpoch(Network network)
Called at the end of an epoch. This implementation sends the end-of-epoch message to all the neuron trainers.

Parameters:
network - The network being trained.

getTraining

public NeuronTraining getTraining()
Returns the default neuron training for this trainer. It also sets the learning rate as appropriate.

Returns:
The neuron training used by this trainer.

requestTermination

public void requestTermination()
Used primarily by the epoch event listeners to request that the trainer stop training.


fireEpochEvent

public void fireEpochEvent(Network network)
Fires an epoch training event to all the listeners. It sends an EpochTrainingEvent, with the trainer as the source.

Parameters:
network - The network for which the event is fired.

addEpochEventListener

public void addEpochEventListener(EpochEventListener listener)
Adds an epoch event listener to the list of listeners. The epoch end event is fired at the end of every epoch, sending the trainer as the source.

Parameters:
listener - A new listener for epoch events.

getEpochCount

public int getEpochCount()
Returns the epoch count so far.

Returns:
The total training epochs that have transpired.

getErrorManager

public ErrorManager getErrorManager()
Returns the error manager used by the trainer when calculating training error.

Returns:
The error manager.

setDefaultErrorName

public void setDefaultErrorName(java.lang.String name)
Sets the default error name to examine. If no error is specified, the default error name is used.

Parameters:
name - The default error name to use.
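The naming scheme described for the error manager (calculators registered by name, with the default name used when none is given) can be sketched with a map. `SketchErrorManager` is hypothetical, its calculators are plain `ToDoubleBiFunction<double[], double[]>` values, and the initial name "default" is a placeholder because the real default error name is not given on this page.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.ToDoubleBiFunction;

// Hypothetical stand-in for an error manager: named calculators plus a default name.
class SketchErrorManager {
    private final Map<String, ToDoubleBiFunction<double[], double[]>> calculators = new HashMap<>();
    private String defaultErrorName = "default"; // placeholder; the real default is not shown

    // Registers the calculator under the current default name.
    void addCalculator(ToDoubleBiFunction<double[], double[]> calculator) {
        addCalculator(defaultErrorName, calculator);
    }

    void addCalculator(String name, ToDoubleBiFunction<double[], double[]> calculator) {
        calculators.put(name, calculator);
    }

    void setDefaultErrorName(String name) {
        defaultErrorName = name;
    }

    // No name given: fall back to the default error name.
    double error(double[] expected, double[] actual) {
        return error(null, expected, actual);
    }

    double error(String name, double[] expected, double[] actual) {
        String key = (name == null) ? defaultErrorName : name;
        ToDoubleBiFunction<double[], double[]> calculator = calculators.get(key);
        if (calculator == null) {
            throw new IllegalArgumentException("no error calculator named " + key);
        }
        return calculator.applyAsDouble(expected, actual);
    }
}
```

Keeping several named calculators lets a listener examine a different error measure than the one the trainer reports by default.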