/// <summary>
/// Trains the network with mini-batch gradient descent plus early stopping:
/// every <c>EpochsBetweenValidations</c> epochs the validation error is measured,
/// the best-performing weights are snapshotted, and training aborts once
/// <c>MaxEpochsWithoutImprovement</c> epochs pass without a new best. The best
/// weights found are copied back into <paramref name="nn"/> on exit.
/// </summary>
/// <param name="trainingSet">Examples sampled (with noise) to form each mini-batch.</param>
/// <param name="validationSet">Held-out examples used to score candidate weights.</param>
/// <param name="rand">Random source for batch sampling and input noise.</param>
/// <param name="nn">Network being trained; its weights are updated in place.</param>
public override void Train(IList<InputOutput> trainingSet,
    IList<InputOutput> validationSet,
    IRandomGenerator rand,
    INeuralNetwork nn)
{
    var bestWeights = nn.Weights.DeepClone();
    var bestError = GetError(nn, validationSet);
    var epochsSinceLastImprovement = 0;
    var epochsToNextTest = EpochsBetweenValidations;
    // Previous-step gradients (for momentum in AdjustWeights); start at zero.
    var prevWeightGradients = nn.Weights.DeepCloneToZeros();
    for (var epoch = 1; epoch <= NumEpochs; epoch++)
    {
        var batch = GetBatch(trainingSet, BatchSize, rand);
        // Accumulate gradients over the batch, adding relative noise to each
        // input as a regularizer, then average.
        var gradients = nn.Weights.DeepCloneToZeros();
        for (var j = 0; j < BatchSize; j++)
        {
            gradients.AddInPlace(
                nn.CalculateGradients(batch[j].Input.AddRelativeNoise(MaxRelativeNoise, rand), batch[j].Output));
        }
        gradients.MultiplyInPlace(1 / ((double)BatchSize));
        AdjustWeights(nn, gradients, prevWeightGradients);
        gradients.DeepCopyTo(prevWeightGradients);
        // Check against validation set.
        epochsToNextTest--;
        if (epochsToNextTest == 0)
        {
            epochsToNextTest = EpochsBetweenValidations;
            var error = GetError(nn, validationSet);
            if (error < bestError)
            {
                // BUG FIX: bestError must track the new best, otherwise any
                // error below the *initial* error counts as an "improvement"
                // and can overwrite genuinely better saved weights.
                bestError = error;
                nn.Weights.DeepCopyTo(bestWeights);
                epochsSinceLastImprovement = 0;
            }
            else
            {
                epochsSinceLastImprovement += EpochsBetweenValidations;
                if (epochsSinceLastImprovement > MaxEpochsWithoutImprovement)
                    break;
            }
        }
    }
    // Restore the best weights seen during training (early-stopping result).
    bestWeights.DeepCopyTo(nn.Weights);
}