Accord.MachineLearning.DecisionTrees.Pruning.ReducedErrorPruning.Run C# (CSharp) Method

Run() public method

Computes one pass of the pruning algorithm.
public Run ( ) : double
Returns: double
        /// <summary>
        ///   Computes one pass of the pruning algorithm: refreshes the error and
        ///   gain stored for every node of the tree, then prunes the node with the
        ///   largest gain, provided that gain is not negative.
        /// </summary>
        /// <returns>
        ///   The value of <c>computeError()</c> evaluated after the pass.
        /// </returns>
        public double Run()
        {
            // Pass 1: refresh the error stored for every node.
            foreach (var current in tree)
            {
                info[current].error = computeError(current);
            }

            // Pass 2: refresh the gain stored for every node. This runs only after
            // pass 1 has finished, so all errors are updated before any gain is computed.
            foreach (var current in tree)
            {
                info[current].gain = computeGain(current);
            }

            // Pass 3: locate the node with the largest gain. The comparison is
            // strict, so the first node encountered wins when gains are tied.
            DecisionNode best = null;
            double bestGain = Double.NegativeInfinity;
            foreach (var candidate in tree)
            {
                double candidateGain = info[candidate].gain;

                if (candidateGain > bestGain)
                {
                    best = candidate;
                    bestGain = candidateGain;
                }
            }

            bool shouldPrune = best != null && bestGain >= 0;
            if (shouldPrune)
            {
                // Gather the outputs selected by the subset recorded for this node
                // and take their mode as the node's new output.
                int[] subsetOutputs = outputs.Get(info[best].subset.ToArray());
                int replacement = Measures.Mode(subsetOutputs);

                // Prune: drop the node's branches and give it a fixed output.
                best.Branches = null;
                best.Output = replacement;
            }

            return computeError();
        }

Usage Example

        /// <summary>
        ///   Prunes the nursery example tree with <c>ReducedErrorPruning</c> and
        ///   checks the final error together with the node counts before and
        ///   after pruning.
        /// </summary>
        /// <remarks>
        ///   NOTE(review): presumably <c>createNurseryExample</c> builds the tree
        ///   from the first <c>training</c> samples, leaving the remainder for
        ///   pruning - confirm against that helper.
        /// </remarks>
        public void RunTest()
        {
            double[][] inputs;
            int[] outputs;

            int training = 6000;
            DecisionTree tree = createNurseryExample(out inputs, out outputs, training);

            // Node count before pruning.
            int nodeCount = 0;
            foreach (var node in tree)
                nodeCount++;

            // Samples from index `training` onward form the pruning set. Each
            // array is sliced using its own length as the upper bound.
            var pruningInputs = inputs.Submatrix(training, inputs.Length - 1);
            var pruningOutputs = outputs.Submatrix(training, outputs.Length - 1);
            var prune = new ReducedErrorPruning(tree, pruningInputs, pruningOutputs);

            // Keep running passes until a pass returns an error larger than the
            // one returned by the previous pass.
            double lastError, error = Double.PositiveInfinity;
            do
            {
                lastError = error;
                error = prune.Run();
            } while (error <= lastError);

            // Node count after pruning.
            int nodeCount2 = 0;
            foreach (var node in tree)
                nodeCount2++;

            // Compare the floating-point error with a tolerance instead of exact
            // equality, so the check is not sensitive to last-bit rounding.
            Assert.AreEqual(0.19454022988505748, error, 1e-10);
            Assert.AreEqual(447, nodeCount);
            Assert.AreEqual(4, nodeCount2);
        }