Improving RAG AI Systems
If you’ve spent any time in the AI field recently, you’ve almost certainly heard of Retrieval Augmented Generation (RAG) as a technique for building domain-specific question-answering systems (or more generally, chatbots 😆).
Anyone who has built a RAG application, whether as a side project or for a production environment, will have stumbled upon a few challenges. In this article, I will discuss some techniques to overcome them. Much of what I’ll share comes from DeepLearning.AI’s recently released course, Advanced Retrieval for AI with Chroma.
INTRODUCTION
Before we go any further, let me do a quick crash course to remind us of what RAG is and how it works.
RAG is a technique for building Generative AI (GenAI) chatbots to answer questions based on specific user-provided data.
In general, GenAI models (like ChatGPT) are trained on a large corpus of data and so can usually answer most of our questions. The downsides are that the model has no knowledge of data more recent than its training cutoff and no knowledge of an organization’s internal data.
To give the model access to the data (context) it needs to answer certain questions (queries), we have to pass that data to the model. The problem, though, is that models have a limited context window, which prevents us from just passing in all the available data. We need to limit the data to what is relevant to the query. In some cases it is also important to make sure the model doesn’t respond based on assumptions (known as hallucination) but sticks to real information and, when there’s no answer, says so. In comes Retrieval Augmented Generation.
In RAG applications, data is chunked, embedded (converted to vectors of numbers in a multi-dimensional space that represent the semantic and contextual meaning of the text), and stored in a vector database (vectordb). When the user asks a question, the query is also embedded (using the same embedding model), and the query embedding is compared against the stored embeddings to “retrieve” the document chunks most similar to the query.
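To make the embed-and-retrieve step concrete, here is a minimal sketch. It assumes a toy word-count “embedding” (toy_embed, a hypothetical stand-in for a real embedding model) and cosine similarity as the similarity measure; a real vector database does the same kind of comparison with learned embeddings and at a much larger scale.

```python
import numpy as np


def toy_embed(text, vocabulary):
    """Toy stand-in for an embedding model: count each vocabulary word in the text."""
    words = text.lower().split()
    return np.array([words.count(term) for term in vocabulary], dtype=float)


def cosine_similarity(a, b):
    """Cosine similarity between two vectors (0.0 if either vector is all zeros)."""
    denominator = np.linalg.norm(a) * np.linalg.norm(b)
    return float(a @ b / denominator) if denominator else 0.0


def retrieve(query, chunks, vocabulary, n_results=2):
    """Return the n_results chunks whose vectors are most similar to the query vector."""
    query_vector = toy_embed(query, vocabulary)
    scored = [(cosine_similarity(query_vector, toy_embed(c, vocabulary)), c) for c in chunks]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [chunk for _, chunk in scored[:n_results]]


# Made-up chunks, purely for illustration.
chunks = [
    "chroma stores embeddings for retrieval",
    "the cafeteria opens at nine",
    "embeddings map text to vectors",
]
vocabulary = sorted({word for chunk in chunks for word in chunk.split()})

top = retrieve("which database stores embeddings", chunks, vocabulary, n_results=1)
print(top)
```

Note that the query and the chunks go through the same toy_embed function, mirroring the point above that the query must be embedded with the same model as the documents.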
These retrieved document chunks are then passed to the LLM along with the user’s query and, in most cases, some instructions (a prompt) that guide the LLM’s behavior.
This article isn’t focused on RAG itself, so I won’t go into more detail, but I hope the overview above is sufficient to lay a foundation or jog your memory, as the case may be. For more details, this article by Onkar Mishral would be very beneficial.
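As a rough sketch of that last step, the retrieved chunks, the query, and the instructions can be assembled into a single prompt. The function name build_rag_prompt and the exact wording of the instructions are my own assumptions, not a fixed recipe; you would tune the wording for your use case.

```python
def build_rag_prompt(query, retrieved_chunks):
    """Combine instructions, retrieved context and the user's query into one prompt."""
    context = "\n\n".join(
        f"[{i}] {chunk}" for i, chunk in enumerate(retrieved_chunks, start=1)
    )
    return (
        "Answer the question using only the context below. "
        "If the context does not contain the answer, say that you do not know.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\n"
        "Answer:"
    )


# Made-up chunks, purely for illustration.
prompt = build_rag_prompt(
    "What does Chroma store?",
    ["Chroma stores embeddings.", "Embeddings are vectors."],
)
print(prompt)
```

The instruction to admit when the answer is missing is what keeps the model from answering based on assumptions, as described earlier.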
Problems with RAG
When building RAG applications today, AI engineers commonly face the same challenge: how to deal with irrelevant queries as well as irrelevant results.
In production environments, it’s common to find users asking the system questions that aren’t relevant to the data uploaded. They may ask questions that really can’t be answered by the data, or questions that cause the system to retrieve document chunks that aren’t relevant to what the user has in mind. This brings in the concept of relevancy versus distraction (more on that later).
“Imagine all your data as a cloud of points sitting in this high-dimensional space. A query that lands inside the cloud is likely to find nearest neighbours that are (sort of) densely packed and close together inside the cloud but a query that lands outside the cloud is likely to find nearest neighbours from a lot of different parts of the cloud so they tend to be more spread out.”
- Anton Troynikov
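The intuition in the quote can be checked on a toy example. This is only a sketch under heavy assumptions: the data is a hand-built 2D layout of four tight clusters (real embeddings have hundreds of dimensions), distance is Euclidean, and neighbour_spread is a hypothetical helper that measures how far apart the nearest neighbours are from each other.

```python
import numpy as np


def neighbour_spread(points, query, k=4):
    """Mean pairwise distance between the k points nearest to the query."""
    distances = np.linalg.norm(points - query, axis=1)
    nearest = points[np.argsort(distances)[:k]]
    pairwise = np.linalg.norm(nearest[:, None, :] - nearest[None, :, :], axis=-1)
    # The pairwise matrix counts every pair twice and has zeros on the diagonal.
    return float(pairwise.sum() / (k * (k - 1)))


# Hand-built toy data: four tight 3x3 grids of points around four centres.
centres = [(10.0, 10.0), (10.0, -10.0), (-10.0, 10.0), (-10.0, -10.0)]
offsets = [(dx, dy) for dx in (-0.5, 0.0, 0.5) for dy in (-0.5, 0.0, 0.5)]
points = np.array([(cx + dx, cy + dy) for cx, cy in centres for dx, dy in offsets])

inside = neighbour_spread(points, np.array([10.0, 10.0]))  # query in a dense region
outside = neighbour_spread(points, np.array([0.0, 0.0]))   # query far from any data

print(f"inside: {inside:.2f}, outside: {outside:.2f}")
```

In this constructed layout, the query that sits inside a cluster gets neighbours that are close to one another, while the query that sits far from all the data pulls one neighbour from each cluster, so its neighbours are much more spread out.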
Another (more often overlooked) challenge when building RAG systems is the context window of your embedding model. Embedding models also have a limited context window and tend to truncate any tokens beyond it. This makes it critical to handle your text splitting carefully (this is beyond the scope of today’s article).
Much of the focus of this article will be on how to make sure the context (document chunks) passed to the LLM consists more of relevant documents than distractors.
QUERY EXPANSION
Paper: https://arxiv.org/abs/2305.03653
Query expansion techniques try to get different “perspectives” on the user’s query so that the retrieval process searches a wider region of the embedding space for relevant document chunks. This can be achieved in two ways:
- Expansion with generated answers
- Expansion with multiple queries
Expansion with Generated Answer
With this technique, the user’s query is first passed to an LLM prompted to come up with a hypothetical answer for the user’s query (essentially the model is asked to hallucinate — I guess hallucinating is not always bad after all 😆).
This hypothetical answer is concatenated with the original query and then passed into the vector database to search for relevant document chunks. This new combination moves the query elsewhere in the high-dimensional space, hopefully producing better results for the retrieval search.
The retrieved documents are passed with the user’s query into the LLM to generate an answer as usual.
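The first two steps above can be sketched as follows. The generate argument is a hypothetical callable standing in for whatever LLM client you use, the prompt wording is my own assumption, and fake_generate is a stub so the sketch runs without any API access.

```python
def expand_with_generated_answer(query, generate):
    """Return the original query concatenated with a hypothetical answer."""
    prompt = (
        "Write a short, plausible answer to the question below, "
        "as it might appear in the source documents.\n"
        f"Question: {query}"
    )
    hypothetical_answer = generate(prompt).strip()
    return f"{query} {hypothetical_answer}"


def fake_generate(prompt):
    # Stub standing in for a real LLM call (an assumption for this sketch).
    return " Total revenue for the year was reported in the annual filing. "


joint_query = expand_with_generated_answer("What was the total revenue?", fake_generate)
print(joint_query)
```

The resulting joint_query is what gets sent to the vector database in place of the bare query, while the original query alone is still what you pass to the final LLM call.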
Expansion with Multiple Queries
In expanding with multiple queries, the first LLM is prompted to generate additional questions that are related to the original query and would help in getting the needed information. It’s also important to instruct the model, in the prompt, to generate simple questions, since simple questions are more useful for retrieving relevant document chunks from the vector database.
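Here is a minimal sketch of that generation step. As before, generate is a hypothetical callable wrapping your LLM of choice, the prompt wording is only an assumption to experiment with, and fake_generate is a stub that imitates a model returning one question per line.

```python
def generate_related_queries(query, generate, max_queries=5):
    """Ask the LLM for related questions and parse its reply into a clean list."""
    prompt = (
        f"Suggest up to {max_queries} short, simple questions related to the question below. "
        "Each question should cover a different aspect of the topic. "
        "Output one question per line.\n"
        f"Question: {query}"
    )
    lines = generate(prompt).splitlines()
    # Drop list markers and blank lines that models often include in their replies.
    questions = [line.strip().lstrip("- ").strip() for line in lines]
    return [question for question in questions if question][:max_queries]


def fake_generate(prompt):
    # Stub standing in for a real LLM call (an assumption for this sketch).
    return "- What is a vector database?\n\n- How are embeddings created?\n"


original_query = "How does retrieval work in RAG?"
augmented_queries = generate_related_queries(original_query, fake_generate)
print(augmented_queries)
```

Parsing the reply defensively matters here because the model’s formatting is not guaranteed, and an empty or malformed query would waste a retrieval slot.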
These new queries as well as the original query are used to search the vectordb for relevant documents. ChromaDB particularly shines here because it accepts multiple queries in a single call, so all you need to do is pass in the list of all queries, like so:
queries = [original_query] + augmented_queries
results = chroma_collection.query(query_texts=queries, n_results=5, include=["documents"])
The above code snippet returns 5 relevant documents for each of the queries passed in. Since these queries are similar, it’s possible to get duplicate documents. It is therefore necessary to deduplicate the document set to limit the tokens passed into the LLM.
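One simple way to deduplicate is sketched below. It assumes the input has the same shape as results["documents"] from the snippet above (one list of documents per query) and that exact string equality is a good enough test for duplicates; deduplicate_documents is a hypothetical helper name.

```python
def deduplicate_documents(documents_per_query):
    """Flatten one-list-per-query results into unique documents, keeping first-seen order."""
    seen = set()
    unique_documents = []
    for documents in documents_per_query:
        for document in documents:
            if document not in seen:
                seen.add(document)
                unique_documents.append(document)
    return unique_documents


# Made-up results for three queries, shaped like results["documents"].
documents_per_query = [
    ["chunk A", "chunk B"],
    ["chunk B", "chunk C"],
    ["chunk A", "chunk D"],
]
unique_documents = deduplicate_documents(documents_per_query)
print(unique_documents)
```

Keeping the first-seen order (rather than using a plain set) means documents retrieved for the original query stay at the front of the list.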
And then finally, the retrieved documents are passed with the original query into the LLM for an answer.
___
When working with Query Expansion techniques, you find that Prompt Engineering as a concept becomes even more important because you’re dealing with at least one more LLM. It is therefore important to experiment with your prompts and see which works best for your use case.
CROSS-ENCODER RE-RANKING
Previously, I had alluded to the concept of relevancy versus distraction. Well, documents retrieved from a vector database have varying levels of relevance to the query with which they were retrieved. In practice, you find that some may even have very low relevance and as such just end up being distractions to the LLM.
Take a scenario where a user asks a question about LangChain and only 2 document chunks in the database actually address it. If the engineer building the RAG system specified that 5 documents be retrieved, then 3 of the retrieved documents are automatically distractors to the LLM. This is a very generic scenario to explain the concept; in practice, it’s a little more nuanced.
This is even more the case when you use Expansion with Multiple Queries, as discussed in the preceding section, because even more documents are retrieved and so there is an increased likelihood of having distractors in the mix.
In addition, it’s common for the answer to a query to be in the knowledge base (the text corpus stored in the vectordb) while the RAG system somehow doesn’t find it, or doesn’t find the best answer. In many cases, increasing the number of documents retrieved solves this issue. But why?
Well, the simplest answer is that the “best fit” chunk wasn’t in the first set of chunks when fewer documents were retrieved, but once the number was increased, the retrieval system found it.
Remember, when working with GenAI, it’s best practice to manage your tokens, either to manage costs (if you’re working with a paid model) or at the very least to improve results (paid service or not), as the model has more tokens to use in generating answers. Simply retrieving more documents works against this, which is why it helps to keep only the most relevant ones.
___
I won’t go into the nuts and bolts of how cross-encoder models work here, but in this context, the cross-encoder is used to score the relevance of a given document (chunk) to the user’s query.
___
Say you retrieve 10 documents from the database. When they’re compared with the user’s query using the cross-encoder, you may find that the document in position 9 is more relevant to the query than the document in position 3.
It is, therefore, necessary to re-order the documents so that the more relevant documents are brought higher up in the listing, and with that, if you want to pass in only 5 chunks to the LLM, you’re sure that those are the top 5 most relevant chunks.
import numpy as np
from sentence_transformers import CrossEncoder
# chroma_collection is assumed to be an existing Chroma collection
query = "..."
results = chroma_collection.query(query_texts=[query], n_results=10, include=['documents'])
retrieved_documents = results['documents'][0]
# Load the cross-encoder model
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
# Create input pairs for the cross-encoder and predict scores for their relevance
pairs = [[query, doc] for doc in retrieved_documents]
scores = cross_encoder.predict(pairs)
# Sort the scores in descending order and keep the documents corresponding
# to the top 5 as the context to be passed into the LLM
relevant_documents = []
for i in np.argsort(scores)[::-1][:5]:
    relevant_documents.append(retrieved_documents[i])
The code snippet above shows the process of handling the cross-encoder re-ranking using a sentence transformers cross-encoder model.
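If you want to reuse the re-ranking step, or try it out without downloading a model, it can be wrapped in a small function. This is only a sketch: rerank and its score_pairs parameter are hypothetical names of my own, and fake_score_pairs is a crude word-overlap stub standing in for cross_encoder.predict.

```python
import numpy as np


def rerank(query, documents, score_pairs, top_k=5):
    """Return the top_k documents, ordered from most to least relevant to the query."""
    pairs = [[query, document] for document in documents]
    scores = np.asarray(score_pairs(pairs), dtype=float)
    order = np.argsort(scores)[::-1][:top_k]  # indices of the highest scores first
    return [documents[i] for i in order]


def fake_score_pairs(pairs):
    # Stub scorer (an assumption for this sketch): count the words shared by query and document.
    return [len(set(q.lower().split()) & set(d.lower().split())) for q, d in pairs]


# Made-up documents, purely for illustration.
documents = ["cats and dogs", "chroma query embeddings guide", "query basics"]
top_documents = rerank("chroma query embeddings", documents, fake_score_pairs, top_k=2)
print(top_documents)
```

With a real model, you would pass cross_encoder.predict as score_pairs; the rest of the function stays the same.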
CONCLUSION
While RAG is a very effective way of ensuring your GenAI system is giving users accurate and relevant answers to their queries, there are still shortcomings with it and it’s important to pay close attention to the retrieval process in working towards improving the overall system. I hope with this article, you find some guidance on how to improve your systems.
With this article, I wanted to keep things as simple and high-level as possible and so with the code snippets shared, I assumed the reader has handled package installations as well as necessary imports. Should any of my readers encounter issues when trying to implement these, please reach out to me and I would be glad to help out as much as I can.
If you found this useful or insightful, remember to give me a clap and share the article as well. Also, feel free to share your thoughts and comments below. Find out more about me and my projects as well as my writings on my website: triumphurias.com