Jerry Liu

Summary

The website provides a comprehensive guide on fine-tuning embedding models to enhance the performance of Retrieval Augmented Generation (RAG) systems using synthetic data without the need for labels or a GPU.

Abstract

The website introduces a detailed, end-to-end tutorial for fine-tuning embedding models to improve the retrieval capabilities of RAG systems. It emphasizes a method that does not require labeled data or GPU resources, achieving a 5-10% increase in retrieval evaluation metrics. The process involves generating synthetic datasets with hypothetical questions linked to text chunks, finetuning an open-source embedding model, and evaluating the model against benchmarks like text-embedding-ada-002. The guide is part of the LlamaIndex repository, which now includes embedding finetuning abstractions, and it showcases the use of LlamaIndex modules to automatically generate training data. The tutorial is designed to be accessible, with step-by-step notebooks provided, and is beneficial for anyone looking to enhance the accuracy and relevance of retrieval in RAG systems.

Opinions

  • The authors believe that the embeddings from pre-trained models may not be optimal for specific retrieval tasks and that finetuning can align them better with the user's data and objectives.
  • They suggest that the ease of prototyping RAG systems belies the difficulty in productionizing them, with retrieval being a common failure point.
  • The authors advocate for the use of synthetic data generation to overcome the lack of pre-existing positive and negative example pairs, which is a novel approach in the field.
  • They highlight the importance of using a comprehensive suite of metrics, such as those provided by the InformationRetrievalEvaluator, to assess the performance of embedding models.
  • The authors are optimistic about the potential of finetuning, pointing to technological advancements and easy-to-use services that facilitate this process.
  • They emphasize the practicality of their approach by noting that the tutorials were tested on an M2 Macbook Pro, demonstrating that finetuning embedding models does not necessitate high-end computational resources.

Fine-Tuning Embeddings for RAG with Synthetic Data

UPDATE 9/10/2023: We’ve added embedding finetuning abstractions to the LlamaIndex repo, so this repo is technically outdated! Please check out our embedding fine-tuning guides in the core documentation.

We’ve created a comprehensive, end-to-end guide showing you how to fine-tune an embedding model to improve performance of Retrieval Augmented Generation (RAG) systems over any unstructured text corpus (no labels required!).

The result is a 5–10% performance increase in retrieval evaluation metrics — our finetuned bge model almost reaches text-embedding-ada-002 levels of retrieval performance in terms of hit rate. This enables more accurate retrieval which leads to better RAG systems as a whole.

This tutorial is helpful to anyone building RAG systems:

  • If you’re new to finetuning, no problem! We have step-by-step notebooks walking through the key steps. Simply swap the file links for links to your own data, and run every cell.
  • Finetuning embedding models is lightweight and doesn’t require a GPU. These notebooks were tested on an M2 Macbook Pro.

Resources

No Llama image this time :)

Background/Context

The Current RAG Stack

RAG is a popular paradigm for connecting Large Language Models (LLMs) with an external source of data that was not present in their training corpus. It pairs a retrieval model over a knowledge bank with the LLM through its input prompt space. RAG stacks typically look like the following:

  • Indexing: Prepare a corpus of unstructured text and parse/chunk it. Then embed each chunk and store it in a vector database.
  • Query-time: Retrieve context from the vector db using top-k embedding similarity lookup, and stuff context into the LLM input space.

(Of course RAG can be much more advanced than this, and LlamaIndex provides tools for both simple and advanced RAG)
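To make the query-time step concrete, here is a minimal sketch of top-k embedding similarity lookup in plain Python. The function names (`cosine_similarity`, `retrieve_top_k`) and the toy 3-dimensional vectors are made up for illustration; a real stack would use embeddings from a model and a vector database rather than a dictionary.

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def retrieve_top_k(query_embedding, chunk_embeddings, k=2):
    """Return the ids of the k chunks most similar to the query."""
    scored = [
        (cosine_similarity(query_embedding, emb), chunk_id)
        for chunk_id, emb in chunk_embeddings.items()
    ]
    scored.sort(reverse=True)  # highest similarity first
    return [chunk_id for _, chunk_id in scored[:k]]

# Toy "embeddings" (hypothetical values, not produced by any model).
chunk_embeddings = {
    "chunk_a": [1.0, 0.0, 0.0],
    "chunk_b": [0.7, 0.7, 0.0],
    "chunk_c": [0.0, 0.0, 1.0],
}
top_chunks = retrieve_top_k([0.9, 0.1, 0.0], chunk_embeddings, k=2)
print(top_chunks)
```

The retrieved chunk texts would then be placed into the LLM prompt as context. If the embedding space does not reflect what "relevant" means for your data, this ranking is where things go wrong.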

Unfortunately, RAG is easy to prototype by cobbling together the different components, but hard to productionize. The simple stack has many failure modes, and the issue often lies with bad retrieval: if the returned context is irrelevant to the query, the capability of the LLM does not matter and the answer will be bad.

How Can We Make Retrieval Better?

We can try more sophisticated retrieval algorithms (e.g. hybrid search, reranking).

An insight from our recent production RAG webinar, however, is that the embeddings themselves may not live in an optimal latent space for your data. Embeddings generated by pre-trained models may be close to or far from each other based on the pre-training objective, but may not completely align with your own retrieval objective. For instance, if you’re building search over ML ArXiv papers, you may want the embeddings to align semantically with specific ML concepts (e.g. “LLMs”, “NLP”) and not with filler words (e.g. “This paper is…”).

Finetuning is a way to solve that. The concept of finetuning has become increasingly popular in the LLM space, with technological advancements as well as easy-to-use services.

In this tutorial, we focus on finetuning the embedding model. We show how finetuning the embedding model can lead to better retrieval performance.

Challenges/Considerations

When you finetune embeddings, you need training examples. In the case of embeddings, this typically means that you have both “positive” and “negative” examples: pairs of texts that should be close to each other, and pairs that should be far from each other.

An issue is that we don’t have these positive or negative examples a priori. Given a dataset of unstructured text, is it possible to automatically generate these example pairs?

With LlamaIndex you can! We use LlamaIndex modules to automatically generate a set of questions from unstructured text chunks. These (question, chunk) pairs are then used as positive examples that provide the training signal for the model (negative examples are randomly sampled from the other chunks).

The next section shows a full walkthrough across all of our modules.

Walkthrough

At a high level, we do the following:

  1. Generating a synthetic dataset for training and evaluation (Notebook)
  2. Finetuning an open-source embedding model (Notebook)
  3. Evaluating the embedding model (Notebook)

Generating synthetic dataset for training and evaluation

The key idea here is that we can leverage an LLM to generate hypothetical questions that are best answered by a given piece of context. This allows us to generate synthetic positive pairs of (query, relevant documents) in a scalable way without requiring human labellers.

More concretely, we first process the given documents into a corpus of text chunks. We do this with the SimpleNodeParser module in LlamaIndex:

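# import paths as of the LlamaIndex release used in the notebooks; newer releases may differ
from llama_index.node_parser import SimpleNodeParser
from llama_index.schema import MetadataMode

# chunk the loaded documents into nodes, then map node id -> raw chunk text
parser = SimpleNodeParser()
nodes = parser.get_nodes_from_documents(docs, show_progress=verbose)
corpus = {
  node.node_id: node.get_content(metadata_mode=MetadataMode.NONE)
  for node in nodes
}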

Then, for each text chunk, we use an LLM to generate a few hypothetical questions that can be answered with information from that text chunk. The example prompt is shown below as well.

prompt_template = prompt_template or """\
  Context information is below.
  
  ---------------------
  {context_str}
  ---------------------
  
  Given the context information and not prior knowledge.
  generate only questions based on the below query.
  
  You are a Teacher/ Professor. Your task is to setup \
  {num_questions_per_chunk} questions for an upcoming \
  quiz/examination. The questions should be diverse in nature \
  across the document. Restrict the questions to the \
  context information provided."
  """

import re

# for a given node, extract questions (do this over all nodes in outer loop)
query = prompt_template.format(context_str=text, num_questions_per_chunk=num_questions_per_chunk)
response = llm.complete(query)

# one question per line; strip leading list numbering such as "1." or "2)"
result = str(response).strip().split("\n")
questions = [
    re.sub(r"^\d+[\).\s]", "", question).strip() for question in result
]
# drop blank lines left over from the split
questions = [question for question in questions if len(question) > 0]
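To see what the parsing step does without calling an LLM, the sketch below applies the same regular expression to a hand-written response. The helper name `parse_questions` and the sample text are hypothetical and only serve the illustration.

```python
import re

def parse_questions(raw_response):
    """Split an LLM response into questions, dropping list numbering."""
    lines = raw_response.strip().split("\n")
    questions = [re.sub(r"^\d+[\).\s]", "", line).strip() for line in lines]
    # drop empty strings produced by blank lines
    return [q for q in questions if len(q) > 0]

# Hand-written stand-in for an LLM response (not real model output).
sample_response = """1. What was the revenue in 2021?
2) Who are the main competitors?

3 How is risk managed?"""

parsed = parse_questions(sample_response)
for question in parsed:
    print(question)
```

One caveat of this pattern: a question that itself begins with digits followed by a space or period (for example a year) would lose that leading number, so it may be worth inspecting a sample of the generated questions.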

Finally, we collect all pairs of questions and text chunks as the dataset. An example query, chunk, and mapping are shown below.

# example query
f331640a-b407-4028-8db8-4b8db691dd34: "What is the market value of Lyft's common stock held by non-affiliates as of June 30, 2021, based on the closing sales price of the Class A common stock on that date?"

# example corpus
d5554f3e-cdaf-41d7-ac49-8f0ffe3f5759: "UNITED STATESSECURITIES AND..."

# example mapping
f331640a-b407-4028-8db8-4b8db691dd34: d5554f3e-cdaf-41d7-ac49-8f0ffe3f5759
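A minimal sketch of how these three mappings might be assembled is shown below. It assumes the questions have already been generated per chunk; the function name `build_dataset`, the placeholder chunk ids and texts, and the choice to store each relevant chunk id in a list are assumptions for illustration rather than the notebook's exact code.

```python
import uuid

def build_dataset(corpus, questions_per_chunk):
    """Link each generated question to the chunk it was generated from."""
    queries = {}
    relevant_docs = {}
    for node_id, questions in questions_per_chunk.items():
        if node_id not in corpus:
            continue  # skip questions whose source chunk is missing
        for question in questions:
            question_id = str(uuid.uuid4())  # fresh id per question
            queries[question_id] = question
            relevant_docs[question_id] = [node_id]
    return {"queries": queries, "corpus": corpus, "relevant_docs": relevant_docs}

# Placeholder data (hypothetical ids and texts).
corpus = {"chunk-1": "Text of the first chunk.", "chunk-2": "Text of the second chunk."}
questions_per_chunk = {
    "chunk-1": ["What does the first chunk describe?"],
    "chunk-2": ["What does the second chunk describe?", "Which details does it add?"],
}
dataset = build_dataset(corpus, questions_per_chunk)
print(len(dataset["queries"]), "queries over", len(dataset["corpus"]), "chunks")
```

Keeping queries, corpus, and the query-to-chunk mapping as separate dictionaries keyed by id matches the shape of the example above and makes it easy to split the same structure into training and validation sets.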

Finetuning an open-source embedding model

We leverage the high-level model fitting API from sentence-transformers to set up a training process very easily.

We use MultipleNegativesRankingLoss as the training objective and InformationRetrievalEvaluator as the evaluator during training. Also, we use BAAI/bge-small-en on Hugging Face as the base model and train for a small number of epochs.

# define model
from sentence_transformers import SentenceTransformer
model_id = "BAAI/bge-small-en"
model = SentenceTransformer(model_id)

...

# define loss
from sentence_transformers import losses
loss = losses.MultipleNegativesRankingLoss(model)

# define evaluator
from sentence_transformers.evaluation import InformationRetrievalEvaluator
# define over validation dataset
...
evaluator = InformationRetrievalEvaluator(queries, corpus, relevant_docs)

# run training
...
model.fit(
    train_objectives=[(loader, loss)],
    epochs=EPOCHS,
    warmup_steps=warmup_steps,
    output_path='exp_finetune',
    show_progress_bar=True,
    evaluator=evaluator, 
    evaluation_steps=50,
)
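To give some intuition for the training objective, here is a simplified pure-Python sketch of the idea behind a multiple-negatives ranking loss. It is not the sentence-transformers implementation: it takes a precomputed batch similarity matrix (made-up numbers below) instead of a model, and treats every other chunk in the batch as a negative for a given query. The `scale` factor of 20 mirrors what I understand to be the library default, which is an assumption worth checking against the library documentation.

```python
import math

def multiple_negatives_ranking_loss(similarities, scale=20.0):
    """Mean cross-entropy where the correct "class" for row i is column i.

    similarities[i][j] is the similarity between query i and chunk j in a
    batch; chunk i is the positive for query i, the others act as negatives.
    """
    total = 0.0
    for i, row in enumerate(similarities):
        logits = [scale * s for s in row]
        max_logit = max(logits)  # subtract the max for numerical stability
        log_sum_exp = max_logit + math.log(
            sum(math.exp(logit - max_logit) for logit in logits)
        )
        total += log_sum_exp - logits[i]  # -log softmax of the positive
    return total / len(similarities)

# Hypothetical similarity matrices for a batch of 3 (query, chunk) pairs.
well_aligned = [[0.9, 0.1, 0.2], [0.0, 0.8, 0.1], [0.2, 0.1, 0.9]]
poorly_aligned = [[0.3, 0.4, 0.5], [0.5, 0.2, 0.4], [0.6, 0.5, 0.1]]

loss_good = multiple_negatives_ranking_loss(well_aligned)
loss_bad = multiple_negatives_ranking_loss(poorly_aligned)
print(loss_good, loss_bad)
```

The loss is small when each query is most similar to its own chunk and large otherwise, which is why only positive (question, chunk) pairs are needed: the rest of the batch supplies the negatives.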

Evaluating the embedding model

We compare the finetuned model against the base model, as well as the OpenAI embedding model text-embedding-ada-002.

We evaluate with two main metrics:

  • Hit-rate metric: For each (query, relevant_doc) pair, we retrieve the top-k documents with the query. It’s a hit if the results contain relevant_doc.
  • InformationRetrievalEvaluator from sentence_transformers. This provides a comprehensive suite of metrics such as cosine similarity accuracy, precision, and recall at different top-k values.
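The hit-rate metric described above can be sketched in a few lines. The function name `hit_rate` and the toy retrieval results are hypothetical; in practice the retrieved ids would come from running each query through the retriever under evaluation.

```python
def hit_rate(retrieved_by_query, relevant_docs):
    """Fraction of queries whose top-k results include the relevant doc."""
    hits = 0
    for query_id, expected_id in relevant_docs.items():
        if expected_id in retrieved_by_query.get(query_id, []):
            hits += 1
    return hits / len(relevant_docs)

# Toy evaluation set: each query has exactly one relevant chunk.
relevant_docs = {"q1": "chunk_a", "q2": "chunk_b", "q3": "chunk_c", "q4": "chunk_d"}
# Made-up top-2 retrieval results per query.
retrieved_by_query = {
    "q1": ["chunk_a", "chunk_x"],
    "q2": ["chunk_y", "chunk_b"],
    "q3": ["chunk_x", "chunk_y"],
    "q4": ["chunk_d", "chunk_a"],
}
score = hit_rate(retrieved_by_query, relevant_docs)
print(f"hit rate: {score:.2f}")
```

Note that hit rate ignores the position of the relevant chunk within the top-k results, which is one reason to also look at rank-aware metrics from the evaluator.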

Results

In terms of the hit-rate metric, the base model gets a 78% hit rate on the validation dataset, and the fine-tuned model gets 84%. text-embedding-ada-002 gets 87%, which means that our fine-tuned model is only 3 percentage points behind!

Hit-rate for `text-embedding-ada-002`, base model, finetuned model

The InformationRetrievalEvaluator shows a similar improvement across an entire suite of metrics. The fine-tuned model increases evaluation metrics by 5–10% compared to the base model.

Evaluation suite from `InformationRetrievalEvaluator`

Conclusion

We successfully finetuned an embedding model over unlabeled, unstructured data to give better retrieval performance for downstream RAG systems. We show a 5–10% improvement across all metrics!

Resources

(copied from intro)
