DigitDetection.MnistFullLayerNeuralNetwork.GetLabel C# (CSharp) Method

GetLabel() public method

This method copies the output probabilities from finalLayer to the CPU, converts them from 16-bit to 32-bit floats, and returns the label with the highest probability.
public GetLabel ( MPSImage finalLayer ) : uint
finalLayer (MPSImage): the network's output image, holding a probability for each digit.
returns (uint): the most probable digit (0–9), or 10 if no label has a positive probability.
		/// <summary>
		/// Copies the network's output probabilities from <paramref name="finalLayer"/> to the CPU,
		/// converts them from 16-bit (half) floats to 32-bit floats and returns the label with the
		/// highest probability.
		/// </summary>
		/// <param name="finalLayer">
		/// Output image of the network. Assumed to hold 10 probabilities packed four per pixel into
		/// pixel (0,0) of three RGBA Float16 slices — TODO confirm against the network definition.
		/// </param>
		/// <returns>
		/// The index (0–9) of the largest probability, or 10 if no probability is greater than 0
		/// or the half-to-float conversion fails.
		/// </returns>
		public uint GetLabel (MPSImage finalLayer)
		{
			const int LabelCount = 10;
			// An RGBAFloat16 pixel holds 4 values, so 10 labels need ceil(10 / 4) = 3 slices,
			// i.e. 12 halves; the last 2 are padding and are never converted.
			const int ChannelsPerSlice = 4;
			const int SliceCount = (LabelCount + ChannelsPerSlice - 1) / ChannelsPerSlice;
			// Value returned when no label can be chosen.
			const uint NoLabel = LabelCount;

			// Both arrays are fully overwritten before they are read, so no placeholder contents are needed.
			var halfValues = new ushort [SliceCount * ChannelsPerSlice];
			var floatValues = new float [LabelCount];

			var halfHandle = default (GCHandle);
			var floatHandle = default (GCHandle);
			try {
				// Pin both arrays so their addresses can be handed to Metal and vImage.
				halfHandle = GCHandle.Alloc (halfValues, GCHandleType.Pinned);
				floatHandle = GCHandle.Alloc (floatValues, GCHandleType.Pinned);
				var halfPtr = halfHandle.AddrOfPinnedObject ();
				var floatPtr = floatHandle.AddrOfPinnedObject ();

				// Read the single RGBA Float16 pixel at (0,0) of each slice into consecutive
				// groups of 4 halves.
				for (uint slice = 0; slice < SliceCount; slice++) {
					finalLayer.Texture.GetBytes (halfPtr + (int)slice * ChannelsPerSlice * sizeof (ushort),
												sizeof (ushort) * ChannelsPerSlice, sizeof (ushort) * ChannelsPerSlice,
												new MTLRegion (new MTLOrigin (0, 0, 0), new MTLSize (1, 1, 1)),
												0, slice);
				}

				// The texture data is half precision while C# float is 32-bit, so widen the
				// first LabelCount halves with vImage.
				var floatBuffer = new vImageBuffer {
					Data = floatPtr,
					Height = 1,
					Width = LabelCount,
					BytesPerRow = LabelCount * sizeof (float)
				};

				var halfBuffer = new vImageBuffer {
					Data = halfPtr,
					Height = 1,
					Width = LabelCount,
					BytesPerRow = LabelCount * sizeof (ushort)
				};

				if (Planar16FtoPlanarF (ref halfBuffer, ref floatBuffer, 0) != vImageError.NoError) {
					Console.WriteLine ("Error in vImage");
					// Without a successful conversion floatValues holds no real data; do not
					// pretend a digit was recognised.
					return NoLabel;
				}

				// Pick the label with the strictly largest positive probability; ties keep the lowest index.
				float max = 0f;
				uint mostProbableDigit = NoLabel;

				for (uint i = 0; i < LabelCount; i++) {
					if (max < floatValues [i]) {
						max = floatValues [i];
						mostProbableDigit = i;
					}
				}

				return mostProbableDigit;
			} finally {
				// Always unpin, even if GetBytes or the conversion throws.
				if (floatHandle.IsAllocated) {
					floatHandle.Free ();
				}
				if (halfHandle.IsAllocated) {
					halfHandle.Free ();
				}
			}
		}
	}