public void RegressTest2()
{
    // Fits a 3-category multinomial logistic regression with
    // LowerBoundNewtonRaphson and checks the iteration count, coefficients,
    // standard errors, log-likelihood, chi-square statistic and Wald
    // statistics against reference values. Every quantity is first checked
    // for NaN before it is compared against its expected value.

    double[][] inputs;
    int[] outputs;
    CreateInputOutputsExample1(out inputs, out outputs);

    // Create a new Multinomial Logistic Regression for 3 categories
    var mlr = new MultinomialLogisticRegression(inputs: 2, categories: 3);

    // Create an estimation algorithm to estimate the regression
    var lbnr = new LowerBoundNewtonRaphson(mlr);

    // Iteratively estimate the model. The value returned by Run is used as
    // the convergence criterion (per the original author, it is the maximum
    // relative change in the model parameters); the loop is also capped.
    const int maxIterations = 100;
    const double convergenceTolerance = 1e-6;

    double delta;
    int iteration = 0;
    do
    {
        // Perform an iteration
        delta = lbnr.Run(inputs, outputs);
        iteration++;
    } while (iteration < maxIterations && delta > convergenceTolerance);

    Assert.AreEqual(52, iteration);

    // This is the same example given in R Data Analysis Examples for
    // Multinomial Logistic Regression: http://www.ats.ucla.edu/stat/r/dae/mlogit.htm
    // Rows: brand 2, brand 3. Columns: intercept, female, age.
    double[][] expectedCoefficients =
    {
        new[] { -11.774655, 0.523814, 0.368206 }, // brand 2
        new[] { -22.721396, 0.465941, 0.685908 }, // brand 3
    };

    for (int i = 0; i < expectedCoefficients.Length; i++)
    {
        for (int j = 0; j < expectedCoefficients[i].Length; j++)
        {
            Assert.IsFalse(double.IsNaN(mlr.Coefficients[i][j]));
        }
    }

    for (int i = 0; i < expectedCoefficients.Length; i++)
    {
        for (int j = 0; j < expectedCoefficients[i].Length; j++)
        {
            Assert.AreEqual(expectedCoefficients[i][j], mlr.Coefficients[i][j], 1e-4);
        }
    }

    // Standard errors from the lower-bound approximation. Note that the
    // expected values are identical for both categories.
    //
    // For reference, the standard Hessian estimation gives:
    //   brand 2: 1.774612, 0.194247, 0.055003
    //   brand 3: 2.058028, 0.226090, 0.062627
    double[][] expectedStandardErrors =
    {
        new[] { 1.047378039787443, 0.153150051082552, 0.031640507386863 }, // brand 2
        new[] { 1.047378039787443, 0.153150051082552, 0.031640507386863 }, // brand 3
    };

    for (int i = 0; i < expectedStandardErrors.Length; i++)
    {
        for (int j = 0; j < expectedStandardErrors[i].Length; j++)
        {
            Assert.IsFalse(double.IsNaN(mlr.StandardErrors[i][j]));
        }
    }

    for (int i = 0; i < expectedStandardErrors.Length; i++)
    {
        for (int j = 0; j < expectedStandardErrors[i].Length; j++)
        {
            Assert.AreEqual(expectedStandardErrors[i][j], mlr.StandardErrors[i][j], 1e-6);
        }
    }

    // Guard against NaN before the tolerance check: with some test
    // frameworks a NaN actual value is not rejected by AreEqual(..., delta).
    double ll = mlr.GetLogLikelihood(inputs, outputs);
    Assert.IsFalse(double.IsNaN(ll));
    Assert.AreEqual(-702.97, ll, 1e-2);

    var chi = mlr.ChiSquare(inputs, outputs);
    Assert.IsFalse(double.IsNaN(chi.Statistic));
    Assert.AreEqual(185.85, chi.Statistic, 1e-2);

    // Wald statistics, computed once per (category, coefficient) pair.
    var wald = new[]
    {
        new[] { mlr.GetWaldTest(0, 0), mlr.GetWaldTest(0, 1), mlr.GetWaldTest(0, 2) },
        new[] { mlr.GetWaldTest(1, 0), mlr.GetWaldTest(1, 1), mlr.GetWaldTest(1, 2) },
    };

    // Wald statistics using the lower-bound approximation.
    //
    // For reference, the standard Hessian estimation gives:
    //   brand 2: -6.6351, 2.6966, 6.6943
    //   brand 3: -11.0404, 2.0609, 10.9524
    double[][] expectedWald =
    {
        new[] { -11.241995503283842, 3.4202662152119889, 11.637150673342207 }, // brand 2
        new[] { -21.693553825772664, 3.0423802097069097, 21.678124991086548 }, // brand 3
    };

    for (int i = 0; i < expectedWald.Length; i++)
    {
        for (int j = 0; j < expectedWald[i].Length; j++)
        {
            Assert.IsFalse(double.IsNaN(wald[i][j].Statistic));
        }
    }

    for (int i = 0; i < expectedWald.Length; i++)
    {
        for (int j = 0; j < expectedWald[i].Length; j++)
        {
            Assert.AreEqual(expectedWald[i][j], wald[i][j].Statistic, 1e-4);
        }
    }
}