Paper Notes - State Space Models for Deep Learning

Apr 2, 2024

This page is primarily based on the papers Efficiently Modeling Long Sequences with Structured State Spaces by Albert Gu, Karan Goel, and Christopher Re, and How to Train Your HiPPO: State Space Models with Generalized Orthogonal Basis Projections by Albert Gu et al.

Albert Gu has largely been paving the way for using state space models in deep learning. The S4 paper is really the first step toward understanding several of his other papers and this entire area of work.

I also conducted an experimental study comparing Transformer-based and State Space Model-based (specifically, Mamba) LLM architectures as sequence length increases with the number of chunks retrieved via Retrieval Augmented Generation (RAG).

This stuff uses a lot of math. In the future, we plan on including posts explaining concepts like:

  • Orthogonal Polynomials
  • Legendre Polynomials
  • Differential Equations (Ordinary, mostly)
  • Convolution

Resources

Good video explanation

  • Timestamp 11:44
    • "It really excels when you need really really long sequences like DNA modeling or audio waveforms and so on. Having a long context and being efficient with that is probably more important than this super ability of transformers to focus in on individual states."

Another video explainer (more detail)

  • Linear Algebra
  • Ordinary Differential Equations
  • Neural Networks including
    • CNNs, RNNs, LSTMs (or GRUs)
  • Numerical Methods/Analysis
    • Orthogonal Polynomials - Legendre Polynomials, in particular

Structured State Space Models for Deep Sequence Modeling

Watch the video for more detail

Deep Sequence Model Overview

  • Incorporating a State Space Model (SSM) layer into a deep model.
  • SSMs are classical statistical models:
    • Typically a "1-layer," linear model.
    • Probabilistic model for the data generating process.

However, in the context of deep learning, SSMs become deep, non-linear models useful for feature extraction, transforming inputs deterministically.

Outline

  • SSM Mechanics
  • Structured State Spaces (S4) for long-range dependencies
  • Deep SSMs: perspectives and directions

SSM Mechanics

What is an SSM?

  • SSMs can be defined using a differential equation:

    x'(t) = Ax(t) + Bu(t)
    y(t) = Cx(t) + Du(t)
  • Origin: First developed in control theory.

    • The Kalman filter was the first well-known SSM.

Goal: Map a 1-D sequence to another 1-D sequence.
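To make the continuous-time definition concrete, here is a minimal sketch that simulates a toy SSM with scipy.integrate.solve_ivp. The matrices A, B, C, D and the input signal are arbitrary illustrative values (not the HiPPO parameterization discussed later).

```python
import numpy as np
from scipy.integrate import solve_ivp

# Toy SSM matrices (illustrative placeholders, not the HiPPO parameterization).
N = 4                       # latent state dimension
A = -np.eye(N)              # stable state matrix
B = np.ones((N, 1))
C = np.ones((1, N)) / N
D = np.zeros((1, 1))

def u(t):
    """Scalar input signal u(t): a simple sine wave."""
    return np.sin(2 * np.pi * t)

def dxdt(t, x):
    """State dynamics x'(t) = A x(t) + B u(t)."""
    return A @ x + (B * u(t)).ravel()

# Integrate the ODE from t=0 to t=5 starting from x(0) = 0.
sol = solve_ivp(dxdt, t_span=(0.0, 5.0), y0=np.zeros(N), dense_output=True)

# Output projection y(t) = C x(t) + D u(t), evaluated on a time grid.
ts = np.linspace(0.0, 5.0, 100)
xs = sol.sol(ts)                           # shape (N, len(ts))
ys = (C @ xs).ravel() + (D * u(ts)).ravel()
print(ys[:5])
```

Here the "sequence" is really a function of continuous time: the solver can evaluate x(t) and y(t) at any t, which is the point of the function-to-function view in the next section.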

Mapping Sequences to Functions

SSMs map a function to a function, not just sequences to sequences. This is key because a function can be much more general than a sequence. For example, audio data is often treated as an underlying signal, and treating data this way provides inductive bias for the model.

  • Inputs u(t) pass through the differential equation, parameterized by matrices A and B, to produce a higher-dimensional latent state x(t).
  • Output y(t) is derived by projecting x(t) into the desired output space using matrix C.

Notation:

  • u(t): Input function.
  • x(t): Latent state function.
  • y(t): Output function.

Continuous Representation

SSMs can work with continuous data, treating sequences as signals. This makes the model more flexible for a wide variety of data types like audio and time-series, where continuous models provide a natural representation.

Discretization: From Continuous to Discrete Models

To apply this model to discrete sequences (as we typically have in deep learning), we discretize the differential equation using Euler's method:

\begin{gather}
x(t + \Delta) \approx x(t) + \Delta x'(t) \\
= x(t) + \Delta(Ax(t) + Bu(t)) \\
= x(t) + \Delta A x(t) + \Delta B u(t) \\
= (I + \Delta A)x(t) + \Delta B u(t) \\
= \overline{A} x(t) + \overline{B} u(t)
\end{gather}

Note the first line is just a first-order approximation of x(t + \Delta).

This results in a discrete recurrence:

x(t + \Delta) \approx \overline{A} x(t) + \overline{B} u(t)

Thus, the discrete form of the model is given by:

  1. Discretization: \overline{A} = I + \Delta A, \quad \overline{B} = \Delta B
  2. Recurrent Update: x_k = \overline{A} x_{k-1} + \overline{B} u_k
  3. Output Projection: y_k = \overline{C} x_k + \overline{D} u_k

Here, the recurrence follows the same idea as an RNN: we have turned the continuous-time model into a discrete-time recurrent update, so it can be unrolled like a linear RNN.
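As a concrete illustration, here is a minimal sketch of the Euler discretization and the resulting linear recurrence. The matrix values, step size delta, and input sequence are placeholders, not values from the paper.

```python
import numpy as np

def discretize_euler(A, B, delta):
    """Euler discretization: A_bar = I + delta * A, B_bar = delta * B."""
    N = A.shape[0]
    A_bar = np.eye(N) + delta * A
    B_bar = delta * B
    return A_bar, B_bar

def ssm_recurrence(A_bar, B_bar, C, u):
    """Unroll x_k = A_bar x_{k-1} + B_bar u_k, y_k = C x_k (D omitted)."""
    x = np.zeros(A_bar.shape[0])
    ys = []
    for u_k in u:
        x = A_bar @ x + (B_bar * u_k).ravel()
        ys.append((C @ x).item())
    return np.array(ys)

# Toy example (placeholder values).
A = -np.eye(3)
B = np.ones((3, 1))
C = np.ones((1, 3)) / 3
delta = 0.1
u = np.sin(np.linspace(0, 2 * np.pi, 50))

A_bar, B_bar = discretize_euler(A, B, delta)
y = ssm_recurrence(A_bar, B_bar, C, u)
print(y[:5])
```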


Recurrent View of SSMs

  • SSMs can be viewed through the lens of recurrent models, where the current output depends on both the previous state and the current input.
  • Key feature: SSMs are autoregressive in that each output depends on the entire input history, yet each new output can be computed in constant time per step (unlike transformers, which slow down as the sequence grows).
  • RNN analogy: SSMs resemble RNNs in their update mechanisms, where x(t) acts like the hidden state.

Advantages:

  • Efficient for online computation: once the recurrence is defined, each new output can be computed in constant time per step.

Disadvantage:

  • The recurrence is sequential, so even if you know the entire input sequence upfront, you can't leverage parallelization across time steps. This is unlike transformers, which can process the whole sequence in parallel.
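To make the online-computation point concrete, here is a minimal sketch of a stateful step function: O(1) work per new input, regardless of how long the history is. The matrices and inputs are illustrative placeholders, and A_bar/B_bar are assumed to come from a discretization like the one above.

```python
import numpy as np

class SSMStepper:
    """Recurrent (online) view: constant work per new input sample."""

    def __init__(self, A_bar, B_bar, C):
        self.A_bar, self.B_bar, self.C = A_bar, B_bar, C
        self.x = np.zeros(A_bar.shape[0])   # hidden state summarizing all past inputs

    def step(self, u_k):
        """Consume one input sample and emit one output sample."""
        self.x = self.A_bar @ self.x + (self.B_bar * u_k).ravel()
        return (self.C @ self.x).item()

# Streaming usage with placeholder matrices.
A_bar = 0.9 * np.eye(2)
B_bar = 0.1 * np.ones((2, 1))
C = np.ones((1, 2))
stepper = SSMStepper(A_bar, B_bar, C)
outputs = [stepper.step(u_k) for u_k in [0.5, -0.2, 1.0, 0.0]]
print(outputs)
```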

Convolutional View of SSMs

How to be efficient when we have all the input data?

Note: We exclude the \overline{D} term because it's easy to compute, so we can ignore it.

\begin{gather}
x_k = \overline{A} x_{k-1} + \overline{B} u_k \\
y_k = \overline{C} x_k
\end{gather}

Instead of using recurrence, we can unroll the linear recurrence explicitly in closed form, similar to a convolution. Here's how:

  • First states: \begin{gather*} x_0 = \overline{B} u_0 \\ x_1 = \overline{A} \overline{B} u_0 + \overline{B} u_1 \\ x_2 = \overline{A}^2 \overline{B} u_0 + \overline{A} \overline{B} u_1 + \overline{B} u_2 \end{gather*}

The output is essentially a linear projection of this.

  • Output Sequence: y_k = \overline{C} \overline{A}^k \overline{B} u_0 + \overline{C} \overline{A}^{k-1} \overline{B} u_1 + \cdots + \overline{C} \overline{B} u_k

This is equivalent to a convolution operation:

y = \overline{K} * u

where:

\overline{K} = (\overline{C} \overline{B}, \overline{C} \overline{A} \overline{B}, \overline{C} \overline{A}^2 \overline{B}, \ldots)

\overline{K} is what we call the SSM convolution kernel, because the output y is a single convolution of the input u with this kernel. Note that the kernel is implicitly infinitely long; in practice we truncate it to the length of the sequence.

  • This is an implicit convolution, meaning you don't have to compute the entire state vector x(t) to get the output.
  • It is highly parallelizable and computationally efficient, similar to convolutions used in CNNs.
  • Can view a local CNN as an SSM where the A matrix is like a shift matrix and the state is just a buffer of the history.
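Below is a minimal sketch of the convolutional view described above: it materializes a truncated kernel \overline{K} and checks that convolving it with the input matches the recurrent computation (repeated here so the sketch is self-contained). The matrices and input are illustrative placeholders.

```python
import numpy as np

def ssm_kernel(A_bar, B_bar, C, L):
    """Truncated SSM kernel K = (C B, C A B, C A^2 B, ..., C A^{L-1} B)."""
    K = []
    A_power_B = B_bar                       # holds A_bar^k @ B_bar
    for _ in range(L):
        K.append((C @ A_power_B).item())
        A_power_B = A_bar @ A_power_B
    return np.array(K)

def ssm_conv(A_bar, B_bar, C, u):
    """Compute y = K * u as a causal (one-sided) convolution."""
    L = len(u)
    K = ssm_kernel(A_bar, B_bar, C, L)
    # Full convolution, then keep the first L outputs (the causal part).
    return np.convolve(K, u)[:L]

def ssm_recurrence(A_bar, B_bar, C, u):
    """Reference implementation via the recurrent view."""
    x = np.zeros(A_bar.shape[0])
    ys = []
    for u_k in u:
        x = A_bar @ x + (B_bar * u_k).ravel()
        ys.append((C @ x).item())
    return np.array(ys)

# Placeholder parameters and input.
A_bar = 0.9 * np.eye(2)
B_bar = 0.1 * np.ones((2, 1))
C = np.ones((1, 2))
u = np.sin(np.linspace(0, 2 * np.pi, 32))

assert np.allclose(ssm_conv(A_bar, B_bar, C, u), ssm_recurrence(A_bar, B_bar, C, u))
```

In practice S4 computes this kernel with FFTs and structured matrix tricks rather than the naive loop shown here; the sketch only demonstrates the recurrence-convolution equivalence.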

Structured State Space Model (S4)

S4 is an extension of the basic SSM with special formulations for the A and B matrices.

  • S4 models long-range dependencies efficiently.
  • HiPPO matrices are used to structure the A and B matrices.
  • The HiPPO operator provides a way to maintain long-term memory across sequences, facilitating better performance for tasks requiring long-range context.

The HiPPO Operator

The HiPPO operator works as follows:

x'(t) = Ax(t) + Bu(t)

  • u(t) \to \text{HiPPO} \to x(t). It specifies fixed formulas for the A and B matrices, encoding long-range dependencies in the state.

Key goal: Design the state to encode the entire history of the input sequence. This is critical for reconstructing long-range context. At every time step, the state evolves to preserve the memory of previous inputs.

  • x(t) can be used to reconstruct the entire input history up to t.
  • This is known as online function reconstruction.

The HiPPO matrix A is structured as:

A_{nk} = \begin{cases} 0 & n < k \\ n+1 & n = k \\ 2n+1 & n > k \end{cases}

This formulation enables efficient modeling of long-range dependencies in sequence data, and forms the basis for the S4 model.
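As a small illustration, here is a sketch that builds the A matrix directly from the cases formula above. The entries follow these notes as written; published HiPPO variants differ in sign and normalization details (e.g., square-root scaling of the off-diagonal terms).

```python
import numpy as np

def hippo_A(N):
    """Build the N x N HiPPO-style A matrix from the cases formula above.

    Entries follow these notes; published HiPPO variants differ in
    sign and normalization details.
    """
    A = np.zeros((N, N))
    for n in range(N):
        for k in range(N):
            if n == k:
                A[n, k] = n + 1
            elif n > k:
                A[n, k] = 2 * n + 1
            # n < k: stays 0, giving the lower-triangular structure
    return A

print(hippo_A(4))
```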


Summary and Future Directions

SSMs, and particularly the S4 model, offer a powerful way to handle long-range dependencies in sequence modeling. By leveraging the flexibility of continuous-time models and efficient computational methods like convolutions, S4 can outperform traditional models like RNNs and transformers on certain tasks (e.g., DNA modeling and speech recognition).

  • S4’s key advantage: It allows for efficient, parallelizable computation of long-range dependencies, making it particularly suited for deep learning applications where data involves long sequences.

There are ongoing directions to explore, including improvements in parameter learning algorithms and variants of S4 that could further enhance efficiency and scalability.