Abstract

Spiking neural networks (SNNs) are highly efficient signal processing mechanisms in biological systems that have inspired a plethora of research efforts aimed at translating their energy efficiency to computational platforms. Efficient training approaches are critical for the successful deployment of SNNs. Compared to mainstream artificial neural networks (ANNs), training SNNs is far more challenging due to their complex neural dynamics that evolve over time and their discrete, binary computing paradigm. Back-propagation-through-time (BPTT) with surrogate gradients has recently emerged as an effective technique to train deep SNNs directly. SNN-BPTT, however, has a major drawback: its memory requirement grows with the number of timesteps. Because SNNs generally result from the discretization of ordinary differential equations, their sequence lengths are typically longer than those of RNNs, compounding the time-dependence problem. It therefore becomes hard to train deep SNNs on a single- or multi-GPU setup with sufficiently large batch sizes or timestep counts, and extended training periods are required to achieve reasonable network performance. In this work, we reduce the memory requirements of BPTT in SNNs to enable the training of deeper SNNs with more timesteps (T). To this end, we leverage activation re-computation in the context of SNN training, which enables GPU memory to scale sub-linearly with increasing timesteps. We observe that naively deploying the re-computation-based approach incurs considerable computational overhead. To solve this, we propose a time-skipped BPTT approximation technique for SNNs, called Skipper, which not only alleviates this computational overhead but also further lowers memory consumption with little to no loss of accuracy. We show the efficacy of our proposed technique by comparing it against a popular method for memory footprint reduction during training. Our evaluations on 5 state-of-the-art networks and 4 datasets show that, for a constant batch size and number of timesteps, Skipper reduces memory usage by 3.3× to 8.4× (6.7× on average) over baseline SNN-BPTT. It also achieves a speedup of 29% to 70% over the checkpointed approach and of 4% to 40% over the baseline approach. For a constant memory budget, Skipper can scale to an order of magnitude more timesteps than baseline SNN-BPTT.
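
For readers unfamiliar with activation re-computation, the sketch below illustrates the general idea in PyTorch; it is not the paper's Skipper implementation, only a minimal example of checkpointed SNN-BPTT. Timesteps are grouped into segments, each segment's forward pass runs through torch.utils.checkpoint, and the per-timestep activations inside a segment are re-computed during backward instead of being stored. The LIF dynamics, threshold, surrogate-gradient window, and segment length are illustrative assumptions.

    # Minimal sketch of activation re-computation for SNN-BPTT (illustrative, not the authors' code).
    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    class SpikeFn(torch.autograd.Function):
        """Heaviside spike with a rectangular surrogate gradient."""
        @staticmethod
        def forward(ctx, v):
            ctx.save_for_backward(v)
            return (v > 0).float()

        @staticmethod
        def backward(ctx, grad_out):
            (v,) = ctx.saved_tensors
            # Surrogate: pass gradients only near the firing threshold (window assumed).
            return grad_out * (v.abs() < 0.5).float()

    class LIFLayer(nn.Module):
        """Leaky integrate-and-fire layer unrolled over a segment of timesteps."""
        def __init__(self, in_dim, out_dim, decay=0.9):
            super().__init__()
            self.fc = nn.Linear(in_dim, out_dim)
            self.decay = decay

        def forward(self, x_seg, mem):
            # x_seg: [T_seg, batch, in_dim]; mem: [batch, out_dim]
            spikes = []
            for t in range(x_seg.shape[0]):
                mem = self.decay * mem + self.fc(x_seg[t])
                s = SpikeFn.apply(mem - 1.0)   # threshold = 1.0 (assumed)
                mem = mem * (1.0 - s)          # hard reset on spike
                spikes.append(s)
            return torch.stack(spikes), mem

    def run_checkpointed(layer, x, seg_len=8):
        """Checkpoint every seg_len timesteps so per-timestep activations are
        re-computed during backward instead of being kept in GPU memory."""
        T, B, _ = x.shape
        mem = torch.zeros(B, layer.fc.out_features, device=x.device)
        outs = []
        for start in range(0, T, seg_len):
            seg = x[start:start + seg_len]
            out, mem = checkpoint(layer, seg, mem, use_reentrant=False)
            outs.append(out)
        return torch.cat(outs, dim=0)

    if __name__ == "__main__":
        layer = LIFLayer(64, 32)
        x = torch.rand(32, 4, 64, requires_grad=True)  # T=32 timesteps, batch=4
        out = run_checkpointed(layer, x)
        out.sum().backward()  # segment activations are re-computed here

Stored state in this sketch scales with the number of segments rather than the number of timesteps, which is the sub-linear memory behavior the abstract refers to; the re-computation cost it adds is the overhead that Skipper is designed to reduce.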
