利用DSPy优化LangChain RAG系统的实战指南

利用DSPy优化LangChain RAG系统的实战指南

技术背景介绍

DSPy 是一个用于大语言模型(LLMs)的出色框架,它引入了一个自动编译器,能够教会模型如何执行你程序中的声明性步骤。具体来说,DSPy 编译器会在内部追踪你的程序,然后为大型语言模型(LLMs)创建高质量的提示(或为小型 LLMs 训练自动微调),以教会它们任务的步骤。感谢 Omar Khattab 的努力,现在 DSPy 可以与 LangChain 集成,从而实现程序的自动优化。

本文将演示如何利用 DSPy 对 LangChain RAG(Retrieve And Generate)系统进行优化。尽管这是一个概念验证教程,但我们将展示如何利用 DSPy 提升 RAG 系统的性能。

核心原理解析

在 RAG 系统中,检索模块负责从文档中提取相关的信息,而生成模块则根据这些信息生成合适的回答。使用 DSPy,我们可以自动化并优化这些流程,以提高系统的整体表现。

DSPy 的核心在于通过追踪程序路径并生成高质量提示,来教会模型如何更好地完成任务。通过与 LangChain 的集成,我们可以将这一过程应用到 RAG 系统中,从而实现检索和生成的自动化优化。

代码实现演示

首先,我们需要安装相关的依赖项:

!pip install -U dspy-ai
!pip install -U openai jinja2
!pip install -U langchain langchain-community langchain-openai langchain-core

设置API密钥

我们将使用 OpenAI 的服务,所以需要设置 API 密钥:

import getpass
import os

os.environ["OPENAI_API_KEY"] = getpass.getpass()

创建检索模块

在这个例子中,我们将使用 DSPy 提供的 ColBERT 检索器:

import dspy

colbertv2 = dspy.ColBERTv2(url="http://20.102.90.50:2017/wiki17_abstracts")

我们为 LangChain 设置缓存:

from langchain.globals import set_llm_cache
from langchain_community.cache import SQLiteCache
from langchain_openai import OpenAI

set_llm_cache(SQLiteCache(database_path="cache.db"))

llm = OpenAI(model_name="gpt-3.5-turbo-instruct", temperature=0)

定义检索函数:

def retrieve(inputs):
    return [doc["text"] for doc in colbertv2(inputs["question"], k=5)]

创建初始 RAG 管道

定义任务和提示模板:

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

prompt = PromptTemplate.from_template(
    "Given {context}, answer the question `{question}` as a tweet."
)

vanilla_chain = (
    RunnablePassthrough.assign(context=retrieve) | prompt | llm | StrOutputParser()
)

集成 DSPy

使用 DSPy 包装 LangChain 的模块:

from dspy.predict.langchain import LangChainModule, LangChainPredict

zeroshot_chain = (
    RunnablePassthrough.assign(context=retrieve)
    | LangChainPredict(prompt, llm)
    | StrOutputParser()
)

zeroshot_chain = LangChainModule(zeroshot_chain)

测试模块

question = "In what region was Eddy Mazzoleni born?"

zeroshot_chain.invoke({"question": question})

加载数据集

使用 HotPotQA 数据集进行训练和评估:

import dspy
from dspy.datasets import HotPotQA

dataset = HotPotQA(
    train_seed=1,
    train_size=200,
    eval_seed=2023,
    dev_size=200,
    test_size=0,
    keep_details=True,
)

trainset = [x.without("id", "type").with_inputs("question") for x in dataset.train]
devset = [x.without("id", "type").with_inputs("question") for x in dataset.dev]
valset, devset = devset[:50], devset[50:]

定义评价指标

class Assess(dspy.Signature):
    context = dspy.InputField(desc="ignore if N/A")
    assessed_text = dspy.InputField()
    assessment_question = dspy.InputField()
    assessment_answer = dspy.OutputField(desc="Yes or No")

gpt4T = dspy.OpenAI(model="gpt-4-1106-preview", max_tokens=1000, model_type="chat")
METRIC = None

def metric(gold, pred, trace=None):
    question, answer, tweet = gold.question, gold.answer, pred.output
    context = colbertv2(question, k=5)

    engaging = "Does the assessed text make for a self-contained, engaging tweet?"
    faithful = "Is the assessed text grounded in the context? Say no if it includes significant facts not in the context."
    correct = f"The text above is should answer `{question}`. The gold answer is `{answer}`."
    correct = f"{correct} Does the assessed text above contain the gold answer?"

    with dspy.context(lm=gpt4T):
        faithful = dspy.Predict(Assess)(
            context=context, assessed_text=tweet, assessment_question=faithful
        )
        correct = dspy.Predict(Assess)(
            context="N/A", assessed_text=tweet, assessment_question=correct
        )
        engaging = dspy.Predict(Assess)(
            context="N/A", assessed_text=tweet, assessment_question=engaging
        )

    correct, engaging, faithful = [
        m.assessment_answer.split()[0].lower() == "yes"
        for m in [correct, engaging, faithful]
    ]
    score = (correct + engaging + faithful) if correct and (len(tweet) <= 280) else 0

    if METRIC is not None:
        if METRIC == "correct":
            return correct
        if METRIC == "engaging":
            return engaging
        if METRIC == "faithful":
            return faithful

    if trace is not None:
        return score >= 3
    return score / 3.0

评估基线

from dspy.evaluate.evaluate import Evaluate

evaluate = Evaluate(
    metric=metric, devset=devset, num_threads=8, display_progress=True, display_table=5
)
evaluate(zeroshot_chain)

优化性能

from dspy.teleprompt import BootstrapFewShotWithRandomSearch

optimizer = BootstrapFewShotWithRandomSearch(
    metric=metric, max_bootstrapped_demos=3, num_candidate_programs=3
)

optimized_chain = optimizer.compile(zeroshot_chain, trainset=trainset, valset=valset)

评估优化后的链

evaluate(optimized_chain)

应用场景分析

DSPy 可以广泛应用于各种机器学习任务中,特别是在需要优化提示和生成高质量文本的场景中。与 LangChain 的集成使得 DSPy 在构建 RAG 系统时更加高效,实现了从检索到生成的全流程优化。

实践建议

  1. 数据准备:在进行优化之前,确保数据集是干净且标注准确的。
  2. 设定评价指标:明确评价指标,以便在优化过程中有明确的目标。
  3. 逐步优化:先进行基线评估,再逐步进行优化,以观察每一步的提升效果。
  4. 人工检查:尽管自动化的评价指标很重要,但人工检查有助于发现模型可能忽视的问题。

如果遇到问题欢迎在评论区交流。

—END—

你可能感兴趣的:(langchain,python)