This page is primarily based on the papers Efficiently Modeling Long Sequences with Structured State Spaces by Albert Gu, Karan Goel, and Christopher Ré, and How to Train Your HiPPO: State Space Models with Generalized Orthogonal Basis Projections by Albert Gu et al.
Albert Gu has largely been paving the way for using state space models in deep learning. The S4 paper is really the first step to understanding several of his later papers and this entire area of work.
I also conducted an experimental study comparing the performance of Transformer- and State Space Model-based (specifically, Mamba) LLM architectures as the effective sequence length increases with the number of chunks retrieved using Retrieval-Augmented Generation (RAG).
This stuff uses a lot of math. In the future, we plan on including posts explaining concepts like:
- Orthogonal Polynomials
- Legendre Polynomials
- Differential Equations (Ordinary, mostly)
- Convolution
Resources
- Timestamp 11:44
- "It really excels when you need really really long sequences like DNA modeling or audio waveforms and so on. Having a long context and being efficient with that is probably more important than this super ability of transformers to focus in on individual states."
Another video explainer (more detail)
Recommended Prerequisites
- Linear Algebra
- Ordinary Differential Equations
- Neural Networks, including CNNs, RNNs, LSTMs (or GRUs)
- Numerical Methods/Analysis
- Orthogonal Polynomials - Legendre Polynomials, in particular
Structured State Space Models for Deep Sequence Modeling
Watch the video for more detail
Deep Sequence Model Overview
- Incorporating a State Space Model (SSM) layer into a deep model.
- SSMs are classical statistical models:
- Typically a "1-layer," linear model.
- Probabilistic model for the data generating process.
However, in the context of deep learning, SSMs become a deep, non-linear model that is useful for feature extraction, transforming inputs deterministically.
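To make this concrete, here is a minimal NumPy sketch (hypothetical names, not the actual S4 architecture) of how a linear SSM layer can be stacked into a deep, non-linear feature extractor. The SSM itself is treated as a black-box sequence-to-sequence map, with a causal running mean standing in for it until the mechanics are defined below.

```python
import numpy as np

def ssm_layer(u):
    """Stand-in for a linear SSM applied along the sequence axis (axis 0).

    A causal running mean is used only so the sketch runs on its own;
    the real map is defined by the state-space equations in later sections.
    """
    return np.cumsum(u, axis=0) / np.arange(1, u.shape[0] + 1)[:, None]

def deep_ssm_block(u, W_mix):
    """One block: per-channel SSM -> pointwise non-linearity -> channel mixing + residual."""
    h = ssm_layer(u)           # linear transformation along the sequence
    h = np.tanh(h)             # non-linearity makes the deep model non-linear
    return u + h @ W_mix       # mix channels and add a residual connection

rng = np.random.default_rng(0)
L, H = 16, 4                   # sequence length, number of channels (arbitrary)
u = rng.standard_normal((L, H))
W_mix = rng.standard_normal((H, H)) / np.sqrt(H)

x = u
for _ in range(3):             # stacking blocks gives a deep SSM model
    x = deep_ssm_block(x, W_mix)
print(x.shape)                 # (16, 4)
```

The pattern is the usual deep-learning one: a linear sequence transformation, a pointwise non-linearity, channel mixing, and a residual connection.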
Outline
- SSM Mechanics
- Structured State Spaces (S4) for long-range dependencies
- Deep SSMs: perspectives and directions
SSM Mechanics
What is an SSM?
- SSMs can be defined using a differential equation:

$$x'(t) = A\,x(t) + B\,u(t)$$
$$y(t) = C\,x(t)$$

- Origin: First developed in control theory.
  - The Kalman filter was the first well-known SSM.
- Goal: Map a 1-D sequence to another 1-D sequence.
Mapping Sequences to Functions
SSMs map a function to a function, not just sequences to sequences. This is key because a function can be much more general than a sequence. For example, audio data is often treated as an underlying signal, and treating data this way provides inductive bias for the model.
- Inputs $u(t)$ pass through the differential equation, parameterized by matrices $A$ and $B$, to produce a higher-dimensional latent state $x(t)$.
- Output $y(t)$ is derived by projecting $x(t)$ into the desired output space using matrix $C$ (a minimal sketch follows the notation below).
Notation:
- $u(t)$: Input function.
- $x(t)$: Latent state function.
- $y(t)$: Output function.
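As a concrete reference for this notation, here is a minimal NumPy sketch of the continuous-time state-space map; the state dimension and random matrices are arbitrary illustrative choices, not values from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)
N = 8                                  # latent state dimension (illustrative)
A = rng.standard_normal((N, N)) / N    # state matrix
B = rng.standard_normal((N, 1))        # input matrix
C = rng.standard_normal((1, N))        # output (projection) matrix

def dx_dt(x, u_t):
    """Right-hand side of the SSM ODE: x'(t) = A x(t) + B u(t)."""
    return A @ x + B * u_t

def readout(x):
    """Output projection: y(t) = C x(t)."""
    return C @ x

x = np.zeros((N, 1))                   # latent state x(t)
print(dx_dt(x, u_t=1.0).shape, readout(x).shape)   # (8, 1) (1, 1)
```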
Continuous Representation
SSMs can work with continuous data, treating sequences as signals. This makes the model more flexible for a wide variety of data types like audio and time-series, where continuous models provide a natural representation.
Discretization: From Continuous to Discrete Models
To apply this model to discrete sequences (as we typically have in deep learning), we discretize the differential equation using Euler's method:

$$x(t + \Delta) \approx x(t) + \Delta\,x'(t)$$
$$= x(t) + \Delta\,\bigl(A\,x(t) + B\,u(t)\bigr)$$
$$= (I + \Delta A)\,x(t) + \Delta B\,u(t)$$

Note the first line is just a first-order approximation of $x(t + \Delta)$.
This results in a discrete recurrence:

$$x_k = (I + \Delta A)\,x_{k-1} + \Delta B\,u_k$$

Thus, the discrete form of the model is given by:
- Discretization: $\bar{A} = I + \Delta A,\quad \bar{B} = \Delta B$
- Recurrent Update: $x_k = \bar{A}\,x_{k-1} + \bar{B}\,u_k$
- Output Projection: $y_k = C\,x_k$
Here, the recurrence follows the same idea as an RNN: we turn the continuous-time model into a discrete-time recurrent update, so it can be unrolled as a linear RNN.
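Below is a minimal NumPy sketch of this discretization and the resulting linear RNN, assuming the Euler scheme above; the step size and the random matrices are illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)
N, L, dt = 8, 32, 0.1                              # state size, sequence length, step size
A = -np.eye(N) + rng.standard_normal((N, N)) / N   # a (roughly) stable state matrix
B = rng.standard_normal((N, 1))
C = rng.standard_normal((1, N))

# Euler discretization: A_bar = I + dt*A, B_bar = dt*B
A_bar = np.eye(N) + dt * A
B_bar = dt * B

u = rng.standard_normal(L)            # discrete input sequence u_0 ... u_{L-1}
x = np.zeros((N, 1))                  # initial state
ys = []
for u_k in u:                         # unrolled as a linear RNN
    x = A_bar @ x + B_bar * u_k       # recurrent update
    ys.append((C @ x).item())         # output projection y_k = C x_k
y = np.array(ys)
print(y.shape)                        # (32,)
```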
Recurrent View of SSMs
- SSMs can be viewed through the lens of recurrent models, where the current output depends on both the previous state and the current input.
- Key feature: SSMs are autoregressive in that each output depends on the entire input history, yet each new output can be computed in constant time per step (unlike transformers, which slow down as the sequence grows).
- RNN analogy: SSMs resemble RNNs in their update mechanisms, where the state $x_k$ acts like the hidden state.
Advantages:
- Efficient for online computation: once the recurrence is defined, each new output can be computed in constant time per step.
Disadvantage:
- RNNs are sequential, so if you know the entire input sequence upfront, you can't leverage parallelization. This is unlike transformers, which can handle parallel computations.
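A minimal sketch of the online computation: each step touches only the current state and input, so the cost per new output is constant in the sequence length (the parameters below are arbitrary stand-ins with the same shapes as the sketch above).

```python
import numpy as np

def ssm_step(x_prev, u_k, A_bar, B_bar, C):
    """One recurrent step: consume u_k, update the state, emit y_k.

    The work is O(N^2) no matter how much history has already been seen,
    so each new output arrives in constant time per step.
    """
    x_k = A_bar @ x_prev + B_bar * u_k
    return x_k, (C @ x_k).item()

rng = np.random.default_rng(0)
N = 8
A_bar = 0.9 * np.eye(N)                       # toy discretized parameters
B_bar = rng.standard_normal((N, 1))
C = rng.standard_normal((1, N))

x = np.zeros((N, 1))
for u_k in rng.standard_normal(100):          # inputs arriving one at a time (streaming)
    x, y_k = ssm_step(x, u_k, A_bar, B_bar, C)
```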
Convolutional View of SSMs
How to be efficient when we have all the input data?
Note: We exclude the skip term $D\,u_k$ (from the more general output equation $y_k = C\,x_k + D\,u_k$) because it's easy to compute, so we can ignore it.
Instead of using recurrence, we can unroll the linear recurrence explicitly in closed form, similar to a convolution. Here's how:
- First states:

$$x_0 = \bar{B}\,u_0$$
$$x_1 = \bar{A}\bar{B}\,u_0 + \bar{B}\,u_1$$
$$x_2 = \bar{A}^2\bar{B}\,u_0 + \bar{A}\bar{B}\,u_1 + \bar{B}\,u_2 \quad \ldots$$

The output is essentially a linear projection of this.
- Output Sequence:

$$y_k = C\bar{A}^{k}\bar{B}\,u_0 + C\bar{A}^{k-1}\bar{B}\,u_1 + \cdots + C\bar{A}\bar{B}\,u_{k-1} + C\bar{B}\,u_k$$

This is equivalent to a convolution operation:

$$y = \bar{K} * u$$

where:

$$\bar{K} = \bigl(C\bar{B},\; C\bar{A}\bar{B},\; C\bar{A}^2\bar{B},\; \ldots\bigr)$$

$\bar{K}$ is what we call the SSM convolution kernel because the output is a single convolution of the input with the kernel. Note that the kernel is implicitly infinitely long; in practice we truncate it to the length of the sequence (see the sketch after this list).
- This is an implicit convolution, meaning you don't have to compute the entire state vector to get the output.
- It is highly parallelizable and computationally efficient, similar to convolutions used in CNNs.
- Can view a local CNN as an SSM where the $A$ matrix is like a shift matrix and the state is just a buffer of the history.
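Below is a minimal NumPy sketch of the convolutional view, assuming the Euler-discretized parameters from earlier: the kernel $\bar{K}$ is materialized naively (S4's contribution is computing it efficiently, which this sketch does not attempt), truncated to the sequence length, and the convolution output is checked against the recurrence.

```python
import numpy as np

rng = np.random.default_rng(0)
N, L, dt = 8, 32, 0.1
A = -np.eye(N) + rng.standard_normal((N, N)) / N
B = rng.standard_normal((N, 1))
C = rng.standard_normal((1, N))
A_bar, B_bar = np.eye(N) + dt * A, dt * B

# Truncated SSM convolution kernel K_bar = (C B_bar, C A_bar B_bar, ..., C A_bar^{L-1} B_bar).
K = np.zeros(L)
M = B_bar.copy()
for i in range(L):
    K[i] = (C @ M).item()
    M = A_bar @ M

u = rng.standard_normal(L)

# Causal convolution y_k = sum_j K_j * u_{k-j} (done naively here; in practice via FFT).
y_conv = np.array([K[: k + 1] @ u[k::-1] for k in range(L)])

# The recurrence gives the same answer, confirming the two views are equivalent.
x, y_rec = np.zeros((N, 1)), []
for u_k in u:
    x = A_bar @ x + B_bar * u_k
    y_rec.append((C @ x).item())
print(np.allclose(y_conv, np.array(y_rec)))   # True
```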
Structured State Space Model (S4)
S4 is an extension of the basic SSM with special formulations for the $A$ and $B$ matrices.
- S4 models long-range dependencies efficiently.
- HiPPO matrices are used to structure the $A$ and $B$ matrices.
- The HiPPO operator provides a way to maintain long-term memory across sequences, facilitating better performance for tasks requiring long-range context.
The HiPPO Operator
The HiPPO operator works as follows:
- HiPPO specifies fixed formulas for the $A$ and $B$ matrices, encoding long-range dependencies in the state.
Key goal: Design the state to encode the entire history of the input sequence. This is critical for reconstructing long-range context. At every time step, the state evolves to preserve the memory of previous inputs.
- $x(t)$ can be used to reconstruct the entire input history up to time $t$.
- This is known as online function reconstruction.
The HiPPO matrix is structured as:

$$A_{nk} = -\begin{cases} \sqrt{(2n+1)(2k+1)} & \text{if } n > k \\ n+1 & \text{if } n = k \\ 0 & \text{if } n < k \end{cases}$$
This formulation enables efficient modeling of long-range dependencies in sequence data, and forms the basis for the S4 model.
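For reference, here is a minimal NumPy construction of the HiPPO matrix defined above (the state size N = 4 is just for display).

```python
import numpy as np

def hippo_matrix(N):
    """HiPPO matrix: A_nk = -sqrt((2n+1)(2k+1)) if n > k, -(n+1) if n == k, 0 otherwise."""
    n = np.arange(N)[:, None]
    k = np.arange(N)[None, :]
    A = np.sqrt((2 * n + 1) * (2 * k + 1)) * (n > k)   # strictly lower-triangular part
    A = A + np.diag(np.arange(1, N + 1))               # diagonal entries n + 1
    return -A

A = hippo_matrix(4)
print(A)   # lower triangular; e.g. A[3] ≈ [-2.646, -4.583, -5.916, -4.0]
```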
Summary and Future Directions
SSMs, and particularly the S4 model, offer a powerful way to handle long-range dependencies in sequence modeling. By leveraging the flexibility of continuous-time models and efficient computational methods like convolutions, S4 can outperform traditional models like RNNs and transformers on certain tasks (e.g., DNA modeling and speech recognition).
- S4’s key advantage: It allows for efficient, parallelizable computation of long-range dependencies, making it particularly suited for deep learning applications where data involves long sequences.
There are ongoing directions to explore, including improvements in parameter learning algorithms and variants of S4 that could further enhance efficiency and scalability.