每一篇文章都短小精悍,不啰嗦。
在 LangChain 框架中,Chain
是连接各种组件(模型、工具、数据库等)的核心骨架,负责将多个步骤按逻辑串联成可执行的工作流。无论是简单的「提问 - 回答」流程,还是复杂的「检索 - 思考 - 工具调用」 pipeline,都依赖 Chain
基类提供的基础能力。本文将从架构设计、核心功能到实现细节,全面解析这一基类的设计智慧。
Chain
基类是所有具体 Chain(如 LLMChain
、RetrievalQAChain
、AgentExecutor
)的抽象父类,它定义了流程编排的通用接口和核心逻辑,解决了三个关键问题:
invoke
、ainvoke
),降低使用和扩展成本;Chain
作为抽象基类(ABC),定义了几个必须由子类实现的核心接口,构成了 Chain 的「能力契约」。
input_keys
与 output_keys
:输入输出的「规格说明」@property
@abstractmethod
def input_keys(self) -> list[str]:
"""Keys expected to be in the chain input."""
@property
@abstractmethod
def output_keys(self) -> list[str]:
"""Keys expected to be in the chain output."""
这两个抽象属性定义了 Chain 的「输入输出规格」:
input_keys
:指定 Chain 运行所需的输入参数名称(如问答 Chain 可能需要 ["question", "context"]
);output_keys
:指定 Chain 输出的结果参数名称(如问答 Chain 可能输出 ["answer"]
)。示例:
一个简单的加法计算 Chain,input_keys = ["a", "b"]
(需要两个加数),output_keys = ["sum"]
(输出和)。
_call
与 _acall
:核心执行逻辑的「入口」@abstractmethod
def _call(
self,
inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> dict[str, Any]:
"""同步执行 Chain 核心逻辑"""
async def _acall(
self,
inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> dict[str, Any]:
"""异步执行 Chain 核心逻辑(默认基于同步方法实现)"""
_call
是所有 Chain 的「心脏」,子类必须实现具体的业务逻辑:
inputs
:符合 input_keys
规格的参数字典(如 {"a": 1, "b": 2}
);output_keys
规格的结果字典(如 {"sum": 3}
);run_manager
:用于触发回调事件(如执行开始、结束、错误)。异步支持:_acall
提供异步执行能力,默认通过 run_in_executor
调用同步 _call
,子类可重写以实现原生异步逻辑(如异步调用 API)。
Chain
基类封装了从输入到输出的完整流程,以 invoke
方法为入口,包含以下关键步骤:
prep_inputs
def prep_inputs(self, inputs: Union[dict[str, Any], Any]) -> dict[str, str]:
# 1. 处理非字典输入(如直接传入字符串,自动映射到 input_keys 中的第一个键)
if not isinstance(inputs, dict):
_input_keys = set(self.input_keys)
if self.memory: # 排除 memory 提供的键
_input_keys -= set(self.memory.memory_variables)
inputs = {next(iter(_input_keys)): inputs} # 如 input_keys 为 ["question"],则 inputs 转为 {"question": "xxx"}
# 2. 从 memory 加载上下文(如对话历史)
if self.memory:
external_context = self.memory.load_memory_variables(inputs)
inputs = {** inputs, **external_context} # 合并用户输入与 memory 上下文
return inputs
作用:将原始输入(可能是字符串、字典)转换为符合 input_keys
规格的字典,并整合内存中的上下文(如对话历史)。
示例:
用户调用 chain.invoke("北京天气如何?")
,input_keys
为 ["question"]
,则 prep_inputs
会将输入转为 {"question": "北京天气如何?"}
;若 Chain 有 memory,还会添加 {"history": "之前的对话..."}
。
_validate_inputs
def _validate_inputs(self, inputs: dict[str, Any]) -> None:
missing_keys = set(self.input_keys) - set(inputs)
if missing_keys:
raise ValueError(f"Missing some input keys: {missing_keys}")
作用:确保输入包含 input_keys
定义的所有键,避免因参数缺失导致执行错误。
示例:
若 input_keys
为 ["a", "b"]
,而输入仅提供 {"a": 1}
,则抛出 Missing some input keys: {'b'}
。
_call
(由子类实现)这一步由子类根据业务逻辑实现,例如 LLMChain
的 _call
会调用大语言模型,RetrievalQAChain
会先检索文档再调用模型。
prep_outputs
def prep_outputs(
self,
inputs: dict[str, str],
outputs: dict[str, str],
return_only_outputs: bool = False,
) -> dict[str, str]:
# 1. 验证输出包含所有 output_keys
self._validate_outputs(outputs)
# 2. 将结果保存到 memory(如更新对话历史)
if self.memory:
self.memory.save_context(inputs, outputs)
# 3. 决定是否仅返回输出(还是合并输入)
if return_only_outputs:
return outputs
return {** inputs, **outputs} # 合并输入与输出(方便下游使用上下文)
作用:验证输出格式,保存结果到内存,并决定返回内容(仅输出或包含输入)。
示例:
outputs 为 {"answer": "北京25℃"}
,return_only_outputs=True
则返回 {"answer": "北京25℃"}
;否则返回 {"question": "北京天气如何?", "answer": "北京25℃", ...}
(包含输入和 memory 上下文)。
Chain
通过 run_manager
触发一系列回调事件,实现执行过程的可观察性:
# 在 invoke 方法中
run_manager = callback_manager.on_chain_start(...) # 触发执行开始事件
try:
outputs = self._call(...) # 执行核心逻辑
run_manager.on_chain_end(outputs) # 触发执行结束事件
except Exception as e:
run_manager.on_chain_error(e) # 触发执行错误事件
raise e
应用场景:
memory
属性class Chain(...):
memory: Optional[BaseMemory] = None # 如 ConversationBufferMemory 保存对话历史
作用:让 Chain 具备「状态」,可记住上下文(如对话历史、用户偏好)。例如,聊天机器人通过 memory
记住之前的对话,实现连贯交互。
工作流程:
prep_inputs
阶段:从 memory
加载历史上下文;prep_outputs
阶段:将当前输入输出保存到 memory
。class Chain(...):
callbacks: Callbacks = None # 回调函数或管理器
tags: Optional[list[str]] = None # 标签,用于分类日志
metadata: Optional[dict[str, Any]] = None # 元数据,如请求ID
作用:通过回调函数监控 Chain 的全生命周期(开始、结束、错误),支持日志、监控、调试等。
示例:
def on_chain_start(chain, inputs):
print(f"Chain 开始执行,输入:{inputs}")
chain = MyChain(callbacks=[on_chain_start])
chain.invoke("北京天气") # 执行时会打印 "Chain 开始执行,输入:..."
Chain
提供多种调用接口,适应不同场景:
方法 | 用途 | 示例 |
---|---|---|
invoke |
同步执行,返回完整结果 | result = chain.invoke({"question": "xxx"}) |
ainvoke |
异步执行,返回完整结果 | result = await chain.ainvoke({"question": "xxx"}) |
run |
简化接口,直接返回输出键对应的结果 | answer = chain.run("北京天气") |
__call__ |
兼容旧版本,功能类似 invoke |
result = chain({"question": "xxx"}) |
示例:
run
方法适合简单场景,自动提取 output_keys
中的第一个结果:
# 假设 output_keys = ["answer"]
answer = chain.run("北京天气如何?") # 直接返回 answer 的值,无需处理字典
save
与 dict
def save(self, file_path: Union[Path, str]) -> None:
chain_dict = self.dict() # 转为字典(包含 _type、参数等)
if save_path.suffix == ".json":
json.dump(chain_dict, f, indent=4) # 保存为 JSON
elif save_path.suffix in (".yaml", ".yml"):
yaml.dump(chain_dict, f) # 保存为 YAML
作用:将 Chain 配置序列化到文件,方便复用和部署。加载时可通过 from_config
方法重建 Chain。
Chain
基类定义了流程骨架(invoke
→ prep_inputs
→ _call
→ prep_outputs
),而具体步骤(如 _call
)由子类实现。这种设计让所有 Chain 遵循统一流程,同时保留个性化实现空间。
举例:
Chain
定义「输入→验证→执行→输出」的固定流程;LLMChain
重写 _call
实现「调用大模型」的逻辑;RetrievalQAChain
重写 _call
实现「检索 + 调用模型」的逻辑。通过 callbacks
机制,允许外部对象(观察者)监控 Chain 的执行状态(开始、结束、错误),而无需修改 Chain 本身的代码。这种设计实现了「执行逻辑」与「监控逻辑」的解耦。
memory
、callbacks
等属性通过外部注入,使 Chain 可以灵活组合不同组件(如不同的内存实现、不同的回调工具),符合「开闭原则」(对扩展开放,对修改关闭)。
Chain
基类作为 LangChain 流程编排的基础,其核心价值体现在:
1.标准化 :定义统一的输入输出流程,让所有 Chain 遵循相同的接口和规范;
2. 复用性 :封装通用能力(内存、回调、验证),子类无需重复开发;
3. 扩展性 :通过抽象方法和依赖注入,支持灵活扩展(如自定义 Chain、内存、回调);
4. 可观测性 **:内置回调机制,方便调试、监控和日志记录。
理解 Chain
基类,不仅能掌握 LangChain 的使用精髓,更能领悟「流程编排」的设计思想 ——** 将复杂流程拆解为标准化步骤,通过抽象与实现的分离,实现灵活性与复用性的平衡 **。对于开发者而言,基于 Chain
基类开发自定义 Chain 时,只需聚焦核心业务逻辑(_call
方法),其余通用能力均可直接复用,大幅提升开发效率。