Accord.Tests.Neuro.ContrastiveDivergenceLearningTest.RunTest2 C# (CSharp) Method

RunTest2() public method

public RunTest2 ( ) : void
returns void
        public void RunTest2()
        {
            // Trains a Gaussian-Bernoulli RBM (6 visible, 2 hidden units) with
            // contrastive divergence on the toy data set from Edwin Chen's
            // "Introduction to Restricted Boltzmann Machines":
            // http://blog.echen.me/2011/07/18/introduction-to-restricted-Boltzmann-machines/
            //
            // The first three samples have their ones among columns 0-2, the last
            // three among columns 2-4.
            var inputs = new double[][]
            {
                new double[] { 1, 1, 1, 0, 0, 0 },
                new double[] { 1, 0, 1, 0, 0, 0 },
                new double[] { 1, 1, 1, 0, 0, 0 },
                new double[] { 0, 0, 1, 1, 1, 0 },
                new double[] { 0, 0, 1, 1, 0, 0 },
                new double[] { 0, 0, 1, 1, 1, 0 },
            };

            // The generator is reseeded with the same seed before each stage
            // (construction, weight initialization, learner creation) so that the
            // run - and therefore the exact outputs asserted below - is reproducible.
            Accord.Math.Tools.SetupGenerator(0);
            var machine = RestrictedBoltzmannMachine.CreateGaussianBernoulli(6, 2);

            Accord.Math.Tools.SetupGenerator(0);
            var weightInitializer = new GaussianWeights(machine);
            weightInitializer.Randomize();
            machine.UpdateVisibleWeights();

            Accord.Math.Tools.SetupGenerator(0);

            // Single-threaded learner with plain updates: no momentum, no decay.
            // NOTE(review): one thread presumably keeps the order of random draws
            // deterministic across runs - confirm.
            var teacher = new ContrastiveDivergenceLearning(machine)
            {
                ParallelOptions = { MaxDegreeOfParallelism = 1 },
                Momentum = 0,
                LearningRate = 0.1,
                Decay = 0
            };

            // Only the error of the first and of the last epoch is compared, so
            // there is no need to keep the whole error history.
            const int epochs = 5000;
            double firstEpochError = teacher.RunEpoch(inputs);
            double lastEpochError = firstEpochError;
            for (int epoch = 1; epoch < epochs; epoch++)
            {
                lastEpochError = teacher.RunEpoch(inputs);
            }

            Assert.IsTrue(firstEpochError > lastEpochError);

            // A probe resembling the last three samples is expected to produce 0 on
            // both hidden units.
            double[] secondGroupOutput = machine.GenerateOutput(new double[] { 0, 0, 0, 1, 1, 0 });
            Assert.AreEqual(2, secondGroupOutput.Length);
            Assert.AreEqual(0, secondGroupOutput[0]);
            Assert.AreEqual(0, secondGroupOutput[1]);

            // A probe identical to the first sample is expected to produce 1 on both
            // hidden units.
            double[] firstGroupOutput = machine.GenerateOutput(new double[] { 1, 1, 1, 0, 0, 0 });
            Assert.AreEqual(2, firstGroupOutput.Length);
            Assert.AreEqual(1, firstGroupOutput[0]);
            Assert.AreEqual(1, firstGroupOutput[1]);
        }
ContrastiveDivergenceLearningTest