Skip to content
This repository has been archived by the owner on Mar 1, 2024. It is now read-only.

Add Corrective RAG LlamaPack #945

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions llama_hub/llama_packs/corrective_rag/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Corrective Retrieval Augmented Generation Pack

This LlamaPack implements the Corrective Retrieval Augmented Generation (CRAG) [paper](https://arxiv.org/pdf/2401.15884.pdf)

Corrective Retrieval Augmented Generation (CRAG) is a method designed to enhance the robustness of language model generation by evaluating and augmenting the relevance of retrieved documents through a lightweight evaluator and large-scale web searches, ensuring more accurate and reliable information is used in generation.

This LlamaPack uses the [Tavily AI](https://app.tavily.com/home) API for web searches, so we recommend obtaining an API key before proceeding further.

## CLI Usage

You can download llamapacks directly using `llamaindex-cli`, which comes installed with the `llama-index` python package:

```bash
llamaindex-cli download-llamapack CorrectiveRAGPack --download-dir ./corrective_rag_pack
```

You can then inspect the files at `./corrective_rag_pack` and use them as a template for your own project.

## Code Usage

You can download the pack to the `./corrective_rag_pack` directory:

```python
from llama_index.llama_pack import download_llama_pack

# download and install dependencies
CorrectiveRAGPack = download_llama_pack(
"CorrectiveRAGPack", "./corrective_rag_pack"
)

# You can use any llama-hub loader to get documents!
corrective_rag_pack = CorrectiveRAGPack(documents, tavily_ai_api_key)
```

From here, you can use the pack, or inspect and modify the pack in `./corrective_rag_pack`.

The `run()` function contains the core logic behind Corrective Retrieval Augmented Generation - [CRAG](https://arxiv.org/pdf/2401.15884.pdf) paper.

```python
response = corrective_rag_pack.run("What did the author do growing up?", similarity_top_k=2)
```
3 changes: 3 additions & 0 deletions llama_hub/llama_packs/corrective_rag/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from llama_hub.llama_packs.corrective_rag.base import CorrectiveRAGPack

__all__ = ["CorrectiveRAGPack"]
134 changes: 134 additions & 0 deletions llama_hub/llama_packs/corrective_rag/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""LlamaPack class."""
from typing import Any, Dict, List

from llama_index import VectorStoreIndex, SummaryIndex
from llama_index.llama_pack.base import BaseLlamaPack
from llama_index.llms import OpenAI
from llama_index.schema import Document, NodeWithScore
from llama_index.query_pipeline.query import QueryPipeline
from llama_hub.tools.tavily_research.base import TavilyToolSpec
from llama_index.prompts import PromptTemplate

DEFAULT_RELEVANCY_PROMPT_TEMPLATE = PromptTemplate(
template="""As a grader, your task is to evaluate the relevance of a document retrieved in response to a user's question.

Retrieved Document:
-------------------
{context_str}

User Question:
--------------
{query_str}

Evaluation Criteria:
- Consider whether the document contains keywords or topics related to the user's question.
- The evaluation should not be overly stringent; the primary objective is to identify and filter out clearly irrelevant retrievals.

Decision:
- Assign a binary score to indicate the document's relevance.
- Use 'yes' if the document is relevant to the question, or 'no' if it is not.

Please provide your binary score ('yes' or 'no') below to indicate the document's relevance to the user question."""
)

DEFAULT_TRANSFORM_QUERY_TEMPLATE = PromptTemplate(
template="""Your task is to refine a query to ensure it is highly effective for retrieving relevant search results. \n
Analyze the given input to grasp the core semantic intent or meaning. \n
Original Query:
\n ------- \n
{query_str}
\n ------- \n
Your goal is to rephrase or enhance this query to improve its search performance. Ensure the revised query is concise and directly aligned with the intended search objective. \n
Respond with the optimized query only:"""
)


class CorrectiveRAGPack(BaseLlamaPack):
    """Corrective Retrieval Augmented Generation (CRAG) pack.

    Implements the CRAG paper (https://arxiv.org/pdf/2401.15884.pdf):
    retrieved documents are graded for relevancy by an LLM; if any retrieval
    is judged irrelevant, the query is rewritten and a Tavily web search
    supplements the context before the final answer is synthesized.
    """

    def __init__(self, documents: List[Document], tavily_ai_apikey: str) -> None:
        """Init params.

        Args:
            documents: Documents to index for retrieval.
            tavily_ai_apikey: Tavily AI API key used for the web-search fallback.
        """
        llm = OpenAI(model="gpt-4")
        # Grades each retrieved node as 'yes'/'no' relevant to the query.
        self.relevancy_pipeline = QueryPipeline(
            chain=[DEFAULT_RELEVANCY_PROMPT_TEMPLATE, llm]
        )
        # Rewrites the query for web search when retrieval is judged weak.
        self.transform_query_pipeline = QueryPipeline(
            chain=[DEFAULT_TRANSFORM_QUERY_TEMPLATE, llm]
        )

        self.llm = llm
        self.index = VectorStoreIndex.from_documents(documents)
        self.tavily_tool = TavilyToolSpec(api_key=tavily_ai_apikey)

    def get_modules(self) -> Dict[str, Any]:
        """Get modules."""
        return {"llm": self.llm, "index": self.index}

    def retrieve_nodes(self, query_str: str, **kwargs: Any) -> List[NodeWithScore]:
        """Retrieve the relevant nodes for the query.

        kwargs (e.g. ``similarity_top_k``) are forwarded to ``as_retriever``.
        """
        retriever = self.index.as_retriever(**kwargs)
        return retriever.retrieve(query_str)

    def evaluate_relevancy(
        self, retrieved_nodes: List[NodeWithScore], query_str: str
    ) -> List[str]:
        """Evaluate relevancy of each retrieved node against the query.

        Returns:
            One normalized ('yes'/'no', lower-cased, stripped) verdict per
            node, in the same order as ``retrieved_nodes``.
        """
        # NOTE: the original annotated this parameter as List[Document], but
        # run() passes the List[NodeWithScore] produced by retrieve_nodes().
        relevancy_results = []
        for node in retrieved_nodes:
            relevancy = self.relevancy_pipeline.run(
                context_str=node.text, query_str=query_str
            )
            relevancy_results.append(relevancy.message.content.lower().strip())
        return relevancy_results

    def extract_relevant_texts(
        self, retrieved_nodes: List[NodeWithScore], relevancy_results: List[str]
    ) -> str:
        """Join the text of nodes whose relevancy verdict is 'yes'.

        ``retrieved_nodes`` and ``relevancy_results`` are parallel lists;
        zip pairs them directly instead of indexing by position.
        """
        relevant_texts = [
            node.text
            for node, result in zip(retrieved_nodes, relevancy_results)
            if result == "yes"
        ]
        return "\n".join(relevant_texts)

    def search_with_transformed_query(self, query_str: str) -> str:
        """Search the transformed query with the Tavily API.

        Returns the top results' texts joined with newlines.
        """
        search_results = self.tavily_tool.search(query_str, max_results=2)
        return "\n".join(result.text for result in search_results)

    def get_result(self, relevant_text: str, search_text: str, query_str: str) -> Any:
        """Synthesize the final answer over the combined context.

        Builds a throwaway SummaryIndex over the concatenated relevant and
        search texts and queries it with the original question.
        """
        documents = [Document(text=relevant_text + "\n" + search_text)]
        index = SummaryIndex.from_documents(documents)
        query_engine = index.as_query_engine()
        return query_engine.query(query_str)

    def run(self, query_str: str, **kwargs: Any) -> Any:
        """Run the CRAG pipeline for ``query_str``.

        kwargs (e.g. ``similarity_top_k``) are forwarded to the retriever.
        """
        # Retrieve nodes based on the input query string.
        retrieved_nodes = self.retrieve_nodes(query_str, **kwargs)

        # Grade each retrieved node's relevancy to the query.
        relevancy_results = self.evaluate_relevancy(retrieved_nodes, query_str)

        # Keep only the text of nodes graded relevant.
        relevant_text = self.extract_relevant_texts(retrieved_nodes, relevancy_results)

        search_text = ""
        if "no" in relevancy_results:
            # At least one retrieval was irrelevant: rewrite the query and
            # fall back to a web search to supplement the context.
            transformed_query_str = self.transform_query_pipeline.run(
                query_str=query_str
            ).message.content
            search_text = self.search_with_transformed_query(transformed_query_str)

        # An empty search_text means only the initially retrieved relevant
        # text is used, so a single call covers both cases (the original
        # if/else passed identical effective arguments on both branches).
        return self.get_result(relevant_text, search_text, query_str)
1 change: 1 addition & 0 deletions llama_hub/llama_packs/corrective_rag/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
tavily-python
5 changes: 5 additions & 0 deletions llama_hub/llama_packs/library.json
Original file line number Diff line number Diff line change
Expand Up @@ -292,5 +292,10 @@
"id": "llama_packs/rag_cli_local",
"author": "jerryjliu",
"keywords": ["rag", "cli", "local"]
},
"CorrectiveRAGPack": {
"id": "llama_packs/corrective_rag",
"author": "ravi03071991",
"keywords": ["rag", "retrieve", "crag", "corrective", "corrective_rag"]
}
}
Loading