public void multilabel_gaussian_new_usage()
{
    // Trains a multi-class support vector machine with a Gaussian kernel on a
    // 15-sample, 3-class toy problem, then checks the decisions, scores,
    // log-likelihoods and probabilities against recorded reference values.

    #region doc_learn_gaussian
    // Let's say we have the following data to be classified
    // into three possible classes. Those are the samples:
    //
    double[][] inputs =
    {
        //               input         output
        new double[] { 0, 1, 1, 0 }, //  0
        new double[] { 0, 1, 0, 0 }, //  0
        new double[] { 0, 0, 1, 0 }, //  0
        new double[] { 0, 1, 1, 0 }, //  0
        new double[] { 0, 1, 0, 0 }, //  0
        new double[] { 1, 0, 0, 0 }, //  1
        new double[] { 1, 0, 0, 0 }, //  1
        new double[] { 1, 0, 0, 1 }, //  1
        new double[] { 0, 0, 0, 1 }, //  1
        new double[] { 0, 0, 0, 1 }, //  1
        new double[] { 1, 1, 1, 1 }, //  2
        new double[] { 1, 0, 1, 1 }, //  2
        new double[] { 1, 1, 0, 1 }, //  2
        new double[] { 0, 1, 1, 1 }, //  2
        new double[] { 1, 1, 1, 1 }, //  2
    };

    int[] outputs = // those are the class labels
    {
        0, 0, 0, 0, 0,
        1, 1, 1, 1, 1,
        2, 2, 2, 2, 2,
    };

    // Create the multi-class learning algorithm for the machine
    var teacher = new MulticlassSupportVectorLearning<Gaussian>()
    {
        // Configure the learning algorithm to use SMO to train the
        // underlying SVMs in each of the binary class subproblems.
        Learner = (param) => new SequentialMinimalOptimization<Gaussian>()
        {
            // Estimate a suitable guess for the Gaussian kernel's parameters.
            // This estimate can serve as a starting point for a grid search.
            UseKernelEstimation = true
        }
    };

    // Configure parallel execution options (here: at most one
    // concurrent task while learning the binary subproblems)
    teacher.ParallelOptions.MaxDegreeOfParallelism = 1;

    // Learn a machine
    var machine = teacher.Learn(inputs, outputs);

    // Obtain class predictions for each sample
    int[] predicted = machine.Decide(inputs);

    // Get class scores for each sample
    double[] scores = machine.Score(inputs);

    // Compute classification error
    double error = new ZeroOneLoss(outputs).Loss(predicted);
    #endregion

    // Get per-class log-likelihoods for each sample. In the reference values
    // below, the entry for the predicted class equals that sample's score.
    double[][] logl = machine.LogLikelihoods(inputs);

    // Get per-class probabilities for each sample
    double[][] prob = machine.Probabilities(inputs);

    // Compute the categorical cross-entropy loss of the
    // predicted probabilities against the true labels
    double loss = new CategoryCrossEntropyLoss(outputs).Loss(prob);

    // Recorded reference values: one score per sample
    double[] expectedScores =
    {
        1.00888999727541, 1.00303259868784, 1.00068403386636, 1.00888999727541,
        1.00303259868784, 1.00831890183328, 1.00831890183328, 0.843757409449037,
        0.996768862332386, 0.996768862332386, 1.02627325826713, 1.00303259868784,
        0.996967401312164, 0.961947708617365, 1.02627325826713
    };

    // Recorded reference values: one row per sample, one column per class
    double[][] expectedLogL =
    {
        new double[] { 1.00888999727541, -1.00888999727541, -1.00135670089335 },
        new double[] { 1.00303259868784, -0.991681098166717, -1.00303259868784 },
        new double[] { 1.00068403386636, -0.54983354268499, -1.00068403386636 },
        new double[] { 1.00888999727541, -1.00888999727541, -1.00135670089335 },
        new double[] { 1.00303259868784, -0.991681098166717, -1.00303259868784 },
        new double[] { -1.00831890183328, 1.00831890183328, -0.0542719287771535 },
        new double[] { -1.00831890183328, 1.00831890183328, -0.0542719287771535 },
        new double[] { -0.843757409449037, 0.843757409449037, -0.787899083913034 },
        new double[] { -0.178272229157676, 0.996768862332386, -0.996768862332386 },
        new double[] { -0.178272229157676, 0.996768862332386, -0.996768862332386 },
        new double[] { -1.02627325826713, -1.00323113766761, 1.02627325826713 },
        new double[] { -1.00303259868784, -0.38657999872922, 1.00303259868784 },
        new double[] { -0.996967401312164, -0.38657999872922, 0.996967401312164 },
        new double[] { -0.479189991343958, -0.961947708617365, 0.961947708617365 },
        new double[] { -1.02627325826713, -1.00323113766761, 1.02627325826713 }
    };

    // Recorded reference values: one row per sample, one column per class
    double[][] expectedProbs =
    {
        new double[] { 0.789324598208647, 0.104940932711551, 0.105734469079803 },
        new double[] { 0.78704862182644, 0.107080012017624, 0.105871366155937 },
        new double[] { 0.74223157627093, 0.157455631737191, 0.100312791991879 },
        new double[] { 0.789324598208647, 0.104940932711551, 0.105734469079803 },
        new double[] { 0.78704862182644, 0.107080012017624, 0.105871366155937 },
        new double[] { 0.0900153422818135, 0.676287261796794, 0.233697395921392 },
        new double[] { 0.0900153422818135, 0.676287261796794, 0.233697395921392 },
        new double[] { 0.133985810363445, 0.72433118122885, 0.141683008407705 },
        new double[] { 0.213703968297751, 0.692032433073136, 0.0942635986291124 },
        new double[] { 0.213703968297751, 0.692032433073136, 0.0942635986291124 },
        new double[] { 0.10192623206507, 0.104302095948601, 0.79377167198633 },
        new double[] { 0.0972161784678357, 0.180077937396817, 0.722705884135347 },
        new double[] { 0.0981785890979593, 0.180760971768703, 0.721060439133338 },
        new double[] { 0.171157270099157, 0.105617610634377, 0.723225119266465 },
        new double[] { 0.10192623206507, 0.104302095948601, 0.79377167198633 }
    };

    // Every training sample must be classified correctly ...
    Assert.AreEqual(0, error);
    Assert.AreEqual(4.5289447815997672, loss, 1e-10);
    Assert.IsTrue(predicted.IsEqual(outputs));

    // ... and the numeric outputs must match the reference values
    Assert.IsTrue(expectedScores.IsEqual(scores, 1e-10));
    Assert.IsTrue(expectedLogL.IsEqual(logl, 1e-10));
    Assert.IsTrue(expectedProbs.IsEqual(prob, 1e-10));
}