langchain.chains.create_sql_query_chain
函数是 LangChain 库中的一个函数,用于创建基于自然语言生成 SQL 查询的链(Chain),结合语言模型(LLM)和数据库上下文生成可执行的 SQL 语句。
本文基于 LangChain 0.3.x,详细介绍 create_sql_query_chain
的定义、参数、方法和典型场景,并提供一个独立示例,展示如何使用 create_sql_query_chain
结合 ChatOpenAI
和 SQLDatabase
实现人工智能主题的数据库查询,示例突出该函数在自然语言到 SQL 转换中的作用。( LangChain Chains 文档)
langchain.chains.create_sql_query_chain
简介create_sql_query_chain
是 LangChain 库中 langchain.chains.sql_database.query
模块的一个函数,用于创建一个 LCEL(LangChain Expression Language)链,将自然语言查询转换为 SQL 语句。它结合语言模型和数据库上下文(如表结构),生成符合数据库方言的 SQL 查询,适合需要从数据库中提取数据的场景。
核心功能:
适用场景:
与其他链对比:
create_sql_query_chain
: 生成 SQL 查询,专注于数据库交互。RetrievalQA
: 用于文档检索的 RAG 链。ConversationChain
: 处理对话历史。SQLDatabaseChain
: 执行 SQL 查询并返回结果(更高级,但已部分弃用)。注意:
langchain_community.utilities.SQLDatabase
提供数据库上下文。以下是 create_sql_query_chain
的定义,基于 LangChain 源码(langchain/chains/sql_database/query.py
)和官方文档(create_sql_query_chain)。
def create_sql_query_chain(
llm: BaseLanguageModel,
db: SQLDatabase,
prompt: Optional[BasePromptTemplate] = None,
k: int = 5
) -> Runnable[Union[str, Dict[str, Any]], str]
llm
(BaseLanguageModel
):语言模型,用于生成 SQL 查询(如 ChatOpenAI
)。db
(SQLDatabase
):数据库实例,提供表结构和方言信息。prompt
(Optional[BasePromptTemplate]
):自定义提示模板,默认为内置模板。k
(int
):生成查询时考虑的上下文行数,默认为 5。Runnable[Union[str, Dict[str, Any]], str]
:一个 LCEL 链,接受字符串或字典输入,输出 SQL 查询字符串。默认模板(简化为中文):
你是一个 SQL 专家,基于以下数据库信息将用户问题转换为 SQL 查询。
数据库方言: {dialect}
表结构: {table_info}
问题: {input}
输出仅包含 SQL 查询语句。
chain.invoke("有多少用户?")
input
键:chain.invoke({"input": "有多少用户?"})
SELECT COUNT(*) FROM users;
create_sql_query_chain
的运行逻辑如下:
db
获取数据库元数据(表名、列名、样本数据)。llm
和 prompt
,生成 SQL 查询。k
参数控制上下文数据量。db.run(query)
)。在 LCEL 中的作用:
SQLDatabase.run
或自定义解析器组合,获取查询结果。示例流程:
from langchain.chains import create_sql_query_chain
from langchain_openai import ChatOpenAI
from langchain_community.utilities import SQLDatabase
llm = ChatOpenAI()
db = SQLDatabase.from_uri("sqlite:///example.db")
chain = create_sql_query_chain(llm, db)
query = chain.invoke("有多少用户?")
# 输出: SELECT COUNT(*) FROM users;
create_sql_query_chain
返回一个 Runnable
对象,支持以下方法:
invoke
def invoke(self, input: Union[str, Dict[str, Any]], config: Optional[RunnableConfig] = None) -> str
input
(str | Dict
):问题字符串或包含 input
键的字典。config
(Optional[RunnableConfig]
):运行配置(如超时)。query = chain.invoke("列出所有用户的姓名")
print(query) # 输出: SELECT name FROM users;
ainvoke
async def ainvoke(self, input: Union[str, Dict[str, Any]], config: Optional[RunnableConfig] = None) -> str
query = await chain.ainvoke("有多少用户?")
print(query) # 输出: SELECT COUNT(*) FROM users;
stream
/ astream
for chunk in chain.stream("列出所有用户"):
print(chunk, end="")
以下是使用 create_sql_query_chain
的步骤。
pip install --upgrade langchain langchain-openai sqlalchemy
export OPENAI_API_KEY="your-api-key"
或在代码中:
import os
os.environ["OPENAI_API_KEY"] = "your-api-key"
from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI
db = SQLDatabase.from_uri("sqlite:///example.db")
llm = ChatOpenAI(model="gpt-3.5-turbo")
from langchain.chains import create_sql_query_chain
chain = create_sql_query_chain(llm, db)
query = chain.invoke("有多少用户?")
result = db.run(query)
print(result)
create_sql_query_chain
的示例以下是一个独立示例,展示如何使用 create_sql_query_chain
结合 ChatOpenAI
和 SQLDatabase
实现人工智能主题的数据库查询。链生成 SQL 查询并执行,回答用户关于 AI 相关数据的自然语言问题。
准备环境:
export OPENAI_API_KEY="your-api-key"
pip install --upgrade langchain langchain-openai sqlalchemy
ai_data.db
:CREATE TABLE ai_projects (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
category TEXT,
year INTEGER
);
INSERT INTO ai_projects (name, category, year) VALUES
('AlphaGo', 'Game AI', 2016),
('GPT-3', 'NLP', 2020),
('DALL-E', 'Generative AI', 2021);
代码:
from langchain_openai import ChatOpenAI
from langchain_community.utilities import SQLDatabase
from langchain.chains import create_sql_query_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda
from langchain_core.output_parsers import StrOutputParser
# 初始化 ChatOpenAI
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.7)
# 初始化 SQLDatabase
db = SQLDatabase.from_uri("sqlite:///ai_data.db")
# 创建 SQL 查询链
sql_chain = create_sql_query_chain(llm, db)
# 定义提示模板,用于格式化查询结果
prompt = ChatPromptTemplate.from_template(
"你是一个数据库专家,基于以下 SQL 查询结果回答问题:\n结果:{result}\n问题:{input}\n回答:"
)
# 定义输出解析器
parser = StrOutputParser()
# 创建完整工作流
chain = (
{
"query": sql_chain,
"input": RunnableLambda(lambda x: x.strip())
}
| RunnableLambda(lambda x: {"result": db.run(x["query"]), "input": x["input"]})
| prompt
| llm
| parser
)
# 测试 SQL 查询工作流
print("测试 create_sql_query_chain 和数据库查询:")
try:
questions = [
"有多少 AI 项目?",
"列出 2020 年后的 AI 项目"
]
for question in questions:
result = chain.invoke(question)
print(f"\n输入问题: {question}")
print(f"回答: {result}")
except Exception as e:
print(f"错误: {e}")
输出示例(实际输出取决于模型和数据库):
测试 create_sql_query_chain 和数据库查询:
输入问题: 有多少 AI 项目?
回答: 数据库中有 3 个 AI 项目。
输入问题: 列出 2020 年后的 AI 项目
回答: 2020 年后的 AI 项目包括:
- GPT-3(类别:NLP,2020 年)
- DALL-E(类别:Generative AI,2021 年)
ChatOpenAI
调用 gpt-3.5-turbo
,设置 temperature=0.7
。ai_data.db
,包含 ai_projects
表。create_sql_query_chain
创建 SQL 生成链。sql_chain
生成 SQL 查询。RunnableLambda
清理输入并执行查询(db.run
)。prompt
、llm
和 parser
格式化结果。try-except
捕获 API 或数据库错误。运行要求:
export OPENAI_API_KEY="your-api-key"
pip install --upgrade langchain langchain-openai sqlalchemy
ai_data.db
已创建。https://api.openai.com
.OPENAI_API_KEY
已设置:echo $OPENAI_API_KEY
llm = ChatOpenAI(api_key="your-api-key")
db = SQLDatabase.from_uri("sqlite:///ai_data.db")
print(db.get_table_info())
db = SQLDatabase.from_uri("sqlite:///ai_data.db", include_tables=["ai_projects"])
query = sql_chain.invoke("无效查询")
print(query)
prompt = ChatPromptTemplate.from_template("仅生成 SELECT 查询: {input}")
sql_chain = create_sql_query_chain(llm, db, prompt=prompt)
ainvoke
:query = await sql_chain.ainvoke("有多少用户?")
langchain.cache
:from langchain.cache import InMemoryCache
langchain.llm_cache = InMemoryCache()
k
:sql_chain = create_sql_query_chain(llm, db, k=3)
query = sql_chain.invoke("无效查询")
print(query)
print(db.run(query))
print(os.environ.get("OPENAI_API_KEY"))
llm = ChatOpenAI(timeout=30)
print(db.get_usable_table_names())
print(db.run("SELECT 1"))
Q1:如何自定义提示模板?
A:创建自定义 ChatPromptTemplate
:
from langchain_core.prompts import ChatPromptTemplate
prompt = ChatPromptTemplate.from_template(
"生成 {dialect} 的 SQL 查询,基于表结构 {table_info},回答问题:{input}\n仅输出 SQL 语句。"
)
sql_chain = create_sql_query_chain(llm, db, prompt=prompt)
Q2:如何执行和解析查询结果?
A:结合 db.run
和提示:
chain = sql_chain | RunnableLambda(lambda x: {"query": x, "result": db.run(x)})
result = chain.invoke("有多少用户?")
print(result["result"])
Q3:如何与代理结合?
A:使用 SQLDatabaseToolkit
:
from langchain.agents import create_sql_agent
from langchain_community.agent_toolkits import SQLDatabaseToolkit
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
agent = create_sql_agent(llm=llm, toolkit=toolkit, verbose=True)
result = agent.run("有多少 AI 项目?")
Q4:如何支持开源模型?
A:使用 ChatOllama
:
from langchain_ollama import ChatOllama
llm = ChatOllama(model="llama3")
sql_chain = create_sql_query_chain(llm, db)
query = sql_chain.invoke("有多少用户?")
langchain.chains.create_sql_query_chain
是 LangChain 中用于自然语言到 SQL 查询转换的强大工具,核心功能包括:
llm
、db
、prompt
和 k
。invoke
(同步)、ainvoke
(异步)、stream
(流式)。