/// <summary>
/// Trains a BP-sLDA model with mini-batch mirror-descent back propagation.
/// For every epoch the training set is shuffled, split into mini-batches, and for each batch
/// the method runs the forward pass, back propagation, the Phi update (via <c>SMD_Update</c>)
/// and the U update. Optionally a running average of the model is maintained during the
/// second half of training. At configurable epoch intervals the model is tested, saved to
/// disk and its features are dumped.
/// </summary>
/// <param name="TrainData">Training inputs; the columns of each mini-batch are extracted with <c>GetColumns</c>.</param>
/// <param name="TrainLabel">Training targets; columns are extracted with the same indices as <paramref name="TrainData"/>.</param>
/// <param name="TestData">Test inputs passed to <c>Testing_BP_sLDA</c> and <c>DumpingFeature_BP_LDA</c>.</param>
/// <param name="TestLabel">Test targets passed to <c>Testing_BP_sLDA</c>.</param>
/// <param name="ValidData">Validation inputs; only used when <c>paramTrain.flag_HasValidSet</c> is set.</param>
/// <param name="ValidLabel">Validation targets; only used when <c>paramTrain.flag_HasValidSet</c> is set.</param>
/// <param name="paramModel">Model hyper-parameters and weights. Initialized by <c>ModelInit_LDA_Feedforward</c> and updated in place (Phi, U, T).</param>
/// <param name="paramTrain">Training configuration (epochs, learning rates, batch size, schedules, reporting flags).</param>
/// <param name="ModelFile">Not used by this method.</param>
/// <param name="ResultFile">Path prefix for all outputs written here (".log", ".model.Phi", ".model.U", ".perf", ".validscore", ".testscore", "*.fea").</param>
/// <exception cref="Exception">Thrown when <c>paramTrain.LearnRateSchedule</c> is neither "PreCompute" nor "Constant".</exception>
public static void TrainingBP_sLDA(
    SparseMatrix TrainData,
    SparseMatrix TrainLabel,
    SparseMatrix TestData,
    SparseMatrix TestLabel,
    SparseMatrix ValidData,
    SparseMatrix ValidLabel,
    paramModel_t paramModel,
    paramTrain_t paramTrain,
    string ModelFile,
    string ResultFile
    )
{
    Console.WriteLine("*****************************************************************");
    Console.WriteLine("jvking version of BP-sLDA: Mirror-Descent Back Propagation");
    Console.WriteLine("*****************************************************************");

    // ---- Extract the parameters ----
    // Model parameters
    int nInput = paramModel.nInput;
    int nHid = paramModel.nHid;
    int nOutput = paramModel.nOutput;
    float eta = paramModel.eta;
    float T_value = paramModel.T_value;
    string OutputType = paramModel.OutputType;
    float beta = paramModel.beta;
    // Training parameters
    int nEpoch = paramTrain.nEpoch;
    float mu_Phi = paramTrain.mu_Phi;
    float mu_U = paramTrain.mu_U;
    int nTrain = paramTrain.nTrain;
    float mu_ReduceFactor = paramTrain.mu_Phi_ReduceFactor;
    string LearnRateSchedule = paramTrain.LearnRateSchedule;
    int nSamplesPerDisplay = paramTrain.nSamplesPerDisplay;
    int nEpochPerSave = paramTrain.nEpochPerSave;
    int nEpochPerTest = paramTrain.nEpochPerTest;
    int nEpochPerDump = paramTrain.nEpochPerDump;

    // ---- Initialize the model ----
    ModelInit_LDA_Feedforward(paramModel);

    // ---- Initialize the training algorithm ----
    float TotLoss = 0.0f;
    float TotTrErr = 0.0f;
    double TotTime = 0.0;
    double TotTimeThisEpoch = 0.0;
    int TotSamples = 0;
    int TotSamplesThisEpoch = 0;
    float CntRunningAvg = 0.0f;
    float CntModelUpdate = 0.0f;
    double AvgnHidLayerEffective = 0.0;
    DenseRowVector mu_phi_search = new DenseRowVector(nHid, mu_Phi);
    DenseRowVector AdaGradSum = new DenseRowVector(nHid, 0.0f);
    DenseRowVector TmpDenseRowVec = new DenseRowVector(nHid, 0.0f);
    // One slot per test point (tests happen every nEpochPerTest epochs).
    int nTestPoints = nEpoch / nEpochPerTest;
    DenseRowVector TestError_pool = new DenseRowVector(nTestPoints, 0.0f);
    DenseRowVector ValidError_pool = new DenseRowVector(nTestPoints, 0.0f);
    DenseRowVector TrainError_pool = new DenseRowVector(nTestPoints, 0.0f);
    DenseRowVector TrainLoss_pool = new DenseRowVector(nTestPoints, 0.0f);
    DenseRowVector TestError_epoch = new DenseRowVector(nTestPoints, 0.0f);
    DenseRowVector TestError_time = new DenseRowVector(nTestPoints, 0.0f);
    int[] IdxPerm = null;
    int BatchSize_NormalBatch = paramTrain.BatchSize;
    int BatchSize_tmp = paramTrain.BatchSize;
    int nBatch = (int)Math.Ceiling(((float)nTrain) / ((float)BatchSize_NormalBatch));
    // Separate run buffers for full-size batches and for the (possibly shorter) last batch.
    DNNRun_t DNNRun_NormalBatch = new DNNRun_t(nHid, BatchSize_NormalBatch, paramModel.nHidLayer, nOutput);
    DNNRun_t DNNRun_EndBatch = new DNNRun_t(nHid, nTrain - (nBatch - 1) * BatchSize_NormalBatch, paramModel.nHidLayer, nOutput);
    DNNRun_t DNNRun = null;
    Grad_t Grad = new Grad_t(nHid, nOutput, nInput, paramModel.nHidLayer, OutputType);
    DenseMatrix TmpMatDensePhi = new DenseMatrix(nInput, nHid);
    DenseMatrix TmpMatDenseU = new DenseMatrix(nOutput, nHid);
    paramModel_t paramModel_avg = new paramModel_t(paramModel);
    Stopwatch stopWatch = new Stopwatch();
    // First (zero-based) epoch from which the running average is maintained and used.
    int runningAvgStartEpoch = (int)Math.Ceiling(((float)nEpoch) / 2.0f);

    // ---- Compute the schedule of the learning rate ----
    double[] stepsize_pool_Phi = null;
    double[] stepsize_pool_U = null;
    switch (LearnRateSchedule)
    {
        case "PreCompute":
            stepsize_pool_Phi = PrecomputeLearningRateSchedule(nBatch, nEpoch, mu_Phi, mu_Phi / mu_ReduceFactor, 1e-8f);
            stepsize_pool_U = PrecomputeLearningRateSchedule(nBatch, nEpoch, mu_U, mu_U / mu_ReduceFactor, 1e-8f);
            break;
        case "Constant":
            stepsize_pool_Phi = new double[nEpoch];
            stepsize_pool_U = new double[nEpoch];
            for (int Idx = 0; Idx < nEpoch; Idx++)
            {
                stepsize_pool_Phi[Idx] = mu_Phi;
                stepsize_pool_U[Idx] = mu_U;
            }
            break;
        default:
            throw new Exception("Unknown type of LearnRateSchedule");
    }

    // Now start training.........................
    for (int epoch = 0; epoch < nEpoch; epoch++)
    {
        TotSamplesThisEpoch = 0;
        TotTimeThisEpoch = 0.0;
        AvgnHidLayerEffective = 0.0;

        // Whether the averaged model is maintained and used for test/save/dump in this epoch.
        bool useAvgModel = paramTrain.flag_RunningAvg && epoch >= runningAvgStartEpoch;
        paramModel_t evalModel = useAvgModel ? paramModel_avg : paramModel;

        // -- Set the batch size if there is schedule --
        if (paramTrain.flag_BachSizeSchedule)
        {
            // The schedule is keyed by the one-based epoch number.
            if (paramTrain.BachSizeSchedule.TryGetValue(epoch + 1, out BatchSize_tmp))
            {
                BatchSize_NormalBatch = BatchSize_tmp;
                nBatch = (int)Math.Ceiling(((float)nTrain) / ((float)BatchSize_NormalBatch));
                DNNRun_NormalBatch = new DNNRun_t(nHid, BatchSize_NormalBatch, paramModel.nHidLayer, nOutput);
                DNNRun_EndBatch = new DNNRun_t(nHid, nTrain - (nBatch - 1) * BatchSize_NormalBatch, paramModel.nHidLayer, nOutput);
            }
        }

        // -- Shuffle the data (generating shuffled index) --
        IdxPerm = Statistics.RandPerm(nTrain);

        // -- Reset the (MDA) inference step-sizes --
        if (epoch > 0)
        {
            for (int Idx = 0; Idx < paramModel.nHidLayer; Idx++)
            {
                paramModel.T[Idx] = T_value;
            }
        }

        // -- Take the learning rate for the current epoch --
        mu_Phi = (float)stepsize_pool_Phi[epoch];
        mu_U = (float)stepsize_pool_U[epoch];

        // -- Start this epoch --
        Console.WriteLine("############## Epoch #{0}. BatchSize: {1} Learning Rate: Phi:{2}, U:{3} ##################",
            epoch + 1, BatchSize_NormalBatch, mu_Phi, mu_U);
        for (int IdxBatch = 0; IdxBatch < nBatch; IdxBatch++)
        {
            stopWatch.Start();

            // Extract the batch (the last batch takes whatever samples remain)
            int BatchSize = 0;
            if (IdxBatch < nBatch - 1)
            {
                BatchSize = BatchSize_NormalBatch;
                DNNRun = DNNRun_NormalBatch;
            }
            else
            {
                BatchSize = nTrain - IdxBatch * BatchSize_NormalBatch;
                DNNRun = DNNRun_EndBatch;
            }
            SparseMatrix Xt = new SparseMatrix(nInput, BatchSize);
            SparseMatrix Dt = new SparseMatrix(nOutput, BatchSize);
            int[] IdxSample = new int[BatchSize];
            Array.Copy(IdxPerm, IdxBatch * BatchSize_NormalBatch, IdxSample, 0, BatchSize);
            TrainData.GetColumns(Xt, IdxSample);
            TrainLabel.GetColumns(Dt, IdxSample);

            // Forward activation
            LDA_Learn.ForwardActivation_LDA(Xt, DNNRun, paramModel, true);

            // Back propagation
            LDA_Learn.BackPropagation_LDA(Xt, Dt, DNNRun, paramModel, Grad);

            // Compute the gradient and update the model (All gradients of Phi are accumulated into Grad.grad_Q_Phi)
            // (i) Update Phi
            MatrixOperation.ScalarDivideMatrix(Grad.grad_Q_Phi, (-1.0f) * ((beta - 1) / ((float)nTrain)), paramModel.Phi, true);
            mu_phi_search.FillValue(mu_Phi);
            // Different learning rate for different columns of Phi: Similar to AdaGrad but does not decay with time
            ++CntModelUpdate;
            MatrixOperation.ElementwiseMatrixMultiplyMatrix(TmpMatDensePhi, Grad.grad_Q_Phi, Grad.grad_Q_Phi);
            MatrixOperation.VerticalSumMatrix(TmpDenseRowVec, TmpMatDensePhi);
            MatrixOperation.ScalarMultiplyVector(TmpDenseRowVec, 1.0f / ((float)nInput));
            // Incremental mean: AdaGradSum += (current - AdaGradSum) / CntModelUpdate
            MatrixOperation.VectorSubtractVector(TmpDenseRowVec, AdaGradSum);
            MatrixOperation.ScalarMultiplyVector(TmpDenseRowVec, 1.0f / CntModelUpdate);
            MatrixOperation.VectorAddVector(AdaGradSum, TmpDenseRowVec);
            MatrixOperation.ElementwiseSquareRoot(TmpDenseRowVec, AdaGradSum);
            MatrixOperation.ScalarAddVector(TmpDenseRowVec, mu_Phi);
            MatrixOperation.ElementwiseVectorDivideVector(mu_phi_search, mu_phi_search, TmpDenseRowVec);
            // The value returned by SMD_Update was never read, so it is not stored.
            SMD_Update(paramModel.Phi, Grad.grad_Q_Phi, mu_phi_search, eta);

            // (ii) Update U
            MatrixOperation.ScalarMultiplyMatrix(Grad.grad_Q_U, (-1.0f) * mu_U);
            MatrixOperation.MatrixAddMatrix(paramModel.U, Grad.grad_Q_U);

            // (iii) Running average of the model: avg += (current - avg) / CntRunningAvg
            if (useAvgModel)
            {
                ++CntRunningAvg;
                MatrixOperation.MatrixSubtractMatrix(TmpMatDensePhi, paramModel.Phi, paramModel_avg.Phi);
                MatrixOperation.MatrixSubtractMatrix(TmpMatDenseU, paramModel.U, paramModel_avg.U);
                MatrixOperation.ScalarMultiplyMatrix(TmpMatDensePhi, 1.0f / CntRunningAvg);
                MatrixOperation.ScalarMultiplyMatrix(TmpMatDenseU, 1.0f / CntRunningAvg);
                MatrixOperation.MatrixAddMatrix(paramModel_avg.Phi, TmpMatDensePhi);
                MatrixOperation.MatrixAddMatrix(paramModel_avg.U, TmpMatDenseU);
            }

            // Accumulate the statistics to display
            TotTrErr += 100 * ComputeNumberOfErrors(Dt, DNNRun.y);
            TotLoss += ComputeSupervisedLoss(Dt, DNNRun.y, paramModel.OutputType);
            TotSamples += BatchSize;
            TotSamplesThisEpoch += BatchSize;
            AvgnHidLayerEffective =
                (((double)(TotSamplesThisEpoch - BatchSize)) / ((double)TotSamplesThisEpoch)) * AvgnHidLayerEffective
                +
                1.0 / ((double)TotSamplesThisEpoch) * DNNRun.nHidLayerEffective.Sum();
            stopWatch.Stop();
            TimeSpan ts = stopWatch.Elapsed;
            TotTime += ts.TotalSeconds;
            TotTimeThisEpoch += ts.TotalSeconds;
            stopWatch.Reset();

            if (TotSamplesThisEpoch % nSamplesPerDisplay == 0)
            {
                // Display results
                Console.WriteLine(
                    "* Ep#{0}/{1} Bat#{2}/{3}. Loss={4:F3}. TrErr={5:F3}%. Speed={6} Samples/Sec.",
                    epoch + 1, nEpoch,
                    IdxBatch + 1, nBatch,
                    TotLoss / TotSamples, TotTrErr / TotSamples,
                    (int)((double)TotSamplesThisEpoch / TotTimeThisEpoch)
                    );
                if (paramTrain.DebugLevel == DebugLevel_t.medium)
                {
                    Console.WriteLine(
                        " muPhiMax={0} \n muPhiMin={1}",
                        mu_phi_search.VectorValue.Max(), mu_phi_search.VectorValue.Min()
                        );
                    Console.WriteLine();
                }
                if (paramTrain.DebugLevel == DebugLevel_t.high)
                {
                    Console.WriteLine(
                        " muPhiMax={0} \n muPhiMin={1}",
                        mu_phi_search.VectorValue.Max(), mu_phi_search.VectorValue.Min()
                        );
                    float MaxAbsVal_Grad_Q_Phi = Grad.grad_Q_Phi.MaxAbsValue();
                    float MaxAbsVal_Grad_Q_U = Grad.grad_Q_U.MaxAbsValue();
                    Console.WriteLine(
                        " AvgnHidLayerEff={0:F1}. G_Phi={1:F3}. G_U={2:F3}",
                        AvgnHidLayerEffective,
                        MaxAbsVal_Grad_Q_Phi,
                        MaxAbsVal_Grad_Q_U
                        );
                    // Save the screen into a log file
                    (new FileInfo(ResultFile + ".log")).Directory.Create();
                    using (StreamWriter LogFile = File.AppendText(ResultFile + ".log"))
                    {
                        LogFile.WriteLine(
                            "- Ep#{0}/{1} Bat#{2}/{3}. Loss={4:F3}. TrErr={5:F3}%. Speed={6} Samples/Sec.",
                            epoch + 1, nEpoch,
                            IdxBatch + 1, nBatch,
                            TotLoss / TotSamples, TotTrErr / TotSamples,
                            (int)((double)TotSamplesThisEpoch / TotTimeThisEpoch)
                            );
                        LogFile.WriteLine(
                            " muPhiMax={0} \n muPhiMin={1}",
                            mu_phi_search.VectorValue.Max(), mu_phi_search.VectorValue.Min()
                            );
                        LogFile.WriteLine(
                            " AvgnHidLayerEff={0:F1}. G_Phi={1:F3}. G_U={2:F3}",
                            AvgnHidLayerEffective,
                            MaxAbsVal_Grad_Q_Phi,
                            MaxAbsVal_Grad_Q_U
                            );
                        // NOTE(review): this writes a blank line to the console, not to the log file;
                        // possibly LogFile.WriteLine() was intended - confirm before changing.
                        Console.WriteLine();
                    }
                    Console.WriteLine();
                }
            }
        }

        // -- Test --
        if ((epoch + 1) % nEpochPerTest == 0)
        {
            int testIdx = (epoch + 1) / nEpochPerTest - 1;

            // Standard performance metric
            TestError_epoch.VectorValue[testIdx] = epoch + 1;
            TestError_time.VectorValue[testIdx] = (float)TotTime;
            if (paramTrain.flag_HasValidSet)
            {
                ValidError_pool.VectorValue[testIdx]
                    = Testing_BP_sLDA(
                        ValidData,
                        ValidLabel,
                        evalModel,
                        paramTrain.BatchSize_Test,
                        ResultFile + ".validscore",
                        "Validation Set"
                        );
            }
            TestError_pool.VectorValue[testIdx]
                = Testing_BP_sLDA(
                    TestData,
                    TestLabel,
                    evalModel,
                    paramTrain.BatchSize_Test,
                    ResultFile + ".testscore",
                    "Test Set"
                    );
            TrainError_pool.VectorValue[testIdx] = TotTrErr / TotSamples;
            TrainLoss_pool.VectorValue[testIdx] = TotLoss / TotSamples;

            // Performance metric evaluated using external evaluation tools, e.g., AUC, Top@K accuracy, etc.
            if (paramTrain.flag_ExternalEval)
            {
                ExternalEvaluation(
                    paramTrain.ExternalEval,
                    ResultFile,
                    paramTrain.TestLabelFile,
                    epoch,
                    "Test Set"
                    );
                if (paramTrain.flag_HasValidSet)
                {
                    ExternalEvaluation(
                        paramTrain.ExternalEval,
                        ResultFile,
                        paramTrain.ValidLabelFile,
                        epoch,
                        "Validation Set"
                        );
                }
            }
        }

        // -- Save --
        if ((epoch + 1) % nEpochPerSave == 0)
        {
            // Save model
            (new FileInfo(ResultFile + ".model.Phi")).Directory.Create();
            string ModelName_Phi = ResultFile + ".model.Phi";
            string ModelName_U = ResultFile + ".model.U";
            if (paramTrain.flag_SaveAllModels)
            {
                // Keep one file per saved epoch instead of overwriting.
                ModelName_Phi = ModelName_Phi + ".iter" + (epoch + 1).ToString();
                ModelName_U = ModelName_U + ".iter" + (epoch + 1).ToString();
            }
            // One tab-separated line per entry of DenseMatrixValue (indexed up to nCols).
            using (StreamWriter FileSaveModel_Phi = new StreamWriter(ModelName_Phi, false))
            {
                for (int IdxCol = 0; IdxCol < evalModel.Phi.nCols; IdxCol++)
                {
                    FileSaveModel_Phi.WriteLine(String.Join("\t", evalModel.Phi.DenseMatrixValue[IdxCol].VectorValue));
                }
            }
            using (StreamWriter FileSaveModel_U = new StreamWriter(ModelName_U, false))
            {
                for (int IdxCol = 0; IdxCol < evalModel.U.nCols; IdxCol++)
                {
                    FileSaveModel_U.WriteLine(String.Join("\t", evalModel.U.DenseMatrixValue[IdxCol].VectorValue));
                }
            }

            // Save the learning curves collected so far (the file is overwritten on every save)
            using (StreamWriter FileSavePerf = new StreamWriter(ResultFile + ".perf", false))
            {
                FileSavePerf.Write("Epoch:\t");
                FileSavePerf.WriteLine(String.Join("\t", TestError_epoch.VectorValue));
                FileSavePerf.Write("TrainTime:\t");
                FileSavePerf.WriteLine(String.Join("\t", TestError_time.VectorValue));
                if (paramTrain.flag_HasValidSet)
                {
                    FileSavePerf.Write("Validation:\t");
                    FileSavePerf.WriteLine(String.Join("\t", ValidError_pool.VectorValue));
                }
                FileSavePerf.Write("Test:\t");
                FileSavePerf.WriteLine(String.Join("\t", TestError_pool.VectorValue));
                FileSavePerf.Write("TrainError:\t");
                FileSavePerf.WriteLine(String.Join("\t", TrainError_pool.VectorValue));
                FileSavePerf.Write("TrainLoss:\t");
                FileSavePerf.WriteLine(String.Join("\t", TrainLoss_pool.VectorValue));
            }
        }

        // -- Dump feature --
        // Dump with exactly one model (averaged when active, otherwise the current one).
        // Previously a missing "else" dumped the averaged model and then immediately
        // overwrote the same files with features from the non-averaged model.
        if (paramTrain.flag_DumpFeature && (epoch + 1) % nEpochPerDump == 0)
        {
            DumpingFeature_BP_LDA(TrainData, evalModel, paramTrain.BatchSize_Test, ResultFile + ".train.fea", "Train");
            DumpingFeature_BP_LDA(TestData, evalModel, paramTrain.BatchSize_Test, ResultFile + ".test.fea", "Test");
            if (paramTrain.flag_HasValidSet)
            {
                DumpingFeature_BP_LDA(ValidData, evalModel, paramTrain.BatchSize_Test, ResultFile + ".valid.fea", "Validation");
            }
        }
    }
}