DSPy is an excellent framework for large language models (LLMs) that introduces an automatic compiler able to teach models how to perform the declarative steps in your program. Specifically, the DSPy compiler traces your program internally and then creates high-quality prompts for large LLMs (or trains automatic finetunes for small LLMs) to teach them the steps of the task. Thanks to Omar Khattab's work, DSPy now integrates with LangChain, enabling automatic optimization of LangChain programs.
This article demonstrates how to use DSPy to optimize a LangChain RAG (Retrieval-Augmented Generation) system. Although this is a proof-of-concept tutorial, we will show how DSPy can improve a RAG system's performance.
In a RAG system, the retrieval module extracts relevant information from documents, while the generation module produces an answer based on that information. With DSPy, we can automate and optimize these steps to improve the system's overall performance.
The core idea of DSPy is to teach the model to perform the task better by tracing the program's execution path and generating high-quality prompts. Through its integration with LangChain, we can apply this process to a RAG system and automatically optimize both retrieval and generation.
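Conceptually, a RAG pipeline is just two composable steps: retrieve, then generate. The sketch below uses hypothetical stub components (a toy keyword retriever and an echo "LLM", not the real ColBERTv2 retriever or OpenAI model used later) purely to illustrate the shape of the pipeline that DSPy will optimize:

```python
# Conceptual sketch of retrieve-then-generate with stub components.
# These stubs are hypothetical stand-ins, not the components used below.

def stub_retrieve(question: str) -> list[str]:
    # A real retriever ranks passages from an index; here we do a
    # crude keyword lookup over a two-document corpus.
    corpus = {
        "France": "Paris is the capital of France.",
        "Germany": "Berlin is the capital of Germany.",
    }
    return [text for key, text in corpus.items() if key in question]

def stub_generate(question: str, context: list[str]) -> str:
    # A real LLM would condition on a prompt; here we just echo the context.
    return f"Based on: {' '.join(context)}"

def rag(question: str) -> str:
    context = stub_retrieve(question)        # step 1: retrieval
    return stub_generate(question, context)  # step 2: generation

print(rag("What is the capital of France?"))
```

DSPy's job in the rest of this article is to optimize the prompt that drives the generation step, given traces of the whole pipeline.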
First, install the required dependencies:
```
!pip install -U dspy-ai
!pip install -U openai jinja2
!pip install -U langchain langchain-community langchain-openai langchain-core
```
We will use OpenAI's services, so we need to set the API key:
```python
import getpass
import os

os.environ["OPENAI_API_KEY"] = getpass.getpass()
```
In this example, we will use the ColBERT retriever provided by DSPy:
```python
import dspy

colbertv2 = dspy.ColBERTv2(url="http://20.102.90.50:2017/wiki17_abstracts")
```
Next, we set up caching for LangChain:
```python
from langchain.globals import set_llm_cache
from langchain_community.cache import SQLiteCache
from langchain_openai import OpenAI

set_llm_cache(SQLiteCache(database_path="cache.db"))
llm = OpenAI(model_name="gpt-3.5-turbo-instruct", temperature=0)
```
Define the retrieval function:
```python
def retrieve(inputs):
    return [doc["text"] for doc in colbertv2(inputs["question"], k=5)]
```
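To make the function's contract concrete: it maps a dict with a `question` key to a list of passage strings. The sketch below demonstrates this with a hypothetical fake retriever standing in for the live ColBERTv2 endpoint (which returns the top-k passages as dicts with a `text` key):

```python
# Sketch of retrieve()'s input/output contract, using a fake retriever
# (a hypothetical stand-in for the remote ColBERTv2 index).

def fake_colbertv2(query, k=5):
    # Mimics the retriever's return shape: top-k dicts with a "text" key.
    return [{"text": f"passage {i} about {query}"} for i in range(k)]

def retrieve_stub(inputs):
    return [doc["text"] for doc in fake_colbertv2(inputs["question"], k=5)]

docs = retrieve_stub({"question": "Eddy Mazzoleni"})
print(len(docs))  # 5 passages
print(docs[0])
```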
Define the task and the prompt template:
```python
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

prompt = PromptTemplate.from_template(
    "Given {context}, answer the question `{question}` as a tweet."
)

vanilla_chain = (
    RunnablePassthrough.assign(context=retrieve) | prompt | llm | StrOutputParser()
)
```
Now wrap the LangChain modules with DSPy:
```python
from dspy.predict.langchain import LangChainModule, LangChainPredict

zeroshot_chain = (
    RunnablePassthrough.assign(context=retrieve)
    | LangChainPredict(prompt, llm)
    | StrOutputParser()
)
zeroshot_chain = LangChainModule(zeroshot_chain)

question = "In what region was Eddy Mazzoleni born?"
zeroshot_chain.invoke({"question": question})
```
We use the HotPotQA dataset for training and evaluation:
```python
import dspy
from dspy.datasets import HotPotQA

dataset = HotPotQA(
    train_seed=1,
    train_size=200,
    eval_seed=2023,
    dev_size=200,
    test_size=0,
    keep_details=True,
)

trainset = [x.without("id", "type").with_inputs("question") for x in dataset.train]
devset = [x.without("id", "type").with_inputs("question") for x in dataset.dev]
valset, devset = devset[:50], devset[50:]
```
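The last line carves the 200 dev examples into a 50-example validation set (used by the optimizer) and a 150-example dev set (used for evaluation), with no overlap. The slicing pattern itself is plain Python, shown here on placeholder items:

```python
# The split pattern used above, demonstrated on 200 placeholder items
# (integers standing in for the dev examples).
items = list(range(200))
val_part, dev_part = items[:50], items[50:]

print(len(val_part), len(dev_part))  # 50 and 150, disjoint
```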
Next, we define an evaluation metric. A DSPy signature, Assess, lets GPT-4 Turbo act as the judge:

```python
class Assess(dspy.Signature):
    context = dspy.InputField(desc="ignore if N/A")
    assessed_text = dspy.InputField()
    assessment_question = dspy.InputField()
    assessment_answer = dspy.OutputField(desc="Yes or No")

gpt4T = dspy.OpenAI(model="gpt-4-1106-preview", max_tokens=1000, model_type="chat")
```
```python
METRIC = None

def metric(gold, pred, trace=None):
    question, answer, tweet = gold.question, gold.answer, pred.output
    context = colbertv2(question, k=5)

    engaging = "Does the assessed text make for a self-contained, engaging tweet?"
    faithful = "Is the assessed text grounded in the context? Say no if it includes significant facts not in the context."
    correct = f"The text above should answer `{question}`. The gold answer is `{answer}`."
    correct = f"{correct} Does the assessed text above contain the gold answer?"

    with dspy.context(lm=gpt4T):
        faithful = dspy.Predict(Assess)(
            context=context, assessed_text=tweet, assessment_question=faithful
        )
        correct = dspy.Predict(Assess)(
            context="N/A", assessed_text=tweet, assessment_question=correct
        )
        engaging = dspy.Predict(Assess)(
            context="N/A", assessed_text=tweet, assessment_question=engaging
        )

    correct, engaging, faithful = [
        m.assessment_answer.split()[0].lower() == "yes"
        for m in [correct, engaging, faithful]
    ]
    score = (correct + engaging + faithful) if correct and (len(tweet) <= 280) else 0

    if METRIC is not None:
        if METRIC == "correct":
            return correct
        if METRIC == "engaging":
            return engaging
        if METRIC == "faithful":
            return faithful

    if trace is not None:
        return score >= 3
    return score / 3.0
```
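The scoring logic relies on Python treating booleans as integers: the three yes/no verdicts sum to a raw score of 0 to 3, gated on the answer being correct and fitting the 280-character tweet limit. A self-contained sketch of just that arithmetic, decoupled from the LLM judges:

```python
# Sketch of the metric's scoring arithmetic. The three booleans stand for
# the judge verdicts; tweet is the generated text.

def score_tweet(correct: bool, engaging: bool, faithful: bool, tweet: str) -> float:
    # Booleans sum as integers (True == 1); the score is zeroed out
    # entirely if the answer is wrong or exceeds the tweet length limit.
    score = (correct + engaging + faithful) if correct and len(tweet) <= 280 else 0
    return score / 3.0

print(score_tweet(True, True, False, "short answer"))  # two of three criteria met
print(score_tweet(False, True, True, "short answer"))  # gated to 0: not correct
print(score_tweet(True, True, True, "x" * 300))        # gated to 0: too long
```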
With the metric in place, we evaluate the zero-shot chain on the dev set:

```python
from dspy.evaluate.evaluate import Evaluate

evaluate = Evaluate(
    metric=metric, devset=devset, num_threads=8, display_progress=True, display_table=5
)
evaluate(zeroshot_chain)
```
Finally, we optimize the chain with BootstrapFewShotWithRandomSearch and evaluate the result:

```python
from dspy.teleprompt import BootstrapFewShotWithRandomSearch

optimizer = BootstrapFewShotWithRandomSearch(
    metric=metric, max_bootstrapped_demos=3, num_candidate_programs=3
)
optimized_chain = optimizer.compile(zeroshot_chain, trainset=trainset, valset=valset)
evaluate(optimized_chain)
```
DSPy can be applied broadly across machine learning tasks, especially in scenarios that require optimizing prompts and generating high-quality text. Its integration with LangChain makes building RAG systems more efficient, enabling end-to-end optimization from retrieval to generation.
If you run into problems, feel free to discuss them in the comments.
—END—