Skip to content

Commit

Permalink
Merge pull request #148 from 0xThresh/feat-text-to-sql-valves
Browse files Browse the repository at this point in the history
Add Valves to Text to SQL Pipeline
  • Loading branch information
tjbck authored Jul 3, 2024
2 parents a5bf4bd + 7b257d3 commit f5a758c
Showing 1 changed file with 34 additions and 14 deletions.
48 changes: 34 additions & 14 deletions examples/pipelines/rag/text_to_sql_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,30 +10,51 @@

from typing import List, Union, Generator, Iterator
import os
from pydantic import BaseModel
from llama_index.llms.ollama import Ollama
from llama_index.core.query_engine import NLSQLTableQueryEngine
from llama_index.core import SQLDatabase, PromptTemplate
from sqlalchemy import create_engine


class Pipeline:
class Valves(BaseModel):
    """User-configurable settings ("valves") for the Text-to-SQL pipeline."""

    # Hostname of the database server
    DB_HOST: str
    # Port of the database server (kept as str — sourced from an env var)
    DB_PORT: str
    # User to connect to the database as
    DB_USER: str
    # Password for DB_USER
    DB_PASSWORD: str
    # Database (schema) to select on the DB instance
    DB_DATABASE: str
    # Table(s) the query engine is allowed to run queries against
    DB_TABLES: list[str]
    # Base URL of the Ollama server used for text-to-SQL generation
    OLLAMA_HOST: str
    # Name of the Ollama model used for text-to-SQL generation
    TEXT_TO_SQL_MODEL: str


# Update valves/ environment variables based on your selected database
def __init__(self):
    """Initialize the pipeline and populate its Valves from environment variables.

    Raises:
        KeyError: if any required PG_* environment variable is missing
            (fails fast at startup rather than at query time).
    """
    self.name = "Database RAG Pipeline"
    self.engine = None  # SQLAlchemy engine; created lazily by init_db_connection()
    self.nlsql_response = ""

    # Initialize the valves. All connection settings come from the environment;
    # the remaining values are editable defaults.
    self.valves = self.Valves(
        **{
            "pipelines": ["*"],  # Connect to all pipelines
            "DB_HOST": os.environ["PG_HOST"],  # Database hostname
            "DB_PORT": os.environ["PG_PORT"],  # Database port
            "DB_USER": os.environ["PG_USER"],  # User to connect to the database with
            "DB_PASSWORD": os.environ["PG_PASSWORD"],  # Password to connect to the database with
            "DB_DATABASE": os.environ["PG_DB"],  # Database to select on the DB instance
            "DB_TABLES": ["albums"],  # Table(s) to run queries against
            "OLLAMA_HOST": "http://host.docker.internal:11434",  # URL of your Ollama host, such as http://localhost:11434 or a remote server address
            "TEXT_TO_SQL_MODEL": "phi3:latest",  # Model to use for text-to-SQL generation
        }
    )

def init_db_connection(self):
    """Create and cache a SQLAlchemy engine for the configured database.

    The connection string targets Postgres via psycopg2; update it here if
    the valves point at a different database engine.

    Returns:
        The created SQLAlchemy engine (also stored on ``self.engine``).
    """
    v = self.valves
    # NOTE(review): credentials are interpolated raw — a password containing
    # URL-special characters (e.g. '@' or '/') would need URL-encoding.
    self.engine = create_engine(
        f"postgresql+psycopg2://{v.DB_USER}:{v.DB_PASSWORD}@{v.DB_HOST}:{v.DB_PORT}/{v.DB_DATABASE}"
    )
    return self.engine


async def on_startup(self):
    """Open the database connection when the Pipelines server starts."""
    # This function is called when the server is started.
    self.init_db_connection()
Expand All @@ -48,10 +69,10 @@ def pipe(
# Debug logging is required to see what SQL query is generated by the LlamaIndex library; enable on Pipelines server if needed

# Create database reader for Postgres
sql_database = SQLDatabase(self.engine, include_tables=self.tables)
sql_database = SQLDatabase(self.engine, include_tables=self.valves.DB_TABLES)

# Set up LLM connection; uses phi3 model with 128k context limit since some queries have returned 20k+ tokens
llm = Ollama(model=self.model, base_url=self.ollama_host, request_timeout=180.0, context_window=30000)
llm = Ollama(model=self.valves.TEXT_TO_SQL_MODEL, base_url=self.valves.OLLAMA_HOST, request_timeout=180.0, context_window=30000)

# Set up the custom prompt used when generating SQL queries from text
text_to_sql_prompt = """
Expand All @@ -78,7 +99,7 @@ def pipe(

query_engine = NLSQLTableQueryEngine(
sql_database=sql_database,
tables=self.tables,
tables=self.valves.DB_TABLES,
llm=llm,
embed_model="local",
text_to_sql_prompt=text_to_sql_template,
Expand All @@ -88,4 +109,3 @@ def pipe(
response = query_engine.query(user_message)

return response.response_gen

0 comments on commit f5a758c

Please sign in to comment.