NNX.Core.Training.UntilDoneGradientTrainer.Train C# (CSharp) Method

Train() public method

public override Train(IList<InputOutput> trainingSet, IList<InputOutput> validationSet, IRandomGenerator rand, INeuralNetwork nn) : void
trainingSet     IList<InputOutput>
validationSet   IList<InputOutput>
rand            IRandomGenerator
nn              INeuralNetwork
return          void
        /// <summary>
        /// Trains <paramref name="nn"/> using gradients averaged over randomly drawn
        /// mini-batches (applied via <c>AdjustWeights</c>). Every
        /// <c>EpochsBetweenValidations</c> epochs the network is scored on
        /// <paramref name="validationSet"/>. Training stops early once more than
        /// <c>MaxEpochsWithoutImprovement</c> epochs have passed without a new lowest
        /// validation error, or after <c>NumEpochs</c> epochs.
        /// On return, <paramref name="nn"/> holds the weights that achieved the lowest
        /// validation error observed, which may be its initial weights.
        /// </summary>
        /// <param name="trainingSet">Samples from which each mini-batch is drawn.</param>
        /// <param name="validationSet">Samples used to measure error for early stopping.</param>
        /// <param name="rand">Random source passed to batch selection and to the input-noise step.</param>
        /// <param name="nn">Network to train; its weights are modified in place.</param>
        public override void Train(IList<InputOutput> trainingSet,
            IList<InputOutput> validationSet,
            IRandomGenerator rand,
            INeuralNetwork nn)
        {
            // Best weights/error seen so far; the untrained network is the baseline.
            var bestWeights = nn.Weights.DeepClone();
            var bestError = GetError(nn, validationSet);
            var epochsSinceLastImprovement = 0;
            var epochsToNextTest = EpochsBetweenValidations;

            // Gradients from the previous step, handed to AdjustWeights; all zeros before the first step.
            var prevWeightGradients = nn.Weights.DeepCloneToZeros();

            for (var epoch = 1; epoch <= NumEpochs; epoch++)
            {
                var batch = GetBatch(trainingSet, BatchSize, rand);

                // Average the per-sample gradients over the batch; each input gets relative noise added first.
                // NOTE(review): assumes GetBatch returns at least BatchSize items — confirm.
                var gradients = nn.Weights.DeepCloneToZeros();
                for (var j = 0; j < BatchSize; j++)
                {
                    var noisyInput = batch[j].Input.AddRelativeNoise(MaxRelativeNoise, rand);
                    gradients.AddInPlace(nn.CalculateGradients(noisyInput, batch[j].Output));
                }

                gradients.MultiplyInPlace(1.0 / BatchSize);

                AdjustWeights(nn, gradients, prevWeightGradients);
                gradients.DeepCopyTo(prevWeightGradients);

                // Check against the validation set every EpochsBetweenValidations epochs.
                epochsToNextTest--;

                if (epochsToNextTest == 0)
                {
                    epochsToNextTest = EpochsBetweenValidations;
                    var error = GetError(nn, validationSet);

                    if (error < bestError)
                    {
                        // New best: record the error too, so later checks compare against it
                        // rather than against the initial error.
                        bestError = error;
                        nn.Weights.DeepCopyTo(bestWeights);
                        epochsSinceLastImprovement = 0;
                    }
                    else
                    {
                        epochsSinceLastImprovement += EpochsBetweenValidations;

                        if (epochsSinceLastImprovement > MaxEpochsWithoutImprovement)
                        {
                            break;
                        }
                    }
                }
            }

            // Restore the best-scoring weights (possibly the initial ones).
            bestWeights.DeepCopyTo(nn.Weights);
        }