Accord.Tests.MachineLearning.KNearestNeighborTest.KNearestNeighbor_CrossValidation C# (CSharp) Method

KNearestNeighbor_CrossValidation() public method

public KNearestNeighbor_CrossValidation ( ) : void
        public void KNearestNeighbor_CrossValidation()
        {
            // Create some sample learning data. In this data,
            // the first two instances belong to one class, the
            // next four belong to another class, and the last
            // three to yet another.

            double[][] inputs = 
            {
                // The first two are from class 0
                new double[] { -5, -2, -1 },
                new double[] { -5, -5, -6 },

                // The next four are from class 1
                new double[] {  2,  1,  1 },
                new double[] {  1,  1,  2 },
                new double[] {  1,  2,  2 },
                new double[] {  3,  1,  2 },

                // The last three are from class 2
                new double[] { 11,  5,  4 },
                new double[] { 15,  5,  6 },
                new double[] { 10,  5,  6 },
            };

            int[] outputs =
            {
                0, 0,        // First two from class 0
                1, 1, 1, 1,  // Next four from class 1
                2, 2, 2      // Last three from class 2
            };



            // Create a new Cross-validation algorithm passing the data set size and the number of folds
            var crossvalidation = new CrossValidation(size: inputs.Length, folds: 3);
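            // (With 9 samples and 3 folds, each fold reserves 3 of the original
            // indices for validation and leaves the remaining 6 for training.)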

            // Define a fitting function using k-Nearest Neighbors. The objective of
            // this function is to learn a k-NN classifier on the subset of the data
            // indicated by the cross-validation procedure.

            crossvalidation.Fitting = delegate(int k, int[] indicesTrain, int[] indicesValidation)
            {
                // The fitting function receives the indices of the original set
                // which should be considered training data and the indices of the
                // original set which should be considered validation data.

                // Let's now grab the training data:
                var trainingInputs = inputs.Submatrix(indicesTrain);
                var trainingOutputs = outputs.Submatrix(indicesTrain);

                // And now the validation data:
                var validationInputs = inputs.Submatrix(indicesValidation);
                var validationOutputs = outputs.Submatrix(indicesValidation);

                // Now we will create the k-Nearest Neighbors algorithm. For this
                // example, we will be choosing k = 4. This means that, for a given
                // instance, its nearest 4 neighbors will be used to cast a decision.
                // Note that this test builds the classifier over the complete data
                // set, which is why perfect accuracy is expected below; a leakage-free
                // version would pass trainingInputs and trainingOutputs instead
                // (see the sketch after the method).
                KNearestNeighbors knn = new KNearestNeighbors(k: 4, classes: 3,
                    inputs: inputs, outputs: outputs);


                // After the algorithm has been created, we can classify instances:
                int[] train_predicted = trainingInputs.Apply(knn.Compute);
                int[] test_predicted = validationInputs.Apply(knn.Compute);
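                // (Apply, from Accord.Math, maps knn.Compute over every input
                // vector, yielding one predicted class label per instance.)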

                // Compute the classification accuracy on the training data:
                var cmTrain = new ConfusionMatrix(train_predicted, trainingOutputs);
                double trainingAcc = cmTrain.Accuracy;

                // And the classification accuracy on the validation data:
                var cmTest = new ConfusionMatrix(test_predicted, validationOutputs);
                double validationAcc = cmTest.Accuracy;
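                // (Accuracy is the proportion of instances on the confusion
                // matrix diagonal, i.e. the fraction classified correctly.)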

                // Return a new information structure containing the model and the accuracies achieved.
                return new CrossValidationValues(knn, trainingAcc, validationAcc);
            };


            // Compute the cross-validation
            var result = crossvalidation.Compute();

            // Finally, access the measured performance: the mean training and
            // validation accuracy across the three folds.
            double trainingAccs = result.Training.Mean;
            double validationAccs = result.Validation.Mean;
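            // (result.Training and result.Validation aggregate the per-fold
            // values; besides the mean, they also expose measures of spread
            // across folds, such as the variance.)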


            Assert.AreEqual(1, trainingAccs);
            Assert.AreEqual(1, validationAccs);
        }
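
The fitting function above builds the k-NN classifier over the complete data set, which is why both asserted accuracies come out as exactly 1. For comparison, here is a minimal leakage-free sketch of the same fitting delegate that trains only on each fold's training subset. It assumes the same Accord.NET types used above (CrossValidation, KNearestNeighbors, ConfusionMatrix, and the Submatrix/Apply extensions from Accord.Math); note that with only two samples in class 0, the perfect accuracies asserted above would generally no longer hold.

            crossvalidation.Fitting = delegate(int k, int[] indicesTrain, int[] indicesValidation)
            {
                // Slice out this fold's training and validation subsets:
                var trainInputs = inputs.Submatrix(indicesTrain);
                var trainOutputs = outputs.Submatrix(indicesTrain);
                var valInputs = inputs.Submatrix(indicesValidation);
                var valOutputs = outputs.Submatrix(indicesValidation);

                // Train only on the training subset, so the validation
                // accuracy is an honest estimate of generalization:
                var knn = new KNearestNeighbors(k: 4, classes: 3,
                    inputs: trainInputs, outputs: trainOutputs);

                // Score each subset with a confusion matrix, as above:
                double trainAcc = new ConfusionMatrix(
                    trainInputs.Apply(knn.Compute), trainOutputs).Accuracy;
                double valAcc = new ConfusionMatrix(
                    valInputs.Apply(knn.Compute), valOutputs).Accuracy;

                return new CrossValidationValues(knn, trainAcc, valAcc);
            };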