DigitDetection.MnistFullLayerNeuralNetwork.Forward C# (CSharp) Method

Forward() public method

This method encodes all the layers of the network into a command buffer, calling a subroutine for each piece of the network. Returns: the network's guess of what the digit is, as a uint.
public Forward ( MPSImage inputImage = null, int imageNum = 9999, int correctLabel = 10 ) : uint
inputImage MPSImage Image coming in on which the network will run
imageNum int If the test set is being used, this is a value between 0 and 9999 indicating which of the 10,000 images is being evaluated
correctLabel int The correct label for the inputImage while testing
return uint
		/// <summary>
		/// Runs one forward pass of the network. The fully connected layer and the softmax are encoded
		/// into a freshly created command buffer, which is then committed to the GPU.
		/// </summary>
		/// <param name="inputImage">Image to classify. When null, <c>SrcImage</c> is used as the source.</param>
		/// <param name="imageNum">
		/// Index of the test-set image being evaluated. The call blocks until the GPU has finished
		/// when this equals 9999 or when <paramref name="inputImage"/> is null.
		/// </param>
		/// <param name="correctLabel">Expected label; <c>Atomics.Increment ()</c> is called when the prediction equals it.</param>
		/// <returns>
		/// The label read by <c>GetLabel</c> from the softmax output. When the call does not block, the
		/// completion handler has typically not run yet and the placeholder value 99 is returned instead.
		/// </returns>
		public virtual uint Forward (MPSImage inputImage = null, int imageNum = 9999, int correctLabel = 10)
		{
			const uint NoPrediction = 99;

			// NOTE(review): this local is written from the command buffer's completion handler; when the call
			// does not block, the value returned below races that write — confirm callers only rely on it when blocking.
			uint prediction = NoPrediction;

			using (var cmdBuffer = commandQueue.CommandBuffer ()) {
				// Destination of the softmax stage; GetLabel reads it once the GPU is done.
				var softmaxOutput = new MPSImage (cmdBuffer.Device, DID);

				// Fall back to the network's own source image when the caller supplied none.
				var source = inputImage ?? SrcImage;
				layer.EncodeToCommandBuffer (cmdBuffer, source, dstImage);
				softmax.EncodeToCommandBuffer (cmdBuffer, dstImage, softmaxOutput);

				// Extract the label as soon as the GPU finishes and tally it if it matches the expected one.
				cmdBuffer.AddCompletedHandler (_ => {
					prediction = GetLabel (softmaxOutput);
					if (prediction == correctLabel)
						Atomics.Increment ();
				});

				cmdBuffer.Commit ();

				// Only single-image runs and the last test image wait for the GPU; other test images return immediately.
				bool blockUntilDone = inputImage == null || imageNum == 9999;
				if (blockUntilDone)
					cmdBuffer.WaitUntilCompleted ();
			}

			return prediction;
		}

Usage Example

Example #1
0
        /// <summary>
        /// Button handler: copies the pixels drawn in <c>DigitView</c> into the network's source texture,
        /// runs a forward pass and shows the predicted digit in <c>PredictionLabel</c>.
        /// </summary>
        /// <param name="sender">The tapped button (not used).</param>
        /// <exception cref="InvalidOperationException">Thrown when <c>runningNet</c> has not been initialized.</exception>
        partial void TappedDetectDigit(UIButton sender)
        {
            // Validate the network before touching the view. A missing network is an invalid object state,
            // not invalid IL, so InvalidOperationException (not InvalidProgramException) is the right type.
            if (runningNet == null)
            {
                throw new InvalidOperationException("The neural network has not been initialized.");
            }

            // Get the digitView context so we can read the pixel values to input to the network.
            var context = DigitView.GetViewContext();

            // Put the input into the MTLTexture backing the MPSImage.
            // NOTE(review): bytesPerRow == mnistInputWidth assumes one byte per pixel — confirm against the context's pixel format.
            var region = new MTLRegion(new MTLOrigin(0, 0, 0), new MTLSize((nint)mnistInputWidth, mnistInputHeight, 1));

            runningNet.SrcImage.Texture.ReplaceRegion(region,
                                                      level: 0,
                                                      slice: 0,
                                                      pixelBytes: context.Data,
                                                      bytesPerRow: mnistInputWidth,
                                                      bytesPerImage: 0);

            // Run the network forward pass; with no explicit input image, Forward uses SrcImage filled above.
            var label = runningNet.Forward();

            // Show the prediction.
            PredictionLabel.Text   = $"{label}";
            PredictionLabel.Hidden = false;
        }