Abstract: In the quest for next-generation sequence modeling architectures, State Space Models (SSMs) have emerged as a potent alternative to transformers, particularly for their computational efficiency and suitability for dynamical systems. This paper investigates the effect of quantization on the S5 model to understand its impact on model performance and to facilitate its deployment to edge and resource-constrained platforms. Using quantization-aware training (QAT) and post-training quantization (PTQ), we systematically evaluate the quantization sensitivity of SSMs across tasks including dynamical systems modeling, Sequential MNIST (sMNIST), and most of the Long Range Arena (LRA). We present fully quantized S5 models whose test accuracy drops by less than 1% on sMNIST and most of the LRA. We find that performance on most tasks degrades significantly when recurrent weights are quantized below 8-bit precision, but that other components can be compressed further without significant loss of performance. Our results further show that PTQ performs well only on language-based LRA tasks, whereas all others require QAT. Our investigation provides necessary insights for the continued development of efficient and hardware-optimized SSMs.
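To make the distinction between PTQ and QAT concrete, the following is a minimal JAX sketch of symmetric per-tensor uniform quantization. It is an illustration under stated assumptions, not the quantization scheme used in the paper: the function names `fake_quantize` and `fake_quantize_ste`, the per-tensor scaling, and the example weights are all hypothetical. PTQ corresponds to applying `fake_quantize` to already trained weights, while QAT uses the straight-through variant during training so that gradients still reach the full-precision weights.

```python
import jax
import jax.numpy as jnp


def fake_quantize(w, bits):
    """Round w onto a symmetric signed integer grid with `bits` bits, then rescale.

    Assumption: one scale per tensor, chosen from the largest absolute value.
    """
    qmax = 2.0 ** (bits - 1) - 1.0            # e.g. 127 for 8 bits, 7 for 4 bits
    scale = jnp.max(jnp.abs(w)) / qmax
    scale = jnp.where(scale > 0, scale, 1.0)  # avoid division by zero for all-zero w
    q = jnp.clip(jnp.round(w / scale), -qmax, qmax)
    return q * scale


def fake_quantize_ste(w, bits):
    """QAT-style variant: quantized values in the forward pass,
    identity gradient in the backward pass (straight-through estimator)."""
    return w + jax.lax.stop_gradient(fake_quantize(w, bits) - w)


# Hypothetical weights, used only to exercise the functions.
w = jnp.array([-1.0, -0.3, 0.0, 0.42, 1.0])

w8 = fake_quantize(w, 8)   # PTQ-style: quantize after training
w4 = fake_quantize(w, 4)   # coarser grid, larger rounding error

# QAT-style: the loss sees quantized weights, the gradient flows to w.
loss = lambda p: jnp.sum(fake_quantize_ste(p, 4) ** 2)
grad_w = jax.grad(loss)(w)
```

In this sketch the rounding error of `w8` is bounded by half a quantization step, and the error grows as the bit width shrinks, which is the kind of trade-off the abstract evaluates per component.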
Abstract: As the role of artificial intelligence becomes increasingly pivotal in modern society, the efficient training and deployment of deep neural networks have emerged as critical areas of focus. Recent advancements in attention-based large neural architectures have spurred the development of AI accelerators, facilitating the training of extensive, multi-billion parameter models. Despite their effectiveness, these powerful networks often incur high execution costs in production environments. Neuromorphic computing, inspired by biological neural processes, offers a promising alternative. By utilizing temporally-sparse computations, Spiking Neural Networks (SNNs) aim to enhance energy efficiency through a reduced and low-power hardware footprint. However, training SNNs can be challenging due to their recurrent nature, which cannot as easily leverage the massive parallelism of modern AI accelerators. To facilitate the investigation of SNN architectures and dynamics, researchers have sought to bridge Python-based deep learning frameworks such as PyTorch or TensorFlow with custom-implemented compute kernels. This paper introduces Spyx, a new and lightweight SNN simulation and optimization library designed in JAX. By pre-staging data in the expansive VRAM of contemporary accelerators and employing extensive JIT compilation, Spyx allows for SNN optimization to be executed as a unified, low-level program on NVIDIA GPUs or Google TPUs. This approach achieves optimal hardware utilization, surpassing the performance of many existing SNN training frameworks while maintaining considerable flexibility.
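The idea of compiling a recurrent spiking simulation into a single program can be sketched in plain JAX. The example below is an assumption-laden illustration and does not use or reproduce the Spyx API: `lif_step`, `run_lif`, the decay factor `beta`, the threshold, and the soft-reset rule are all hypothetical choices for a basic leaky integrate-and-fire (LIF) layer. The time loop is expressed with `jax.lax.scan` and wrapped in `jax.jit`, so the whole simulation is traced once and compiled rather than stepped from Python.

```python
import jax
import jax.numpy as jnp


def lif_step(v, x, beta=0.9, threshold=1.0):
    """One time step of a leaky integrate-and-fire neuron layer.

    v: membrane potentials carried between steps; x: input current at this step.
    """
    v = beta * v + x                               # leaky integration
    spike = (v > threshold).astype(jnp.float32)    # emit 1.0 where the threshold is crossed
    v = v - spike * threshold                      # soft reset: subtract the threshold
    return v, spike


@jax.jit
def run_lif(inputs):
    """Simulate all time steps; inputs has shape (time, neurons)."""
    v0 = jnp.zeros(inputs.shape[1:])
    v_final, spikes = jax.lax.scan(lif_step, v0, inputs)
    return v_final, spikes


# Hypothetical constant input: 10 time steps, 3 neurons.
inputs = jnp.full((10, 3), 0.3)
v_final, spikes = run_lif(inputs)
```

Training such a layer would additionally require a surrogate gradient for the non-differentiable threshold, which this sketch omits.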