diff --git a/llama_hub/llama_packs/corrective_rag/README.md b/llama_hub/llama_packs/corrective_rag/README.md
new file mode 100644
index 0000000000..9b82c8e174
--- /dev/null
+++ b/llama_hub/llama_packs/corrective_rag/README.md
@@ -0,0 +1,41 @@
+# Corrective Retrieval Augmented Generation Pack
+
+This LlamaPack implements the Corrective Retrieval Augmented Generation (CRAG) [paper](https://arxiv.org/pdf/2401.15884.pdf)
+
+Corrective Retrieval Augmented Generation (CRAG) is a method designed to enhance the robustness of language model generation by evaluating and augmenting the relevance of retrieved documents through a lightweight evaluator and large-scale web searches, ensuring more accurate and reliable information is used in generation.
+
+This LlamaPack uses the [Tavily AI](https://app.tavily.com/home) API for web searches, so we recommend you get an API key before proceeding further.
+
+## CLI Usage
+
+You can download llamapacks directly using `llamaindex-cli`, which comes installed with the `llama-index` python package:
+
+```bash
+llamaindex-cli download-llamapack CorrectiveRAGPack --download-dir ./corrective_rag_pack
+```
+
+You can then inspect the files at `./corrective_rag_pack` and use them as a template for your own project.
+
+## Code Usage
+
+You can download the pack to the `./corrective_rag_pack` directory:
+
+```python
+from llama_index.llama_pack import download_llama_pack
+
+# download and install dependencies
+CorrectiveRAGPack = download_llama_pack(
+    "CorrectiveRAGPack", "./corrective_rag_pack"
+)
+
+# You can use any llama-hub loader to get documents!
+corrective_rag_pack = CorrectiveRAGPack(documents, tavily_ai_api_key)
+```
+
+From here, you can use the pack, or inspect and modify the pack in `./corrective_rag_pack`.
+
+The `run()` function contains the core logic behind the Corrective Retrieval Augmented Generation - [CRAG](https://arxiv.org/pdf/2401.15884.pdf) paper.
+
+```python
+response = corrective_rag_pack.run("What did the author do growing up?", similarity_top_k=2)
+```
diff --git a/llama_hub/llama_packs/corrective_rag/__init__.py b/llama_hub/llama_packs/corrective_rag/__init__.py
new file mode 100644
index 0000000000..c6f27e2cf6
--- /dev/null
+++ b/llama_hub/llama_packs/corrective_rag/__init__.py
@@ -0,0 +1,3 @@
+from llama_hub.llama_packs.corrective_rag.base import CorrectiveRAGPack
+
+__all__ = ["CorrectiveRAGPack"]
diff --git a/llama_hub/llama_packs/corrective_rag/base.py b/llama_hub/llama_packs/corrective_rag/base.py
new file mode 100644
index 0000000000..95eaa20f76
--- /dev/null
+++ b/llama_hub/llama_packs/corrective_rag/base.py
@@ -0,0 +1,134 @@
+"""LlamaPack class."""
+from typing import Any, Dict, List
+
+from llama_index import VectorStoreIndex, SummaryIndex
+from llama_index.llama_pack.base import BaseLlamaPack
+from llama_index.llms import OpenAI
+from llama_index.schema import Document, NodeWithScore
+from llama_index.query_pipeline.query import QueryPipeline
+from llama_hub.tools.tavily_research.base import TavilyToolSpec
+from llama_index.prompts import PromptTemplate
+
+DEFAULT_RELEVANCY_PROMPT_TEMPLATE = PromptTemplate(
+    template="""As a grader, your task is to evaluate the relevance of a document retrieved in response to a user's question.
+
+    Retrieved Document:
+    -------------------
+    {context_str}
+
+    User Question:
+    --------------
+    {query_str}
+
+    Evaluation Criteria:
+    - Consider whether the document contains keywords or topics related to the user's question.
+    - The evaluation should not be overly stringent; the primary objective is to identify and filter out clearly irrelevant retrievals.
+
+    Decision:
+    - Assign a binary score to indicate the document's relevance.
+    - Use 'yes' if the document is relevant to the question, or 'no' if it is not.
+
+    Please provide your binary score ('yes' or 'no') below to indicate the document's relevance to the user question."""
+)
+
+DEFAULT_TRANSFORM_QUERY_TEMPLATE = PromptTemplate(
+    template="""Your task is to refine a query to ensure it is highly effective for retrieving relevant search results. \n
+    Analyze the given input to grasp the core semantic intent or meaning. \n
+    Original Query:
+    \n ------- \n
+    {query_str}
+    \n ------- \n
+    Your goal is to rephrase or enhance this query to improve its search performance. Ensure the revised query is concise and directly aligned with the intended search objective. \n
+    Respond with the optimized query only:"""
+)
+
+
+class CorrectiveRAGPack(BaseLlamaPack):
+    def __init__(self, documents: List[Document], tavily_ai_apikey: str) -> None:
+        """Init params."""
+
+        llm = OpenAI(model="gpt-4")
+        self.relevancy_pipeline = QueryPipeline(
+            chain=[DEFAULT_RELEVANCY_PROMPT_TEMPLATE, llm]
+        )
+        self.transform_query_pipeline = QueryPipeline(
+            chain=[DEFAULT_TRANSFORM_QUERY_TEMPLATE, llm]
+        )
+
+        self.llm = llm
+        self.index = VectorStoreIndex.from_documents(documents)
+        self.tavily_tool = TavilyToolSpec(api_key=tavily_ai_apikey)
+
+    def get_modules(self) -> Dict[str, Any]:
+        """Get modules."""
+        return {"llm": self.llm, "index": self.index}
+
+    def retrieve_nodes(self, query_str: str, **kwargs: Any) -> List[NodeWithScore]:
+        """Retrieve the relevant nodes for the query"""
+        retriever = self.index.as_retriever(**kwargs)
+        return retriever.retrieve(query_str)
+
+    def evaluate_relevancy(
+        self, retrieved_nodes: List[NodeWithScore], query_str: str
+    ) -> List[str]:
+        """Evaluate relevancy of retrieved documents with the query"""
+        relevancy_results = []
+        for node in retrieved_nodes:
+            relevancy = self.relevancy_pipeline.run(
+                context_str=node.text, query_str=query_str
+            )
+            relevancy_results.append(relevancy.message.content.lower().strip())
+        return relevancy_results
+
+    def extract_relevant_texts(
+        self, retrieved_nodes: List[NodeWithScore], relevancy_results: List[str]
+    ) -> str:
+        """Extract relevant texts from retrieved documents"""
+        relevant_texts = [
+            retrieved_nodes[i].text
+            for i, result in enumerate(relevancy_results)
+            if result == "yes"
+        ]
+        return "\n".join(relevant_texts)
+
+    def search_with_transformed_query(self, query_str: str) -> str:
+        """Search the transformed query with Tavily API"""
+        search_results = self.tavily_tool.search(query_str, max_results=2)
+        return "\n".join([result.text for result in search_results])
+
+    def get_result(self, relevant_text: str, search_text: str, query_str: str) -> Any:
+        """Get result with relevant text"""
+        documents = [Document(text=relevant_text + "\n" + search_text)]
+        index = SummaryIndex.from_documents(documents)
+        query_engine = index.as_query_engine()
+        return query_engine.query(query_str)
+
+    def run(self, query_str: str, **kwargs: Any) -> Any:
+        """Run the pipeline."""
+        # Retrieve nodes based on the input query string.
+        retrieved_nodes = self.retrieve_nodes(query_str, **kwargs)
+
+        # Evaluate the relevancy of each retrieved document in relation to the query string.
+        relevancy_results = self.evaluate_relevancy(retrieved_nodes, query_str)
+
+        # Extract texts from documents that are deemed relevant based on the evaluation.
+        relevant_text = self.extract_relevant_texts(retrieved_nodes, relevancy_results)
+
+        # Initialize search_text variable to handle cases where it might not get defined.
+        search_text = ""
+
+        # If any document is found irrelevant, transform the query string for better search results.
+        if "no" in relevancy_results:
+            transformed_query_str = self.transform_query_pipeline.run(
+                query_str=query_str
+            ).message.content
+
+            # Conduct a search with the transformed query string and collect the results.
+            search_text = self.search_with_transformed_query(transformed_query_str)
+
+        # Compile the final result. If there's additional search text from the transformed query,
+        # it's included; otherwise, only the relevant text from the initial retrieval is returned.
+        if search_text:
+            return self.get_result(relevant_text, search_text, query_str)
+        else:
+            return self.get_result(relevant_text, "", query_str)
diff --git a/llama_hub/llama_packs/corrective_rag/requirements.txt b/llama_hub/llama_packs/corrective_rag/requirements.txt
new file mode 100644
index 0000000000..5ab5bcac8a
--- /dev/null
+++ b/llama_hub/llama_packs/corrective_rag/requirements.txt
@@ -0,0 +1 @@
+tavily-python
\ No newline at end of file
diff --git a/llama_hub/llama_packs/library.json b/llama_hub/llama_packs/library.json
index 186dddb6c4..71c2ddf5ac 100644
--- a/llama_hub/llama_packs/library.json
+++ b/llama_hub/llama_packs/library.json
@@ -292,5 +292,10 @@
     "id": "llama_packs/rag_cli_local",
     "author": "jerryjliu",
     "keywords": ["rag", "cli", "local"]
+  },
+  "CorrectiveRAGPack": {
+    "id": "llama_packs/corrective_rag",
+    "author": "ravi03071991",
+    "keywords": ["rag", "retrieve", "crag", "corrective", "corrective_rag"]
   }
 }