在TVM的Relay IR中,即使是看似简单的常量表达式relay.const(1),其背后也隐藏着整个类型系统的精妙设计。让我们从include/tvm/relay/expr.h
中的Constant类入手,逐步拆解…"
类名 | 继承关系 | 角色 | 关键特性 |
---|---|---|---|
ConstantNode |
public ExprNode |
常量表达式的实际数据存储 | 包含常量数据(NDArray )、类型信息,并实现属性访问、哈希和相等比较逻辑。 |
Constant |
public RelayExpr |
常量表达式的智能指针封装 | 提供用户友好的构造函数和访问方法,隐藏内存管理细节。 |
ConstantNode
详解class ConstantNode : public ExprNode {
public:
/*! \brief The data of the tensor */
runtime::NDArray data;
/*! \return The corresponding tensor type of the data */
TensorType tensor_type() const;
/*! \return Whether it is scalar(rank-0 tensor) */
bool is_scalar() const { return data->ndim == 0; }
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("data", &data);
v->Visit("span", &span);
v->Visit("mdata", &mdata);
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const ConstantNode* other, SEqualReducer equal) const {
return equal(data, other->data);
}
void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(data); }
static constexpr const char* _type_key = "relay.Constant";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, ExprNode);
};
data
(runtime::NDArray
)
NDArray
统一表示多维数组。tensor_type()
data
的维度(shape
)和数据类型(dtype
)自动生成对应的 TensorType
。is_scalar()
data->ndim == 0
。VisitAttrs
v->Visit("data", &data); // 张量数据
v->Visit("span", &span); // 源码位置信息
v->Visit("mdata", &mdata); // 元数据(如调试信息)
v->Visit("_checked_type_", &checked_type_); // 类型检查后的类型
SEqualReduce
和 SHashReduce
ConstantNode
的 data
是否相同(用于优化中的常量折叠)。data
生成哈希值(用于快速查找重复常量)。TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, ExprNode);
_type_key = "relay.Constant"
:唯一标识常量节点类型。FINAL
:禁止继承,确保常量节点的行为不可被修改。Constant
详解class Constant : public Expr {
public:
/*!
* \brief The constructor
* \param data The data of the constant tensor.
* \param span The source span of the expression.
*/
TVM_DLL explicit Constant(runtime::NDArray data, Span span = Span(), MetaData mdata = MetaData());
TVM_DEFINE_OBJECT_REF_METHODS(Constant, RelayExpr, ConstantNode);
};
构造函数
explicit Constant(runtime::NDArray data, Span span = Span(), MetaData mdata = MetaData());
NDArray
数据,构造一个常量表达式。# Python 前端等价代码
data = np.array([1, 2, 3], dtype="float32")
const_expr = relay.Constant(tvm.nd.array(data))
智能指针方法
TVM_DEFINE_OBJECT_REF_METHODS(Constant, RelayExpr, ConstantNode);
展开后提供:
operator->()
:直接访问 ConstantNode
成员(如 const_expr->data
)。get()
:获取底层 ConstantNode
指针。ObjectRef
的引用计数)。 在TVM源码中,include/tvm/relay/expr.h
是 Relay IR(中间表示)的核心头文件,定义了所有Relay表达式的基础数据结构和类型系统。它是实现TVM高层计算图表示的关键组成部分。以下是该文件的详细解析:
相关重要文件
文件路径 | 关联内容 |
---|---|
include/tvm/relay/type.h |
类型系统(TensorType等) |
include/tvm/relay/op.h |
运算符定义 |
include/tvm/relay/adt.h |
代数数据类型支持 |
src/relay/ir/expr.cc |
表达式方法的实现 |
include/tvm/relay/expr.h
文件主要包含:
RelayExpr
/RelayExprNode
)class RelayExprNode : public BaseExprNode { /*...*/ };
class RelayExpr : public BaseExpr { /*...*/ };
checked_type_
字段)VisitAttrs
)SEqualReduce
) RelayExprNode
是 Relay 表达式的实际实现类,是一个 C++ 类,包含了表达式的所有数据和功能实现。它是所有 Relay 表达式类型的基类。
RelayExpr
是一个智能指针(relay::Expr
),它指向 RelayExprNode
或其子类的实例。它提供了对 RelayExprNode
的安全访问和管理。
特性 | RelayExprNode | RelayExpr |
---|---|---|
类型 | C++ 类 | 智能指针(std::shared_ptr 的封装) |
生命周期管理 | 需要手动管理 | 自动管理 |
使用方式 | 通常不直接使用,作为实现细节 | 用户主要交互的接口 |
继承关系 | 作为基类定义表达式结构 | 作为访问接口 |
在 TVM 中,通常的模式是:
RelayExprNode
的具体表达式节点类RelayExpr
作为这些节点的引用// 创建一个常量表达式
auto const_node = relay::ConstantNode::make(tvm::runtime::NDArray::Zeros(...));
RelayExpr const_expr = const_node;
// 通常更简洁的写法
RelayExpr const_expr = relay::Constant(tvm::runtime::NDArray::Zeros(...));
// 创建一个变量表达式
auto var_node = relay::VarNode::make("x", relay::Type());
RelayExpr var_expr = var_node;
// 或者更简洁地
RelayExpr var_expr = relay::Var("x", relay::Type());
// 创建函数应用表达式
RelayExpr func = ...; // 某个函数
RelayExpr arg = ...; // 某个参数
auto call_node = relay::CallNode::make(func, {arg});
RelayExpr call_expr = call_node;
// 或者
RelayExpr call_expr = relay::Call(func, {arg});
用户代码:在大多数情况下,你应该使用 RelayExpr
而不是直接操作 RelayExprNode
。
扩展 Relay:如果你想定义新的表达式类型,需要继承 RelayExprNode
并实现相应接口。
类型转换:可以使用 as
方法将 RelayExpr
向下转换为特定类型的节点指针:
RelayExpr expr = ...;
if (const auto* call = expr.as<CallNode>()) {
// 现在可以访问 CallNode 的特定成员
call->op;
call->args;
}
relay::Var()
创建 VarNode
的 RelayExpr
)。这种分离设计使得 Relay IR 既灵活又安全,同时保持了良好的性能特性
表达式类型 | 说明 | 关键成员/方法 |
---|---|---|
VarNode |
变量(输入/中间结果) | String name_hint , Type type_annotation , Id vid |
ConstantNode |
常量张量(如模型权重) | runtime::NDArray data , tensor_type() , is_scalar() |
CallNode |
函数/运算符调用 | Expr op , Array , Attrs attrs , Array |
LetNode |
Let绑定(实现变量作用域) | Var var , Expr value , Expr body |
TupleNode |
元组结构(多返回值) | Array |
TupleGetItemNode |
从元组中获取元素 | Expr tuple , int index |
IfNode |
条件表达式 | Expr cond , Expr true_branch , Expr false_branch |
OpNode |
基本运算符(如add/concat) | 通过Op::Get("op_name") 获取 |
FunctionNode |
函数定义(在function.h 中声明,但属于表达式) |
Array params , Expr body , Type ret_type , Array |
RefCreateNode |
创建可变引用(用于状态更新) | Expr value |
RefReadNode |
读取引用值 | Expr ref |
RefWriteNode |
更新引用值 | Expr ref , Expr value |
ConstructorNode |
代数数据类型(ADT)的构造器(在adt.h 中声明) |
String tag , Array |
MatchNode |
模式匹配(ADT处理) | Expr data , Array |
TempExprNode |
临时表达式(用于优化过程中的中间表示) | 通常作为优化Pass的中间载体 |
GlobalVarNode |
全局函数引用(跨模块调用) | String name_hint |
SeqExprNode |
顺序执行多个表达式(类似语句块) | Array , Expr body |
include/tvm/relay/expr.h
class Var;
/*! \brief Container for Var */
class VarNode : public ExprNode {
public:
/*!
* \brief The unique identifier of the Var.
*
* vid will be preserved for the same Var during type inference
* and other rewritings, while the VarNode might be recreated
* to attach additional information.
* This property can be used to keep track of parameter Var
* information across passes.
*/
Id vid;
/*!
* \brief type annotaion of the variable.
* This field records user provided type annotation of the Var.
* This field is optional and can be None.
*/
Type type_annotation;
/*! \return The name hint of the variable */
const String& name_hint() const { return vid->name_hint; }
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("vid", &vid);
v->Visit("type_annotation", &type_annotation);
v->Visit("span", &span);
v->Visit("mdata", &mdata);
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {
return equal(type_annotation, other->type_annotation) && equal.FreeVarEqualImpl(this, other);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(type_annotation);
hash_reduce.FreeVarHashImpl(this);
}
static constexpr const char* _type_key = "relay.Var";
TVM_DECLARE_FINAL_OBJECT_INFO(VarNode, ExprNode);
};
class Var : public Expr {
public:
/*!
* \brief The constructor
* \param name_hint The name hint of a variable.
* \param type_annotation The type annotation of a variable.
* \param span The source span of the expression.
*/
TVM_DLL Var(String name_hint, Type type_annotation, Span span = Span(), MetaData mdata = MetaData())
: Var(Id(name_hint), type_annotation, span, mdata) {}
/*!
* \brief The constructor
* \param vid The unique id of a variable.
* \param type_annotation The type annotation of a variable.
* \param span The source span of the expression.
*/
TVM_DLL Var(Id vid, Type type_annotation, Span span = Span(), MetaData mdata = MetaData());
TVM_DEFINE_OBJECT_REF_METHODS(Var, RelayExpr, VarNode);
};
VarNode
和 Var
共同实现了 Relay IR 的变量系统,采用 TVM 标准的 Object-ObjectRef 设计模式:
VarNode
:存储实际数据的节点类(继承自 ExprNode
)Var
:管理 VarNode
的智能指针包装类(继承自 Expr
)成员 | 类型 | 作用 |
---|---|---|
vid |
Id |
唯一标识符,跨 Pass 保持不变(即使节点被重建) |
type_annotation |
Type |
用户显式指定的类型注解(可空) |
name_hint() |
String |
通过 vid->name_hint 获取的可读名称(非唯一) |
span |
Span |
源码位置信息(用于错误定位) |
mdata |
MetaData |
扩展元数据 |
方法 | 功能 |
---|---|
SEqualReduce |
结构化相等比较(用于优化 Pass 的重复检测) |
SHashReduce |
哈希计算(支持快速查找) |
VisitAttrs |
属性序列化/反序列化 |
class IdNode : public Object {
public:
String name_hint;
// ... 其他元数据
};
Id(name_hint)
构造,但系统会保证其唯一性vid
保持不变x = relay.var("input", shape=(1,3)) # 实际创建:
# vid = Id("input_0x7f") (自动去重)
# name_hint = "input" (用户友好)
graph TD
A[用户构造] -->|relay.var(..., dtype="float32")| B(type_annotation)
B --> C[类型检查]
C -->|更新| D(_checked_type_)
type_annotation
存在:必须与实际使用类型兼容// 方式1:通过 name_hint
Var x("data", TensorType({1,3}, DataType::Float(32)));
// 方式2:直接指定 Id
Var x(Id("data_0x7f"), TensorType({1,3}, DataType::Float(32)));
# Python 前端接口
x = relay.var(
name="input",
shape=(1,3),
dtype="float32",
span=SourceSpan(...)
)
sequenceDiagram
Python->>C++: relay.var() 创建请求
C++->>Heap: 分配 VarNode
C++->>Python: 返回 Var(ObjectRef)
Python->>C++: 析构时触发引用计数-1
def build_linear():
x = relay.var("x", shape=(1,3))
w = relay.var("w", shape=(3,2))
b = relay.var("b", shape=(2,))
y = relay.add(relay.matmul(x, w), b)
return relay.Function([x, w, b], y)
// 在 ConstantFolding 中识别变量引用
if (const VarNode* var = expr.as<VarNode>()) {
if (var_map.count(var->vid)) {
// 替换为已知常量
}
}
// 检查变量类型是否匹配
bool CheckType(const VarNode* var, const Type& expected) {
return var->checked_type().as<TensorType>()->dtype == expected;
}
vid
保证变量在优化过程中的持久标识type_annotation
支持显式/隐式类型指定TVM_DECLARE_FINAL_OBJECT_INFO
防止错误继承span
和 name_hint
增强错误可读性SEqualReduce
/SHashReduce
优化图操作效率Q: 为什么需要同时存在 vid
和 name_hint
?
A: 分工不同:
name_hint
:面向用户,提供可读性(允许重复)vid
:面向系统,保证唯一性和跨Pass一致性Q: 何时会重建 VarNode
?
A: 典型场景:
_checked_type_
vid
但新建节点这个宏是 TVM 类型系统的核心,用于在 C++ 中动态注册和管理对象的类型信息。它的核心作用是: 为每个类自动生成类型注册代码,使其能被 TVM 运行时识别和操作。
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)
TypeName
:当前类名(如 ConstantNode
)ParentType
:父类名(如 ExprNode
)static_assert(!ParentType::_type_final, "ParentObj marked as final");
final
(通过 _type_final
),则禁止子类继承。static uint32_t RuntimeTypeIndex() {
// 检查子类槽位配置是否合法
static_assert(TypeName::_type_child_slots == 0 ||
ParentType::_type_child_slots == 0 ||
TypeName::_type_child_slots < ParentType::_type_child_slots,
"子类槽位数不能超过父类限制");
// 如果已预分配类型ID,直接返回
if (TypeName::_type_index != ::tvm::runtime::TypeIndex::kDynamic) {
return TypeName::_type_index;
}
// 否则动态分配
return _GetOrAllocRuntimeTypeIndex();
}
uint32_t
)。_type_index
(性能更高),否则动态分配。static uint32_t _GetOrAllocRuntimeTypeIndex() {
static uint32_t tidx = Object::GetOrAllocRuntimeTypeIndex(
TypeName::_type_key, // 类型名称字符串(如 "relay.Constant")
TypeName::_type_index, // 预分配的类型ID
ParentType::RuntimeTypeIndex(), // 父类类型ID
TypeName::_type_child_slots, // 为子类预留的槽位数
TypeName::_type_child_slots_can_overflow // 是否允许超额
);
return tidx;
}
_type_child_slots
:限制子类数量(防止类型爆炸)。_type_child_slots_can_overflow
:为 true
时允许突破限制。你可以把TVM的类型系统想象成一个学校的学生管理系统,而TVM_DECLARE_BASE_OBJECT_INFO
就是给学生(类)办身份证的机器:
// 给"小明同学"办证,班主任是"李老师"
TVM_DECLARE_BASE_OBJECT_INFO(小明, 李老师)
这个宏会自动做三件事:
检查家世清白
static_assert(!李老师::是final班, "班主任明确不收新学生!");
分配学号
_type_index
)_GetOrAllocRuntimeTypeIndex
)登记亲属关系
学号 = 教务处.登记(
姓名:"小明",
班主任:李老师.学号,
可带小弟人数:3 // _type_child_slots
);
TVM_DECLARE_FINAL_OBJECT_INFO(学霸班, 实验班)
obj->IsInstance<小明>()
比查户口本快obj.as<小明>()
能安全转换类型# Python前端定义一个"汉堡店"类
@register_relay_node("food.HamburgerShop")
class HamburgerShopNode(ExprNode):
_type_key = "food.HamburgerShop"
_type_child_slots = 2 # 允许开2家分店
C++层通过这个宏:
CheeseBurgerShop
、ChickenBurgerShop
)这个宏就是TVM给类发身份证+建家族档案的工具,让系统能:
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("data", &data);
v->Visit("span", &span);
v->Visit("mdata", &mdata);
v->Visit("_checked_type_", &checked_type_);
}
VisitAttrs
是 TVM 中用于统一序列化、反序列化和属性访问的核心接口。以下是 ConstantNode
使用该函数的具体示例,涵盖 C++ 和 Python 场景:
// 创建常量节点
runtime::NDArray arr = runtime::NDArray::Empty({2, 2}, DLDataType{kDLFloat, 32, 1}, DLContext{kDLCPU, 0});
ConstantNode* const_node = new ConstantNode();
const_node->data = arr;
// 序列化为JSON
JSONAttrVisitor visitor;
const_node->VisitAttrs(&visitor); // 触发以下调用:
// visitor.Visit("data", &data)
// visitor.Visit("span", &span)...
std::string json = visitor.GetJSON();
输出JSON片段:
{
"type_key": "relay.Constant",
"data": {"b64": "AABAA...", "dtype": "float32", "shape": [2, 2]},
"span": null,
"_checked_type_": "TensorType([2,2], float32)"
}
class ConstantMutator : public AttrMutator {
public:
void VisitAttrs(AttrVisitor* v) override {
if (v->IsMutator()) { // 检查是否为修改模式
runtime::NDArray new_data = ...; // 生成新数据
v->Visit("data", &new_data); // 修改data字段
}
}
};
// 调用示例:
ConstantMutator mutator;
const_node->VisitAttrs(&mutator); // 修改常量数据
class DebugPrinter : public AttrVisitor {
public:
void Visit(const char* key, runtime::NDArray* data) override {
std::cout << key << ": shape=" << data.Shape();
}
};
DebugPrinter printer;
const_node->VisitAttrs(&printer); // 输出:data: shape=[2,2]
import tvm
from tvm import relay
# 创建常量
data = tvm.nd.array(np.zeros((2,2), dtype="float32"))
const = relay.Constant(data)
# Python属性访问(背后调用VisitAttrs)
print(const.data) # 访问NDArray → 触发Visit("data", &data)
print(const.span) # 访问源码位置 → Visit("span", &span)
输出:
None # 未设置span时的默认值
# 保存模型(触发序列化)
mod = tvm.IRModule.from_expr(const)
mod.save("const.json") # 内部调用VisitAttrs
# 加载模型(触发反序列化)
loaded_mod = tvm.ir.load_json("const.json")
loaded_const = loaded_mod["main"].body
assert isinstance(loaded_const, relay.Constant)
class MyVisitor(tvm.ir.AttrVisitor):
def visit(self, name, value):
print(f"Attribute {name} has type {type(value)}")
visitor = MyVisitor()
const.visit_attrs(visitor) # 显式调用VisitAttrs
输出:
Attribute data has type
Attribute span has type
...
class Constant;
/*!
* \brief Constant tensor type.
*/
class ConstantNode : public ExprNode {
public:
/*! \brief The data of the tensor */
runtime::NDArray data;
/*! \return The corresponding tensor type of the data */
TensorType tensor_type() const;
/*! \return Whether it is scalar(rank-0 tensor) */
bool is_scalar() const { return data->ndim == 0; }
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("data", &data);
v->Visit("span", &span);
v->Visit("mdata", &mdata);
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const ConstantNode* other, SEqualReducer equal) const {
return equal(data, other->data);
}
void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(data); }
static constexpr const char* _type_key = "relay.Constant";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, ExprNode);
};
class Constant : public Expr {
public:
/*!
* \brief The constructor
* \param data The data of the constant tensor.
* \param span The source span of the expression.
*/
TVM_DLL explicit Constant(runtime::NDArray data, Span span = Span(), MetaData mdata = MetaData());
TVM_DEFINE_OBJECT_REF_METHODS(Constant, RelayExpr, ConstantNode);
};
以下是关于 ConstantNode
和 Constant
类的详细解释与概括,结合它们在 TVM Relay IR 中的作用和实现设计: