Note
Go to the end to download the full example code.
DVS128 Gesture Inference on HiAER-Spike Hardware#
This example demonstrates running inference on the DVS128 Gesture dataset using a pre-converted spiking convolutional network (stride-2, 100 channels, 3-layer conv) deployed on HiAER-Spike neuromorphic hardware.
The network was trained with SpikingJelly and converted to the CRI format
using hs_api. Input frames are resized to 63×63 and binarized before
being fed to the hardware timestep-by-timestep.
# sphinx_gallery_thumbnail_path = '_static/dvs_gesture_thumb.png'
The DVS128 Gesture Dataset#
The DVS128 Gesture dataset was collected by IBM Research using the DVS128 dynamic vision sensor — a neuromorphic camera that reports pixel-level brightness changes (events) rather than full frames. Each pixel independently fires an ON event (brightness increase) or an OFF event (brightness decrease) with microsecond-level temporal resolution, producing sparse, asynchronous data that is a natural fit for spiking neural networks.
The dataset contains 11 hand and arm gestures performed by 29 subjects under three different lighting conditions:
| Label | Gesture |
|-------|---------|
| 0 | Hand Clapping |
| 1 | Right Hand Wave |
| 2 | Left Hand Wave |
| 3 | Right Arm CW |
| 4 | Right Arm CCW |
| 5 | Left Arm CW |
| 6 | Left Arm CCW |
| 7 | Arm Roll |
| 8 | Air Drums |
| 9 | Air Guitar |
| 10 | Other gestures |
Frame representation — SpikingJelly converts the raw event stream into
fixed-size frames by accumulating events over equal-count windows
(split_by="number"). Each frame has shape (2, H, W):
Channel 0 — ON events (brightness increases)
Channel 1 — OFF events (brightness decreases)
A full sample therefore has shape (T, 2, H, W) where T is the number
of frames and H = W = 128 px for the DVS128 sensor.
Importing the necessary libraries#
We need PyTorch and SpikingJelly for data loading and preprocessing,
hs_api for the CRI network interface, and hs_bridge to reset
membrane potentials between samples on the FPGA.
import os
import pickle
import urllib.request
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.animation import PillowWriter
from spikingjelly.datasets.dvs128_gesture import DVS128Gesture
from torch.utils.data import DataLoader
from spikingjelly.datasets import pad_sequence_collate
Configuration#
Set the path to the DVS128 Gesture dataset and select the compute device. The model configuration is downloaded from Dropbox if not already cached.
# Root directory of the DVS128 Gesture dataset. The DVS_GESTURE_DIR environment
# variable, when set, overrides the built-in default so the example can run on
# other machines without editing the source.
data_dir = os.environ.get("DVS_GESTURE_DIR", "/home/ckdeng/myprojects/DVS_Gesture")

# Local cache location of the pickled CRI network description.
model_config_path = "DVS_model_config.pkl"

# Prefer CUDA when PyTorch reports it as available, otherwise use the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

MODEL_CONFIG_URL = (
    "https://www.dropbox.com/scl/fi/dz7lfg4hs2mjh3vw0jec8/"
    "DVS_model_config.pkl?rlkey=5scc386le9356arxxy3hkrm6e&st=ysecv0s3&dl=1"
)

if not os.path.exists(model_config_path):
    print("Downloading model configuration...")
    # Download to a temporary name and rename only after the transfer has
    # finished. Writing directly to the final path would leave a truncated
    # file behind if the download is interrupted, and the next run would then
    # treat that corrupt file as a valid cache entry.
    partial_path = model_config_path + ".part"
    try:
        urllib.request.urlretrieve(MODEL_CONFIG_URL, partial_path)
        os.replace(partial_path, model_config_path)
    finally:
        # After a successful os.replace the partial file no longer exists;
        # this only cleans up after a failed or interrupted download.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print(f" Saved to {model_config_path}")
else:
    print(f" {model_config_path} already cached.")
Loading the raw dataset for visualisation#
We first load the dataset without any spatial transform so we can visualise the original 128×128 event frames before they are resized for inference.
# Human-readable gesture names; the position in the list is the integer class
# label (0-10) used by the dataset.
GESTURE_NAMES = [
    "Hand Clapping",
    "Right Hand Wave",
    "Left Hand Wave",
    "Right Arm CW",
    "Right Arm CCW",
    "Left Arm CW",
    "Left Arm CCW",
    "Arm Roll",
    "Air Drums",
    "Air Guitar",
    "Other",
]

# Test split loaded as frames, with no transform, so the frames keep the
# sensor's native resolution for visualisation.
# NOTE(review): both `frames_number` and `duration` are passed; presumably
# only one of the two integration modes is honoured by SpikingJelly -- confirm.
raw_test_set = DVS128Gesture(
    root=data_dir,
    train=False,
    data_type="frame",
    frames_number=10,
    split_by="number",
    duration=1600000,
)
Visualising a gesture#
Each sample is a tensor of shape (T, 2, 128, 128).
We render ON events in green and OFF events in red and animate the
frames to show how the gesture unfolds over time.
# First test sample: frames of shape [T, 2, H, W] plus its integer label.
sample_frames, sample_label = raw_test_set[0]
if isinstance(sample_frames, np.ndarray):
    sample_frames = torch.from_numpy(sample_frames)
T_vis, _, H, W = sample_frames.shape


def make_rgb(t):
    """Composite ON (green) and OFF (red) event channels into an RGB frame.

    Args:
        t: Index of the frame inside ``sample_frames``.

    Returns:
        A float32 array of shape ``(H, W, 3)`` with values clipped to
        ``[0, 1]``. The blue channel is always zero.
    """
    rgb = np.zeros((H, W, 3), dtype=np.float32)
    rgb[..., 1] = sample_frames[t, 0].float().numpy()  # green → ON events
    rgb[..., 0] = sample_frames[t, 1].float().numpy()  # red → OFF events
    # Float RGB images must lie in [0, 1]. Frame values may exceed 1 (they
    # are presumably accumulated event counts), so saturate explicitly
    # instead of relying on imshow's implicit clipping and its warning.
    return np.clip(rgb, 0.0, 1.0, out=rgb)


fig, ax = plt.subplots(figsize=(4, 4))
ax.axis("off")
fig.suptitle(f"Gesture: {GESTURE_NAMES[sample_label]}", fontsize=12)
# vmin/vmax are omitted on purpose: imshow ignores them for RGB input.
im = ax.imshow(make_rgb(0))
time_text = ax.set_title(f"t = 0 / {T_vis - 1}", fontsize=9)


def update(t):
    """Draw frame ``t`` and refresh the frame counter shown in the title."""
    im.set_data(make_rgb(t))
    time_text.set_text(f"t = {t} / {T_vis - 1}")
    return [im, time_text]


# blit=False: the title is drawn outside the axes box, which is the only
# region refreshed when blitting, so a blitted interactive animation would
# keep showing a stale frame counter. A handful of frames at 200 ms needs no
# blitting for performance.
ani = animation.FuncAnimation(fig, update, frames=T_vis, interval=200, blit=False)
ani.save("gesture_sample.gif", writer=PillowWriter(fps=5))
plt.show()
Defining the preprocessing transform#
DVS128 Gesture frames are 128×128. The network expects 63×63 binary inputs, so we resize each frame with bilinear interpolation and then binarize the result (any value greater than zero becomes 1).
import torch.nn as nn
import torch.nn.functional as F
import hs_bridge
from hs_api.api import CRI_network
class DVSResizeAndBinarize:
    """Resize and binarize DVS event frames along the temporal dimension.

    Every frame of a ``[T, C, H, W]`` sequence is resized to ``size`` with
    bilinear interpolation and then thresholded: values greater than zero
    become 1, everything else becomes 0.

    Args:
        size: Target ``(height, width)`` of each output frame.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, data):
        """Apply the transform.

        Args:
            data: Either the frames alone or a ``(frames, label)`` tuple.
                ``frames`` is a NumPy array or a tensor of shape
                ``[T, C, H, W]``.

        Returns:
            The binarized frames, shape ``[T, C, size[0], size[1]]``, with the
            same dtype and device as the input frames. When a label other
            than ``None`` was supplied, ``(frames, label)`` is returned.

        Raises:
            ValueError: If ``frames`` is not 4-dimensional.
        """
        frames, label = data if isinstance(data, tuple) else (data, None)
        if isinstance(frames, np.ndarray):
            frames = torch.from_numpy(frames)
        if frames.dim() != 4:
            raise ValueError(
                f"expected frames of shape [T, C, H, W], got {tuple(frames.shape)}"
            )

        if frames.shape[0] == 0:
            # Nothing to interpolate: return an empty sequence of the target size.
            binary = frames.new_zeros(
                (0, frames.shape[1], self.size[0], self.size[1])
            )
        else:
            # Bilinear interpolation is only implemented for floating-point
            # tensors, so integer/bool frames are resized in float32.
            work = frames if frames.is_floating_point() else frames.float()
            # interpolate() treats dim 0 as the batch axis, so all T frames
            # are resized in one call instead of one call per frame.
            resized = torch.nn.functional.interpolate(
                work, size=self.size, mode="bilinear", align_corners=False
            )
            # Cast back so the output dtype matches the input frames.
            binary = (resized > 0).to(frames.dtype)

        return (binary, label) if label is not None else binary


# Transform used for inference: 63×63 binary frames.
resize_transform = DVSResizeAndBinarize(size=(63, 63))
Loading the dataset for inference#
We reload the dataset with the resize-and-binarize transform applied. SpikingJelly's pad_sequence_collate is passed to the DataLoader as the collate function so that variable-length sequences can be padded into a batch.
# Same test split as before, but every sample now goes through
# `resize_transform` when it is read.
test_set = DVS128Gesture(
    root=data_dir,
    train=False,
    data_type="frame",
    frames_number=10,
    split_by="number",
    duration=1600000,
    transform=resize_transform,
)

# Batched, in-order view of the test split. An incomplete final batch is
# dropped (drop_last=True).
test_loader = DataLoader(
    dataset=test_set,
    batch_size=64,
    shuffle=False,
    drop_last=True,
    pin_memory=True,
    collate_fn=pad_sequence_collate,
)

print(f"Test samples: {len(test_set)}")
Loading the model configuration#
The CRI network topology (axons, connections, output neuron IDs) was
serialised to disk during the hs_api conversion step and is loaded here.
print("Loading model configuration...")

# NOTE(security): pickle.load can execute arbitrary code embedded in the file.
# Only unpickle a configuration file obtained from a source you trust.
with open(model_config_path, "rb") as config_file:
    model_config = pickle.load(config_file)

# Pull out the three pieces of the network description used below.
axons, connections, outputs = (
    model_config[key] for key in ("axons", "connections", "outputs")
)
Creating the CRI network#
A CRI_network object is initialised with the loaded topology and
targets the physical CRI hardware (target="CRI").
print("Creating CRI network...")

# Build the network object from the loaded description; target="CRI" is the
# backend selector passed to hs_api.
network = CRI_network(
    target="CRI",
    axons=axons,
    connections=connections,
    outputs=outputs,
)
Running inference#
For each test sample we:
1. Clear the FPGA membrane potentials.
2. Feed each temporal frame as a list of active axon IDs.
3. Accumulate output spikes across all frames plus 6 drain timesteps.
4. Classify by the output neuron with the highest spike rate.
The hardware performance counters (clock cycles, HBM accesses) returned by the final drain step are recorded for each sample.
print("Running inference on test set...")

data = []  # (clock_cycles, hbm_accesses) of the last drain step, one entry per sample
correct = 0
total = 0
loss_fn = nn.CrossEntropyLoss()
test_loss = 0

# The logits vector has one entry per output neuron, so the one-hot target
# must have the same length; derive it instead of hard-coding it.
num_classes = len(outputs)


def _count_output_spikes(spikes, counts):
    """Add one to ``counts[spike]`` for every spike that is listed in ``outputs``."""
    for spike in spikes:
        if spike in outputs:
            counts[spike] += 1


for img, label in test_set:
    # Reset membrane potentials before each new sample
    hs_bridge.FPGA_Execution.fpga_controller.clear(
        len(connections), False, 0
    )  # num_neurons, simDump, coreOverride

    img = img.to(device)  # [T, C, H, W]
    spike_counts = torch.zeros(num_classes)

    # Feed each frame as a sparse spike list
    for t in range(img.shape[0]):
        # Flatten [C, H, W] in C-major order; position i maps to axon "A{i}".
        flat = img[t].flatten().to(torch.int16)
        # One vectorised nonzero() per frame replaces a Python-level .item()
        # call for every pixel (each of which forces a device sync on CUDA).
        active = torch.nonzero(flat > 0).flatten().tolist()
        inputs = [f"A{i}" for i in active]
        hardwareSpikes, _, _ = network.step(inputs)
        _count_output_spikes(hardwareSpikes, spike_counts)

    # Drain remaining spikes with 6 empty timesteps
    for _ in range(6):
        hardwareSpikes, clock_cycles, hbm_accesses = network.step([])
        _count_output_spikes(hardwareSpikes, spike_counts)
    # Only the counters returned by the final drain step are kept.
    data.append((clock_cycles, hbm_accesses))

    # Classify by average spike rate (normalised by the number of input frames)
    spike_counts = spike_counts / img.size(0)
    predicted = torch.argmax(spike_counts).item()

    total += 1
    if predicted == label:
        correct += 1

    # Cross-entropy loss, using the spike rates as logits against a one-hot target
    label_onehot = F.one_hot(torch.tensor(label), num_classes=num_classes).float()
    loss = loss_fn(spike_counts, label_onehot)
    test_loss += loss.item()

    print(
        f"[{total}] Predicted: {predicted}, Label: {label} | "
        f"Running accuracy: {100 * correct / total:.2f}%"
    )
Results#
Print the final accuracy and average loss over the test set.
# Final metrics over every sample processed by the inference loop.
accuracy = (100 * correct) / total  # percentage of correctly classified samples
avg_loss = test_loss / total  # mean per-sample cross-entropy loss

for summary_line in (
    f"Test accuracy: {accuracy:.2f}%",
    f"Test loss: {avg_loss:.4f}",
):
    print(summary_line)