/// <summary>
///   Returns an enumerator that produces mini-batches by repeatedly refilling
///   the batch object obtained from <c>Init()</c> and yielding it.
/// </summary>
/// <remarks>
///   <para>
///   The progress counters (<c>CurrentIteration</c>, <c>CurrentEpoch</c>,
///   <c>CurrentSample</c> and <c>CurrentMiniBatch</c>) are reset to zero when
///   enumeration starts and are updated as batches are produced.</para>
///   <para>
///   The batch variable is assigned once and yielded on every iteration, so
///   (assuming <c>TBatch</c> is a reference type) consumers receive the same
///   object each time with its contents overwritten.</para>
///   <para>
///   Enumeration ends once <c>MaxEpochs</c> epochs have been completed or
///   <c>MaxIterations</c> mini-batches have been yielded, whichever happens
///   first. A limit that is zero or negative is ignored; if both limits are
///   ignored the sequence never ends.</para>
/// </remarks>
/// <returns>An enumerator over the generated mini-batches.</returns>
public virtual IEnumerator<TBatch> GetEnumerator()
{
    TBatch batch = Init();

    if (Shuffle == ShuffleMethod.None)
        PrepareBatch(Vector.Range(Inputs.Length));
    else // if (Shuffle == ShuffleMethod.OnlyOnce || ShuffleMethod.EveryEpoch)
        PrepareBatch(Vector.Sample(Inputs.Length));

    this.CurrentIteration = 0;
    this.CurrentEpoch = 0;
    this.CurrentSample = 0;
    this.CurrentMiniBatch = 0;

    while (true)
    {
        // Fill every slot of the mini-batch. The sample cursor wraps around at the
        // end of the data, so a single mini-batch may span two consecutive epochs.
        for (int i = 0; i < batch.Inputs.Length; i++)
        {
            PutCurrentSampleInMiniBatch(batch, i);

            CurrentSample++;
            if (CurrentSample >= NumberOfSamples)
            {
                // A full pass over the samples has been completed.
                CurrentEpoch++;
                CurrentSample = 0;

                if (Shuffle == ShuffleMethod.EveryEpoch)
                    PrepareBatch(Vector.Sample(Inputs.Length));
            }
        }

        batch.Index = CurrentMiniBatch;

        yield return batch;

        CurrentMiniBatch++;
        if (CurrentMiniBatch >= NumberOfMiniBatches)
            CurrentMiniBatch = 0;

        // CurrentIteration now equals the number of mini-batches yielded so far and
        // CurrentEpoch the number of completed epochs, so the limits are reached
        // when the counters become equal to them. Using '>' here would produce one
        // extra mini-batch (or one extra epoch) beyond the configured maximum.
        CurrentIteration++;
        if (MaxEpochs > 0 && CurrentEpoch >= MaxEpochs)
            yield break;
        if (MaxIterations > 0 && CurrentIteration >= MaxIterations)
            yield break;
    }
}