BP_LDA.LDA_Learn.PrecomputeLearningRateSchedule C# (CSharp) Method

PrecomputeLearningRateSchedule() public static method

public static PrecomputeLearningRateSchedule ( int nBatch, int nEpoch, double LearnRateStart, double LearnRateEnd, double Accuracy ) : double[]
nBatch int
nEpoch int
LearnRateStart double
LearnRateEnd double
Accuracy double
return double[]
		/// <summary>
		/// Precomputes a learning-rate schedule with one entry per epoch. The last entry is exactly
		/// <paramref name="LearnRateEnd"/>; the first entry lies within a relative tolerance
		/// <paramref name="Accuracy"/> of <paramref name="LearnRateStart"/>.
		/// </summary>
		/// <remarks>
		/// Entries are generated backwards from the last epoch by
		/// <c>rate[k-1] = 0.5 * (1 + 1 / (1 - b * rate[k])^(2 * nBatch)) * rate[k]</c>
		/// for a single scalar <c>b</c>. With <c>b = 0</c> every factor is exactly 1 (a flat schedule);
		/// while <c>0 &lt; b * rate[k] &lt; 2</c> the factor exceeds 1, so the rate grows towards epoch 0.
		/// <c>b</c> is found by bisection: the upper bound is raised in steps of 1 until it overshoots
		/// (first entry above <paramref name="LearnRateStart"/>, or <c>b * max(rate)</c> reaching 2),
		/// then the bracket is halved until <c>|rate[0] - LearnRateStart| &lt;= Accuracy * LearnRateStart</c>.
		/// </remarks>
		/// <param name="nBatch">Enters the recursion as the exponent <c>2 * nBatch</c> (presumably the number
		/// of mini-batches per epoch — TODO confirm with callers). Must be at least 1.</param>
		/// <param name="nEpoch">Length of the returned schedule. Must be at least 1.</param>
		/// <param name="LearnRateStart">Target value of the first entry. Must be finite and not smaller than
		/// <paramref name="LearnRateEnd"/>.</param>
		/// <param name="LearnRateEnd">Exact value of the last entry. Must be positive and finite.</param>
		/// <param name="Accuracy">Relative tolerance on the first entry. Must be non-negative.</param>
		/// <returns>An array of length <paramref name="nEpoch"/>.</returns>
		/// <exception cref="System.ArgumentOutOfRangeException">An argument is outside its valid range.</exception>
		/// <exception cref="System.ArgumentException"><paramref name="LearnRateEnd"/> is larger than
		/// <paramref name="LearnRateStart"/>, or <paramref name="nEpoch"/> is 1 and the two rates are not
		/// within tolerance of each other.</exception>
		/// <exception cref="System.InvalidOperationException">The search cannot reach the requested accuracy.</exception>
		public static double[] PrecomputeLearningRateSchedule(int nBatch, int nEpoch, double LearnRateStart, double LearnRateEnd, double Accuracy)
		{
			// Hard cap on outer iterations. It must be a long: an int counter wraps around long
			// before reaching 1e10, which made the cap impossible to hit.
			const long MaxIterations = 10000000000L;

			// Validate before allocating or indexing anything. These inputs previously surfaced as
			// index/overflow errors, made the search loop forever, or returned a meaningless array.
			if (nEpoch < 1)
			{
				throw new System.ArgumentOutOfRangeException("nEpoch", nEpoch, "nEpoch must be at least 1.");
			}
			if (nBatch < 1)
			{
				throw new System.ArgumentOutOfRangeException("nBatch", nBatch, "nBatch must be at least 1.");
			}
			if (double.IsNaN(LearnRateEnd) || double.IsInfinity(LearnRateEnd) || LearnRateEnd <= 0)
			{
				throw new System.ArgumentOutOfRangeException("LearnRateEnd", LearnRateEnd, "LearnRateEnd must be positive and finite.");
			}
			if (double.IsNaN(LearnRateStart) || double.IsInfinity(LearnRateStart))
			{
				throw new System.ArgumentOutOfRangeException("LearnRateStart", LearnRateStart, "LearnRateStart must be finite.");
			}
			if (LearnRateEnd > LearnRateStart)
			{
				throw new System.ArgumentException("LearnRateEnd should be smaller than LearnRateStart");
			}
			if (double.IsNaN(Accuracy) || Accuracy < 0)
			{
				throw new System.ArgumentOutOfRangeException("Accuracy", Accuracy, "Accuracy must be non-negative.");
			}

			// Start from the flat (b = 0) schedule instead of zeros. If the tolerance already accepts
			// it, a real schedule is returned rather than an array of zeros ending in LearnRateEnd.
			double[] schedule = new double[nEpoch];
			for (int k = 0; k < nEpoch; k++)
			{
				schedule[k] = LearnRateEnd;
			}

			// A single-entry schedule is pinned to LearnRateEnd; no value of b can move it.
			if (nEpoch == 1 && Math.Abs(schedule[0] - LearnRateStart) > Accuracy * LearnRateStart)
			{
				throw new System.ArgumentException("With nEpoch == 1 the schedule is just LearnRateEnd, which is not within Accuracy of LearnRateStart");
			}

			double bMin = 0;
			double bMax = 0;
			long iter = 0;

			while (Math.Abs(schedule[0] - LearnRateStart) > Accuracy * LearnRateStart)
			{
				// Does the current upper bound overshoot? (This test uses ">= 2"; the bisection
				// update below uses "> 2".)
				double maxRate = FillScheduleBackward(schedule, bMax, nBatch);
				bool upperOvershoots = schedule[0] > LearnRateStart || bMax * maxRate >= 2;

				// Sanity-check the lower bound. This pass also leaves the bMin schedule in the
				// array, which is what the loop condition tests after a widening step.
				maxRate = FillScheduleBackward(schedule, bMin, nBatch);
				bool lowerValid = schedule[0] <= LearnRateStart || bMin * maxRate < 2;
				if (!lowerValid)
				{
					throw new System.InvalidOperationException("lower_flag cannot be zero");
				}

				bool bracketExhausted = false;
				if (!upperOvershoots)
				{
					// Widen the bracket. If adding 1 no longer changes bMax it can never grow again.
					double widened = bMax + 1;
					if (widened == bMax)
					{
						throw new System.InvalidOperationException("Could not find an upper bound for b; LearnRateEnd is too small");
					}
					bMax = widened;
				}
				else
				{
					double b = (bMax + bMin) / 2;
					// When bMin and bMax are adjacent doubles the midpoint rounds onto one of them
					// and the bracket cannot shrink any further.
					bracketExhausted = !(b > bMin && b < bMax);
					maxRate = FillScheduleBackward(schedule, b, nBatch);
					if (schedule[0] > LearnRateStart || b * maxRate > 2)
					{
						bMax = b;
					}
					else
					{
						bMin = b;
					}
				}

				iter++;
				if (iter > MaxIterations)
				{
					throw new System.InvalidOperationException("Maximum number of iterations has reached");
				}

				// Once the bracket is exhausted, later iterations only re-evaluate the b just tried,
				// so fail instead of spinning forever.
				if (bracketExhausted && Math.Abs(schedule[0] - LearnRateStart) > Accuracy * LearnRateStart)
				{
					throw new System.InvalidOperationException("Bisection on b stalled at double precision before reaching the requested Accuracy");
				}
			}

			return schedule;
		}

		/// <summary>
		/// Runs the backward recursion
		/// <c>schedule[k-1] = 0.5 * (1 + 1 / (1 - b * schedule[k])^(2 * nBatch)) * schedule[k]</c>
		/// for k = Length-1 down to 1. It overwrites every entry except the last and returns the
		/// largest entry of the whole array, including the last one.
		/// </summary>
		private static double FillScheduleBackward(double[] schedule, double b, int nBatch)
		{
			// 2.0 * nBatch cannot overflow the way the int product 2 * nBatch can; for any
			// realistic nBatch the value is identical.
			double exponent = 2.0 * nBatch;
			int last = schedule.Length - 1;
			double max = schedule[last];
			for (int k = last; k >= 1; k--)
			{
				double previous = 0.5 * (1 + 1 / Math.Pow(1 - schedule[k] * b, exponent)) * schedule[k];
				schedule[k - 1] = previous;
				// Track the maximum in the same pass instead of a separate scan afterwards.
				if (previous > max)
				{
					max = previous;
				}
			}
			return max;
		}