Paper Notes - RAG for Knowledge Intensive NLP Tasks

Jan 1, 2024

Paper Notes for RAG for Knowledge Intensive NLP Tasks

These are my personal notes summarizing the paper: Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks by Lewis et al.


Abstract Points

  • Large pre-trained language models have been shown to store factual knowledge in their parameters.
  • The downside is that their ability to access and precisely manipulate this knowledge is limited.
  • They struggle to provide provenance for their decisions and to update their world knowledge.
  • Present a general-purpose fine-tuning recipe for retrieval-augmented generation (RAG).
    • Models which combine pre-trained parametric and non-parametric memory for language generation.
  • Parametric memory is a pre-trained seq2seq model.
  • Non-parametric memory is a dense vector index of Wikipedia, accessed with a pre-trained neural retriever.


1. Introduction

Pre-trained neural language models like BERT and GPT have demonstrated an impressive ability to internalize large amounts of factual knowledge during training. However, their memory is parametric—baked into their weights—which means it cannot be easily updated or expanded after training. These models also struggle to provide transparency or provenance for their outputs, often generating confident but hallucinated responses with no clear link to an external source.

To address these limitations, researchers have explored hybrid models that combine parametric memory with non-parametric memory—external sources of information that can be retrieved and inspected at inference time. Notably, models like REALM and ORQA have paired masked language models with differentiable retrievers for open-domain extractive question answering. However, these methods are limited in scope and do not apply to generative, sequence-to-sequence tasks.

This paper introduces Retrieval-Augmented Generation (RAG), a framework that brings the hybrid memory approach to seq2seq generation tasks. RAG equips a pre-trained generative model (parametric memory) with access to a non-parametric memory in the form of a dense vector index over Wikipedia. This index is queried using a neural retriever—Dense Passage Retriever (DPR)—to retrieve relevant documents based on the input.

The generation model—based on BART, a powerful seq2seq transformer—is then conditioned on both the input and the retrieved documents to produce outputs. RAG integrates these components into a probabilistic model trained end-to-end, where both the retriever and generator can be fine-tuned jointly. Retrieved documents are treated as latent variables and marginalized using a top-K approximation—either once per output sequence (RAG-Sequence) or at each token step (RAG-Token).

By combining the strengths of parametric and non-parametric memory, RAG offers a flexible, interpretable, and updatable solution for knowledge-intensive NLP tasks.



2. Methods

  • Use the input sequence $x$ to retrieve text documents $z$ and use them as additional context when generating the target sequence $y$.
  • Models leverage two components:
  1. A retriever $p_\eta(z|x)$ with parameters $\eta$ that returns (top-K truncated) distributions over text passages given a query $x$
  2. A generator $p_\theta(y_i|x, z, y_{1:i-1})$ parametrized by $\theta$ that generates the current token based on a context of the previous $i-1$ tokens $y_{1:i-1}$, the original input $x$, and a retrieved passage $z$.
  • To train the retriever and generator end-to-end, they treat the retrieved document as a latent variable. They propose two models that marginalize over the latent documents in different ways to produce a distribution over generated text.
    • In one approach, RAG-Sequence, the model uses the same document to predict every target token.
    • In the second approach, RAG-Token, the model can predict each target token based on a different document.


2.1 Models

  • RAG-Sequence: Uses the same retrieved document to generate the complete sequence. Treats the retrieved document as a single latent variable that is marginalized to get the seq2seq probability $p(y|x)$ via a top-K approximation.
    • The top $K$ documents are retrieved using the retriever, and the generator produces the output sequence probability for each document, which are then marginalized:
$$p_{\text{RAG-Sequence}}(y|x) \approx \sum_{z\in \text{top-k}(p(\cdot|x))} p_\eta(z|x)\, p_\theta(y|x, z) = \sum_{z\in \text{top-k}(p(\cdot|x))} p_\eta(z|x) \prod_i^N p_\theta(y_i|x, z, y_{1:i-1})$$
  • RAG-Token: Can draw a different latent document for each target token and marginalize accordingly (a small numerical sketch comparing both marginalizations follows this list).
    • This allows the generator to choose content from several documents when producing an answer.
    • The top $K$ documents are retrieved using the retriever, and then the generator produces a distribution for the next output token for each document, before marginalizing and repeating the process with the following output token:
$$p_{\text{RAG-Token}}(y|x) \approx \prod_i^N \sum_{z\in \text{top-k}(p(\cdot|x))} p_\eta(z|x)\, p_\theta(y_i|x, z, y_{1:i-1})$$
  • Note that RAG can be used for sequence classification tasks by considering the target class as a target sequence of length one, in which case RAG-Sequence and RAG-Token are equivalent.
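
To make the two marginalizations concrete, here is a minimal numerical sketch (my own toy example, not code from the paper) with made-up retriever and gold-token generator probabilities for $K = 2$ retrieved documents and an $N = 3$ token target:

```python
import numpy as np

# Toy setup: K = 2 retrieved documents, a target sequence of N = 3 tokens.
# p_eta[z] = retriever probability p_eta(z|x) for each top-K document.
p_eta = np.array([0.7, 0.3])

# p_theta[z, i] = generator probability p_theta(y_i | x, z, y_{1:i-1})
# of the gold token y_i when conditioning on document z (made-up numbers).
p_theta = np.array([
    [0.9, 0.8, 0.6],   # token probabilities under document z_1
    [0.2, 0.5, 0.7],   # token probabilities under document z_2
])

# RAG-Sequence: marginalize over documents once, at the sequence level.
#   p(y|x) ~= sum_z p_eta(z|x) * prod_i p_theta(y_i | x, z, y_{1:i-1})
p_rag_sequence = np.sum(p_eta * np.prod(p_theta, axis=1))

# RAG-Token: marginalize over documents at every token position.
#   p(y|x) ~= prod_i sum_z p_eta(z|x) * p_theta(y_i | x, z, y_{1:i-1})
p_rag_token = np.prod(np.sum(p_eta[:, None] * p_theta, axis=0))

print(f"RAG-Sequence p(y|x) ~ {p_rag_sequence:.4f}")
print(f"RAG-Token    p(y|x) ~ {p_rag_token:.4f}")
```

The only difference between the two quantities is where the sum over documents sits: outside the product over tokens (RAG-Sequence) or inside it (RAG-Token), which is why RAG-Token can mix content from several documents within one output.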


2.2 Retriever: DPR

The retrieval component $p_\eta(z|x)$ is based on DPR. DPR follows a bi-encoder architecture:

$$p_\eta(z|x) \propto \exp\left(d(z)^\top q(x)\right)$$

where $d(z) = \mathrm{BERT}_d(z)$ and $q(x) = \mathrm{BERT}_q(x)$. Here $d(z)$ is a dense representation of a document produced by a $\mathrm{BERT}_{\text{BASE}}$ document encoder, and $q(x)$ is a query representation produced by a query encoder, also based on $\mathrm{BERT}_{\text{BASE}}$.

Calculating $\text{top-k}(p_\eta(\cdot|x))$, the list of the $k$ documents $z$ with the highest prior probability $p_\eta(z|x)$, is a Maximum Inner Product Search (MIPS) problem, which can be approximately solved in sub-linear time. They use a pre-trained bi-encoder from DPR to initialize the retriever and to build the document index.

  • Non-parametric memory refers to the document index.
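
As a rough illustration (not the paper's implementation), the sketch below fakes the two encoders with random vectors and performs the MIPS step by brute force in NumPy; in practice the index is built with an approximate MIPS library such as FAISS so the search runs in sub-linear time:

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-ins for the DPR encoders: in the paper d(z) and q(x) are 768-d
# outputs of two BERT_BASE encoders; here we fake them with random vectors.
num_docs, dim, k = 10_000, 768, 5
doc_index = rng.standard_normal((num_docs, dim)).astype(np.float32)  # d(z) for every passage
query_vec = rng.standard_normal(dim).astype(np.float32)              # q(x) for one query

# Exact MIPS: score every document with the inner product d(z)^T q(x).
scores = doc_index @ query_vec

# top-k: the k passages with the highest inner-product score.
top_k_ids = np.argpartition(-scores, k)[:k]
top_k_ids = top_k_ids[np.argsort(-scores[top_k_ids])]

# p_eta(z|x) proportional to exp(d(z)^T q(x)), renormalized over the top-k documents.
top_scores = scores[top_k_ids]
p_eta = np.exp(top_scores - top_scores.max())
p_eta /= p_eta.sum()

print(list(zip(top_k_ids.tolist(), p_eta.round(3).tolist())))
```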


2.3 Generator: BART

The generator component $p_\theta(y_i|x, z, y_{1:i-1})$ could be modeled using any encoder-decoder, but they use BART-large, a pre-trained seq2seq transformer with 400M parameters.

  • To combine the input $x$ with the retrieved content $z$ when generating with BART, they simply concatenate them.
  • Parametric memory refers to the BART generator.
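
A rough sketch of this concatenate-and-generate step using the Hugging Face transformers BART checkpoint. The separator format is my own assumption rather than the paper's exact preprocessing, and the raw bart-large checkpoint is not fine-tuned for QA, so this only illustrates the plumbing:

```python
from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")

query = "Who wrote the play Hamlet?"
passage = ("Hamlet is a tragedy written by William Shakespeare "
           "sometime between 1599 and 1601.")

# Concatenate the retrieved content z with the input x and feed it to the encoder.
encoder_input = passage + " // " + query
inputs = tokenizer(encoder_input, return_tensors="pt", truncation=True, max_length=1024)

# Generate conditioned on both the passage and the query.
output_ids = model.generate(**inputs, num_beams=4, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```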


2.4 Training

Jointly train the retriever and generator components without any direct supervision on what document should be retrieved.

Given a fine-tuning training corpus of input/output pairs $(x_j, y_j)$, they minimize the negative marginal log-likelihood of each target, $\sum_j -\log p(y_j|x_j)$, using stochastic gradient descent with the Adam optimizer.

  • Updating the document encoder $\mathrm{BERT}_d$ during training is costly, as it requires the document index to be periodically rebuilt, as REALM does during pre-training.
  • They do not find this to be necessary and keep the document encoder and index fixed, fine-tuning only the query encoder $\mathrm{BERT}_q$ and the BART generator.
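
A minimal PyTorch sketch of this objective for a single example, assuming the top-K document log-priors and the gold-token log-probabilities have already been computed (RAG-Sequence marginalization):

```python
import torch

def rag_sequence_nll(log_p_eta: torch.Tensor, log_p_theta_tokens: torch.Tensor) -> torch.Tensor:
    """Negative marginal log-likelihood -log p(y|x) for one example.

    log_p_eta:          (K,)   log p_eta(z|x) for the top-K retrieved documents.
    log_p_theta_tokens: (K, N) log p_theta(y_i | x, z, y_{1:i-1}) of the gold
                               tokens under each document.
    """
    # log p_theta(y|x,z) = sum_i log p_theta(y_i | x, z, y_{1:i-1})
    log_p_y_given_z = log_p_theta_tokens.sum(dim=-1)                 # (K,)
    # log p(y|x) = logsumexp_z [ log p_eta(z|x) + log p_theta(y|x,z) ]
    log_marginal = torch.logsumexp(log_p_eta + log_p_y_given_z, dim=-1)
    return -log_marginal

# Toy check with K=2 documents and N=3 gold tokens (made-up probabilities).
log_p_eta = torch.log(torch.tensor([0.7, 0.3]))
log_p_theta = torch.log(torch.tensor([[0.9, 0.8, 0.6],
                                      [0.2, 0.5, 0.7]]))
print(rag_sequence_nll(log_p_eta, log_p_theta).item())

# In the paper, gradients from this loss flow into BERT_q and the BART
# generator only; BERT_d and the document index stay frozen.
```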


2.5 Decoding

At test time, RAG-Sequence and RAG-Token require different ways to approximate $\arg\max_y p(y|x)$.

  • RAG-Token: The RAG-Token model can be seen as a standard autoregressive seq2seq generator with transition probability $p'_\theta(y_i|x, y_{1:i-1}) = \sum_{z\in \text{top-k}(p(\cdot|x))} p_\eta(z|x)\, p_\theta(y_i|x, z, y_{1:i-1})$. To decode, we can plug this probability into a standard beam search decoder.
  • RAG-Sequence: For RAG-Sequence, the likelihood $p(y|x)$ does not break into a conventional per-token likelihood, so it cannot be solved with a single beam search.
    • Instead, they run beam search for each document $z$, scoring each hypothesis using $p_\theta(y_i|x, z, y_{1:i-1})$. This yields a set of hypotheses $Y$, some of which may not have appeared in the beams of all documents.
    • To estimate the probability of a hypothesis $y$, they run an additional forward pass for each document $z$ for which $y$ does not appear in the beam, multiply the generator probability by $p_\eta(z|x)$, and then sum the probabilities across beams for the marginal.
    • This is called "Thorough Decoding". For longer outputs, they also propose "Fast Decoding", which approximates $p_\theta(y|x, z_i) \approx 0$ whenever $y$ was not generated during beam search from $x, z_i$, avoiding the additional forward passes.
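
Here is a small sketch of Thorough Decoding under stated assumptions: `beam_search(x, z)` and `generator_log_prob(y, x, z)` are hypothetical helpers standing in for the real seq2seq decoder and scorer, and `p_eta` maps each retrieved document to its retriever probability:

```python
import math

def thorough_decoding(x, docs, p_eta, beam_search, generator_log_prob):
    """Sketch of RAG-Sequence "Thorough Decoding".

    docs: iterable of (hashable) retrieved documents z.
    p_eta: dict mapping each document z to p_eta(z|x).
    beam_search(x, z): returns candidate output sequences for one document.
    generator_log_prob(y, x, z): returns log p_theta(y|x, z).
    """
    # 1. Run beam search once per retrieved document and pool the hypotheses.
    candidates = set()
    for z in docs:
        candidates.update(beam_search(x, z))

    # 2. Score every pooled hypothesis against every document (an extra
    #    forward pass when y was not in that document's beam), weight by
    #    p_eta(z|x), and sum over documents to get the marginal p(y|x).
    best_y, best_score = None, float("-inf")
    for y in candidates:
        marginal = sum(p_eta[z] * math.exp(generator_log_prob(y, x, z)) for z in docs)
        if marginal > best_score:
            best_y, best_score = y, marginal
    return best_y
```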


2.6 Performance and Results

The paper evaluates RAG on a variety of knowledge-intensive NLP tasks, such as:

  • Open-Domain Question Answering (ODQA): RAG significantly outperforms prior methods such as BERT-based extractive models (e.g., REALM and DPR) on benchmarks like Natural Questions and TriviaQA.
  • Abstractive QA: RAG achieves higher performance compared to non-retrieval baselines.
  • Dialogue Generation: RAG demonstrates its ability to produce grounded, factual responses in dialogue systems.

Key Insights from the results:

  • End-to-End Fine-Tuning: Joint optimization of the retriever and generator improves performance compared to training them independently.
  • RAG-Token vs RAG-Sequence:
    • RAG-Token excels in tasks requiring reasoning across multiple documents, thanks to its ability to marginalize at the token level.
    • RAG-Sequence, while simpler, performs better when a single relevant document suffices for generating a coherent output.
  • Interpretability: RAG's non-parametric memory (document index) allows for transparent inspection of retrieved documents, aiding in better understanding and debugging of model outputs.


Discussion and Applications

The authors highlight several advantages of RAG's hybrid memory approach:

  1. Expandable Knowledge: The document index (non-parametric memory) can be easily updated without retraining the model.
  2. Provenance: The retrieved documents provide insight into the knowledge source for generated outputs, making RAG more interpretable.
  3. Versatility: The RAG framework can adapt to a wide range of seq2seq tasks, including QA, summarization, and dialogue.

Potential Applications:

  • Customer Support: Generating accurate, grounded responses using company-specific knowledge bases.
  • Content Creation: Assisting with factually accurate and coherent content generation for articles or blogs.
  • Research Assistance: Summarizing or answering queries based on a given corpus of academic papers.


Limitations and Future Directions

Despite its strong performance, the paper also acknowledges several limitations of RAG:

  • Scalability: While dense retrieval is efficient, the memory requirements for indexing large datasets (e.g., Wikipedia) can be significant.
  • Context Size Limitations: The concatenation of the input and retrieved passages may hit the maximum token limits of models like BART.
  • Retriever Freezing: The authors fixed the document encoder during fine-tuning, which simplifies training but might leave room for performance improvements with an updated retriever.

The paper suggests future work to address these issues, such as:

  • Exploring alternative retrieval mechanisms (e.g., sparse retrievers or hybrid dense-sparse retrieval).
  • Reducing the memory footprint of the document index.
  • Extending RAG to leverage multimodal inputs (e.g., images and videos).


Conclusion

The RAG framework represents a powerful paradigm shift in knowledge-intensive NLP tasks by combining pre-trained generative models with retrieval mechanisms. It effectively mitigates some of the core limitations of parametric models, such as hallucination and the inability to update knowledge dynamically. By introducing hybrid memory into seq2seq tasks, RAG not only improves performance but also enhances interpretability and adaptability across diverse applications.

This work paves the way for more robust and transparent language models, with immense potential in real-world systems that require accurate and explainable outputs.


