CIFAR-10 Inference with Bit-Sliced Inputs on HiAER-Spike Hardware

This example demonstrates running inference on the CIFAR-10 test set using a fully-connected spiking neural network (16-100-100) deployed on HiAER-Spike neuromorphic hardware.

Each 3-channel RGB image is converted to a 15-channel binary representation by extracting the 5 most significant bit planes from each colour channel (bit-slicing). The same binary frame is presented to the hardware at each of T input timesteps, and the output neuron that spikes most often determines the predicted class.
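As a worked example of this encoding, the plain-Python sketch below slices a single RGB pixel into its 15 binary inputs and computes how many binary inputs a 32×32 image produces. `pixel_to_bits` is a hypothetical helper for illustration only, not part of hs_api or the tutorial code.

```python
def pixel_to_bits(rgb, num_bits=5):
    """Hypothetical helper: return the num_bits most significant bits of each
    8-bit channel, MSB first, concatenated in R, G, B order."""
    bits = []
    for value in rgb:
        if not 0 <= value <= 255:
            raise ValueError("channel values must be 8-bit (0-255)")
        # Bit 7 is the MSB of an 8-bit value; keep bits 7, 6, ..., 8 - num_bits.
        bits.extend((value >> shift) & 1 for shift in range(7, 7 - num_bits, -1))
    return bits


pixel = (200, 100, 37)  # 0b11001000, 0b01100100, 0b00100101
bits = pixel_to_bits(pixel)
print(bits)  # [1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0]

# 3 colour channels x 5 bit planes = 15 binary channels per pixel,
# so a 32x32 image becomes 15 * 32 * 32 binary inputs.
num_inputs = 3 * 5 * 32 * 32
print(num_inputs)  # 15360
```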

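Importing the necessary libraries

We need PyTorch and torchvision for data loading and preprocessing, PIL for the image-type check in Cutout, io, pickle and urllib to download the serialised model configuration, hs_api for the CRI network interface, and hs_bridge to reset membrane potentials on the FPGA between samples.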

import io
import pickle
import urllib.request
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
import hs_bridge
from hs_api.api import CRI_network

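Configuration

Select the compute device and set the inference hyperparameters. The device only affects host-side tensor operations; the network itself runs on the FPGA. T is the number of input timesteps per image, extra_timesteps is the number of additional empty steps used to drain spikes still propagating through the network after the last input frame, and MODEL_CONFIG_URL points to the serialised network topology.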
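Per image, the hardware is therefore stepped T + extra_timesteps times. The plain-Python sketch below makes that schedule explicit: T steps that replay the same input spike list, followed by empty drain steps. `build_schedule` is a hypothetical helper, not a function from the tutorial or hs_api.

```python
def build_schedule(inputs, T=30, extra_timesteps=5):
    """Hypothetical helper: return the per-timestep input lists, i.e. T copies
    of `inputs` followed by `extra_timesteps` empty lists (drain steps)."""
    return [list(inputs) for _ in range(T)] + [[] for _ in range(extra_timesteps)]


schedule = build_schedule(["A0", "A7"], T=30, extra_timesteps=5)
print(len(schedule))                  # 35 hardware steps per image
print(sum(1 for s in schedule if s))  # 30 steps carry input spikes
print(schedule[-1])                   # [] (a drain step)
```

With the default settings, the full 10 000-image test set requires 10 000 × 35 = 350 000 hardware steps.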

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
T = 30
extra_timesteps = 5
MODEL_CONFIG_URL = 'https://www.dropbox.com/scl/fi/dqurfqfye1omyyozj0ps6/CIFAR10_model_config.pkl?rlkey=7lhl1ksk4n4j5evpyocwrh2y9&st=zcphe718&dl=1'

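Defining the bit-slicing utilities

CIFAR-10 images use 8 bits per channel. Rather than thresholding each pixel to a single binary value, we extract the 5 most significant bit planes from each of the R, G, and B channels and discard the 3 least significant bits, producing a 15-channel binary tensor per image. This preserves more colour information while keeping the inputs binary. This section also defines a Cutout data-augmentation transform, which the test-time pipeline below does not use, and a Binarize threshold transform.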
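Keeping the top 5 bit planes is equivalent to quantising each 8-bit channel to 32 levels: the planes reconstruct `v & 0b11111000`, so at most 7 intensity steps are lost per channel. The plain-Python check below verifies this for all 256 values; `reconstruct` is a hypothetical helper written for this illustration.

```python
def reconstruct(value, num_bits=5):
    """Hypothetical helper: rebuild an 8-bit value from its num_bits most
    significant bit planes."""
    total = 0
    for shift in range(7, 7 - num_bits, -1):
        bit = (value >> shift) & 1  # one bit plane of the value
        total += bit << shift       # weight the bit by its place value
    return total


values = range(256)
recon = [reconstruct(v) for v in values]

# Reconstruction from 5 MSB planes equals masking off the 3 low bits.
assert all(r == (v & 0b11111000) for v, r in zip(values, recon))
print(len(set(recon)))                            # 32 distinct levels
print(max(v - r for v, r in zip(values, recon)))  # 7 (largest error)
print(reconstruct(37))                            # 32
```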

class Cutout(object):
    """Randomly cut out square regions from an image tensor (data augmentation).

    Args:
        n_holes (int): Number of square holes to cut out.
        length (int): Side length of each square hole.
    """

    def __init__(self, n_holes: int, length: int):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img) -> torch.Tensor:
        if isinstance(img, Image.Image):
            img = transforms.ToTensor()(img)
        elif not isinstance(img, torch.Tensor):
            raise TypeError(
                f"Input must be a PIL Image or torch.Tensor, got {type(img)}"
            )

        C, H, W = img.shape
        mask = torch.ones((H, W), dtype=torch.bool, device=img.device)

        for _ in range(self.n_holes):
            y = torch.randint(H, (1,), device=img.device).item()
            x = torch.randint(W, (1,), device=img.device).item()
            y1 = max(0, y - self.length // 2)
            y2 = min(H, y + self.length // 2)
            x1 = max(0, x - self.length // 2)
            x2 = min(W, x + self.length // 2)
            mask[y1:y2, x1:x2] = False

        return img * mask.unsqueeze(0).expand_as(img)

    def __repr__(self):
        return f"{self.__class__.__name__}(n_holes={self.n_holes}, length={self.length})"


def extract_bit_planes(
    tensor: torch.Tensor,
    num_bits: int = 8,
    msb_first: bool = True,
    target_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Extract individual bit planes from a uint8 tensor.

    Args:
        tensor: uint8 input tensor of any shape.
        num_bits: Number of bit planes to extract (1–8).
        msb_first: If ``True``, extract from the most significant bit downward.
        target_dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape ``(num_bits, *tensor.shape)`` containing 0/1 values.
    """
    if not isinstance(tensor, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(tensor)}")
    if tensor.dtype != torch.uint8:
        raise ValueError(f"Input dtype must be torch.uint8, got {tensor.dtype}")
    if not (1 <= num_bits <= 8):
        raise ValueError(f"num_bits must be between 1 and 8, got {num_bits}")

    tensor = tensor.contiguous()
    dev = tensor.device

    if msb_first:
        shifts = torch.arange(7, 7 - num_bits, -1, device=dev, dtype=torch.uint8)
    else:
        shifts = torch.arange(0, num_bits, device=dev, dtype=torch.uint8)

    shift_shape = (num_bits,) + (1,) * tensor.ndim
    bit_planes = (tensor >> shifts.view(shift_shape)) & 1
    return bit_planes.to(target_dtype)


def cifar10_to_15channel_binary(image_tensor: torch.Tensor) -> torch.Tensor:
    """Convert a uint8 RGB image to a 15-channel binary tensor via bit-slicing.

    Takes the 5 most significant bit planes from each of the R, G, and B
    channels, yielding a ``(15, H, W)`` float32 tensor.

    Args:
        image_tensor: uint8 tensor of shape ``(3, H, W)``.

    Returns:
        float32 tensor of shape ``(15, H, W)`` with values in {0, 1}.
    """
    if image_tensor.dtype != torch.uint8:
        raise TypeError(f"Input must be uint8, got {image_tensor.dtype}")

    r, g, b = torch.unbind(image_tensor, dim=0)
    r_bits = extract_bit_planes(r, num_bits=5, msb_first=True)
    g_bits = extract_bit_planes(g, num_bits=5, msb_first=True)
    b_bits = extract_bit_planes(b, num_bits=5, msb_first=True)
    return torch.cat([r_bits, g_bits, b_bits], dim=0)


class Binarize(object):
    """Threshold a float tensor to binary values.

    Args:
        threshold (float): Values above this become 1, others become 0.
    """

    def __init__(self, threshold=0.5):
        self.th = threshold

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return (x > self.th).float()
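Because `cifar10_to_15channel_binary` concatenates the R, G and B planes, each ordered MSB first, channel k of the result holds bit `7 - k % 5` of colour `"RGB"[k // 5]`. A small plain-Python sketch of that mapping and its inverse follows; the helper names are hypothetical and only meant for inspecting individual input channels.

```python
COLOURS = "RGB"
NUM_BITS = 5  # bit planes kept per colour channel


def channel_info(k):
    """Hypothetical helper: map a channel index of the 15-channel tensor
    to (colour letter, bit position)."""
    if not 0 <= k < len(COLOURS) * NUM_BITS:
        raise IndexError(f"channel index out of range: {k}")
    colour, offset = divmod(k, NUM_BITS)
    return COLOURS[colour], 7 - offset  # MSB first: offset 0 is bit 7


def channel_index(colour, bit):
    """Hypothetical inverse: (colour letter, bit position 3..7) -> channel index."""
    return COLOURS.index(colour) * NUM_BITS + (7 - bit)


for k in (0, 4, 5, 14):
    print(k, channel_info(k))
# 0 ('R', 7)
# 4 ('R', 3)
# 5 ('G', 7)
# 14 ('B', 3)
```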

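Defining the preprocessing transform

Images are resized to 32×32, converted to uint8 tensors with PILToTensor, bit-sliced into 15 channels, and finally binarized with a 0.5 threshold. For CIFAR-10 the resize is a no-op because 32×32 is already the native size; it only matters if the pipeline is reused with other image sizes. The final Binarize step is redundant but harmless, since the bit planes already contain only 0 and 1.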
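The plain-Python sketch below checks both points for every 8-bit channel value. After bit-slicing, a 0.5 threshold leaves each plane unchanged. By contrast, thresholding the [0, 1]-scaled intensity directly (the single-binary-value alternative mentioned earlier) keeps only the most significant bit plane. `binarize` and `top_bits` are hypothetical stand-ins for the Binarize and bit-slicing transforms, not the transforms themselves.

```python
def binarize(values, threshold=0.5):
    """Stand-in for Binarize: 1.0 where value > threshold, else 0.0."""
    return [1.0 if v > threshold else 0.0 for v in values]


def top_bits(value, num_bits=5):
    """Stand-in for bit-slicing one 8-bit channel value (MSB first)."""
    return [float((value >> s) & 1) for s in range(7, 7 - num_bits, -1)]


for v in range(256):
    planes = top_bits(v)
    # Binarize after bit-slicing is a no-op: the planes are already 0.0/1.0.
    assert binarize(planes) == planes
    # Thresholding the scaled intensity v / 255 at 0.5 recovers only the MSB.
    assert binarize([v / 255])[0] == planes[0]

print("checks passed")
```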

transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.PILToTensor(),
    cifar10_to_15channel_binary,
    Binarize(threshold=0.5),
])

Loading the dataset

We use the torchvision CIFAR-10 test split. After the transform each sample is a (15, 32, 32) binary tensor.
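During inference each sample is flattened in row-major (C, H, W) order, so element (c, h, w) lands at index `c*H*W + h*W + w`, and the inference loop sends the indices of the non-zero elements as axon keys of the form `A{i}`. A plain-Python sketch of that bookkeeping follows; the helper names are hypothetical, while the axon naming follows the inference loop.

```python
C, H, W = 15, 32, 32


def flat_index(c, h, w):
    """Hypothetical helper: row-major index of element (c, h, w)."""
    return (c * H + h) * W + w


def unflatten(i):
    """Hypothetical inverse of flat_index."""
    c, rest = divmod(i, H * W)
    h, w = divmod(rest, W)
    return c, h, w


def active_axons(flat_values):
    """Hypothetical helper: axon keys for every non-zero input element."""
    return [f"A{i}" for i, v in enumerate(flat_values) if v > 0]


print(flat_index(14, 31, 31))              # 15359, the last of 15 * 32 * 32 inputs
print(unflatten(1024 + 33))                # (1, 1, 1)
print(active_axons([0.0, 1.0, 0.0, 1.0]))  # ['A1', 'A3']
```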

test_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    transform=transform,
    download=True,
)

C, H, W = test_dataset[0][0].shape
print(f"Input shape: {(C, H, W)}")
print(f"Test samples: {len(test_dataset)}")

Loading the model configuration

The CRI network topology (axons, connections, and output neuron IDs) was serialised with pickle after the hs_api conversion step. Here it is downloaded from MODEL_CONFIG_URL and unpickled from an in-memory buffer.
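The loading code expects a pickled dictionary with the keys `axons`, `connections` and `outputs`. As a sketch of that round trip, the example below builds a tiny hypothetical configuration, serialises it, and loads it back through an in-memory buffer as the tutorial does. The inner layout (each axon or neuron key mapped to a list of `(postsynaptic neuron, weight)` tuples, plus a list of output neuron keys) is an assumption based on the hs_api documentation; the keys and weights are made up.

```python
import io
import pickle

# Hypothetical toy topology; keys and weights are illustrative only.
toy_config = {
    "axons": {"A0": [("N0", 3), ("N1", -1)], "A1": [("N1", 2)]},
    "connections": {"N0": [("N2", 1)], "N1": [("N2", 1)], "N2": []},
    "outputs": ["N2"],
}

# Serialise to bytes, then load back via an in-memory buffer, as the tutorial
# does with the downloaded response body.
payload = pickle.dumps(toy_config)
model_config = pickle.load(io.BytesIO(payload))

axons = model_config["axons"]
connections = model_config["connections"]
outputs = model_config["outputs"]

# Basic consistency checks on the (assumed) topology layout.
neurons = set(connections)
targets = {post for synapses in list(axons.values()) + list(connections.values())
           for post, _weight in synapses}
assert targets <= neurons, "every synapse must target a known neuron"
assert set(outputs) <= neurons, "output neurons must exist"
print(len(axons), len(neurons), outputs)  # 2 3 ['N2']
```

Under this assumed layout, `len(connections)` is the number of neurons, which matches the `num_neurons` argument the inference loop passes when clearing the FPGA.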

print("Loading model configuration...")
with urllib.request.urlopen(MODEL_CONFIG_URL) as response:
    model_config = pickle.load(io.BytesIO(response.read()))

axons = model_config['axons']
connections = model_config['connections']
outputs = model_config['outputs']

Creating the CRI network

A CRI_network object is initialised with the loaded topology and targets the physical CRI hardware (target="CRI").

print("Creating CRI network...")
network = CRI_network(
    axons=axons,
    connections=connections,
    outputs=outputs,
    target="CRI",
)

Running inference

For each test image we:

  1. Clear the FPGA membrane potentials.

  2. Flatten the 15-channel binary image and send the indices of its non-zero elements as axon keys of the form A{i}, replaying the same spike list at each of the T timesteps.

  3. Accumulate output spikes across all timesteps plus extra_timesteps drain steps.

  4. Classify by the output neuron with the highest average spike rate.
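Steps 3 and 4 amount to a per-class spike-count histogram followed by an argmax. The plain-Python sketch below runs them on made-up spike lists. It maps each output neuron ID to its position in `outputs` (assumed here to be the class index), so it works whatever IDs the output layer was given. The neuron IDs, spike data, and the `decode` helper are all hypothetical.

```python
def decode(step_spikes, outputs, T):
    """Hypothetical helper: count output spikes per class over all steps and
    return (predicted class, average spikes per input timestep)."""
    index = {neuron: k for k, neuron in enumerate(outputs)}
    counts = [0] * len(outputs)
    for spikes in step_spikes:   # T input steps followed by drain steps
        for spike in spikes:
            if spike in index:   # ignore anything that is not an output neuron
                counts[index[spike]] += 1
    rates = [c / T for c in counts]
    # max() returns the first maximal index, so ties go to the lowest class.
    predicted = max(range(len(rates)), key=rates.__getitem__)
    return predicted, rates


outputs = ["N10", "N11", "N12"]  # hypothetical output IDs for classes 0-2
T = 4
step_spikes = [["N11"], ["N11", "N12"], [], ["N11"],  # T input steps
               ["N12"], []]                           # drain steps
predicted, rates = decode(step_spikes, outputs, T)
print(predicted, rates)  # 1 [0.0, 0.75, 0.5]
```

Dividing by T rather than by T + extra_timesteps only rescales every class equally, so it does not affect the prediction.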

print("Running inference on test set...")

correct = 0
total = 0

# Map each output neuron ID to its class index (its position in `outputs`),
# so the counters work whatever IDs the converter assigned to the output layer.
output_index = {neuron: k for k, neuron in enumerate(outputs)}


def accumulate(spikes, counts):
    """Add hardware output spikes to the per-class spike counters."""
    for spike in spikes:
        k = output_index.get(spike)
        if k is None:
            print(f"Warning: unexpected output spike {spike}")
        else:
            counts[k] += 1


for img, label in test_dataset:
    # Reset membrane potentials before each new sample
    hs_bridge.FPGA_Execution.fpga_controller.clear(
        len(connections), False, 0
    )  # num_neurons, simDump, coreOverride

    flat = img.to(device).flatten()  # [C*H*W], row-major (c, h, w) order
    spike_counts = torch.zeros(len(outputs))

    # Build the spike list once; the same binary frame is replayed each timestep.
    # Element i of the flattened image drives axon "A{i}".
    inputs = [f"A{i}" for i in torch.nonzero(flat).flatten().tolist()]

    for _ in range(T):
        hardwareSpikes, _, _ = network.step(inputs)
        accumulate(hardwareSpikes, spike_counts)

    # Drain remaining spikes with empty timesteps
    for _ in range(extra_timesteps):
        hardwareSpikes, _, _ = network.step([])
        accumulate(hardwareSpikes, spike_counts)

    # Average spikes per input timestep; the scaling does not change the argmax
    spike_counts = spike_counts / T
    predicted = torch.argmax(spike_counts).item()

    total += 1
    if predicted == label:
        correct += 1

    print(
        f"[{total}] Predicted: {predicted}, Label: {label} | "
        f"Running accuracy: {100 * correct / total:.2f}%"
    )

Results

Print the final classification accuracy over the full 10 000-image test set.

accuracy = 100 * correct / total
print(f"Test accuracy: {accuracy:.2f}%")
