`public virtual uint Forward (MPSImage inputImage = null, int imageNum = 9999, int correctLabel = 10)`

| Parameter | Type | Description |
|-----------|------|-------------|
inputImage | MPSImage | Image on which the network runs; when null, the network's own `SrcImage` is used instead. |
imageNum | int | When running the 10,000-image test set, the index (0–9999) of the image being evaluated. The value 9999, which is also the default, makes the call wait for the GPU to finish. |
correctLabel | int | The expected label for `inputImage` during testing; `Atomics.Increment ()` is called when the predicted label matches it. |
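return | uint | The predicted label when the call waits for the GPU (`inputImage` is null or `imageNum` is 9999); otherwise 99, because the prediction is not yet available when the method returns. |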
/// <summary>
/// Runs one forward pass of the network on the GPU: <c>layer</c> writes into <c>dstImage</c>,
/// then <c>softmax</c> writes into a new image created from <c>DID</c>.
/// </summary>
/// <param name="inputImage">Image to classify; when null, the network's own <c>SrcImage</c> is used.</param>
/// <param name="imageNum">Index of the test-set image being evaluated; 9999 (the default) forces a blocking wait.</param>
/// <param name="correctLabel">Expected label; <c>Atomics.Increment ()</c> is called when the prediction matches it.</param>
/// <returns>
/// The predicted label when the call blocks for the GPU (<paramref name="inputImage"/> is null or
/// <paramref name="imageNum"/> is 9999); otherwise 99, since the completion handler has not run yet.
/// </returns>
public virtual uint Forward (MPSImage inputImage = null, int imageNum = 9999, int correctLabel = 10)
{
    // Sentinel returned when the caller does not block for the GPU result.
    const uint NoLabel = 99;
    // Evaluating this test-set index (also the default) forces a blocking wait.
    const int LastTestImage = 9999;

    uint prediction = NoLabel;

    using (var cmdBuf = commandQueue.CommandBuffer ()) {
        // Destination image for the softmax stage.
        var softmaxOutput = new MPSImage (cmdBuf.Device, DID);

        // Fall back to the network's own source image when no explicit input was supplied.
        MPSImage source = inputImage == null ? SrcImage : inputImage;
        layer.EncodeToCommandBuffer (cmdBuf, source, dstImage);
        softmax.EncodeToCommandBuffer (cmdBuf, dstImage, softmaxOutput);

        // Once the GPU is done, read the prediction and score it against the expected label.
        cmdBuf.AddCompletedHandler (completed => {
            prediction = GetLabel (softmaxOutput);
            if (prediction == correctLabel) {
                Atomics.Increment ();
            }
        });

        cmdBuf.Commit ();

        // NOTE(review): the returned label relies on the completed handler having run before
        // WaitUntilCompleted returns — confirm against Metal's ordering guarantees.
        bool blockForResult = imageNum == LastTestImage || inputImage == null;
        if (blockForResult) {
            cmdBuf.WaitUntilCompleted ();
        }
    }

    return prediction;
}
/// <summary>
/// Button handler: copies the digit drawn in <c>DigitView</c> into the network's <c>SrcImage</c>,
/// runs a forward pass, and displays the predicted label.
/// </summary>
/// <param name="sender">The tapped button (not used).</param>
/// <exception cref="InvalidOperationException">The neural network has not been initialized.</exception>
partial void TappedDetectDigit(UIButton sender) {
    // Get the DigitView context so its pixel values can be used as network input.
    var context = DigitView.GetViewContext();

    // Validate that the neural network was initialized properly.
    // InvalidProgramException is reserved by .NET for invalid IL/metadata; an uninitialized
    // dependency is an invalid object state, i.e. InvalidOperationException.
    if (runningNet == null) {
        throw new InvalidOperationException("The neural network has not been initialized.");
    }

    // Copy the input pixels into the MTLTexture backing the network's MPSImage.
    // NOTE(review): bytesPerRow == mnistInputWidth assumes one byte per pixel — confirm against the texture's pixel format.
    var region = new MTLRegion(new MTLOrigin(0, 0, 0), new MTLSize((nint)mnistInputWidth, mnistInputHeight, 1));
    runningNet.SrcImage.Texture.ReplaceRegion(region, level: 0, slice: 0, pixelBytes: context.Data, bytesPerRow: mnistInputWidth, bytesPerImage: 0);

    // Run the forward pass on SrcImage; with no input image supplied, Forward waits for the GPU result.
    var label = runningNet.Forward();

    // Show the prediction.
    PredictionLabel.Text = $"{label}";
    PredictionLabel.Hidden = false;
}