.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/CIFAR10_model.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_CIFAR10_model.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_CIFAR10_model.py:


CIFAR-10 Inference with Bit-Sliced Inputs on HiAER-Spike Hardware
==================================================================

This example demonstrates running inference on the CIFAR-10 test set using a
fully-connected spiking neural network (16-100-100) deployed on HiAER-Spike
neuromorphic hardware.

Each 3-channel RGB image is converted to a 15-channel binary representation by
extracting the 5 most significant bit planes from each colour channel
(bit-slicing). The resulting binary tensor is fed to the hardware timestep by
timestep, and the output neuron with the highest spike rate determines the
predicted class.

.. GENERATED FROM PYTHON SOURCE LINES 15-17

.. code-block:: Python

    # sphinx_gallery_thumbnail_path = '_static/cifar10_bitslice_thumb.png'

.. GENERATED FROM PYTHON SOURCE LINES 18-23

Importing the necessary libraries
----------------------------------

We need PyTorch and torchvision for data loading and preprocessing, ``hs_api``
for the CRI network interface, and ``hs_bridge`` to reset the membrane
potentials on the FPGA between samples.

.. GENERATED FROM PYTHON SOURCE LINES 23-34

.. code-block:: Python

    import io
    import pickle
    import urllib.request

    import torch
    import torchvision
    import torchvision.transforms as transforms
    from PIL import Image

    import hs_bridge
    from hs_api.api import CRI_network

.. GENERATED FROM PYTHON SOURCE LINES 35-41

Configuration
-------------

Select the compute device and set inference hyperparameters.
``T`` is the number of active input timesteps; ``extra_timesteps`` are
additional empty steps used to drain spikes still propagating through the
network after the last input frame.

.. GENERATED FROM PYTHON SOURCE LINES 41-47

.. code-block:: Python

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    T = 30
    extra_timesteps = 5

    MODEL_CONFIG_URL = 'https://www.dropbox.com/scl/fi/dqurfqfye1omyyozj0ps6/CIFAR10_model_config.pkl?rlkey=7lhl1ksk4n4j5evpyocwrh2y9&st=zcphe718&dl=1'

.. GENERATED FROM PYTHON SOURCE LINES 48-54

Defining the bit-slicing utilities
------------------------------------

CIFAR-10 images use 8 bits per channel. Rather than thresholding each pixel to
a single binary value, we extract the 5 most significant bit planes from each
of the R, G, and B channels, producing a 15-channel binary tensor per image.
This preserves more colour information while keeping the inputs binary. The
``Cutout`` class below is a training-time augmentation and is not used in this
inference example.

.. GENERATED FROM PYTHON SOURCE LINES 54-167

.. code-block:: Python

    class Cutout(object):
        """Randomly cut out square regions from an image tensor (data augmentation).

        Args:
            n_holes (int): Number of square holes to cut out.
            length (int): Side length of each square hole.
        """

        def __init__(self, n_holes: int, length: int):
            self.n_holes = n_holes
            self.length = length

        def __call__(self, img) -> torch.Tensor:
            if isinstance(img, Image.Image):
                img = transforms.ToTensor()(img)
            elif not isinstance(img, torch.Tensor):
                raise TypeError(
                    f"Input must be a PIL Image or torch.Tensor, got {type(img)}"
                )

            C, H, W = img.shape
            mask = torch.ones((H, W), dtype=torch.bool, device=img.device)

            for _ in range(self.n_holes):
                y = torch.randint(H, (1,), device=img.device).item()
                x = torch.randint(W, (1,), device=img.device).item()

                y1 = max(0, y - self.length // 2)
                y2 = min(H, y + self.length // 2)
                x1 = max(0, x - self.length // 2)
                x2 = min(W, x + self.length // 2)

                mask[y1:y2, x1:x2] = False

            return img * mask.unsqueeze(0).expand_as(img)

        def __repr__(self):
            return f"{self.__class__.__name__}(n_holes={self.n_holes}, length={self.length})"


    def extract_bit_planes(
        tensor: torch.Tensor,
        num_bits: int = 8,
        msb_first: bool = True,
        target_dtype: torch.dtype = torch.float32,
    ) -> torch.Tensor:
        """Extract individual bit planes from a uint8 tensor.

        Args:
            tensor: uint8 input tensor of any shape.
            num_bits: Number of bit planes to extract (1–8).
            msb_first: If ``True``, extract from the most significant bit downward.
            target_dtype: dtype of the returned tensor.

        Returns:
            Tensor of shape ``(num_bits, *tensor.shape)`` containing 0/1 values.
        """
        if not isinstance(tensor, torch.Tensor):
            raise TypeError(f"Expected torch.Tensor, got {type(tensor)}")
        if tensor.dtype != torch.uint8:
            raise ValueError(f"Input dtype must be torch.uint8, got {tensor.dtype}")
        if not (1 <= num_bits <= 8):
            raise ValueError(f"num_bits must be between 1 and 8, got {num_bits}")

        tensor = tensor.contiguous()
        dev = tensor.device

        # Shift amounts: 7, 6, ... for MSB-first, or 0, 1, ... for LSB-first
        if msb_first:
            shifts = torch.arange(7, 7 - num_bits, -1, device=dev, dtype=torch.uint8)
        else:
            shifts = torch.arange(0, num_bits, device=dev, dtype=torch.uint8)

        # Reshape shifts to (num_bits, 1, ..., 1) so they broadcast over the input
        shift_shape = (num_bits,) + (1,) * tensor.ndim
        bit_planes = (tensor >> shifts.view(shift_shape)) & 1

        return bit_planes.to(target_dtype)


    def cifar10_to_15channel_binary(image_tensor: torch.Tensor) -> torch.Tensor:
        """Convert a uint8 RGB image to a 15-channel binary tensor via bit-slicing.

        Takes the 5 most significant bit planes from each of the R, G, and B
        channels, yielding a ``(15, H, W)`` float32 tensor.

        Args:
            image_tensor: uint8 tensor of shape ``(3, H, W)``.

        Returns:
            float32 tensor of shape ``(15, H, W)`` with values in {0, 1}.
        """
        if image_tensor.dtype != torch.uint8:
            raise TypeError(f"Input must be uint8, got {image_tensor.dtype}")

        r, g, b = torch.unbind(image_tensor, dim=0)
        r_bits = extract_bit_planes(r, num_bits=5, msb_first=True)
        g_bits = extract_bit_planes(g, num_bits=5, msb_first=True)
        b_bits = extract_bit_planes(b, num_bits=5, msb_first=True)

        # Channel order: R bits 7..3, then G bits 7..3, then B bits 7..3
        return torch.cat([r_bits, g_bits, b_bits], dim=0)


    class Binarize(object):
        """Threshold a float tensor to binary values.

        Args:
            threshold (float): Values above this become 1, others become 0.
        """

        def __init__(self, threshold=0.5):
            self.th = threshold

        def __call__(self, x: torch.Tensor) -> torch.Tensor:
            return (x > self.th).float()

.. GENERATED FROM PYTHON SOURCE LINES 168-173

Defining the preprocessing transform
--------------------------------------

Images are resized to 32×32 (a no-op for CIFAR-10, whose native resolution is
already 32×32), converted to uint8 tensors with ``PILToTensor``, and bit-sliced
into 15 channels. The final ``Binarize(threshold=0.5)`` step leaves the values
unchanged, since the bit planes already contain only 0 and 1; it simply
guarantees a binary float tensor.

.. GENERATED FROM PYTHON SOURCE LINES 173-181

.. code-block:: Python

    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.PILToTensor(),
        cifar10_to_15channel_binary,
        Binarize(threshold=0.5),
    ])

.. GENERATED FROM PYTHON SOURCE LINES 182-186

Loading the dataset
--------------------

We use the torchvision CIFAR-10 test split. After the transform, each sample
is a ``(15, 32, 32)`` binary tensor.

.. GENERATED FROM PYTHON SOURCE LINES 186-198

.. code-block:: Python

    test_dataset = torchvision.datasets.CIFAR10(
        root='./data',
        train=False,
        transform=transform,
        download=True,
    )

    C, H, W = test_dataset[0][0].shape
    print(f"Input shape: {(C, H, W)}")
    print(f"Test samples: {len(test_dataset)}")

.. GENERATED FROM PYTHON SOURCE LINES 199-203

Loading the model configuration
--------------------------------

The CRI network topology (axons, connections, and output neuron IDs) was
serialised with ``pickle`` during the ``hs_api`` conversion step. Here it is
downloaded from ``MODEL_CONFIG_URL`` and unpickled.

.. GENERATED FROM PYTHON SOURCE LINES 203-212

.. code-block:: Python

    print("Loading model configuration...")
    with urllib.request.urlopen(MODEL_CONFIG_URL) as response:
        model_config = pickle.load(io.BytesIO(response.read()))

    axons = model_config['axons']
    connections = model_config['connections']
    outputs = model_config['outputs']

.. GENERATED FROM PYTHON SOURCE LINES 213-217

Creating the CRI network
------------------------

A ``CRI_network`` object is initialised with the loaded topology and targets
the physical CRI hardware (``target="CRI"``).

.. GENERATED FROM PYTHON SOURCE LINES 217-226

.. code-block:: Python

    print("Creating CRI network...")
    network = CRI_network(
        axons=axons,
        connections=connections,
        outputs=outputs,
        target="CRI",
    )

.. GENERATED FROM PYTHON SOURCE LINES 227-237

Running inference
-----------------

For each test image we:

1. Clear the FPGA membrane potentials.
2. Flatten the 15-channel binary image and feed the indices of its active
   pixels as axon IDs at each of the ``T`` timesteps.
3. Accumulate output spikes across all ``T`` timesteps plus the
   ``extra_timesteps`` drain steps.
4. Classify the image by the output neuron with the highest average spike
   rate.

.. GENERATED FROM PYTHON SOURCE LINES 237-283

.. code-block:: Python

    print("Running inference on test set...")
    correct = 0
    total = 0

    for img, label in test_dataset:
        # Reset membrane potentials before each new sample
        hs_bridge.FPGA_Execution.fpga_controller.clear(
            len(connections), False, 0
        )  # num_neurons, simDump, coreOverride

        img = img.to(device)                            # [C, H, W]
        flat = img.unsqueeze(0).flatten(start_dim=1)    # [1, C*H*W]

        spike_counts = torch.zeros(len(outputs))

        # Build the spike list once; the same binary image is replayed each
        # timestep. torch.nonzero avoids a per-element .item() call.
        active = torch.nonzero(flat[0] > 0).flatten().tolist()
        inputs = [f"A{i}" for i in active]

        for _ in range(T):
            hardwareSpikes, _, _ = network.step(inputs)
            for spike in hardwareSpikes:
                if spike in outputs:
                    # Indexing by spike assumes output IDs run 0..len(outputs)-1
                    spike_counts[spike] += 1
                else:
                    print(f"Warning: unexpected output spike {spike}")

        # Drain remaining spikes with empty timesteps
        for _ in range(extra_timesteps):
            hardwareSpikes, _, _ = network.step([])
            for spike in hardwareSpikes:
                if spike in outputs:
                    spike_counts[spike] += 1

        # Average spike rate per input step; the normalisation does not
        # change the argmax, so the prediction is the most active neuron
        spike_counts = spike_counts / T
        predicted = torch.argmax(spike_counts).item()

        total += 1
        if predicted == label:
            correct += 1

        print(
            f"[{total}] Predicted: {predicted}, Label: {label} | "
            f"Running accuracy: {100 * correct / total:.2f}%"
        )

.. GENERATED FROM PYTHON SOURCE LINES 284-287

Results
-------

Print the final classification accuracy over the full 10 000-image test set.
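
The accuracy depends only on the readout rule used in the inference loop:
count output spikes over the ``T`` input steps plus the ``extra_timesteps``
drain steps, then take the argmax. That rule can be exercised without the
FPGA. The sketch below swaps ``network.step`` for a hypothetical ``fake_step``
function with scripted output spikes and, like the loop's indexing, assumes the
output neurons are numbered ``0`` to ``len(outputs) - 1``.

.. code-block:: Python

    import torch

    T = 30
    extra_timesteps = 5
    outputs = list(range(10))  # hypothetical output neuron IDs 0..9


    def fake_step(inputs, t):
        """Hypothetical stand-in for network.step: returns output spike IDs."""
        if not inputs:          # drain step: only neuron 7 still fires
            return [7]
        spikes = [3]            # neuron 3 fires on every input step
        if t % 2 == 0:
            spikes.append(7)    # neuron 7 also fires on even input steps
        return spikes


    def classify(step_fn, inputs):
        spike_counts = torch.zeros(len(outputs))
        for t in range(T + extra_timesteps):
            step_inputs = inputs if t < T else []  # empty inputs while draining
            for spike in step_fn(step_inputs, t):
                if spike in outputs:
                    spike_counts[spike] += 1
        # Dividing by T rescales every count equally, so the argmax is unchanged
        rates = spike_counts / T
        return torch.argmax(rates).item(), spike_counts


    predicted, counts = classify(fake_step, ["A0", "A5"])
    print(predicted, counts[3].item(), counts[7].item())  # 3 30.0 20.0

In this scripted run, neuron 7 collects 5 of its 20 spikes during the drain
steps, showing how activity after the last input frame still enters the count.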

.. GENERATED FROM PYTHON SOURCE LINES 287-290

.. code-block:: Python

    accuracy = 100 * correct / total
    print(f"Test accuracy: {accuracy:.2f}%")


.. _sphx_glr_download_auto_examples_CIFAR10_model.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: CIFAR10_model.ipynb <CIFAR10_model.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: CIFAR10_model.py <CIFAR10_model.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: CIFAR10_model.zip <CIFAR10_model.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_