/// <summary>
/// Precomputes a learning-rate schedule of length <paramref name="nEpoch"/> whose last element is
/// exactly <paramref name="LearnRateEnd"/> and whose first element is within
/// <c>Accuracy * LearnRateStart</c> of <paramref name="LearnRateStart"/>.
/// </summary>
/// <remarks>
/// The schedule is generated backwards from the last element with the recurrence
/// <c>lr[k-1] = 0.5 * (1 + 1 / (1 - lr[k] * b)^(2 * nBatch)) * lr[k]</c>.
/// The scalar <c>b</c> is searched in two phases:
/// first the upper bound <c>bMax</c> is raised in unit steps from 0 until it over-shoots
/// (lr[0] exceeds LearnRateStart, or b * max(lr) reaches 2);
/// then the bracket [bMin, bMax] is bisected until lr[0] meets the tolerance.
/// If the constant schedule (every element equal to LearnRateEnd) already meets the tolerance,
/// it is returned without searching.
/// </remarks>
/// <param name="nBatch">Appears as the exponent <c>2 * nBatch</c> in the recurrence
/// (presumably the number of mini-batches per epoch).</param>
/// <param name="nEpoch">Length of the returned schedule; must be at least 1.</param>
/// <param name="LearnRateStart">Target value for element 0 of the schedule.</param>
/// <param name="LearnRateEnd">Value of the last element. Must not exceed
/// <paramref name="LearnRateStart"/>, and must be positive whenever a search is needed.</param>
/// <param name="Accuracy">Relative tolerance on element 0: the search stops once
/// <c>|lr[0] - LearnRateStart| &lt;= Accuracy * LearnRateStart</c>.</param>
/// <returns>An array of length <paramref name="nEpoch"/>. Element 0 approximates
/// <paramref name="LearnRateStart"/> and the last element equals <paramref name="LearnRateEnd"/>.</returns>
/// <exception cref="System.ArgumentOutOfRangeException">
/// <paramref name="nEpoch"/> is less than 1, or a search is needed and
/// <paramref name="LearnRateEnd"/> is not positive.</exception>
/// <exception cref="System.ArgumentException">
/// <paramref name="LearnRateEnd"/> is greater than <paramref name="LearnRateStart"/>.</exception>
/// <exception cref="System.InvalidOperationException">
/// The bisection can no longer narrow its bracket before meeting the tolerance
/// (for example nEpoch == 1 with LearnRateEnd not already within tolerance),
/// or the iteration limit is exceeded.</exception>
public static double[] PrecomputeLearningRateSchedule(int nBatch, int nEpoch, double LearnRateStart, double LearnRateEnd, double Accuracy)
{
    // Safety net on the total number of search steps. The counter must be a long:
    // an int counter wraps at ~2.1e9, so a 1e10 limit could never be reached.
    const long MaxIterations = 10000000000L;

    // Validate before touching the array. nEpoch < 1 previously surfaced as an
    // IndexOutOfRangeException/OverflowException from the allocation/indexing.
    if (nEpoch < 1)
    {
        throw new System.ArgumentOutOfRangeException("nEpoch", "nEpoch must be at least 1");
    }
    if (LearnRateEnd > LearnRateStart)
    {
        throw new System.ArgumentException("LearnRateEnd should be smaller than LearnRateStart");
    }

    // With b = 0 the recurrence reproduces LearnRateEnd at every index.
    double[] pool = new double[nEpoch];
    for (int k = 0; k < nEpoch; k++)
    {
        pool[k] = LearnRateEnd;
    }

    double tolerance = Accuracy * LearnRateStart;
    if (Math.Abs(pool[0] - LearnRateStart) <= tolerance)
    {
        // The constant schedule already meets the target; no search needed.
        return pool;
    }

    // With LearnRateEnd == 0 every element stays 0 for any b, so nothing would ever stop
    // the upward scan of bMax below. Reject non-positive (and NaN) end rates.
    if (!(LearnRateEnd > 0))
    {
        throw new System.ArgumentOutOfRangeException("LearnRateEnd", "LearnRateEnd must be positive");
    }

    double bMin = 0;
    double bMax = 0;
    long iterations = 0;

    // Phase 1: raise bMax in unit steps until it over-shoots.
    // The last element is LearnRateEnd > 0, so pool.Max() >= LearnRateEnd.
    // Hence bMax * pool.Max() reaches 2 no later than bMax >= 2 / LearnRateEnd.
    while (true)
    {
        FillLearningRateScheduleBackward(pool, bMax, nBatch);
        if (pool[0] > LearnRateStart || bMax * pool.Max() >= 2)
        {
            break;
        }
        bMax += 1;
        if (++iterations > MaxIterations)
        {
            throw new System.InvalidOperationException("Maximum number of iterations has reached");
        }
    }

    // Phase 2: bisect. bMin always satisfies lr[0] <= LearnRateStart && b * max(lr) <= 2
    // (it starts at 0 and is only ever moved to such a midpoint).
    // bMax always satisfies the over-shoot test.
    while (true)
    {
        double b = (bMax + bMin) / 2;
        FillLearningRateScheduleBackward(pool, b, nBatch);
        if (Math.Abs(pool[0] - LearnRateStart) <= tolerance)
        {
            return pool;
        }

        // Once the midpoint collapses onto an endpoint the bracket cannot shrink any
        // further; the original loop would spin forever here.
        if (!(b > bMin && b < bMax))
        {
            throw new System.InvalidOperationException("Bisection stalled: LearnRateStart cannot be reached within Accuracy for these parameters");
        }

        if (pool[0] > LearnRateStart || b * pool.Max() > 2)
        {
            bMax = b;
        }
        else
        {
            bMin = b;
        }

        if (++iterations > MaxIterations)
        {
            throw new System.InvalidOperationException("Maximum number of iterations has reached");
        }
    }
}

/// <summary>
/// Overwrites pool[0 .. pool.Length - 2] by running the schedule recurrence backwards
/// from pool[pool.Length - 1], which is left untouched.
/// </summary>
/// <param name="pool">Schedule buffer; its last element seeds the recurrence.</param>
/// <param name="b">The scalar parameter b of the recurrence.</param>
/// <param name="nBatch">The recurrence raises (1 - lr * b) to the power 2 * nBatch.</param>
private static void FillLearningRateScheduleBackward(double[] pool, double b, int nBatch)
{
    // Computed in double so that 2 * nBatch cannot overflow int for very large nBatch.
    double exponent = 2.0 * nBatch;
    for (int k = pool.Length - 1; k >= 1; k--)
    {
        pool[k - 1] = 0.5 * (1 + 1 / Math.Pow(1 - pool[k] * b, exponent)) * pool[k];
    }
}