/// <summary>
/// Entry point of the rating prediction command-line tool: parses the command-line options,
/// loads the data, trains (or loads) a rating predictor, evaluates it, and optionally
/// writes predictions and saves the model.
/// </summary>
/// <remarks>
/// With --find-iter=N the recommender must implement IIterativeModel; it is iterated up to
/// --max-iter times and evaluated on the test data every N iterations, stopping early on the
/// --epsilon / --rmse-cutoff / --mae-cutoff criteria. Otherwise the recommender is trained once
/// (optionally with --cross-validation or --search-hp) or loaded via --load-model, and then evaluated.
/// Evaluation data comes either from --test-file or from a --split-ratio split of the training data.
/// </remarks>
/// <param name="args">the command-line arguments</param>
static void Main(string[] args)
{
	Assembly assembly = Assembly.GetExecutingAssembly();
	// load MyMediaLiteExperimental.dll from the directory this assembly lives in
	Assembly.LoadFile(Path.Combine(Path.GetDirectoryName(assembly.Location), "MyMediaLiteExperimental.dll"));

	AppDomain.CurrentDomain.UnhandledException += new UnhandledExceptionEventHandler(Handlers.UnhandledExceptionHandler);
	Console.CancelKeyPress += new ConsoleCancelEventHandler(AbortHandler);

	// recommender arguments
	string method = "BiasedMatrixFactorization";
	string recommender_options = string.Empty;

	// help/version
	bool show_help = false;
	bool show_version = false;

	// arguments for iteration search
	int find_iter = 0;
	int max_iter = 500;
	double epsilon = 0;
	double rmse_cutoff = double.MaxValue;
	double mae_cutoff = double.MaxValue;

	// data arguments
	string data_dir = string.Empty;
	string user_attributes_file = string.Empty;
	string item_attributes_file = string.Empty;
	string user_relations_file = string.Empty;
	string item_relations_file = string.Empty;

	// other arguments
	bool online_eval = false;
	bool search_hp = false;
	string save_model_file = string.Empty;
	string load_model_file = string.Empty;
	int random_seed = -1;
	string prediction_file = string.Empty;
	string prediction_line = "{0}\t{1}\t{2}";
	int cross_validation = 0;
	double split_ratio = 0;

	var p = new OptionSet() {
		// string-valued options
		{ "training-file=",       v => training_file = v },
		{ "test-file=",           v => test_file = v },
		{ "recommender=",         v => method = v },
		{ "recommender-options=", v => recommender_options += " " + v },
		{ "data-dir=",            v => data_dir = v },
		{ "user-attributes=",     v => user_attributes_file = v },
		{ "item-attributes=",     v => item_attributes_file = v },
		{ "user-relations=",      v => user_relations_file = v },
		{ "item-relations=",      v => item_relations_file = v },
		{ "save-model=",          v => save_model_file = v },
		{ "load-model=",          v => load_model_file = v },
		{ "prediction-file=",     v => prediction_file = v },
		{ "prediction-line=",     v => prediction_line = v },
		// integer-valued options
		{ "find-iter=",           (int v) => find_iter = v },
		{ "max-iter=",            (int v) => max_iter = v },
		{ "random-seed=",         (int v) => random_seed = v },
		{ "cross-validation=",    (int v) => cross_validation = v },
		// double-valued options
		{ "epsilon=",             (double v) => epsilon = v },
		{ "rmse-cutoff=",         (double v) => rmse_cutoff = v },
		{ "mae-cutoff=",          (double v) => mae_cutoff = v },
		{ "split-ratio=",         (double v) => split_ratio = v },
		// enum options
		{ "rating-type=",         (RatingType v) => rating_type = v },
		{ "file-format=",         (RatingFileFormat v) => file_format = v },
		// boolean options
		{ "compute-fit",          v => compute_fit = v != null },
		{ "online-evaluation",    v => online_eval = v != null },
		{ "search-hp",            v => search_hp = v != null },
		{ "help",                 v => show_help = v != null },
		{ "version",              v => show_version = v != null },
	};
	IList<string> extra_args = p.Parse(args);

	// TODO make sure interaction of --find-iter and --cross-validation works properly

	// Evaluation needs held-out data: either an explicit test file, or a split of the
	// training data (see the split_ratio block below, which fills test_data).
	bool no_eval = test_file == null && !(split_ratio > 0);

	if (show_version)
		ShowVersion();
	if (show_help)
		Usage(0);

	if (extra_args.Count > 0)
		Usage("Did not understand " + extra_args[0]);
	if (training_file == null)
		Usage("Parameter --training-file=FILE is missing.");
	if (cross_validation != 0 && split_ratio != 0)
		Usage("--cross-validation=K and --split-ratio=NUM are mutually exclusive.");
	// 0 disables cross-validation; a single fold (or a negative count) is meaningless
	if (cross_validation < 0 || cross_validation == 1)
		Usage("--cross-validation=K requires K to be at least 2.");
	// 0 disables the split; the ratio must leave some data on both sides
	if (split_ratio < 0 || split_ratio >= 1)
		Usage("--split-ratio=NUM requires 0 < NUM < 1.");
	// the split would silently overwrite the test data loaded from --test-file
	if (test_file != null && split_ratio > 0)
		Usage("--test-file=FILE and --split-ratio=NUM are mutually exclusive.");
	// the iteration search evaluates on the test data in every reporting step
	if (find_iter != 0 && no_eval)
		Usage("--find-iter=N requires --test-file=FILE or --split-ratio=NUM.");

	if (random_seed != -1)
		MyMediaLite.Util.Random.InitInstance(random_seed);

	recommender = Recommender.CreateRatingPredictor(method);
	if (recommender == null)
		Usage(string.Format("Unknown method: '{0}'", method));

	Recommender.Configure(recommender, recommender_options, Usage);

	// ID mapping objects: the KDD Cup 2011 format uses the raw IDs directly
	if (file_format == RatingFileFormat.KDDCUP_2011)
	{
		user_mapping = new IdentityMapping();
		item_mapping = new IdentityMapping();
	}

	// load all the data
	LoadData(data_dir, user_attributes_file, item_attributes_file, user_relations_file, item_relations_file, !online_eval);

	Console.Error.WriteLine(string.Format(CultureInfo.InvariantCulture, "ratings range: [{0}, {1}]", recommender.MinRating, recommender.MaxRating));

	if (split_ratio > 0)
	{
		// hold out part of the training data as test data
		var split = new RatingsSimpleSplit(training_data, split_ratio);
		recommender.Ratings = split.Train[0];
		training_data = split.Train[0];
		test_data     = split.Test[0];
	}

	Utils.DisplayDataStats(training_data, test_data, recommender);

	if (find_iter != 0)
	{
		var iterative_recommender = recommender as IIterativeModel;
		if (iterative_recommender == null)
			Usage("Only iterative recommenders support find_iter.");

		Console.WriteLine(recommender.ToString() + " ");

		if (load_model_file == string.Empty)
			recommender.Train();
		else
			Recommender.LoadModel(iterative_recommender, load_model_file);

		if (compute_fit)
			Console.Write(string.Format(CultureInfo.InvariantCulture, "fit {0,0:0.#####} ", iterative_recommender.ComputeFit()));

		MyMediaLite.Eval.Ratings.DisplayResults(MyMediaLite.Eval.Ratings.Evaluate(recommender, test_data));
		Console.WriteLine(" iteration " + iterative_recommender.NumIter);

		// continue from the current iteration count (non-zero if a model was loaded)
		for (int i = (int) iterative_recommender.NumIter + 1; i <= max_iter; i++)
		{
			TimeSpan time = Utils.MeasureTime(delegate() {
				iterative_recommender.Iterate();
			});
			training_time_stats.Add(time.TotalSeconds);

			if (i % find_iter == 0)
			{
				if (compute_fit)
				{
					double fit = 0;
					time = Utils.MeasureTime(delegate() {
						fit = iterative_recommender.ComputeFit();
					});
					fit_time_stats.Add(time.TotalSeconds);
					Console.Write(string.Format(CultureInfo.InvariantCulture, "fit {0,0:0.#####} ", fit));
				}

				Dictionary<string, double> results = null;
				time = Utils.MeasureTime(delegate() { results = MyMediaLite.Eval.Ratings.Evaluate(recommender, test_data); });
				eval_time_stats.Add(time.TotalSeconds);
				MyMediaLite.Eval.Ratings.DisplayResults(results);
				rmse_eval_stats.Add(results["RMSE"]);
				Console.WriteLine(" iteration " + i);

				Recommender.SaveModel(recommender, save_model_file, i);
				if (prediction_file != string.Empty)
					Prediction.WritePredictions(recommender, test_data, user_mapping, item_mapping, prediction_line, prediction_file + "-it-" + i);

				// stop once the RMSE has moved more than epsilon above the best RMSE seen so far
				if (epsilon > 0.0 && results["RMSE"] - rmse_eval_stats.Min() > epsilon)
				{
					Console.Error.WriteLine(string.Format(CultureInfo.InvariantCulture, "{0} >> {1}", results["RMSE"], rmse_eval_stats.Min()));
					Console.Error.WriteLine("Reached convergence on training/validation data after {0} iterations.", i);
					break;
				}
				if (results["RMSE"] > rmse_cutoff || results["MAE"] > mae_cutoff)
				{
					Console.Error.WriteLine("Reached cutoff after {0} iterations.", i);
					break;
				}
			}
		} // for

		DisplayStats();
	}
	else
	{
		TimeSpan seconds;

		if (load_model_file == string.Empty)
		{
			if (cross_validation > 0)
			{
				Console.Write(recommender.ToString());
				Console.WriteLine();
				var split = new RatingCrossValidationSplit(training_data, cross_validation);
				var results = MyMediaLite.Eval.Ratings.EvaluateOnSplit(recommender, split); // TODO if (search_hp)
				MyMediaLite.Eval.Ratings.DisplayResults(results);
				no_eval = true;
				// NOTE(review): this restores the full training data but does not retrain, so the
				// model saved/used below is whatever state EvaluateOnSplit left behind — confirm intent.
				recommender.Ratings = training_data;
			}
			else
			{
				if (search_hp)
				{
					// TODO --search-hp-criterion=RMSE
					double result = NelderMead.FindMinimum("RMSE", recommender);
					Console.Error.WriteLine("estimated quality (on split) {0}", result.ToString(CultureInfo.InvariantCulture));
					// TODO give out hp search time
				}

				Console.Write(recommender.ToString());
				seconds = Utils.MeasureTime( delegate() { recommender.Train(); } );
				Console.Write(" training_time " + seconds + " ");
			}
		}
		else
		{
			Recommender.LoadModel(recommender, load_model_file);
			Console.Write(recommender.ToString() + " ");
		}

		if (!no_eval)
		{
			if (online_eval)  // TODO support also for prediction outputs (to allow external evaluation)
				seconds = Utils.MeasureTime(delegate() { MyMediaLite.Eval.Ratings.DisplayResults(MyMediaLite.Eval.Ratings.EvaluateOnline(recommender, test_data)); });
			else
				seconds = Utils.MeasureTime(delegate() { MyMediaLite.Eval.Ratings.DisplayResults(MyMediaLite.Eval.Ratings.Evaluate(recommender, test_data)); });

			Console.Write(" testing_time " + seconds);
		}

		if (compute_fit)
		{
			Console.Write("fit ");
			seconds = Utils.MeasureTime(delegate() {
				MyMediaLite.Eval.Ratings.DisplayResults(MyMediaLite.Eval.Ratings.Evaluate(recommender, training_data));
			});
			// print the TimeSpan like training_time/testing_time; a numeric format such as
			// "0.#####" is not a valid TimeSpan format and throws FormatException on .NET 4+
			Console.Write(" fit_time " + seconds + " ");
		}

		if (prediction_file != string.Empty)
		{
			seconds = Utils.MeasureTime(delegate() {
				Console.WriteLine();
				Prediction.WritePredictions(recommender, test_data, user_mapping, item_mapping, prediction_line, prediction_file);
			});
			// terminate the line so the following "memory" line on stderr is not glued to it
			Console.Error.WriteLine("predicting_time " + seconds);
		}

		Console.WriteLine();
		Console.Error.WriteLine("memory {0}", Memory.Usage);
	}
	Recommender.SaveModel(recommender, save_model_file);
}