/// <summary>
/// Parses command-line arguments given as alternating "--Key Value" pairs into the
/// model and training parameter objects. It then derives the alpha-dependent settings
/// and checks that the train/test input and label files were supplied.
/// </summary>
/// <param name="args">Command-line arguments; must alternate key and value.</param>
/// <param name="paramModel">Model parameters; fields are assigned from the matching options.</param>
/// <param name="paramTrain">Training parameters; fields are assigned from the matching options.</param>
/// <param name="ModelFile">Not read or assigned by this method.</param>
/// <param name="ResultFile">Set from "--ResultFile" when present; otherwise left unchanged.</param>
/// <returns>
/// <c>true</c> on success. <c>false</c> (after printing a message) on an unknown key,
/// a trailing key with no value, or a missing train/test input or label file.
/// </returns>
/// <exception cref="FormatException">
/// A numeric or boolean value, or a "--BatchSizeSchedule" entry, is malformed.
/// </exception>
/// <exception cref="ArgumentException">
/// Unsupported OutputType, unknown DebugLevel, duplicate schedule key, or alpha not greater than 0.
/// </exception>
public static bool ParseArgument(
    string[] args,
    paramModel_t paramModel,
    paramTrain_t paramTrain,
    ref string ModelFile,
    ref string ResultFile
    )
{
    // Command-line numbers are machine-readable input: parse them culture-independently so
    // "0.5" means one half regardless of the user's locale.
    IFormatProvider Inv = System.Globalization.CultureInfo.InvariantCulture;

    for (int IdxArg = 0; IdxArg < args.Length; IdxArg += 2)
    {
        string ArgKey = args[IdxArg];
        if (IdxArg + 1 >= args.Length)
        {
            // A key without a value is a usage error; do not silently ignore it.
            Console.WriteLine("Missing value for ArgKey: {0}", ArgKey);
            Program.DispHelp();
            return false;
        }
        string ArgValue = args[IdxArg + 1];

        switch (ArgKey)
        {
            case "--nHid":
                paramModel.nHid = int.Parse(ArgValue, Inv);
                break;
            case "--nHidLayer":
                paramModel.nHidLayer = int.Parse(ArgValue, Inv);
                break;
            case "--To":
                paramModel.To = float.Parse(ArgValue, Inv);
                break;
            case "--alpha":
                paramModel.alpha = float.Parse(ArgValue, Inv);
                break;
            case "--beta":
                paramModel.beta = float.Parse(ArgValue, Inv);
                break;
            case "--nEpoch":
                paramTrain.nEpoch = int.Parse(ArgValue, Inv);
                break;
            case "--BatchSize":
                paramTrain.BatchSize = int.Parse(ArgValue, Inv);
                break;
            case "--BatchSize_Test":
                paramTrain.BatchSize_Test = int.Parse(ArgValue, Inv);
                break;
            case "--mu_Phi":
                paramTrain.mu_Phi = float.Parse(ArgValue, Inv);
                break;
            case "--mu_U":
                paramTrain.mu_U = float.Parse(ArgValue, Inv);
                break;
            case "--nSamplesPerDisplay":
                paramTrain.nSamplesPerDisplay = int.Parse(ArgValue, Inv);
                break;
            case "--nEpochPerSave":
                paramTrain.nEpochPerSave = int.Parse(ArgValue, Inv);
                break;
            case "--nEpochPerTest":
                paramTrain.nEpochPerTest = int.Parse(ArgValue, Inv);
                break;
            case "--TrainInputFile":
                paramTrain.TrainInputFile = ArgValue;
                break;
            case "--TestInputFile":
                paramTrain.TestInputFile = ArgValue;
                break;
            case "--TrainLabelFile":
                paramTrain.TrainLabelFile = ArgValue;
                break;
            case "--TestLabelFile":
                paramTrain.TestLabelFile = ArgValue;
                break;
            case "--ResultFile":
                ResultFile = ArgValue;
                break;
            case "--nInput":
                paramModel.nInput = int.Parse(ArgValue, Inv);
                break;
            case "--nOutput":
                paramModel.nOutput = int.Parse(ArgValue, Inv);
                break;
            case "--OutputType":
                paramModel.OutputType = ArgValue;
                if (paramModel.OutputType != "softmaxCE" && paramModel.OutputType != "linearQuad" && paramModel.OutputType != "linearCE")
                {
                    throw new ArgumentException("Unknown OutputType for supervised learning. Only softmaxCE/linearQuad/linearCE is supported.");
                }
                break;
            case "--LearnRateSchedule":
                paramTrain.LearnRateSchedule = ArgValue;
                break;
            case "--flag_DumpFeature":
                paramTrain.flag_DumpFeature = bool.Parse(ArgValue);
                break;
            case "--nEpochPerDump":
                paramTrain.nEpochPerDump = int.Parse(ArgValue, Inv);
                break;
            case "--BatchSizeSchedule":
                // Value format: comma-separated "int:int" pairs, e.g. "1:64,10:128".
                paramTrain.flag_BachSizeSchedule = true;
                paramTrain.BachSizeSchedule = new Dictionary<int, int>();
                foreach (string Entry in ArgValue.Split(','))
                {
                    string[] KeyValPair = Entry.Split(':');
                    if (KeyValPair.Length != 2)
                    {
                        throw new FormatException(String.Format(
                            "Invalid BatchSizeSchedule entry \"{0}\"; expected two integers separated by ':'.", Entry));
                    }
                    // Dictionary.Add throws ArgumentException on a repeated key.
                    paramTrain.BachSizeSchedule.Add(int.Parse(KeyValPair[0], Inv), int.Parse(KeyValPair[1], Inv));
                }
                break;
            case "--ThreadNum":
                paramTrain.ThreadNum = int.Parse(ArgValue, Inv);
                break;
            case "--MaxThreadDeg":
                paramTrain.MaxMultiThreadDegree = int.Parse(ArgValue, Inv);
                break;
            case "--ExternalEval":
                paramTrain.flag_ExternalEval = true;
                paramTrain.ExternalEval = ArgValue;
                break;
            case "--flag_SaveAllModels":
                paramTrain.flag_SaveAllModels = bool.Parse(ArgValue);
                break;
            case "--ValidLabelFile":
                paramTrain.ValidLabelFile = ArgValue;
                paramTrain.flag_HasValidSet = true;
                break;
            case "--ValidInputFile":
                paramTrain.ValidInputFile = ArgValue;
                paramTrain.flag_HasValidSet = true;
                break;
            case "--T_value":
                paramModel.T_value = float.Parse(ArgValue, Inv);
                break;
            case "--eta":
                paramModel.eta = float.Parse(ArgValue, Inv);
                break;
            case "--DebugLevel":
                paramTrain.DebugLevel = (DebugLevel_t)Enum.Parse(typeof(DebugLevel_t), ArgValue, true);
                break;
            case "--flag_AdaptivenHidLayer":
                paramModel.flag_AdaptivenHidLayer = bool.Parse(ArgValue);
                break;
            case "--flag_RunningAvg":
                paramTrain.flag_RunningAvg = bool.Parse(ArgValue);
                break;
            default:
                Console.WriteLine("Unknown ArgKey: {0}", ArgKey);
                Program.DispHelp();
                return false;
        }
    }

    // T_value and flag_AdaptivenHidLayer are derived from alpha here.
    // NOTE(review): this unconditionally overwrites any --T_value / --flag_AdaptivenHidLayer
    // given on the command line, making those options ineffective. Confirm this is intended.
    if (paramModel.alpha >= 1.0f)
    {
        paramModel.T_value = 1.0f;
        paramModel.flag_AdaptivenHidLayer = false;
    }
    else if (paramModel.alpha > 0.0f)
    {
        paramModel.T_value = 0.01f;
        paramModel.flag_AdaptivenHidLayer = true;
    }
    else
    {
        // Reached for alpha <= 0 and also for NaN (all comparisons with NaN are false).
        throw new ArgumentException("Invalid alpha.");
    }

    if (String.IsNullOrEmpty(paramTrain.TrainInputFile) || String.IsNullOrEmpty(paramTrain.TestInputFile)
        || String.IsNullOrEmpty(paramTrain.TrainLabelFile) || String.IsNullOrEmpty(paramTrain.TestLabelFile))
    {
        Console.WriteLine("Empty TrainInputFile, TestInputFile, TrainLabelFile, or TestLabelFile!");
        return false;
    }
    return true;
}