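In this guide we look at strategies for improving the prompt when generating SQL queries with create_sql_query_chain. The focus is on how to obtain database-specific information and use it in the prompt. We will cover how the database's SQL dialect shapes the chain's default prompt, how to format schema information into the prompt with SQLDatabase.get_context, and how to add few-shot examples to the prompt and select them dynamically. First, install the required packages and set the necessary environment variables (for example OPENAI_API_KEY):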
%pip install --upgrade --quiet langchain langchain-community langchain-experimental langchain-openai faiss-cpu
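Next, we use SQLite with the Chinook sample database for the examples. To create the Chinook.db file:
1. Download the Chinook_Sqlite.sql script and save it in your working directory.
2. Run sqlite3 Chinook.db
3. At the sqlite prompt, run .read Chinook_Sqlite.sql
4. Test it with:
SELECT * FROM Artist LIMIT 10;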
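If the sqlite3 command-line tool is not available, the same steps can be sketched with Python's standard-library sqlite3 module, where executescript plays the role of `.read`. This is a minimal sketch: the helper name build_db_from_script and the tiny demo table are illustrative assumptions, and the commented-out lines show how it would be pointed at the real Chinook_Sqlite.sql.

```python
import sqlite3


def build_db_from_script(sql_script: str, db_path: str = ":memory:") -> sqlite3.Connection:
    """Create (or open) a SQLite database and run every statement in sql_script."""
    conn = sqlite3.connect(db_path)
    # executescript runs several semicolon-separated statements, like `.read` in the CLI
    conn.executescript(sql_script)
    conn.commit()
    return conn


# Real usage, assuming Chinook_Sqlite.sql sits in the working directory
# (utf-8-sig tolerates a leading byte-order mark if the file has one):
# with open("Chinook_Sqlite.sql", encoding="utf-8-sig") as f:
#     conn = build_db_from_script(f.read(), "Chinook.db")

# Self-contained demo with a tiny, made-up Artist table
demo_script = """
CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT);
INSERT INTO Artist (Name) VALUES ('AC/DC'), ('Accept'), ('Aerosmith');
"""
conn = build_db_from_script(demo_script)
rows = conn.execute("SELECT * FROM Artist LIMIT 10;").fetchall()
print(rows)  # [(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith')]
```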
Once Chinook.db is in our directory, we can connect to it with the SQLAlchemy-backed SQLDatabase class:
from langchain_community.utilities import SQLDatabase
db = SQLDatabase.from_uri("sqlite:///Chinook.db", sample_rows_in_table_info=3)
print(db.dialect)
print(db.get_usable_table_names())
print(db.run("SELECT * FROM Artist LIMIT 10;"))
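For intuition about what db.get_usable_table_names() reports, here is a rough stdlib analogue that lists the user tables from SQLite's sqlite_master catalog. It is only an illustration under that assumption: SQLDatabase itself goes through SQLAlchemy, and the helper usable_table_names is hypothetical.

```python
import sqlite3


def usable_table_names(conn: sqlite3.Connection) -> list[str]:
    """List user tables (skipping SQLite's internal sqlite_* tables), sorted by name."""
    cur = conn.execute(
        "SELECT name FROM sqlite_master "
        "WHERE type = 'table' AND name NOT LIKE 'sqlite_%' "
        "ORDER BY name"
    )
    return [name for (name,) in cur.fetchall()]


conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT);
CREATE TABLE Album (AlbumId INTEGER PRIMARY KEY, Title TEXT, ArtistId INTEGER);
""")
print(usable_table_names(conn))  # ['Album', 'Artist']
```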
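Adapting the prompt to the SQL dialect in use is one of the simplest strategies. When you use the built-in create_sql_query_chain together with SQLDatabase, this is handled automatically for the following dialects: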
from langchain.chains.sql_database.prompt import SQL_PROMPTS
list(SQL_PROMPTS)
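Conceptually, when create_sql_query_chain receives no prompt, it uses the entry in SQL_PROMPTS for db.dialect if one exists and otherwise falls back to a generic prompt. The sketch below mimics that dispatch with plain strings; DIALECT_PROMPTS, GENERIC_PROMPT and pick_prompt are hypothetical stand-ins, not LangChain's source, and the real values are PromptTemplate objects with much longer instructions.

```python
from typing import Optional

# Hypothetical stand-ins for the dialect-specific templates in SQL_PROMPTS
DIALECT_PROMPTS = {
    "sqlite": "You are a SQLite expert. Question: {input}",
    "postgresql": "You are a PostgreSQL expert. Question: {input}",
}
GENERIC_PROMPT = "Create a syntactically correct {dialect} query. Question: {input}"


def pick_prompt(dialect: str, custom: Optional[str] = None) -> str:
    """Prefer an explicit prompt, then a dialect-specific one, then the generic fallback."""
    if custom is not None:
        return custom
    return DIALECT_PROMPTS.get(dialect, GENERIC_PROMPT)


print(pick_prompt("sqlite"))
print(pick_prompt("some-unlisted-dialect"))
```

Passing your own prompt (as done later with the few-shot template) short-circuits the lookup, which is why a custom prompt has to state the dialect itself.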
Since our database is SQLite, we get a SQLite-specific prompt. Here is how to build the chain and inspect its prompt:
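from langchain.chains import create_sql_query_chain
from langchain_openai import ChatOpenAI

# create_sql_query_chain expects a LangChain model, not the raw openai client.
# base_url points at an OpenAI-compatible endpoint; omit it to use the default OpenAI API.
# The model name is an example; use any chat model your endpoint serves.
llm = ChatOpenAI(
    model="gpt-4o-mini",
    temperature=0,
    base_url='https://yunwu.ai/v1',
    api_key='your-api-key',
)
chain = create_sql_query_chain(llm, db)
chain.get_prompts()[0].pretty_print()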
Most SQL chains need to give the model at least part of the database schema so that it can write valid queries. SQLDatabase.get_context returns the relevant context:
context = db.get_context()
print(list(context))
print(context["table_info"])
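The table_info string combines each table's CREATE TABLE statement with a few sample rows (three here, because of sample_rows_in_table_info=3) wrapped in a SQL comment. The stdlib sketch below approximates that layout for a single table; the helper table_info is hypothetical, and LangChain's exact whitespace and formatting may differ.

```python
import sqlite3


def table_info(conn: sqlite3.Connection, table: str, sample_rows: int = 3) -> str:
    """Return the CREATE TABLE statement plus a few sample rows in a SQL comment."""
    (ddl,) = conn.execute(
        "SELECT sql FROM sqlite_master WHERE type = 'table' AND name = ?", (table,)
    ).fetchone()
    # Table names cannot be bound as parameters, so quote the identifier ourselves
    quoted = '"' + table.replace('"', '""') + '"'
    cur = conn.execute(f"SELECT * FROM {quoted} LIMIT ?", (sample_rows,))
    header = "\t".join(col[0] for col in cur.description)
    rows = ["\t".join(str(v) for v in row) for row in cur.fetchall()]
    sample = "\n".join([f"{sample_rows} rows from {table} table:", header, *rows])
    return f"{ddl}\n\n/*\n{sample}\n*/"


conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT);
INSERT INTO Artist (Name) VALUES ('AC/DC'), ('Accept'), ('Aerosmith'), ('Alanis Morissette');
""")
info = table_info(conn, "Artist")
print(info)
```

Every sample row adds prompt tokens, so this is the text whose size matters when deciding whether the whole schema fits in the prompt.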
When there are not too many tables and the tables are not too wide, we can insert this information into the prompt in full:
prompt_with_context = chain.get_prompts()[0].partial(table_info=context["table_info"])
print(prompt_with_context.pretty_repr()[:1500])
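The .partial(table_info=...) call pre-fills one variable and leaves the others for later. A plain-string sketch of the same idea, using str.format_map with a mapping that keeps unknown placeholders intact; partial_format and _KeepMissing are hypothetical helpers, not LangChain APIs.

```python
class _KeepMissing(dict):
    """Mapping that leaves unknown placeholders as-is instead of raising KeyError."""

    def __missing__(self, key: str) -> str:
        return "{" + key + "}"


def partial_format(template: str, **known: str) -> str:
    """Fill only the supplied variables; other {placeholders} survive for later."""
    return template.format_map(_KeepMissing(known))


template = "Only use these tables:\n{table_info}\n\nQuestion: {input}\nReturn at most {top_k} rows."
partially = partial_format(template, table_info="CREATE TABLE Artist (ArtistId INTEGER, Name TEXT)")
final = partially.format(input="How many artists are there?", top_k=5)
print(final)
```

One caveat of this string-based approach: if the pre-filled value itself contains `{` or `}`, the second .format call would misread it. PromptTemplate.partial stores partial variables separately and merges them at format time, so it does not re-parse the filled-in text.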
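Including examples of natural-language questions converted into valid SQL queries in the prompt often improves model performance, especially for complex queries. For example: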
examples = [
    {"input": "List all artists.", "query": "SELECT * FROM Artist;"},
    {"input": "Find all albums for the artist 'AC/DC'.", "query": "SELECT * FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'AC/DC');"},
    # add more examples here
]
from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate
example_prompt = PromptTemplate.from_template("User input: {input}\nSQL query: {query}")
prompt = FewShotPromptTemplate(
    examples=examples[:5],
    example_prompt=example_prompt,
    prefix="You are a SQLite expert. Given an input question, create a syntactically correct SQLite query to run. Unless otherwise specified, do not return more than {top_k} rows.\n\nHere is the relevant table info: {table_info}\n\nBelow are a number of examples of questions and their corresponding SQL queries.",
    suffix="User input: {input}\nSQL query: ",
    input_variables=["input", "top_k", "table_info"],
)
print(prompt.format(input="How many artists are there?", top_k=3, table_info="foo"))
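FewShotPromptTemplate formats each example with example_prompt and joins the prefix, the formatted examples and the suffix with its example_separator (by default a blank line, "\n\n"). The plain-Python sketch below reproduces that assembly under one simplifying assumption: it fills the prefix and suffix separately instead of re-formatting the joined string. few_shot_prompt is a hypothetical helper.

```python
def few_shot_prompt(prefix: str, examples: list[dict], example_template: str,
                    suffix: str, separator: str = "\n\n", **variables: object) -> str:
    """Join the filled prefix, each formatted example and the filled suffix."""
    pieces = [
        prefix.format(**variables),
        *(example_template.format(**ex) for ex in examples),
        suffix.format(**variables),
    ]
    return separator.join(p for p in pieces if p)


examples = [
    {"input": "List all artists.", "query": "SELECT * FROM Artist;"},
    {"input": "Find all albums for the artist 'AC/DC'.",
     "query": "SELECT * FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'AC/DC');"},
]
text = few_shot_prompt(
    prefix="You are a SQLite expert. Return at most {top_k} rows.\nTable info: {table_info}",
    examples=examples,
    example_template="User input: {input}\nSQL query: {query}",
    suffix="User input: {input}\nSQL query: ",
    top_k=3,
    table_info="foo",
    input="How many artists are there?",
)
print(text)
```

The prompt ends with an open "SQL query: " line, so the model's most natural continuation is the query itself.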
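An ExampleSelector can dynamically pick the examples most relevant to each input. Here we use SemanticSimilarityExampleSelector, which at run time performs a similarity search between the input and the stored examples: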
from langchain_community.vectorstores import FAISS
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_openai import OpenAIEmbeddings
example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    OpenAIEmbeddings(),
    FAISS,
    k=5,
    input_keys=["input"],
)
prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=example_prompt,
    prefix="You are a SQLite expert. Given an input question, create a syntactically correct SQLite query to run. Unless otherwise specified, do not return more than {top_k} rows.\n\nHere is the relevant table info: {table_info}\n\nBelow are a number of examples of questions and their corresponding SQL queries.",
    suffix="User input: {input}\nSQL query: ",
    input_variables=["input", "top_k", "table_info"],
)
print(prompt.format(input="how many artists are there?", top_k=3, table_info="foo"))
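To see the selection mechanism without an API key, here is a toy version of "embed, score, keep the top k". It deliberately swaps OpenAIEmbeddings and FAISS for a bag-of-words vector and cosine similarity, so it is an illustration of the idea rather than SemanticSimilarityExampleSelector's behavior. The helpers are hypothetical, and the third example (about Chinook's Employee table) is made up for the demo.

```python
import math
import re
from collections import Counter


def embed(text: str) -> Counter:
    """Toy 'embedding': lowercase word counts (a stand-in for real embeddings)."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))


def cosine(a: Counter, b: Counter) -> float:
    dot = sum(a[t] * b[t] for t in a.keys() & b.keys())
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def select_examples(question: str, examples: list[dict], k: int = 2, key: str = "input") -> list[dict]:
    """Return the k examples whose `key` field is most similar to the question."""
    q = embed(question)
    return sorted(examples, key=lambda ex: cosine(q, embed(ex[key])), reverse=True)[:k]


examples = [
    {"input": "List all artists.", "query": "SELECT * FROM Artist;"},
    {"input": "Find all albums for the artist 'AC/DC'.",
     "query": "SELECT * FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'AC/DC');"},
    {"input": "How many employees are there?", "query": "SELECT COUNT(*) FROM Employee;"},  # hypothetical
]
selected = select_examples("how many artists are there?", examples, k=2)
print([ex["input"] for ex in selected])  # ['How many employees are there?', 'List all artists.']
```

The word-overlap score ranks the Employee example first because it shares "how many ... are there", even though it targets the wrong table. Learned embeddings are meant to capture meaning better than this, which is why the real selector relies on them.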
Let's try it out and see whether the model identifies the relevant table:
chain = create_sql_query_chain(llm, db, prompt)
chain.invoke({"question": "how many artists are there?"})
Result:
'SELECT COUNT(*) FROM Artist;'