{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# CIFAR-10 Inference with Bit-Sliced Inputs on HiAER-Spike Hardware\n\nThis example demonstrates running inference on the CIFAR-10 test set using a\nfully-connected spiking neural network (16-100-100) deployed on HiAER-Spike\nneuromorphic hardware.\n\nEach 3-channel RGB image is converted to a 15-channel binary representation\nby extracting the 5 most significant bit planes from each colour channel\n(bit-slicing). The resulting binary tensor is fed to the hardware\ntimestep-by-timestep, and the output neuron with the highest spike rate\ndetermines the predicted class.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# sphinx_gallery_thumbnail_path = '_static/cifar10_bitslice_thumb.png'"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Importing the necessary libraries\nWe need PyTorch and torchvision for data loading and preprocessing,\n``hs_api`` for the CRI network interface, and ``hs_bridge`` to reset\nmembrane potentials between samples on the FPGA.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import io\nimport pickle\nimport urllib.request\nimport torch\nimport torchvision\nimport torchvision.transforms as transforms\nfrom PIL import Image\nimport hs_bridge\nfrom hs_api.api import CRI_network"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Configuration\nSelect the compute device and set inference hyperparameters.\n``T`` is the number of active input timesteps; ``extra_timesteps`` are\nadditional empty steps used to drain spikes still propagating through\nthe network after the last input frame.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\nT = 30\nextra_timesteps = 5\nMODEL_CONFIG_URL = 'https://www.dropbox.com/scl/fi/dqurfqfye1omyyozj0ps6/CIFAR10_model_config.pkl?rlkey=7lhl1ksk4n4j5evpyocwrh2y9&st=zcphe718&dl=1'"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Defining the bit-slicing utilities\nCIFAR-10 images are 8-bit per channel. Rather than thresholding to a single\nbinary value, we extract the 5 most significant bit planes from each of the\nR, G, and B channels, producing a 15-channel binary tensor per image.\nThis preserves more colour information while keeping inputs binary.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class Cutout(object):\n    \"\"\"Randomly cut out square regions from an image tensor (data augmentation).\n\n    Args:\n        n_holes (int): Number of square holes to cut out.\n        length (int): Side length of each square hole.\n    \"\"\"\n\n    def __init__(self, n_holes: int, length: int):\n        self.n_holes = n_holes\n        self.length = length\n\n    def __call__(self, img) -> torch.Tensor:\n        if isinstance(img, Image.Image):\n            img = transforms.ToTensor()(img)\n        elif not isinstance(img, torch.Tensor):\n            raise TypeError(\n                f\"Input must be a PIL Image or torch.Tensor, got {type(img)}\"\n            )\n\n        C, H, W = img.shape\n        mask = torch.ones((H, W), dtype=torch.bool, device=img.device)\n\n        for _ in range(self.n_holes):\n            y = torch.randint(H, (1,), device=img.device).item()\n            x = torch.randint(W, (1,), device=img.device).item()\n            y1 = max(0, y - self.length // 2)\n            y2 = min(H, y + self.length // 2)\n            x1 = max(0, x - self.length // 2)\n            x2 = min(W, x + self.length // 2)\n            mask[y1:y2, x1:x2] = False\n\n        return img * mask.unsqueeze(0).expand_as(img)\n\n    def __repr__(self):\n        return f\"{self.__class__.__name__}(n_holes={self.n_holes}, length={self.length})\"\n\n\ndef extract_bit_planes(\n    tensor: torch.Tensor,\n    num_bits: int = 8,\n    msb_first: bool = True,\n    target_dtype: torch.dtype = torch.float32,\n) -> torch.Tensor:\n    \"\"\"Extract individual bit planes from a uint8 tensor.\n\n    Args:\n        tensor: uint8 input tensor of any shape.\n        num_bits: Number of bit planes to extract (1\u20138).\n        msb_first: If ``True``, extract from the most significant bit downward.\n        target_dtype: dtype of the returned tensor.\n\n    Returns:\n        Tensor of shape ``(num_bits, *tensor.shape)`` containing 0/1 values.\n    \"\"\"\n    if not isinstance(tensor, torch.Tensor):\n        raise TypeError(f\"Expected torch.Tensor, got {type(tensor)}\")\n    if tensor.dtype != torch.uint8:\n        raise ValueError(f\"Input dtype must be torch.uint8, got {tensor.dtype}\")\n    if not (1 <= num_bits <= 8):\n        raise ValueError(f\"num_bits must be between 1 and 8, got {num_bits}\")\n\n    tensor = tensor.contiguous()\n    dev = tensor.device\n\n    if msb_first:\n        shifts = torch.arange(7, 7 - num_bits, -1, device=dev, dtype=torch.uint8)\n    else:\n        shifts = torch.arange(0, num_bits, device=dev, dtype=torch.uint8)\n\n    shift_shape = (num_bits,) + (1,) * tensor.ndim\n    bit_planes = (tensor >> shifts.view(shift_shape)) & 1\n    return bit_planes.to(target_dtype)\n\n\ndef cifar10_to_15channel_binary(image_tensor: torch.Tensor) -> torch.Tensor:\n    \"\"\"Convert a uint8 RGB image to a 15-channel binary tensor via bit-slicing.\n\n    Takes the 5 most significant bit planes from each of the R, G, and B\n    channels, yielding a ``(15, H, W)`` float32 tensor.\n\n    Args:\n        image_tensor: uint8 tensor of shape ``(3, H, W)``.\n\n    Returns:\n        float32 tensor of shape ``(15, H, W)`` with values in {0, 1}.\n    \"\"\"\n    if image_tensor.dtype != torch.uint8:\n        raise TypeError(f\"Input must be uint8, got {image_tensor.dtype}\")\n\n    r, g, b = torch.unbind(image_tensor, dim=0)\n    r_bits = extract_bit_planes(r, num_bits=5, msb_first=True)\n    g_bits = extract_bit_planes(g, num_bits=5, msb_first=True)\n    b_bits = extract_bit_planes(b, num_bits=5, msb_first=True)\n    return torch.cat([r_bits, g_bits, b_bits], dim=0)\n\n\nclass Binarize(object):\n    \"\"\"Threshold a float tensor to binary values.\n\n    Args:\n        threshold (float): Values above this become 1, others become 0.\n    \"\"\"\n\n    def __init__(self, threshold=0.5):\n        self.th = threshold\n\n    def __call__(self, x: torch.Tensor) -> torch.Tensor:\n        return (x > self.th).float()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Defining the preprocessing transform\nImages are resized to 32\u00d732 (already the CIFAR-10 native size), converted\nto uint8 tensors with ``PILToTensor``, bit-sliced into 15 channels, and\nfinally binarized with a 0.5 threshold.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "transform = transforms.Compose([\n    transforms.Resize((32, 32)),\n    transforms.PILToTensor(),\n    cifar10_to_15channel_binary,\n    Binarize(threshold=0.5),\n])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Loading the dataset\nWe use the torchvision CIFAR-10 test split. After the transform each sample\nis a ``(15, 32, 32)`` binary tensor.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "test_dataset = torchvision.datasets.CIFAR10(\n    root='./data',\n    train=False,\n    transform=transform,\n    download=True,\n)\n\nC, H, W = test_dataset[0][0].shape\nprint(f\"Input shape: {(C, H, W)}\")\nprint(f\"Test samples: {len(test_dataset)}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Loading the model configuration\nThe CRI network topology (axons, connections, output neuron IDs) was\nserialised to disk during the ``hs_api`` conversion step and is loaded here.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(\"Loading model configuration...\")\nwith urllib.request.urlopen(MODEL_CONFIG_URL) as response:\n    model_config = pickle.load(io.BytesIO(response.read()))\n\naxons = model_config['axons']\nconnections = model_config['connections']\noutputs = model_config['outputs']"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Creating the CRI network\nA ``CRI_network`` object is initialised with the loaded topology and\ntargets the physical CRI hardware (``target=\"CRI\"``).\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(\"Creating CRI network...\")\nnetwork = CRI_network(\n    axons=axons,\n    connections=connections,\n    outputs=outputs,\n    target=\"CRI\",\n)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Running inference\nFor each test image we:\n\n1. Clear the FPGA membrane potentials.\n2. Flatten the 15-channel binary image and feed active pixel indices as axon\n   IDs for each of the ``T`` timesteps.\n3. Accumulate output spikes across all ``T`` input timesteps plus the\n   ``extra_timesteps`` drain steps.\n4. Classify by the output neuron with the highest spike count, normalised\n   by ``T``.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(\"Running inference on test set...\")\n\ncorrect = 0\ntotal = 0\n\nfor img, label in test_dataset:\n    # Reset membrane potentials before each new sample\n    hs_bridge.FPGA_Execution.fpga_controller.clear(\n        len(connections), False, 0\n    )  # num_neurons, simDump, coreOverride\n\n    img = img.to(device)  # [C, H, W]\n    flat = img.unsqueeze(0).flatten(start_dim=1)  # [1, C*H*W]\n    spike_counts = torch.zeros(len(outputs))\n\n    # Build the spike list once; the same binary image is replayed each timestep\n    inputs = [f\"A{i}\" for i, v in enumerate(flat[0]) if v.item() > 0]\n\n    for _ in range(T):\n        hardwareSpikes, _, _ = network.step(inputs)\n        for spike in hardwareSpikes:\n            if spike in outputs:\n                spike_counts[spike] += 1\n            else:\n                print(f\"Warning: unexpected output spike {spike}\")\n\n    # Drain remaining spikes with empty timesteps\n    for _ in range(extra_timesteps):\n        hardwareSpikes, _, _ = network.step([])\n        for spike in hardwareSpikes:\n            if spike in outputs:\n                spike_counts[spike] += 1\n\n    # Spikes per input timestep (counts include the drain steps);\n    # the scaling does not change the argmax.\n    spike_counts = spike_counts / T\n    predicted = torch.argmax(spike_counts).item()\n\n    total += 1\n    if predicted == label:\n        correct += 1\n\n    print(\n        f\"[{total}] Predicted: {predicted}, Label: {label} | \"\n        f\"Running accuracy: {100 * correct / total:.2f}%\"\n    )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Results\nPrint the final classification accuracy over the full 10 000-image test set.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "accuracy = 100 * correct / total\nprint(f\"Test accuracy: {accuracy:.2f}%\")"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.6"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}