
Text to IRIS SQL with LangChain

PHPz, 2024-08-29 06:33:35


An experiment in using the LangChain framework, IRIS Vector Search, and an LLM to generate IRIS-compatible SQL from user prompts.

This article is based on this notebook. You can run it in a ready-to-use environment with this application on OpenExchange.

Setup

First, we need to install the required libraries:

!pip install --upgrade --quiet langchain langchain-openai langchain-iris pandas

Next, we import the required modules and set up the environment:

import os
import datetime
from copy import deepcopy
from sqlalchemy import create_engine
import getpass
import pandas as pd
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_community.document_loaders import DataFrameLoader
from langchain_text_splitters import CharacterTextSplitter
from langchain_core.output_parsers import StrOutputParser
from langchain.globals import set_llm_cache
from langchain_community.cache import SQLiteCache
from langchain_iris import IRISVector

We use SQLiteCache to cache LLM calls:

# Cache for LLM calls
set_llm_cache(SQLiteCache(database_path=".langchain.db"))
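The caching idea can be shown with a minimal standard-library sketch. This is not LangChain's implementation. Responses are stored in SQLite under a hash of the prompt plus a string describing the model settings, which mirrors the `lookup(prompt, llm_string)` / `update(...)` pair that LangChain caches expose. Repeated identical calls therefore skip the API. The names `PromptCache` and `cached_call` are hypothetical, and LangChain's `SQLiteCache` uses its own table layout.

```python
import hashlib
import sqlite3


class PromptCache:
    """Hypothetical minimal LLM response cache backed by SQLite."""

    def __init__(self, path=":memory:"):
        self.conn = sqlite3.connect(path)
        self.conn.execute(
            "CREATE TABLE IF NOT EXISTS cache (key TEXT PRIMARY KEY, response TEXT)"
        )

    @staticmethod
    def _key(prompt, llm_string):
        # Hash the prompt together with the model settings, so a different
        # model or temperature never reuses another configuration's answer.
        return hashlib.sha256(f"{llm_string}\x00{prompt}".encode("utf-8")).hexdigest()

    def lookup(self, prompt, llm_string):
        row = self.conn.execute(
            "SELECT response FROM cache WHERE key = ?", (self._key(prompt, llm_string),)
        ).fetchone()
        return row[0] if row else None

    def update(self, prompt, llm_string, response):
        self.conn.execute(
            "INSERT OR REPLACE INTO cache (key, response) VALUES (?, ?)",
            (self._key(prompt, llm_string), response),
        )
        self.conn.commit()


def cached_call(cache, prompt, llm_string, call_llm):
    # Return the cached response if present; otherwise call the model and store the result.
    hit = cache.lookup(prompt, llm_string)
    if hit is not None:
        return hit
    response = call_llm(prompt)
    cache.update(prompt, llm_string, response)
    return response
```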

Set the IRIS database connection parameters:

# IRIS database connection parameters
os.environ["ISC_LOCAL_SQL_HOSTNAME"] = "localhost"
os.environ["ISC_LOCAL_SQL_PORT"] = "1972"
os.environ["ISC_LOCAL_SQL_NAMESPACE"] = "IRISAPP"
os.environ["ISC_LOCAL_SQL_USER"] = "_system"
os.environ["ISC_LOCAL_SQL_PWD"] = "SYS"

If the OpenAI API key is not already set in the environment, prompt the user for it:

if "OPENAI_API_KEY" not in os.environ:
    os.environ["OPENAI_API_KEY"] = getpass.getpass()

Build the connection string for the IRIS database:

# IRIS database connection string
args = {
    'hostname': os.getenv("ISC_LOCAL_SQL_HOSTNAME"), 
    'port': os.getenv("ISC_LOCAL_SQL_PORT"), 
    'namespace': os.getenv("ISC_LOCAL_SQL_NAMESPACE"), 
    'username': os.getenv("ISC_LOCAL_SQL_USER"), 
    'password': os.getenv("ISC_LOCAL_SQL_PWD")
}
iris_conn_str = f"iris://{args['username']}:{args['password']}@{args['hostname']}:{args['port']}/{args['namespace']}"
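Building the URL with an f-string works for the default `_system`/`SYS` credentials. A password containing characters such as `@`, `/` or `:` would corrupt the URL, though. The sketch below percent-encodes the user name and password with the standard library; the helper name `iris_url` is hypothetical.

```python
from urllib.parse import quote


def iris_url(username, password, hostname, port, namespace):
    # quote(..., safe="") percent-encodes every reserved character, including "@", "/" and ":",
    # so the credentials cannot be confused with the URL's own delimiters.
    user = quote(str(username), safe="")
    pwd = quote(str(password), safe="")
    return f"iris://{user}:{pwd}@{hostname}:{port}/{namespace}"


print(iris_url("_system", "SYS", "localhost", 1972, "IRISAPP"))
```

SQLAlchemy 1.4 and later also provide `URL.create(...)`, which takes the parts separately and avoids manual escaping.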

Create the connection to the IRIS database:

# Connection to IRIS database
engine = create_engine(iris_conn_str)
cnx = engine.connect().connection

Prepare a dictionary to hold context information for the system prompt:

# Dict for context information for system prompt
context = {}
context["top_k"] = 3

Prompt Creation

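To turn user input into SQL queries compatible with the IRIS database, we need to build an effective prompt for the language model. We start with an initial prompt that gives basic instructions for generating SQL queries. This template is derived from LangChain's default MSSQL prompt and customized for the IRIS database.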

# Basic prompt template with IRIS database SQL instructions
iris_sql_template = """
You are an InterSystems IRIS expert. Given an input question, first create a syntactically correct InterSystems IRIS query to run and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the TOP clause as per InterSystems IRIS. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in single quotes ('') to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CAST(CURRENT_DATE as date) function to get the current date, if the question involves "today".
Use double quotes to delimit columns identifiers.
Return just plain SQL; don't apply any kind of formatting.
"""

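This basic prompt configures the language model (LLM) to act as an SQL expert and gives it guidance specific to the IRIS database. Two of its rules conflict: one asks for column names in single quotes, the other for double quotes. In IRIS SQL, as in standard SQL, double quotes delimit identifiers and single quotes delimit string literals. Next, we provide an auxiliary prompt with information about the database schema, to avoid hallucinations.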

# SQL template extension for including tables context information
tables_prompt_template = """
Only use the following tables:
{table_info}
"""

To improve the accuracy of the LLM's answers, we use a technique called few-shot prompting, which means showing the LLM a few examples.

# SQL template extension for including few shots
prompt_sql_few_shots_template = """
Below are a number of examples of questions and their corresponding SQL queries.

{examples_value}
"""

We define the template for the few-shot examples:

# Few shots prompt template
example_prompt_template = "User input: {input}\nSQL query: {query}"
example_prompt = PromptTemplate.from_template(example_prompt_template)
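The few-shot template uses plain Python format syntax. The stdlib sketch below shows how a list of examples and the user turn render; the helper names are hypothetical. `format_examples` mirrors how `examples_value` is assembled later in `get_sql_from_text`. The empty `query` leaves `SQL query:` open for the model to complete. One difference from the article's code: the article keeps `{input}` as a placeholder in `user_prompt`, whereas here the input is filled in directly.

```python
EXAMPLE_TEMPLATE = "User input: {input}\nSQL query: {query}"


def format_examples(examples, template=EXAMPLE_TEMPLATE):
    # Render each {"input": ..., "query": ...} example and separate them with a blank line.
    return "\n\n".join(template.format(**ex) for ex in examples)


def format_user_turn(user_input, template=EXAMPLE_TEMPLATE):
    # An empty query ends the text at "SQL query: ", which the model is expected to continue.
    return "\n" + template.format(input=user_input, query="")


examples = [
    {"input": "Find the total number of incidents.", "query": "SELECT COUNT(*) FROM Aviation.Event"},
    {"input": "How many crew members were involved in incidents?", "query": "SELECT COUNT(*) FROM Aviation.Crew"},
]
print(format_examples(examples))
print(format_user_turn("List all aircrafts."))
```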

We build the user prompt from the few-shot template:

# User prompt template
user_prompt = "\n" + example_prompt.invoke({"input": "{input}", "query": ""}).to_string()

Finally, we combine all the prompts into the final prompt:

# Complete prompt template
prompt = (
    ChatPromptTemplate.from_messages([("system", iris_sql_template)])
    + ChatPromptTemplate.from_messages([("system", tables_prompt_template)])
    + ChatPromptTemplate.from_messages([("system", prompt_sql_few_shots_template)])
    + ChatPromptTemplate.from_messages([("human", user_prompt)])
)
prompt

This prompt expects the variables examples_value, input, table_info, and top_k.

The structure of the prompt is as follows:

ChatPromptTemplate(
    input_variables=['examples_value', 'input', 'table_info', 'top_k'], 
    messages=[
        SystemMessagePromptTemplate(
            prompt=PromptTemplate(
                input_variables=['top_k'], 
                template=iris_sql_template
            )
        ), 
        SystemMessagePromptTemplate(
            prompt=PromptTemplate(
                input_variables=['table_info'], 
                template=tables_prompt_template
            )
        ), 
        SystemMessagePromptTemplate(
            prompt=PromptTemplate(
                input_variables=['examples_value'], 
                template=prompt_sql_few_shots_template
            )
        ), 
        HumanMessagePromptTemplate(
            prompt=PromptTemplate(
                input_variables=['input'], 
                template=user_prompt
            )
        )
    ]
)
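`ChatPromptTemplate` infers `input_variables` from the `{...}` placeholders in its f-string templates. The standard library's `string.Formatter` can perform the same extraction. This sketch uses abbreviated, hypothetical copies of the four templates above to show where each of the four variables comes from.

```python
from string import Formatter


def template_variables(template):
    # Formatter().parse yields (literal_text, field_name, format_spec, conversion);
    # field_name is None for plain trailing text, so only real placeholders are kept.
    return {field for _, field, _, _ in Formatter().parse(template) if field}


# Abbreviated stand-ins for iris_sql_template, tables_prompt_template,
# prompt_sql_few_shots_template and user_prompt.
templates = {
    "system_rules": "... query for at most {top_k} results using the TOP clause ...",
    "tables": "Only use the following tables:\n{table_info}\n",
    "few_shots": "Below are a number of examples ...\n\n{examples_value}\n",
    "human": "\nUser input: {input}\nSQL query: ",
}

required = sorted(set().union(*(template_variables(t) for t in templates.values())))
print(required)  # ['examples_value', 'input', 'table_info', 'top_k']
```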

To visualize how the prompt will be sent to the LLM, we can fill the required variables with placeholder values:

prompt_value = prompt.invoke({
    "top_k": "<top_k>",
    "table_info": "<table_info>",
    "examples_value": "<examples_value>",
    "input": "<input>"
})
print(prompt_value.to_string())
System: 
You are an InterSystems IRIS expert. Given an input question, first create a syntactically correct InterSystems IRIS query to run and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most <top_k> results using the TOP clause as per InterSystems IRIS. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in single quotes ('') to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CAST(CURRENT_DATE as date) function to get the current date, if the question involves "today".
Use double quotes to delimit columns identifiers.
Return just plain SQL; don't apply any kind of formatting.

System: 
Only use the following tables:
<table_info>

System: 
Below are a number of examples of questions and their corresponding SQL queries.

<examples_value>

Human: 
User input: <input>
SQL query: 

The prompt is now ready to send to the LLM once the necessary variables are supplied. Next, we prepare the values for those variables.

Providing Table Information

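To build accurate SQL queries, we need to give the language model (LLM) detailed information about the database tables. Without it, the LLM may produce queries that look plausible but are wrong because of hallucination. Our first step is therefore to write a function that retrieves table definitions from the IRIS database.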

Function to Retrieve Table Definitions

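The following function queries INFORMATION_SCHEMA for the table definitions of a given schema. If a specific table is given, it retrieves only that table's definition; otherwise, it retrieves the definitions of all tables in the schema.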

def get_table_definitions_array(cnx, schema, table=None):
    cursor = cnx.cursor()

    # Base query to get columns information
    query = """
    SELECT TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_DEFAULT, PRIMARY_KEY, null EXTRA
    FROM INFORMATION_SCHEMA.COLUMNS
    WHERE TABLE_SCHEMA = %s
    """

    # Parameters for the query
    params = [schema]

    # Adding optional filters
    if table:
        query += " AND TABLE_NAME = %s"
        params.append(table)

    # Execute the query
    cursor.execute(query, params)

    # Fetch the results
    rows = cursor.fetchall()

    # Process the results to generate the table definition(s)
    table_definitions = {}
    for row in rows:
        table_schema, table_name, column_name, column_type, is_nullable, column_default, column_key, extra = row
        if table_name not in table_definitions:
            table_definitions[table_name] = []
        table_definitions[table_name].append({
            "column_name": column_name,
            "column_type": column_type,
            "is_nullable": is_nullable,
            "column_default": column_default,
            "column_key": column_key,
            "extra": extra
        })

    # Note: never populated below, so no PRIMARY KEY clause is emitted
    primary_keys = {}

    # Build the output string
    result = []
    for table_name, columns in table_definitions.items():
        table_def = f"CREATE TABLE {schema}.{table_name} (\n"
        column_definitions = []
        for column in columns:
            column_def = f"  {column['column_name']} {column['column_type']}"
            if column['is_nullable'] == "NO":
                column_def += " NOT NULL"
            if column['column_default'] is not None:
                column_def += f" DEFAULT {column['column_default']}"
            if column['extra']:
                column_def += f" {column['extra']}"
            column_definitions.append(column_def)
        if table_name in primary_keys:
            pk_def = f"  PRIMARY KEY ({', '.join(primary_keys[table_name])})"
            column_definitions.append(pk_def)
        table_def += ",\n".join(column_definitions)
        table_def += "\n);"
        result.append(table_def)

    return result
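In `get_table_definitions_array`, `primary_keys` is initialized but never filled, so no `PRIMARY KEY` clause is ever emitted. The sketch below separates the formatting from the database access, so it can be tested without IRIS, and collects primary keys from the `PRIMARY_KEY` column of each row. The helper name `rows_to_ddl` is hypothetical. It assumes rows arrive in the same column order as the query above. It also assumes key columns are flagged with `'YES'`, which has not been verified against IRIS.

```python
from collections import defaultdict


def rows_to_ddl(schema, rows, pk_flag="YES"):
    """rows: (TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE, IS_NULLABLE,
    COLUMN_DEFAULT, PRIMARY_KEY, EXTRA) tuples, as selected by the query above."""
    columns = defaultdict(list)
    primary_keys = defaultdict(list)
    for _schema, table, col, dtype, nullable, default, key, extra in rows:
        parts = [f"  {col} {dtype}"]
        if nullable == "NO":
            parts.append("NOT NULL")
        if default is not None:
            parts.append(f"DEFAULT {default}")
        if extra:
            parts.append(str(extra))
        columns[table].append(" ".join(parts))
        if key == pk_flag:  # assumption: how IRIS flags key columns
            primary_keys[table].append(col)

    ddl = []
    for table, col_defs in columns.items():
        if primary_keys[table]:
            col_defs = col_defs + [f"  PRIMARY KEY ({', '.join(primary_keys[table])})"]
        ddl.append(f"CREATE TABLE {schema}.{table} (\n" + ",\n".join(col_defs) + "\n);")
    return ddl
```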

Retrieving the Table Definitions for a Schema

In this example we use the Aviation schema, which is available here.

# Retrieve table definitions for the Aviation schema
tables = get_table_definitions_array(cnx, "Aviation")
print(tables)

The function returns the CREATE TABLE statements for all tables in the Aviation schema:

[
    'CREATE TABLE Aviation.Aircraft (\n  Event bigint NOT NULL,\n  ID varchar NOT NULL,\n  AccidentExplosion varchar,\n  AccidentFire varchar,\n  AirFrameHours varchar,\n  AirFrameHoursSince varchar,\n  AirFrameHoursSinceLastInspection varchar,\n  AircraftCategory varchar,\n  AircraftCertMaxGrossWeight integer,\n  AircraftHomeBuilt varchar,\n  AircraftKey integer NOT NULL,\n  AircraftManufacturer varchar,\n  AircraftModel varchar,\n  AircraftRegistrationClass varchar,\n  AircraftSerialNo varchar,\n  AircraftSeries varchar,\n  Damage varchar,\n  DepartureAirportId varchar,\n  DepartureCity varchar,\n  DepartureCountry varchar,\n  DepartureSameAsEvent varchar,\n  DepartureState varchar,\n  DepartureTime integer,\n  DepartureTimeZone varchar,\n  DestinationAirportId varchar,\n  DestinationCity varchar,\n  DestinationCountry varchar,\n  DestinationSameAsLocal varchar,\n  DestinationState varchar,\n  EngineCount integer,\n  EvacuationOccurred varchar,\n  EventId varchar NOT NULL,\n  FlightMedical varchar,\n  FlightMedicalType varchar,\n  FlightPhase integer,\n  FlightPlan varchar,\n  FlightPlanActivated varchar,\n  FlightSiteSeeing varchar,\n  FlightType varchar,\n  GearType varchar,\n  LastInspectionDate timestamp,\n  LastInspectionType varchar,\n  Missing varchar,\n  OperationDomestic varchar,\n  OperationScheduled varchar,\n  OperationType varchar,\n  OperatorCertificate varchar,\n  OperatorCertificateNum varchar,\n  OperatorCode varchar,\n  OperatorCountry varchar,\n  OperatorIndividual varchar,\n  OperatorName varchar,\n  OperatorState varchar,\n  Owner varchar,\n  OwnerCertified varchar,\n  OwnerCountry varchar,\n  OwnerState varchar,\n  RegistrationNumber varchar,\n  ReportedToICAO varchar,\n  SeatsCabinCrew integer,\n  SeatsFlightCrew integer,\n  SeatsPassengers integer,\n  SeatsTotal integer,\n  SecondPilot varchar,\n  childsub bigint NOT NULL DEFAULT $i(^Aviation.EventC("Aircraft"))\n);',
    'CREATE TABLE Aviation.Crew (\n  Aircraft varchar NOT NULL,\n  ID varchar NOT NULL,\n  Age integer,\n  AircraftKey integer NOT NULL,\n  Category varchar,\n  CrewNumber integer NOT NULL,\n  EventId varchar NOT NULL,\n  Injury varchar,\n  MedicalCertification varchar,\n  MedicalCertificationDate timestamp,\n  MedicalCertificationValid varchar,\n  Seat varchar,\n  SeatbeltUsed varchar,\n  Sex varchar,\n  ShoulderHarnessUsed varchar,\n  ToxicologyTestPerformed varchar,\n  childsub bigint NOT NULL DEFAULT $i(^Aviation.AircraftC("Crew"))\n);',
    'CREATE TABLE Aviation.Event (\n  ID bigint NOT NULL DEFAULT $i(^Aviation.EventD),\n  AirportDirection integer,\n  AirportDistance varchar,\n  AirportElevation integer,\n  AirportLocation varchar,\n  AirportName varchar,\n  Altimeter varchar,\n  EventDate timestamp,\n  EventId varchar NOT NULL,\n  EventTime integer,\n  FAADistrictOffice varchar,\n  InjuriesGroundFatal integer,\n  InjuriesGroundMinor integer,\n  InjuriesGroundSerious integer,\n  InjuriesHighest varchar,\n  InjuriesTotal integer,\n  InjuriesTotalFatal integer,\n  InjuriesTotalMinor integer,\n  InjuriesTotalNone integer,\n  InjuriesTotalSerious integer,\n  InvestigatingAgency varchar,\n  LightConditions varchar,\n  LocationCity varchar,\n  LocationCoordsLatitude double,\n  LocationCoordsLongitude double,\n  LocationCountry varchar,\n  LocationSiteZipCode varchar,\n  LocationState varchar,\n  MidAir varchar,\n  NTSBId varchar,\n  NarrativeCause varchar,\n  NarrativeFull varchar,\n  NarrativeSummary varchar,\n  OnGroundCollision varchar,\n  SkyConditionCeiling varchar,\n  SkyConditionCeilingHeight integer,\n  SkyConditionNonCeiling varchar,\n  SkyConditionNonCeilingHeight integer,\n  TimeZone varchar,\n  Type varchar,\n  Visibility varchar,\n  WeatherAirTemperature integer,\n  WeatherPrecipitation varchar,\n  WindDirection integer,\n  WindDirectionIndicator varchar,\n  WindGust integer,\n  WindGustIndicator varchar,\n  WindVelocity integer,\n  WindVelocityIndicator varchar\n);'
]

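With these table definitions, we can move on to the next step: integrating them into our LLM prompt. This ensures that the LLM has accurate and complete information about the database schema when generating SQL queries.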

Selecting the Most Relevant Tables

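When working with databases, especially large ones, sending the data definition language (DDL) of every table in the prompt may be impractical. This works for small databases, but real-world databases often contain hundreds or thousands of tables, and sending all of them is inefficient.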

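Moreover, the language model rarely needs to know every table in the database to generate an effective SQL query. To address this, we can use semantic search to select only the tables most relevant to the user's query.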

Approach

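We achieve this with semantic search using IRIS Vector Search. This approach works best when SQL element identifiers (such as tables, columns, and keys) have meaningful names. If your identifiers are arbitrary codes, consider using a data dictionary instead.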
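The ranking step can be illustrated without an embedding model. This toy sketch swaps OpenAI embeddings for bag-of-words vectors compared by cosine similarity. Unlike real embeddings, it only matches shared words, after splitting CamelCase identifiers and crudely stripping plurals. That limitation is also why meaningful identifier names matter. `rank_tables` and the miniature DDL strings are hypothetical.

```python
import math
import re
from collections import Counter


def tokenize(text):
    # Split CamelCase identifiers (AircraftManufacturer -> aircraft, manufacturer),
    # lower-case them and crudely strip a trailing plural "s".
    words = re.findall(r"[A-Z]?[a-z]+|[A-Z]+(?![a-z])|\d+", text)
    tokens = []
    for w in words:
        w = w.lower()
        tokens.append(w[:-1] if w.endswith("s") and len(w) > 3 else w)
    return Counter(tokens)


def cosine(a, b):
    dot = sum(a[t] * b[t] for t in a if t in b)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def rank_tables(question, table_ddls, k=1):
    # Score every table's DDL against the question and keep the k best table names.
    q = tokenize(question)
    scored = sorted(((cosine(q, tokenize(ddl)), name) for name, ddl in table_ddls.items()), reverse=True)
    return [name for _, name in scored[:k]]


ddls = {
    "Aviation.Aircraft": "CREATE TABLE Aviation.Aircraft (AircraftManufacturer varchar, AircraftModel varchar)",
    "Aviation.Crew": "CREATE TABLE Aviation.Crew (CrewNumber integer, Age integer, Sex varchar)",
}
print(rank_tables("List the first 2 manufacturers", ddls))  # ['Aviation.Aircraft']
```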

Steps

  1. Retrieve Table Information

First, load the table definitions into a pandas DataFrame:

# Retrieve table definitions into a pandas DataFrame
table_def = get_table_definitions_array(cnx=cnx, schema='Aviation')
table_df = pd.DataFrame(data=table_def, columns=["col_def"])
table_df["id"] = table_df.index + 1
table_df

The DataFrame (table_df) will look something like this:

   col_def                                           id
0  CREATE TABLE Aviation.Aircraft (\n Event bigi...  1
1  CREATE TABLE Aviation.Crew (\n Aircraft varch...  2
2  CREATE TABLE Aviation.Event (\n ID bigint NOT...  3
  2. Split Definitions into Documents

Next, split the table definitions into LangChain Documents. This step keeps each chunk of text small enough to extract a focused text embedding from it:

loader = DataFrameLoader(table_df, page_content_column="col_def")
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=400, chunk_overlap=20, separator="\n")
tables_docs = text_splitter.split_documents(documents)
tables_docs

The resulting tables_docs list contains split documents with metadata, like so:

[Document(metadata={'id': 1}, page_content='CREATE TABLE Aviation.Aircraft (\n  Event bigint NOT NULL,\n  ID varchar NOT NULL,\n  ...'),
 Document(metadata={'id': 2}, page_content='CREATE TABLE Aviation.Crew (\n  Aircraft varchar NOT NULL,\n  ID varchar NOT NULL,\n  ...'),
 Document(metadata={'id': 3}, page_content='CREATE TABLE Aviation.Event (\n  ID bigint NOT NULL DEFAULT $i(^Aviation.EventD),\n  ...')]
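The splitter packs the newline-separated column lines of each `CREATE TABLE` statement into chunks of roughly `chunk_size` characters. It repeats up to `chunk_overlap` characters of trailing lines at the start of the next chunk. Below is a simplified sketch of that greedy packing. It is not LangChain's exact code, and `split_text` is a hypothetical name.

```python
def split_text(text, chunk_size=400, chunk_overlap=20, separator="\n"):
    pieces = [p for p in text.split(separator) if p]
    chunks, current = [], []

    def joined_len(parts):
        return len(separator.join(parts))

    for piece in pieces:
        if current and joined_len(current + [piece]) > chunk_size:
            # The current chunk is full: emit it.
            chunks.append(separator.join(current))
            # Keep only trailing pieces that fit the overlap budget (and leave room for the new piece).
            while current and (joined_len(current) > chunk_overlap
                               or joined_len(current + [piece]) > chunk_size):
                current.pop(0)
        current.append(piece)
    if current:
        chunks.append(separator.join(current))
    return chunks


columns = "\n".join(f"  Col{i} varchar," for i in range(10))  # ten 15-character lines
chunks = split_text(columns, chunk_size=50, chunk_overlap=20)
print(len(chunks))  # 5: consecutive chunks share one overlapping line
```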
  3. Extract Embeddings and Store in IRIS

Now, use the IRISVector class from langchain-iris to extract embedding vectors and store them:

tables_vector_store = IRISVector.from_documents(
    embedding=OpenAIEmbeddings(), 
    documents=tables_docs,
    connection_string=iris_conn_str,
    collection_name="sql_tables",
    pre_delete_collection=True
)

Note: The pre_delete_collection flag is set to True for demonstration purposes to ensure a fresh collection in each test run. In a production environment, this flag should generally be set to False.

  4. Find Relevant Documents

With the table embeddings stored, you can now query for relevant tables based on user input:

input_query = "List the first 2 manufacturers"
relevant_tables_docs = tables_vector_store.similarity_search(input_query, k=3)
relevant_tables_docs

For example, querying for manufacturers might return:

[Document(metadata={'id': 1}, page_content='GearType varchar,\n  LastInspectionDate timestamp,\n  ...'),
 Document(metadata={'id': 1}, page_content='AircraftModel varchar,\n  AircraftRegistrationClass varchar,\n  ...'),
 Document(metadata={'id': 3}, page_content='LocationSiteZipCode varchar,\n  LocationState varchar,\n  ...')]

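From the metadata, you can see that two of the three chunks come from table ID 1 (Aviation.Aircraft), which holds the manufacturer columns and therefore aligns with the query. The third chunk comes from table ID 3 (Aviation.Event).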

  5. Handling Edge Cases

While this approach is generally effective, it may not always be perfect. For instance, querying for crash sites might also return less relevant tables:

input_query = "List the top 10 most crash sites"
relevant_tables_docs = tables_vector_store.similarity_search(input_query, k=3)
relevant_tables_docs

Results might include:

[Document(metadata={'id': 3}, page_content='LocationSiteZipCode varchar,\n  LocationState varchar,\n  ...'),
 Document(metadata={'id': 3}, page_content='InjuriesGroundSerious integer,\n  InjuriesHighest varchar,\n  ...'),
 Document(metadata={'id': 1}, page_content='CREATE TABLE Aviation.Aircraft (\n  Event bigint NOT NULL,\n  ID varchar NOT NULL,\n  ...')]

Despite retrieving the correct Aviation.Event table twice, the Aviation.Aircraft table may also appear, which could be improved with additional filtering or thresholding. This is beyond the scope of this example and will be left for future implementations.

  6. Define a Function to Retrieve Relevant Tables

To automate this process, define a function to filter and return the relevant tables based on user input:

def get_relevant_tables(user_input, tables_vector_store, table_df):
    relevant_tables_docs = tables_vector_store.similarity_search(user_input)
    relevant_tables_docs_indices = [x.metadata["id"] for x in relevant_tables_docs]
    indices = table_df["id"].isin(relevant_tables_docs_indices)
    relevant_tables_array = [x for x in table_df[indices]["col_def"]]
    return relevant_tables_array

This function retrieves only the relevant tables to send to the LLM. That shortens the prompt and gives the model less irrelevant schema to consider.
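`get_relevant_tables` keeps every table with at least one matching chunk, so a single stray chunk is enough to pull in a table, as in the crash-sites example above. Below is the same id-to-DDL mapping in plain Python. It adds a hypothetical `min_hits` parameter as one simple form of the additional filtering mentioned earlier; `select_tables` is also a hypothetical name.

```python
from collections import Counter


def select_tables(hit_ids, table_defs, min_hits=1):
    """hit_ids: metadata["id"] of each retrieved chunk; table_defs: {id: DDL string}."""
    counts = Counter(hit_ids)
    # Keep a stable order by table id (as the DataFrame filter does) and drop
    # tables that were hit fewer than min_hits times.
    return [table_defs[i] for i in sorted(counts) if counts[i] >= min_hits and i in table_defs]


defs = {1: "CREATE TABLE Aviation.Aircraft (...)",
        2: "CREATE TABLE Aviation.Crew (...)",
        3: "CREATE TABLE Aviation.Event (...)"}
crash_site_hits = [3, 3, 1]  # ids from the "crash sites" search above
print(select_tables(crash_site_hits, defs))              # Aircraft and Event
print(select_tables(crash_site_hits, defs, min_hits=2))  # Event only
```

With `min_hits=1` this reproduces the behavior of `get_relevant_tables`. A higher threshold trades recall for a shorter, more focused prompt.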

Selecting the Most Relevant Examples (Few-Shot Prompting)

When working with language models (LLMs), providing them with relevant examples helps ensure accurate and contextually appropriate responses. These examples, referred to as "few-shot" examples, guide the LLM in understanding the structure and context of the queries it should handle.

In our case, we need to populate the examples_value variable with a diverse set of SQL queries that cover a broad spectrum of IRIS SQL syntax and the tables available in the database. This helps prevent the LLM from generating incorrect or irrelevant queries.

Defining Example Queries

Below is a list of example queries designed to illustrate various SQL operations:

examples = [
    {"input": "List all aircrafts.", "query": "SELECT * FROM Aviation.Aircraft"},
    {"input": "Find all incidents for the aircraft with ID 'N12345'.", "query": "SELECT * FROM Aviation.Event WHERE EventId IN (SELECT EventId FROM Aviation.Aircraft WHERE ID = 'N12345')"},
    {"input": "List all incidents in the 'Commercial' operation type.", "query": "SELECT * FROM Aviation.Event WHERE EventId IN (SELECT EventId FROM Aviation.Aircraft WHERE OperationType = 'Commercial')"},
    {"input": "Find the total number of incidents.", "query": "SELECT COUNT(*) FROM Aviation.Event"},
    {"input": "List all incidents that occurred in 'Canada'.", "query": "SELECT * FROM Aviation.Event WHERE LocationCountry = 'Canada'"},
    {"input": "How many incidents are associated with the aircraft with AircraftKey 5?", "query": "SELECT COUNT(*) FROM Aviation.Aircraft WHERE AircraftKey = 5"},
    {"input": "Find the total number of distinct aircrafts involved in incidents.", "query": "SELECT COUNT(DISTINCT AircraftKey) FROM Aviation.Aircraft"},
    {"input": "List all incidents that occurred after 5 PM.", "query": "SELECT * FROM Aviation.Event WHERE EventTime > 1700"},
    {"input": "Who are the top 5 operators by the number of incidents?", "query": "SELECT TOP 5 OperatorName, COUNT(*) AS IncidentCount FROM Aviation.Aircraft GROUP BY OperatorName ORDER BY IncidentCount DESC"},
    {"input": "Which incidents occurred in the year 2020?", "query": "SELECT * FROM Aviation.Event WHERE YEAR(EventDate) = '2020'"},
    {"input": "What was the month with most events in the year 2020?", "query": "SELECT TOP 1 MONTH(EventDate) EventMonth, COUNT(*) EventCount FROM Aviation.Event WHERE YEAR(EventDate) = '2020' GROUP BY MONTH(EventDate) ORDER BY EventCount DESC"},
    {"input": "How many crew members were involved in incidents?", "query": "SELECT COUNT(*) FROM Aviation.Crew"},
    {"input": "List all incidents with detailed aircraft information for incidents that occurred in the year 2012.", "query": "SELECT e.EventId, e.EventDate, a.AircraftManufacturer, a.AircraftModel, a.AircraftCategory FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.EventId = a.EventId WHERE Year(e.EventDate) = 2012"},
    {"input": "Find all incidents where there were more than 5 injuries and include the aircraft manufacturer and model.", "query": "SELECT e.EventId, e.InjuriesTotal, a.AircraftManufacturer, a.AircraftModel FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.EventId = a.EventId WHERE e.InjuriesTotal > 5"},
    {"input": "List all crew members involved in incidents with serious injuries, along with the incident date and location.", "query": "SELECT c.CrewNumber AS 'Crew Number', c.Age, c.Sex AS Gender, e.EventDate AS 'Event Date', e.LocationCity AS 'Location City', e.LocationState AS 'Location State' FROM Aviation.Crew c JOIN Aviation.Event e ON c.EventId = e.EventId WHERE c.Injury = 'Serious'"}
]

Selecting Relevant Examples

Given the ever-expanding list of examples, it’s impractical to provide the LLM with all of them. Instead, we use IRIS Vector Search along with the SemanticSimilarityExampleSelector class to identify the most relevant examples based on user prompts.
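Conceptually, the selector embeds each example's `input` and returns the `k` examples closest to the user input. As a toy stand-in for embeddings plus IRIS Vector Search, this sketch ranks examples by the Jaccard similarity of their word sets. `select_examples` is a hypothetical name, and word overlap is a much weaker signal than semantic embeddings.

```python
import re


def words(text):
    return set(re.findall(r"[a-z0-9]+", text.lower()))


def jaccard(a, b):
    # |A ∩ B| / |A ∪ B|: 1.0 for identical word sets, 0.0 for disjoint ones.
    return len(a & b) / len(a | b) if a | b else 0.0


def select_examples(user_input, examples, k=2):
    query_words = words(user_input)
    # sorted() is stable, so ties keep the original example order.
    ranked = sorted(examples, key=lambda ex: jaccard(query_words, words(ex["input"])), reverse=True)
    return ranked[:k]


examples = [
    {"input": "List all aircrafts.", "query": "SELECT * FROM Aviation.Aircraft"},
    {"input": "Find the total number of incidents.", "query": "SELECT COUNT(*) FROM Aviation.Event"},
    {"input": "How many crew members were involved in incidents?", "query": "SELECT COUNT(*) FROM Aviation.Crew"},
]
best = select_examples("What is the total number of incidents in 2010?", examples, k=1)
print(best[0]["query"])  # SELECT COUNT(*) FROM Aviation.Event
```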

Define the Example Selector:

example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    OpenAIEmbeddings(),
    IRISVector,
    k=5,
    input_keys=["input"],
    connection_string=iris_conn_str,
    collection_name="sql_samples",
    pre_delete_collection=True
)

Note: The pre_delete_collection flag is used here for demonstration purposes to ensure a fresh collection in each test run. In a production environment, this flag should be set to False to avoid unnecessary deletions.

Query the Selector:

To find the most relevant examples for a given input, use the selector as follows:

input_query = "Find all events in 2010 informing the Event Id and date, location city and state, aircraft manufacturer and model."
relevant_examples = example_selector.select_examples({"input": input_query})

The results might look like this:

[{'input': 'List all incidents with detailed aircraft information for incidents that occurred in the year 2012.', 'query': 'SELECT e.EventId, e.EventDate, a.AircraftManufacturer, a.AircraftModel, a.AircraftCategory FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.EventId = a.EventId WHERE Year(e.EventDate) = 2012'},
 {'input': "Find all incidents for the aircraft with ID 'N12345'.", 'query': "SELECT * FROM Aviation.Event WHERE EventId IN (SELECT EventId FROM Aviation.Aircraft WHERE ID = 'N12345')"},
 {'input': 'Find all incidents where there were more than 5 injuries and include the aircraft manufacturer and model.', 'query': 'SELECT e.EventId, e.InjuriesTotal, a.AircraftManufacturer, a.AircraftModel FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.EventId = a.EventId WHERE e.InjuriesTotal > 5'},
 {'input': 'List all aircrafts.', 'query': 'SELECT * FROM Aviation.Aircraft'},
 {'input': 'Find the total number of distinct aircrafts involved in incidents.', 'query': 'SELECT COUNT(DISTINCT AircraftKey) FROM Aviation.Aircraft'}]

If you specifically need examples related to quantities, you can query the selector accordingly:

input_query = "What is the number of incidents involving Boeing aircraft."
quantity_examples = example_selector.select_examples({"input": input_query})

The output may be:

[{'input': 'How many incidents are associated with the aircraft with AircraftKey 5?', 'query': 'SELECT COUNT(*) FROM Aviation.Aircraft WHERE AircraftKey = 5'},
 {'input': 'Find the total number of distinct aircrafts involved in incidents.', 'query': 'SELECT COUNT(DISTINCT AircraftKey) FROM Aviation.Aircraft'},
 {'input': 'How many crew members were involved in incidents?', 'query': 'SELECT COUNT(*) FROM Aviation.Crew'},
 {'input': 'Find all incidents where there were more than 5 injuries and include the aircraft manufacturer and model.', 'query': 'SELECT e.EventId, e.InjuriesTotal, a.AircraftManufacturer, a.AircraftModel FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.EventId = a.EventId WHERE e.InjuriesTotal > 5'},
 {'input': 'List all incidents with detailed aircraft information for incidents that occurred in the year 2012.', 'query': 'SELECT e.EventId, e.EventDate, a.AircraftManufacturer, a.AircraftModel, a.AircraftCategory FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.EventId = a.EventId WHERE Year(e.EventDate) = 2012'}]

This output includes examples that specifically address counting and quantities.

Future Considerations

While the SemanticSimilarityExampleSelector is powerful, it’s important to note that not all selected examples may be perfect. Future improvements may involve adding filters or thresholds to exclude less relevant results, ensuring that only the most appropriate examples are provided to the LLM.

Accuracy Test

To assess the prompt and the SQL query generation, we set up and run a series of tests. The goal is to evaluate how well the LLM generates SQL queries from user inputs, with and without few-shot examples.

Function to Generate SQL Queries

We start by defining a function that uses the LLM to generate SQL queries based on the provided context, prompt, user input, and other parameters:

def get_sql_from_text(context, prompt, user_input, use_few_shots, tables_vector_store, table_df, example_selector=None, example_prompt=None):
    # Schema context: DDL of the tables most similar to the user input
    relevant_tables = get_relevant_tables(user_input, tables_vector_store, table_df)
    context["table_info"] = "\n\n".join(relevant_tables)

    # Few-shot context: only when requested and a selector is available
    examples = example_selector.select_examples({"input": user_input}) if use_few_shots and example_selector else []
    context["examples_value"] = "\n\n".join([
        example_prompt.invoke(x).to_string() for x in examples
    ])

    # prompt -> chat model -> plain string
    model = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
    output_parser = StrOutputParser()
    chain_model = prompt | model | output_parser

    response = chain_model.invoke({
        "top_k": context["top_k"],
        "table_info": context["table_info"],
        "examples_value": context["examples_value"],
        "input": user_input
    })
    return response

Execute the Prompt

Test the prompt with and without examples:

# Prompt execution **with** few shots
input = "Find all events in 2010 informing the Event Id and date, location city and state, aircraft manufacturer and model."
response_with_few_shots = get_sql_from_text(
    context, 
    prompt, 
    user_input=input, 
    use_few_shots=True, 
    tables_vector_store=tables_vector_store, 
    table_df=table_df,
    example_selector=example_selector, 
    example_prompt=example_prompt,
)
print(response_with_few_shots)
SELECT e.EventId, e.EventDate, e.LocationCity, e.LocationState, a.AircraftManufacturer, a.AircraftModel
FROM Aviation.Event e
JOIN Aviation.Aircraft a ON e.EventId = a.EventId
WHERE Year(e.EventDate) = 2010
# Prompt execution **without** few shots
input = "Find all events in 2010 informing the Event Id and date, location city and state, aircraft manufacturer and model."
response_with_no_few_shots = get_sql_from_text(
    context, 
    prompt, 
    user_input=input, 
    use_few_shots=False, 
    tables_vector_store=tables_vector_store, 
    table_df=table_df,
)
print(response_with_no_few_shots)
SELECT TOP 3 "EventId", "EventDate", "LocationCity", "LocationState", "AircraftManufacturer", "AircraftModel"
FROM Aviation.Event e
JOIN Aviation.Aircraft a ON e.ID = a.Event
WHERE e.EventDate >= '2010-01-01' AND e.EventDate < '2011-01-01'

Utility Functions for Testing

To test the generated SQL queries, we define some utility functions:

def execute_sql_query(cnx, query):
    # Returns all rows, or None if IRIS rejects the query
    try:
        cursor = cnx.cursor()
        cursor.execute(query)
        rows = cursor.fetchall()
        return rows
    except Exception:
        print('error on running query: ')
        print(query)
        print('-'*80)
    return None

def sql_result_equals(cnx, query, expected):
    # Compare each row as a set of values: column order is ignored, row order is not
    rows = execute_sql_query(cnx, query)
    result = [set(row._asdict().values()) for row in rows or []]
    if result != expected and rows is not None:
        print('result not expected for query: ')
        print(query)
        print('-'*80)
    return result == expected
# SQL test for prompt **with** few shots
print("SQL is OK" if execute_sql_query(cnx, response_with_few_shots) is not None else "SQL is not OK")
    SQL is OK
# SQL test for prompt **without** few shots
print("SQL is OK" if execute_sql_query(cnx, response_with_no_few_shots) is not None else "SQL is not OK")
    error on running query: 
    SELECT TOP 3 "EventId", "EventDate", "LocationCity", "LocationState", "AircraftManufacturer", "AircraftModel"
    FROM Aviation.Event e
    JOIN Aviation.Aircraft a ON e.ID = a.Event
    WHERE e.EventDate >= '2010-01-01' AND e.EventDate < '2011-01-01'
    --------------------------------------------------------------------------------
    SQL is not OK
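The same execute-and-compare pattern can be tried without an IRIS instance. This sketch runs it against an in-memory SQLite database from the standard library; the `Event` table and its rows are made up for illustration. SQLite uses `LIMIT` where IRIS uses `TOP`. SQLite rows are plain tuples, so `set(row)` replaces `row._asdict().values()`.

```python
import sqlite3


def execute_sql_query(cnx, query):
    """Run a query and return all rows, or None if the database rejects it."""
    try:
        cursor = cnx.cursor()
        cursor.execute(query)
        return cursor.fetchall()
    except sqlite3.Error as exc:
        print(f"error on running query: {exc}\n{query}\n" + "-" * 80)
        return None


def sql_result_equals(cnx, query, expected):
    """Compare rows as sets, so column order does not matter (row order still does)."""
    rows = execute_sql_query(cnx, query)
    return rows is not None and [set(row) for row in rows] == expected


cnx = sqlite3.connect(":memory:")
cnx.execute("CREATE TABLE Event (EventId TEXT, EventYear INTEGER)")
cnx.executemany("INSERT INTO Event VALUES (?, ?)",
                [("a", 2003), ("b", 2003), ("c", 2003), ("d", 2005), ("e", 2005), ("f", 2007)])

query = ("SELECT COUNT(*) AS n, EventYear FROM Event "
         "GROUP BY EventYear ORDER BY n DESC LIMIT 2")
print(sql_result_equals(cnx, query, [{2003, 3}, {2005, 2}]))  # True
print(execute_sql_query(cnx, "SELECT x FROM NoSuchTable"))     # None, after the error message
```

Comparing each row as a set makes the check independent of column order. It also collapses repeated values within a row (a row `(2, 2)` becomes `{2}`), and row order still matters.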

Define and Execute Tests

Define a set of test cases and run them:

tests = [{
    "input": "What were the top 3 years with the most recorded events?",
    "expected": [{128, 2003}, {122, 2007}, {117, 2005}]
},{
    "input": "How many incidents involving Boeing aircraft.",
    "expected": [{5}]
},{
    "input": "How many incidents that resulted in fatalities.",
    "expected": [{237}]
},{
    "input": "List event Id and date and, crew number, age and gender for incidents that occurred in 2013.",
    "expected": [{1, datetime.datetime(2013, 3, 4, 11, 6), '20130305X71252', 59, 'M'},
                 {1, datetime.datetime(2013, 1, 1, 15, 0), '20130101X94035', 32, 'M'},
                 {2, datetime.datetime(2013, 1, 1, 15, 0), '20130101X94035', 35, 'M'},
                 {1, datetime.datetime(2013, 1, 12, 15, 0), '20130113X42535', 25, 'M'},
                 {2, datetime.datetime(2013, 1, 12, 15, 0), '20130113X42535', 34, 'M'},
                 {1, datetime.datetime(2013, 2, 1, 15, 0), '20130203X53401', 29, 'M'},
                 {1, datetime.datetime(2013, 2, 15, 15, 0), '20130218X70747', 27, 'M'},
                 {1, datetime.datetime(2013, 3, 2, 15, 0), '20130303X21011', 49, 'M'},
                 {1, datetime.datetime(2013, 3, 23, 13, 52), '20130326X85150', 'M', None}]
},{
    "input": "Find the total number of incidents that occurred in the United States.",
    "expected": [{1178}]
},{
    "input": "List all incidents latitude and longitude coordinates with more than 5 injuries that occurred in 2010.",
    "expected": [{-78.76833333333333, 43.25277777777778}]
},{
    "input": "Find all incidents in 2010 informing the Event Id and date, location city and state, aircraft manufacturer and model.",
    "expected": [
        {datetime.datetime(2010, 5, 20, 13, 43), '20100520X60222', 'CIRRUS DESIGN CORP', 'Farmingdale', 'New York', 'SR22'},
        {datetime.datetime(2010, 4, 11, 15, 0), '20100411X73253', 'CZECH AIRCRAFT WORKS SPOL SRO', 'Millbrook', 'New York', 'SPORTCRUISER'},
        {'108', datetime.datetime(2010, 1, 9, 12, 55), '20100111X41106', 'Bayport', 'New York', 'STINSON'},
        {datetime.datetime(2010, 8, 1, 14, 20), '20100801X85218', 'A185F', 'CESSNA', 'New York', 'Newfane'}
    ]
}]

Accuracy Evaluation

Run the tests and calculate the accuracy:

def execute_tests(cnx, context, prompt, use_few_shots, tables_vector_store, table_df, example_selector, example_prompt):
    # Generate one SQL query per test case (uses the global `tests` list)
    tests_generated_sql = [(x, get_sql_from_text(
            context, 
            prompt, 
            user_input=x['input'], 
            use_few_shots=use_few_shots, 
            tables_vector_store=tables_vector_store, 
            table_df=table_df,
            example_selector=example_selector if use_few_shots else None, 
            example_prompt=example_prompt if use_few_shots else None,
        )) for x in deepcopy(tests)]

    # Run each generated query and compare its result with the expected rows
    tests_sql_executions = [(x[0], sql_result_equals(cnx, x[1], x[0]['expected'])) 
                            for x in tests_generated_sql]

    accuracy = sum(1 for _, passed in tests_sql_executions if passed) / len(tests_sql_executions)
    print(f'accuracy: {accuracy}')
    print('-'*80)

Results

# Accuracy tests for prompts executed **without** few shots
use_few_shots = False
execute_tests(
    cnx,
    context, 
    prompt, 
    use_few_shots, 
    tables_vector_store, 
    table_df, 
    example_selector, 
    example_prompt
)
    error on running query: 
    SELECT "EventDate", COUNT("EventId") as "TotalEvents"
    FROM Aviation.Event
    GROUP BY "EventDate"
    ORDER BY "TotalEvents" DESC
    TOP 3;
    --------------------------------------------------------------------------------
    error on running query: 
    SELECT "EventId", "EventDate", "C"."CrewNumber", "C"."Age", "C"."Sex"
    FROM "Aviation.Event" AS "E"
    JOIN "Aviation.Crew" AS "C" ON "E"."ID" = "C"."EventId"
    WHERE "E"."EventDate" >= '2013-01-01' AND "E"."EventDate" < '2014-01-01'
    --------------------------------------------------------------------------------
    result not expected for query: 
    SELECT TOP 3 "e"."EventId", "e"."EventDate", "e"."LocationCity", "e"."LocationState", "a"."AircraftManufacturer", "a"."AircraftModel"
    FROM "Aviation"."Event" AS "e"
    JOIN "Aviation"."Aircraft" AS "a" ON "e"."ID" = "a"."Event"
    WHERE "e"."EventDate" >= '2010-01-01' AND "e"."EventDate" < '2011-01-01'
    --------------------------------------------------------------------------------
    accuracy: 0.5714285714285714
    --------------------------------------------------------------------------------
# Accuracy tests for prompts executed **with** few shots
use_few_shots = True
execute_tests(
    cnx,
    context, 
    prompt, 
    use_few_shots, 
    tables_vector_store, 
    table_df, 
    example_selector, 
    example_prompt
)
    error on running query: 
    SELECT e.EventId, e.EventDate, e.LocationCity, e.LocationState, a.AircraftManufacturer, a.AircraftModel
    FROM Aviation.Event e
    JOIN Aviation.Aircraft a ON e.EventId = a.EventId
    WHERE Year(e.EventDate) = 2010 TOP 3
    --------------------------------------------------------------------------------
    accuracy: 0.8571428571428571
    --------------------------------------------------------------------------------

Conclusion

With few-shot examples, 6 of the 7 test queries returned the expected results (accuracy ≈ 0.86), versus 4 of 7 without them (≈ 0.57). That is a relative improvement of 50%, or about 29 percentage points.
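The figures follow directly from the test counts. Three failing queries are logged in the run without few shots and one in the run with them, out of 7 tests. Exact fractions make the arithmetic explicit:

```python
from fractions import Fraction

tests_total = 7
passed_without = 4  # 3 failing queries logged without few shots
passed_with = 6     # 1 failing query logged with few shots

acc_without = Fraction(passed_without, tests_total)
acc_with = Fraction(passed_with, tests_total)
relative_gain = (acc_with - acc_without) / acc_without

print(float(acc_without), float(acc_with))  # 0.5714285714285714 0.8571428571428571
print(relative_gain)                        # 1/2, i.e. 50% higher
```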

References

  • https://python.langchain.com/v0.1/docs/expression_language/get_started/
  • https://python.langchain.com/v0.1/docs/use_cases/sql/prompting/
  • https://python.langchain.com/v0.1/docs/modules/model_io/prompts/composition/
