These are notes on the S4 paper, Efficiently Modeling Long Sequences with Structured State Spaces, by Albert Gu, Karan Goel, and Christopher Ré. S4 is the direct predecessor of the well-known Mamba state space model. These notes are a good starting point for learning about state space models for deep learning in general, and they complement my post "Paper Notes - State Space Models for Deep Learning".
I also conducted an experimental study comparing Transformer-based and State Space Model-based (specifically, Mamba) LLM architectures as sequence length increases, where the length is increased through the number of chunks retrieved with Retrieval Augmented Generation (RAG).
Efficiently Modeling Long Sequences with Structured State Spaces: S4
- The central problem in sequence modeling is efficiently handling data that contains long-range dependencies (LRDs).
Background: State Spaces
2.1 SSMs: A Continuous-time Latent State Model
- The SSM is defined by the simple equation (1). It maps a 1-D input signal $u(t)$ to an N-D latent state $x(t)$ before projecting to a 1-D output signal $y(t)$:

$$x'(t) = Ax(t) + Bu(t), \qquad y(t) = Cx(t) + Du(t) \tag{1}$$
- It is related to latent state models such as Hidden Markov Models.
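- The goal is to use the SSM as a black-box representation in a deep sequence model, where $A$, $B$, $C$, $D$ are parameters learned by gradient descent. The $Du$ term is an easily computed skip connection, so it is often omitted (equivalently, $D = 0$) for exposition.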
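2.2 Addressing Long-Range Dependencies with HiPPO
- The basic SSM (1) performs poorly in practice: linear first-order ODEs solve to exponential functions, which can cause gradients to vanish or explode with sequence length. To address this, prior work (the LSSL) leveraged the HiPPO theory of continuous-time memorization.
- HiPPO specifies a class of matrices $A \in \mathbb{R}^{N \times N}$ that, when incorporated into (1), allow the state $x(t)$ to memorize the history of the input $u(t)$.
- The most important matrix in this class is defined by equation (2), which the authors call the HiPPO matrix:

$$A_{nk} = -\begin{cases} (2n+1)^{1/2}(2k+1)^{1/2} & \text{if } n > k \\ n+1 & \text{if } n = k \\ 0 & \text{if } n < k \end{cases} \tag{2}$$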
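A minimal NumPy sketch that builds the HiPPO matrix of equation (2) for a given state size `N` may help make it concrete. It is an illustration, not the authors' code, and `make_hippo` is a hypothetical helper name. Because the matrix is lower triangular, its eigenvalues are simply its diagonal entries $-1, -2, \dots, -N$.

```python
import numpy as np

def make_hippo(N: int) -> np.ndarray:
    """Hypothetical helper: the HiPPO matrix of equation (2).

    A[n, k] = -sqrt(2n+1) * sqrt(2k+1)  if n > k
              -(n + 1)                  if n == k
              0                         if n < k
    """
    n = np.arange(N)
    q = np.sqrt(2 * n + 1)
    strictly_lower = np.tril(np.outer(q, q), k=-1)  # sqrt(2n+1) sqrt(2k+1) below the diagonal
    return -(strictly_lower + np.diag(n + 1.0))     # add the diagonal n + 1, then negate

A = make_hippo(4)
```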
2.3 Discrete-time SSM: The Recurrent Representation
- To apply the model to a discrete input sequence $(u_0, u_1, \dots)$ instead of a continuous function $u(t)$, (1) must be discretized by a step size $\Delta$ that represents the resolution of the input. Conceptually, the inputs $u_k$ can be viewed as sampling an implicit underlying continuous signal $u(t)$, where $u_k = u(k\Delta)$.
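- To discretize the continuous-time SSM, the authors follow prior work in using the bilinear method, which converts the state matrix $A$ into an approximation $\overline{A}$. The discrete SSM is then:

$$x_k = \overline{A}x_{k-1} + \overline{B}u_k, \qquad y_k = \overline{C}x_k \tag{3}$$
$$\overline{A} = (I - \Delta/2 \cdot A)^{-1}(I + \Delta/2 \cdot A), \qquad \overline{B} = (I - \Delta/2 \cdot A)^{-1}\Delta B, \qquad \overline{C} = C$$

Equation (3) is now a sequence-to-sequence map $u_k \mapsto y_k$ instead of a function-to-function map. Moreover, the state equation is now a recurrence in $x_k$, allowing the discrete SSM to be computed like an RNN. Concretely, $x_k \in \mathbb{R}^N$ can be viewed as a hidden state with transition matrix $\overline{A}$.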
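The bilinear discretization and the RNN-style recurrence of (3) can be sketched in a few lines of NumPy. This is an illustrative sketch, not the paper's implementation. The function names and the random state matrix are hypothetical. The state starts at $x_{-1} = 0$ and is updated once per input sample. That one-step-at-a-time loop is exactly the sequential computation that makes the recurrent view slow to train.

```python
import numpy as np

def discretize_bilinear(A, B, step):
    """Bilinear method: Ab = (I - step/2 A)^{-1} (I + step/2 A), Bb = (I - step/2 A)^{-1} step B."""
    I = np.eye(A.shape[0])
    inv = np.linalg.inv(I - (step / 2.0) * A)
    return inv @ (I + (step / 2.0) * A), (inv @ B) * step

def run_recurrence(Ab, Bb, C, u):
    """x_k = Ab x_{k-1} + Bb u_k and y_k = C x_k, starting from x_{-1} = 0."""
    x = np.zeros((Ab.shape[0], 1))
    ys = []
    for u_k in u:                      # one sequential step per sample
        x = Ab @ x + Bb * u_k
        ys.append((C @ x).item())
    return np.array(ys)

rng = np.random.default_rng(0)
N, L, step = 4, 16, 0.1
A = -np.eye(N) + 0.3 * rng.standard_normal((N, N))   # hypothetical state matrix
B = rng.standard_normal((N, 1))
C = rng.standard_normal((1, N))
u = rng.standard_normal(L)

Ab, Bb = discretize_bilinear(A, B, step)
y = run_recurrence(Ab, Bb, C, u)
```

For a scalar state with $A = -1$, the bilinear rule gives $\overline{A} = (1 - \Delta/2)/(1 + \Delta/2)$. This value is slightly below 1, so the discrete state decays just as the continuous one does.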
2.4 Training SSMs: Convolutional Representation
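- The recurrent SSM (3) is not practical for training on modern hardware due to its sequential nature. Instead, there is a well-known connection between linear time-invariant (LTI) SSMs such as (1) and continuous convolutions. Correspondingly, (3) can be written as a discrete convolution.
- Let the initial state be $x_{-1} = 0$. Then unrolling (3) yields:

$$x_0 = \overline{B}u_0, \quad x_1 = \overline{A}\overline{B}u_0 + \overline{B}u_1, \quad x_2 = \overline{A}^2\overline{B}u_0 + \overline{A}\overline{B}u_1 + \overline{B}u_2, \quad \dots$$
$$y_0 = \overline{C}\overline{B}u_0, \quad y_1 = \overline{C}\overline{A}\overline{B}u_0 + \overline{C}\overline{B}u_1, \quad y_2 = \overline{C}\overline{A}^2\overline{B}u_0 + \overline{C}\overline{A}\overline{B}u_1 + \overline{C}\overline{B}u_2, \quad \dots$$

This can be vectorized into a convolution (4) with an explicit formula for the convolution kernel:

$$y_k = \overline{C}\overline{A}^k\overline{B}u_0 + \overline{C}\overline{A}^{k-1}\overline{B}u_1 + \dots + \overline{C}\overline{A}\overline{B}u_{k-1} + \overline{C}\overline{B}u_k$$
$$y = \overline{K} * u, \qquad \overline{K} \in \mathbb{R}^L := \left(\overline{C}\overline{B}, \overline{C}\overline{A}\overline{B}, \dots, \overline{C}\overline{A}^{L-1}\overline{B}\right) \tag{4}$$

In other words, equation (4) is a single non-circular convolution and can be computed very efficiently with FFTs, provided that $\overline{K}$ is known (pre-computed beforehand).
- However, computing $\overline{K}$ is non-trivial and is the focus of the technical contributions of S4. $\overline{K}$ is referred to as the SSM convolution kernel or filter.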
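A small NumPy check shows that the convolutional view (4) computes the same outputs as the recurrence (3). This is a sketch with hypothetical helper names. $\overline{K}$ is computed here by the naïve unrolling that S4 is designed to avoid. Zero-padding both signals to length $2L$ before the FFT makes the convolution non-circular.

```python
import numpy as np

def discretize_bilinear(A, B, step):
    """Bilinear method from (3)."""
    I = np.eye(A.shape[0])
    inv = np.linalg.inv(I - (step / 2.0) * A)
    return inv @ (I + (step / 2.0) * A), (inv @ B) * step

def ssm_kernel(Ab, Bb, C, L):
    """K = (C Bb, C Ab Bb, ..., C Ab^{L-1} Bb), computed by naive unrolling."""
    K, v = [], Bb
    for _ in range(L):
        K.append((C @ v).item())
        v = Ab @ v
    return np.array(K)

def causal_conv_fft(u, K):
    """Non-circular convolution y_k = sum_{j<=k} K_j u_{k-j}, via FFTs zero-padded to 2L."""
    L = len(u)
    n = 2 * L                                   # padding removes wrap-around
    y = np.fft.irfft(np.fft.rfft(u, n=n) * np.fft.rfft(K, n=n), n=n)
    return y[:L]

def run_recurrence(Ab, Bb, C, u):
    """Reference: the RNN view of (3) with x_{-1} = 0."""
    x, ys = np.zeros((Ab.shape[0], 1)), []
    for u_k in u:
        x = Ab @ x + Bb * u_k
        ys.append((C @ x).item())
    return np.array(ys)

rng = np.random.default_rng(1)
N, L = 4, 64
A = -np.eye(N) + 0.3 * rng.standard_normal((N, N))    # hypothetical state matrix
B, C = rng.standard_normal((N, 1)), rng.standard_normal((1, N))
u = rng.standard_normal(L)

Ab, Bb = discretize_bilinear(A, B, step=0.1)
K = ssm_kernel(Ab, Bb, C, L)
y_conv = causal_conv_fft(u, K)
y_rec = run_recurrence(Ab, Bb, C, u)
print(np.allclose(y_conv, y_rec))  # True
```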
Method: Structured State Spaces (S4)
- The technical focus of S4 is on developing the S4 parameterization and showing how to efficiently compute all views of the SSM: the continuous representation, the recurrent representation, and the convolutional representation.
- Section 3.1 motivates the approach, which is based on the linear-algebraic concepts of conjugation and diagonalization, and discusses why the naïve application of this approach does not work. Section 3.2 gives an overview of the key technical components of the approach and formally defines the S4 parameterization. Section 3.3 sketches the main results.
3.1 Motivation: Diagonalization
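- The fundamental bottleneck in computing the discrete-time SSM (3) is that it involves repeated matrix multiplication by $\overline{A}$. For example, computing the kernel (4) naïvely takes $L$ successive multiplications by $\overline{A}$, requiring $O(N^2L)$ operations and $O(NL)$ space.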
- To overcome this, they use a structural result that allows simplification of the SSMs.
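Lemma 3.1. Conjugation is an equivalence relation on SSMs: $(A, B, C) \sim (V^{-1}AV, V^{-1}B, CV)$.
Proof: Write out the two SSMs with states denoted by $x$ and $\tilde{x}$ respectively:

$$x' = Ax + Bu, \quad y = Cx \qquad\qquad \tilde{x}' = V^{-1}AV\tilde{x} + V^{-1}Bu, \quad y = CV\tilde{x}$$

After multiplying the right-side SSM by $V$, the two SSMs become identical with $x = V\tilde{x}$. Therefore, these compute the exact same operator $u \mapsto y$, but with a change of basis by $V$ in the state $x$.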
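Lemma 3.1 motivates putting $A$ into a canonical form by conjugation, ideally a diagonal one, which is more structured and allows faster computation. For example, if $A$ were diagonal, the kernel $\overline{K}$ in (4) would be a Vandermonde product.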
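Lemma 3.1 is easy to check numerically. Conjugating $(A, B, C)$ by any invertible $V$ leaves the kernel $\overline{K}$, and hence the map $u \mapsto y$, unchanged. The sketch below uses hypothetical helper names and a random, well-conditioned $V$. It discretizes both SSMs with the bilinear method from (3), which commutes with conjugation, and compares their kernels.

```python
import numpy as np

def discretize_bilinear(A, B, step):
    I = np.eye(A.shape[0])
    inv = np.linalg.inv(I - (step / 2.0) * A)
    return inv @ (I + (step / 2.0) * A), (inv @ B) * step

def ssm_kernel(A, B, C, L, step=0.1):
    """Discretize (A, B, C) and unroll K_k = C Ab^k Bb for k < L."""
    Ab, Bb = discretize_bilinear(A, B, step)
    K, v = [], Bb
    for _ in range(L):
        K.append((C @ v).item())
        v = Ab @ v
    return np.array(K)

rng = np.random.default_rng(2)
N, L = 5, 32
A = -np.eye(N) + 0.3 * rng.standard_normal((N, N))    # hypothetical SSM
B, C = rng.standard_normal((N, 1)), rng.standard_normal((1, N))
V = N * np.eye(N) + rng.standard_normal((N, N))       # hypothetical invertible change of basis
V_inv = np.linalg.inv(V)

K_original = ssm_kernel(A, B, C, L)
K_conjugated = ssm_kernel(V_inv @ A @ V, V_inv @ B, C @ V, L)
print(np.allclose(K_original, K_conjugated))  # True
```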
Unfortunately, the naïve application of diagonalization does not work due to numerical issues. The authors derive the explicit diagonalization for the HiPPO matrix (2) and show that the diagonalizing matrix has entries exponentially large in the state size $N$, rendering diagonalization numerically infeasible.
Lemma 3.2. The HiPPO matrix $A$ in equation (2) is diagonalized by the matrix $V_{ij} = \binom{i+j}{i-j}$. In particular, $V_{3i,i} = \binom{4i}{2i} \approx 2^{4i}$. Therefore, $V$ has entries of magnitude up to $2^{4N/3}$.
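The blow-up in Lemma 3.2 can be reproduced numerically. The HiPPO matrix is lower triangular with distinct diagonal entries, so each eigenvector can be computed by forward substitution. The sketch below uses hypothetical helpers and scales eigenvectors so that $V_{kk} = 1$. It prints the largest entry of the eigenvector matrix for a few state sizes. That entry grows exponentially with $N$, in line with the lemma, even though each column satisfies $Av = \lambda v$ to working precision.

```python
import numpy as np

def make_hippo(N):
    """HiPPO matrix of equation (2)."""
    n = np.arange(N)
    q = np.sqrt(2 * n + 1)
    return -(np.tril(np.outer(q, q), k=-1) + np.diag(n + 1.0))

def triangular_eigenvectors(A):
    """Eigenvectors of a lower-triangular A with distinct diagonal entries.

    Column k solves (A - A[k,k] I) v = 0 with v[k] = 1 and v[j] = 0 for j < k:
    v[j] = -sum_{i=k}^{j-1} A[j, i] v[i] / (A[j, j] - A[k, k]) for j > k.
    """
    N = A.shape[0]
    V = np.zeros((N, N))
    for k in range(N):
        V[k, k] = 1.0
        for j in range(k + 1, N):
            V[j, k] = -(A[j, k:j] @ V[k:j, k]) / (A[j, j] - A[k, k])
    return V

max_entry = {}
for N in (8, 16, 32):
    V = triangular_eigenvectors(make_hippo(N))
    max_entry[N] = np.abs(V).max()
    print(f"N = {N:2d}: max |V_ij| = {max_entry[N]:.1e}")
```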
3.2 The S4 Parameterization: Normal Plus Low-Rank
- The core innovation in S4 is its parameterization of the state matrix $A$ as the sum of a normal matrix and a low-rank matrix (normal plus low-rank, NPLR). This is motivated by Lemma 3.2. Diagonalizing the HiPPO matrix directly requires an ill-conditioned change of basis with exponentially large entries, so conjugation should only use well-conditioned (ideally unitary) matrices. Normal matrices are exactly the matrices that can be diagonalized by a unitary matrix. The HiPPO matrix is not normal, but it can be written as a normal matrix minus a low-rank correction.
- Concretely, the HiPPO matrices considered in the paper admit the representation

$$A = V\Lambda V^* - PQ^\top = V\left(\Lambda - (V^*P)(V^*Q)^*\right)V^*$$

for a unitary $V \in \mathbb{C}^{N \times N}$, a diagonal $\Lambda$, and a low-rank factorization $P, Q \in \mathbb{R}^{N \times r}$, with $r = 1$ or $r = 2$.
- By Lemma 3.1, the SSM can therefore be conjugated into diagonal plus low-rank (DPLR) form, $\Lambda - PQ^*$ (over $\mathbb{C}$), by a unitary matrix, which avoids the numerical blow-up.
- The key to this approach is that the low-rank part can be handled efficiently with the Woodbury identity, which reduces the problem to the diagonal case. This avoids repeated multiplication by the dense matrix $\overline{A}$ and allows efficient computation of the convolution kernel $\overline{K}$.
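For the HiPPO matrix of equation (2), the normal-plus-low-rank split can be verified directly. The sketch below uses hypothetical helper names and a rank-1 factor $P$ with $P_n = \sqrt{n + 1/2}$. Adding $PP^\top$ to $A$ removes the lower-triangular structure and leaves $S = -\tfrac{1}{2}I + (\text{skew-symmetric})$, which is normal. Since $-i(S + \tfrac{1}{2}I)$ is Hermitian, `np.linalg.eigh` returns a unitary eigenvector matrix. The result is $S = V\Lambda V^*$ with a perfectly conditioned $V$, in contrast to the direct diagonalization of Lemma 3.2.

```python
import numpy as np

def make_hippo(N):
    """HiPPO matrix of equation (2)."""
    n = np.arange(N)
    q = np.sqrt(2 * n + 1)
    return -(np.tril(np.outer(q, q), k=-1) + np.diag(n + 1.0))

N = 16
A = make_hippo(N)
P = np.sqrt(np.arange(N) + 0.5)           # rank-1 factor, P_n = sqrt(n + 1/2)
S = A + np.outer(P, P)                    # candidate normal part, so A = S - P P^T

Ksk = S + 0.5 * np.eye(N)                 # S = -1/2 I + Ksk
print(np.allclose(Ksk, -Ksk.T))           # True: Ksk is skew-symmetric
print(np.allclose(S @ S.T, S.T @ S))      # True: hence S is normal

# -1j * Ksk is Hermitian, so eigh returns a unitary V with S = V diag(Lam) V^*.
w, V = np.linalg.eigh(-1j * Ksk)
Lam = -0.5 + 1j * w
A_rebuilt = V @ np.diag(Lam) @ V.conj().T - np.outer(P, P)
print(np.allclose(V.conj().T @ V, np.eye(N)), np.allclose(A_rebuilt, A))
```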
3.3 Efficient Computation of $\overline{K}$
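- The convolution kernel $\overline{K}$ in equation (4) is computed with an algorithm that exploits the DPLR structure of $A$. The naïve approach performs $L$ successive multiplications by $\overline{A}$, and S4 avoids this altogether.
- The key steps in the computation are:
  - Instead of computing $\overline{K}$ directly, compute its spectrum by evaluating its truncated generating function $\sum_{j=0}^{L-1}\overline{K}_j z^j$ at the roots of unity $z$. $\overline{K}$ is then recovered with an inverse FFT.
  - At the roots of unity, the sum over powers $\overline{A}^j$ collapses into a single matrix inverse $(I - \overline{A}z)^{-1}$ (a matrix resolvent).
  - Correct for the low-rank term with the Woodbury identity, which expresses the inverse involving $\Lambda - PQ^*$ in terms of the inverse involving the diagonal $\Lambda$, reducing to the diagonal case.
  - In the diagonal case, the computation is equivalent to a Cauchy kernel $1/(\omega_j - \lambda_k)$, where $\omega_j$ are the (transformed) evaluation points and $\lambda_k$ the diagonal entries of $\Lambda$. This is a well-studied problem with fast, numerically stable algorithms.
- By handling the diagonal and low-rank parts separately, the authors show that computing $\overline{K}$ reduces to 4 Cauchy multiplies. This lowers the cost from $O(N^2L)$ operations to $\tilde{O}(N + L)$ operations and $O(N + L)$ space, where $N$ is the size of the latent state and $L$ is the sequence length.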
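A simplified NumPy sketch of this kind of kernel computation follows. It evaluates the generating function at the roots of unity, handles a rank-1 term with the Woodbury identity, reduces everything to Cauchy-kernel sums, and finishes with an inverse FFT. It rests on several assumptions. It uses a rank-1 DPLR matrix $A = \Lambda - PQ^*$ with random, hypothetical parameters and complex arithmetic throughout. It also uses a naïve $O(NL)$ Cauchy sum instead of a fast Cauchy algorithm. Substituting the bilinear $\overline{A}$ and $\overline{B}$ into the resolvent gives $\hat{K}(z) = \frac{2}{1+z}\tilde{C}\,(g(z)I - A)^{-1}B$ with $g(z) = \frac{2}{\Delta}\frac{1-z}{1+z}$ and $\tilde{C} = \overline{C}(I - \overline{A}^L)$. Here $\tilde{C}$ is formed densely only so the result can be checked against the naïve kernel.

```python
import numpy as np

def discretize_bilinear(A, B, step):
    """Bilinear method from equation (3)."""
    I = np.eye(A.shape[0])
    inv = np.linalg.inv(I - (step / 2.0) * A)
    return inv @ (I + (step / 2.0) * A), (inv @ B) * step

def kernel_naive(Ab, Bb, C, L):
    """Reference kernel K_k = C Ab^k Bb via L successive multiplications by Ab."""
    out, v = [], Bb
    for _ in range(L):
        out.append((C @ v).item())
        v = Ab @ v
    return np.array(out)

def cauchy(u, v, g, lam):
    """Naive Cauchy kernel: sum_i u_i v_i / (g_j - lam_i) for every point g_j."""
    return ((u * v)[None, :] / (g[:, None] - lam[None, :])).sum(axis=1)

def kernel_dplr(Lam, P, Q, B, Ct, step, L):
    """Kernel of the SSM with A = diag(Lam) - P Q^*, from generating-function values."""
    z = np.exp(-2j * np.pi * np.arange(L) / L)          # L-th roots of unity
    g = (2.0 / step) * (1 - z) / (1 + z)                 # bilinear map of each z
    k00 = cauchy(Ct, B, g, Lam)                          # Ct (gI - Lam)^{-1} B
    k01 = cauchy(Ct, P, g, Lam)                          # Ct (gI - Lam)^{-1} P
    k10 = cauchy(Q.conj(), B, g, Lam)                    # Q^* (gI - Lam)^{-1} B
    k11 = cauchy(Q.conj(), P, g, Lam)                    # Q^* (gI - Lam)^{-1} P
    # Woodbury identity for the rank-1 correction, then the bilinear prefactor 2 / (1 + z).
    K_hat = (2.0 / (1 + z)) * (k00 - k01 * k10 / (1 + k11))
    return np.fft.ifft(K_hat)                            # K_hat[j] = sum_k K_k z_j^k

rng = np.random.default_rng(0)
N, L, step = 8, 32, 0.1
Lam = -0.5 + 1j * rng.standard_normal(N)                 # hypothetical diagonal part
P = 0.1 * rng.standard_normal(N) + 0j                    # hypothetical rank-1 factors
Q = 0.1 * rng.standard_normal(N) + 0j
B = rng.standard_normal(N) + 0j
C = rng.standard_normal(N) + 0j

# Dense reference: build A, discretize, and unroll the kernel directly.
A = np.diag(Lam) - np.outer(P, Q.conj())
Ab, Bb = discretize_bilinear(A, B[:, None], step)
K_ref = kernel_naive(Ab, Bb, C[None, :], L)

# Ct = C (I - Ab^L) is formed densely here only to compare against K_ref.
Ct = (C[None, :] @ (np.eye(N) - np.linalg.matrix_power(Ab, L))).ravel()
K_fast = kernel_dplr(Lam, P, Q, B, Ct, step, L)
print(np.allclose(K_ref, K_fast))  # True
```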
Experiments
- The authors evaluate S4 on a range of benchmarks, with a focus on long-range dependencies. S4 achieves state-of-the-art results on every task of the Long Range Arena benchmark, including the Path-X task (sequence length 16K) that all prior work fails on.
- S4 also performs well across domains. It reaches 91% accuracy on sequential CIFAR-10 without data augmentation or auxiliary losses, and it substantially closes the gap to Transformers on image and language modeling while performing generation 60× faster. The experiments also cover raw speech classification and time-series forecasting.
- S4 is computationally efficient: it is substantially faster and more memory-efficient than the prior LSSL, and it is as efficient as the competing models on Long Range Arena.
Conclusion
- The S4 method presents a novel approach to sequence modeling by leveraging structured state spaces and a normal plus low-rank parameterization of the state matrix.
- The key contribution of S4 is the efficient computation of the convolution kernel $\overline{K}$, which allows for modeling long-range dependencies in sequences much more efficiently than prior methods.
- S4 offers a promising solution to a fundamental challenge in machine learning and time-series analysis, achieving state-of-the-art performance while maintaining computational efficiency.
Summary of the Paper
The paper introduces S4 (Structured State Spaces), a novel approach to efficiently modeling long-range dependencies in sequence data. The central idea of S4 is based on state-space models (SSMs), which map input sequences to latent states and then to output sequences. The paper highlights the challenge of computing with these models efficiently, particularly when handling long-range dependencies in long sequences.
S4 addresses this challenge by using HiPPO (High-order Polynomial Projection Operators) theory to construct structured state matrices that let the state memorize the history of the input, capturing long-term dependencies. It then introduces a practical parameterization of the state matrix as the sum of a normal matrix and a low-rank matrix. This decomposition allows the matrix to be stably conjugated into diagonal plus low-rank form, which significantly reduces the cost of computing the convolution kernel.
The authors demonstrate that this approach enables efficient modeling of long-range dependencies with reduced memory and time requirements across tasks such as image classification, speech, time-series forecasting, and language modeling. In experimental evaluations, S4 sets the state of the art on the Long Range Arena benchmark and substantially closes the gap to Transformers on language modeling, while remaining computationally efficient.
In conclusion, S4 provides a promising new direction for sequence modeling, efficiently addressing the problem of long-range dependencies without compromising performance. It offers a solid foundation for further research and practical applications in machine learning and data analysis.