【LangChain】langchain.chains.create_sql_query_chain() 函数:基于自然语言生成 SQL 查询的链(Chain)

langchain.chains.create_sql_query_chain 函数是 LangChain 库中的一个函数,用于创建基于自然语言生成 SQL 查询的链(Chain),结合语言模型(LLM)和数据库上下文生成可执行的 SQL 语句。

本文基于 LangChain 0.3.x,详细介绍 create_sql_query_chain 的定义、参数、方法和典型场景,并提供一个独立示例,展示如何使用 create_sql_query_chain 结合 ChatOpenAISQLDatabase 实现人工智能主题的数据库查询,示例突出该函数在自然语言到 SQL 转换中的作用。( LangChain Chains 文档)


langchain.chains.create_sql_query_chain 简介

create_sql_query_chain 是 LangChain 库中 langchain.chains.sql_database.query 模块的一个函数,用于创建一个 LCEL(LangChain Expression Language)链,将自然语言查询转换为 SQL 语句。它结合语言模型和数据库上下文(如表结构),生成符合数据库方言的 SQL 查询,适合需要从数据库中提取数据的场景。

核心功能

  • 将自然语言问题转换为结构化的 SQL 查询。
  • 利用数据库元数据(如表名、列名)生成准确的查询。
  • 支持多种数据库方言(如 SQLite、PostgreSQL、MySQL)。
  • 与 LCEL 链无缝集成,可进一步处理查询结果。

适用场景

  • 构建自然语言数据库查询接口(如问答系统)。
  • 自动化数据分析,允许用户用自然语言提取数据。
  • 结合 RAG 或代理,从数据库中获取动态数据。
  • 开发交互式数据查询工具,支持非技术用户。

与其他链对比

  • create_sql_query_chain: 生成 SQL 查询,专注于数据库交互。
  • RetrievalQA: 用于文档检索的 RAG 链。
  • ConversationChain: 处理对话历史。
  • SQLDatabaseChain: 执行 SQL 查询并返回结果(更高级,但已部分弃用)。

注意

  • 生成的 SQL 查询需验证,避免语法错误或注入风险。
  • 依赖 langchain_community.utilities.SQLDatabase 提供数据库上下文。

函数定义和参数

以下是 create_sql_query_chain 的定义,基于 LangChain 源码(langchain/chains/sql_database/query.py)和官方文档(create_sql_query_chain)。

函数签名
def create_sql_query_chain(
    llm: BaseLanguageModel,
    db: SQLDatabase,
    prompt: Optional[BasePromptTemplate] = None,
    k: int = 5
) -> Runnable[Union[str, Dict[str, Any]], str]
  • 参数
    • llmBaseLanguageModel):语言模型,用于生成 SQL 查询(如 ChatOpenAI)。
    • dbSQLDatabase):数据库实例,提供表结构和方言信息。
    • promptOptional[BasePromptTemplate]):自定义提示模板,默认为内置模板。
    • kint):生成查询时考虑的上下文行数,默认为 5。
  • 返回值
    • Runnable[Union[str, Dict[str, Any]], str]:一个 LCEL 链,接受字符串或字典输入,输出 SQL 查询字符串。
  • 功能
    • 使用 LLM 和数据库上下文生成 SQL 查询。
    • 支持自定义提示模板以优化查询生成。
    • 输出为纯 SQL 语句,需进一步执行。
默认提示模板

默认模板(简化为中文):

你是一个 SQL 专家,基于以下数据库信息将用户问题转换为 SQL 查询。
数据库方言: {dialect}
表结构: {table_info}
问题: {input}
输出仅包含 SQL 查询语句。
输入格式
  • 字符串:直接输入自然语言问题:
    chain.invoke("有多少用户?")
    
  • 字典:指定 input 键:
    chain.invoke({"input": "有多少用户?"})
    
输出格式
  • 字符串:生成的 SQL 查询,如:
    SELECT COUNT(*) FROM users;
    

工作原理

create_sql_query_chain 的运行逻辑如下:

  • 输入:接受自然语言问题(字符串或字典)。
  • 处理
    • db 获取数据库元数据(表名、列名、样本数据)。
    • 结合 llmprompt,生成 SQL 查询。
    • 使用 k 参数控制上下文数据量。
  • 输出:返回 SQL 查询字符串。
  • 执行:需手动或通过链执行查询(如 db.run(query))。

在 LCEL 中的作用

  • 作为链的一部分,生成 SQL 查询。
  • 可与 SQLDatabase.run 或自定义解析器组合,获取查询结果。

示例流程

from langchain.chains import create_sql_query_chain
from langchain_openai import ChatOpenAI
from langchain_community.utilities import SQLDatabase
llm = ChatOpenAI()
db = SQLDatabase.from_uri("sqlite:///example.db")
chain = create_sql_query_chain(llm, db)
query = chain.invoke("有多少用户?")
# 输出: SELECT COUNT(*) FROM users;

常用方法

create_sql_query_chain 返回一个 Runnable 对象,支持以下方法:

1. invoke
def invoke(self, input: Union[str, Dict[str, Any]], config: Optional[RunnableConfig] = None) -> str
  • 功能:同步调用,生成 SQL 查询。
  • 输入
    • inputstr | Dict):问题字符串或包含 input 键的字典。
    • configOptional[RunnableConfig]):运行配置(如超时)。
  • 输出:SQL 查询字符串。
  • 示例
    query = chain.invoke("列出所有用户的姓名")
    print(query)  # 输出: SELECT name FROM users;
    
2. ainvoke
async def ainvoke(self, input: Union[str, Dict[str, Any]], config: Optional[RunnableConfig] = None) -> str
  • 功能:异步调用,生成 SQL 查询。
  • 示例
    query = await chain.ainvoke("有多少用户?")
    print(query)  # 输出: SELECT COUNT(*) FROM users;
    
3. stream / astream
  • 功能:支持流式输出,逐块返回查询。
  • 示例
    for chunk in chain.stream("列出所有用户"):
        print(chunk, end="")
    

使用方式

以下是使用 create_sql_query_chain 的步骤。

1. 安装依赖
pip install --upgrade langchain langchain-openai sqlalchemy
2. 设置 OpenAI API 密钥
export OPENAI_API_KEY="your-api-key"

或在代码中:

import os
os.environ["OPENAI_API_KEY"] = "your-api-key"
3. 初始化数据库和 LLM
from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI
db = SQLDatabase.from_uri("sqlite:///example.db")
llm = ChatOpenAI(model="gpt-3.5-turbo")
4. 创建 SQL 查询链
from langchain.chains import create_sql_query_chain
chain = create_sql_query_chain(llm, db)
5. 调用链并执行查询
query = chain.invoke("有多少用户?")
result = db.run(query)
print(result)

使用 create_sql_query_chain 的示例

以下是一个独立示例,展示如何使用 create_sql_query_chain 结合 ChatOpenAISQLDatabase 实现人工智能主题的数据库查询。链生成 SQL 查询并执行,回答用户关于 AI 相关数据的自然语言问题。

准备环境

  • 获取 OpenAI API 密钥:OpenAI Platform。
  • 设置环境变量:
    export OPENAI_API_KEY="your-api-key"
    
  • 安装依赖:
    pip install --upgrade langchain langchain-openai sqlalchemy
    
  • 创建 SQLite 数据库 ai_data.db
    CREATE TABLE ai_projects (
        id INTEGER PRIMARY KEY,
        name TEXT NOT NULL,
        category TEXT,
        year INTEGER
    );
    INSERT INTO ai_projects (name, category, year) VALUES
        ('AlphaGo', 'Game AI', 2016),
        ('GPT-3', 'NLP', 2020),
        ('DALL-E', 'Generative AI', 2021);
    

代码

from langchain_openai import ChatOpenAI
from langchain_community.utilities import SQLDatabase
from langchain.chains import create_sql_query_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda
from langchain_core.output_parsers import StrOutputParser

# 初始化 ChatOpenAI
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.7)

# 初始化 SQLDatabase
db = SQLDatabase.from_uri("sqlite:///ai_data.db")

# 创建 SQL 查询链
sql_chain = create_sql_query_chain(llm, db)

# 定义提示模板,用于格式化查询结果
prompt = ChatPromptTemplate.from_template(
    "你是一个数据库专家,基于以下 SQL 查询结果回答问题:\n结果:{result}\n问题:{input}\n回答:"
)

# 定义输出解析器
parser = StrOutputParser()

# 创建完整工作流
chain = (
    {
        "query": sql_chain,
        "input": RunnableLambda(lambda x: x.strip())
    }
    | RunnableLambda(lambda x: {"result": db.run(x["query"]), "input": x["input"]})
    | prompt
    | llm
    | parser
)

# 测试 SQL 查询工作流
print("测试 create_sql_query_chain 和数据库查询:")
try:
    questions = [
        "有多少 AI 项目?",
        "列出 2020 年后的 AI 项目"
    ]
    
    for question in questions:
        result = chain.invoke(question)
        print(f"\n输入问题: {question}")
        print(f"回答: {result}")
except Exception as e:
    print(f"错误: {e}")

输出示例(实际输出取决于模型和数据库):

测试 create_sql_query_chain 和数据库查询:

输入问题: 有多少 AI 项目?
回答: 数据库中有 3 个 AI 项目。

输入问题: 列出 2020 年后的 AI 项目
回答: 2020 年后的 AI 项目包括:
- GPT-3(类别:NLP,2020 年)
- DALL-E(类别:Generative AI,2021 年)
代码说明
  1. LLM 初始化
    • 使用 ChatOpenAI 调用 gpt-3.5-turbo,设置 temperature=0.7
  2. 数据库初始化
    • 连接 SQLite 数据库 ai_data.db,包含 ai_projects 表。
  3. SQL 查询链
    • 使用 create_sql_query_chain 创建 SQL 生成链。
  4. 工作流
    • 使用 sql_chain 生成 SQL 查询。
    • 使用 RunnableLambda 清理输入并执行查询(db.run)。
    • 组合 promptllmparser 格式化结果。
  5. 测试
    • 测试两个问题:统计项目数和过滤年份。
    • 显示问题和回答,展示自然语言到 SQL 的转换。
  6. 错误处理
    • 使用 try-except 捕获 API 或数据库错误。

运行要求

  • 有效的 OpenAI API 密钥:
    export OPENAI_API_KEY="your-api-key"
    
  • 安装依赖:
    pip install --upgrade langchain langchain-openai sqlalchemy
    
  • SQLite 数据库 ai_data.db 已创建。
  • 网络连接:访问 https://api.openai.com.

注意事项

  1. API 密钥
    • 确保 OPENAI_API_KEY 已设置:
      echo $OPENAI_API_KEY
      
    • 或在代码中设置:
      llm = ChatOpenAI(api_key="your-api-key")
      
  2. 数据库配置
    • 验证数据库 URI:
      db = SQLDatabase.from_uri("sqlite:///ai_data.db")
      print(db.get_table_info())
      
    • 限制表范围:
      db = SQLDatabase.from_uri("sqlite:///ai_data.db", include_tables=["ai_projects"])
      
  3. SQL 安全
    • 检查生成的 SQL 查询,避免注入:
      query = sql_chain.invoke("无效查询")
      print(query)
      
    • 限制 LLM 生成危险命令:
      prompt = ChatPromptTemplate.from_template("仅生成 SELECT 查询: {input}")
      sql_chain = create_sql_query_chain(llm, db, prompt=prompt)
      
  4. 性能优化
    • 异步调用:使用 ainvoke
      query = await sql_chain.ainvoke("有多少用户?")
      
    • 缓存查询:结合 langchain.cache
      from langchain.cache import InMemoryCache
      langchain.llm_cache = InMemoryCache()
      
    • 限制上下文:调整 k
      sql_chain = create_sql_query_chain(llm, db, k=3)
      
  5. 错误调试
    • 查询错误
      • 检查 SQL 语法:
        query = sql_chain.invoke("无效查询")
        print(query)
        
      • 验证执行:
        print(db.run(query))
        
    • API 错误
      • 检查密钥:
        print(os.environ.get("OPENAI_API_KEY"))
        
      • 增加超时:
        llm = ChatOpenAI(timeout=30)
        
    • 数据库连接
      • 检查 URI:
        print(db.get_usable_table_names())
        
      • 测试连接:
        print(db.run("SELECT 1"))
        

常见问题

Q1:如何自定义提示模板?
A:创建自定义 ChatPromptTemplate

from langchain_core.prompts import ChatPromptTemplate
prompt = ChatPromptTemplate.from_template(
    "生成 {dialect} 的 SQL 查询,基于表结构 {table_info},回答问题:{input}\n仅输出 SQL 语句。"
)
sql_chain = create_sql_query_chain(llm, db, prompt=prompt)

Q2:如何执行和解析查询结果?
A:结合 db.run 和提示:

chain = sql_chain | RunnableLambda(lambda x: {"query": x, "result": db.run(x)})
result = chain.invoke("有多少用户?")
print(result["result"])

Q3:如何与代理结合?
A:使用 SQLDatabaseToolkit

from langchain.agents import create_sql_agent
from langchain_community.agent_toolkits import SQLDatabaseToolkit
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
agent = create_sql_agent(llm=llm, toolkit=toolkit, verbose=True)
result = agent.run("有多少 AI 项目?")

Q4:如何支持开源模型?
A:使用 ChatOllama

from langchain_ollama import ChatOllama
llm = ChatOllama(model="llama3")
sql_chain = create_sql_query_chain(llm, db)
query = sql_chain.invoke("有多少用户?")

总结

langchain.chains.create_sql_query_chain 是 LangChain 中用于自然语言到 SQL 查询转换的强大工具,核心功能包括:

  • 定义:生成 SQL 查询,基于 LLM 和数据库上下文。
  • 参数llmdbpromptk
  • 常用方法invoke(同步)、ainvoke(异步)、stream(流式)。
  • 适用场景:数据库查询、数据分析、自然语言接口。

你可能感兴趣的:(LangChain,langchain,chains,create_sql_quer,sql_database,sql)