《Relay IR的基石:expr.h 中的表达式类型系统剖析》

文章目录

  • 一 、从Constant看Relay表达式的设计哲学
    • 1. 类定义概述
    • 2. `ConstantNode` 详解
      • 1. 核心成员
      • 2. 关键方法
      • 3. 类型系统注册
    • 3. `Constant` 详解
      • 1. 核心功能
  • 二. 核心内容概述
    • (1) Relay表达式基类
      • 1. RelayExprNode 和 RelayExpr 的区别与用法
      • 2. 主要区别
      • 3. 使用模式
        • 例子1:常量表达式
        • 例子2:变量表达式
        • 例子3:函数应用
      • 4. 实际使用建议
    • (2) 具体表达式类型
      • 1. 表达式类型 VarNode举例子
        • 1. 核心设计理念
        • 2. 关键成员解析
          • (1) 核心字段
          • (2) 特殊方法
        • 3. 变量标识系统
          • (1) vid (Unique ID)
          • (2) name_hint 与 vid 的关系
        • 4. 类型系统整合
          • (1) 类型注解流程
          • (2) 类型推导规则
        • 5. 内存模型与跨语言交互
          • (1) C++ 层构造
          • (2) Python 绑定
          • **(3) 对象生命周期**
        • 6. 关键应用场景
          • (1) 函数参数定义
          • (2) 优化 Pass 中的变量处理
          • (3) 类型检查
        • 7. 设计亮点总结
        • 8. 典型问题分析
    • (3) TVM_DECLARE_BASE_OBJECT_INFO 宏详解
      • 1. 宏的参数
      • 2. 静态断言检查(防止非法继承)
      • 2. 运行时类型索引(RuntimeTypeIndex)
      • 3. 动态分配类型索引(_GetOrAllocRuntimeTypeIndex)
      • 通俗版解释:TVM的类型身份证系统
        • 1. 为什么要办身份证?
        • 2. 办证过程(宏的作用)
        • 3. 特殊班级(FINAL版)
        • 4. 实际有什么用?
        • 举个栗子
        • 一句话总结
      • (4) 遍历接口
        • 1. C++ 场景示例
          • (1) 模型序列化(保存为JSON)
          • (2) 优化Pass中的常量修改
          • (3) 调试打印
        • 2. Python 场景示例
          • (1) 直接属性访问
          • (2) 模型保存与加载
          • (3) 自定义属性访问器


一 、从Constant看Relay表达式的设计哲学

  在TVM的Relay IR中,即使是看似简单的常量表达式relay.const(1),其背后也隐藏着整个类型系统的精妙设计。让我们从include/tvm/relay/expr.h中的Constant类入手,逐步拆解…"

1. 类定义概述

类名 继承关系 角色 关键特性
ConstantNode public ExprNode 常量表达式的实际数据存储 包含常量数据(NDArray)、类型信息,并实现属性访问、哈希和相等比较逻辑。
Constant public RelayExpr 常量表达式的智能指针封装 提供用户友好的构造函数和访问方法,隐藏内存管理细节。

2. ConstantNode 详解

class ConstantNode : public ExprNode {
 public:
  /*! \brief The data of the tensor */
  runtime::NDArray data;

  /*! \return The corresponding tensor type of the data */
  TensorType tensor_type() const;

  /*! \return Whether it is scalar(rank-0 tensor) */
  bool is_scalar() const { return data->ndim == 0; }

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("data", &data);
    v->Visit("span", &span);
    v->Visit("mdata", &mdata);
    v->Visit("_checked_type_", &checked_type_);
  }

  bool SEqualReduce(const ConstantNode* other, SEqualReducer equal) const {
    return equal(data, other->data);
  }

  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(data); }

  static constexpr const char* _type_key = "relay.Constant";
  TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, ExprNode);
};

1. 核心成员

  • data (runtime::NDArray)

    • 存储常量张量的实际数据(如权重、偏置等),TVM 使用 NDArray 统一表示多维数组。
    • 示例:卷积层的权重矩阵会被存储在这里。
  • tensor_type()

    • 根据 data 的维度(shape)和数据类型(dtype)自动生成对应的 TensorType
    • 用途:类型推断时确定常量的类型。
  • is_scalar()

    • 判断常量是否为标量(0维张量),如 data->ndim == 0

2. 关键方法

  • VisitAttrs

    • 实现属性的序列化/反序列化,支持以下字段:
      v->Visit("data", &data);          // 张量数据
      v->Visit("span", &span);         // 源码位置信息
      v->Visit("mdata", &mdata);       // 元数据(如调试信息)
      v->Visit("_checked_type_", &checked_type_);  // 类型检查后的类型
      
  • SEqualReduceSHashReduce

    • 结构化相等比较:比较两个 ConstantNodedata 是否相同(用于优化中的常量折叠)。
    • 哈希计算:基于 data 生成哈希值(用于快速查找重复常量)。

3. 类型系统注册

TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, ExprNode);
  • _type_key = "relay.Constant":唯一标识常量节点类型。
  • FINAL:禁止继承,确保常量节点的行为不可被修改。

3. Constant 详解

class Constant : public Expr {
 public:
  /*!
   * \brief The constructor
   * \param data The data of the constant tensor.
   * \param span The source span of the expression.
   */
  TVM_DLL explicit Constant(runtime::NDArray data, Span span = Span(), MetaData mdata = MetaData());

  TVM_DEFINE_OBJECT_REF_METHODS(Constant, RelayExpr, ConstantNode);
};

1. 核心功能

  • 构造函数

    explicit Constant(runtime::NDArray data, Span span = Span(), MetaData mdata = MetaData());
    
    • 接收 NDArray 数据,构造一个常量表达式。
    • 示例
      # Python 前端等价代码
      data = np.array([1, 2, 3], dtype="float32")
      const_expr = relay.Constant(tvm.nd.array(data))
      
  • 智能指针方法

    TVM_DEFINE_OBJECT_REF_METHODS(Constant, RelayExpr, ConstantNode);
    

    展开后提供:

    • operator->():直接访问 ConstantNode 成员(如 const_expr->data)。
    • get():获取底层 ConstantNode 指针。
    • 自动内存管理(通过 ObjectRef 的引用计数)。

二. 核心内容概述

  在TVM源码中,include/tvm/relay/expr.hRelay IR(中间表示)的核心头文件,定义了所有Relay表达式的基础数据结构和类型系统。它是实现TVM高层计算图表示的关键组成部分。以下是该文件的详细解析:
相关重要文件

文件路径 关联内容
include/tvm/relay/type.h 类型系统(TensorType等)
include/tvm/relay/op.h 运算符定义
include/tvm/relay/adt.h 代数数据类型支持
src/relay/ir/expr.cc 表达式方法的实现

include/tvm/relay/expr.h文件主要包含:

  • (1) Relay表达式基类RelayExpr/RelayExprNode
  • (2) 所有具体表达式类型的声明(如变量、常量、函数调用等)
  • (3) 表达式类型的遍历和转换接口
  • (4) 类型系统和属性访问的支持

(1) Relay表达式基类

class RelayExprNode : public BaseExprNode { /*...*/ };
class RelayExpr : public BaseExpr { /*...*/ };
  • 角色:所有Relay表达式的公共基类
  • 功能
    • 提供类型系统支持(通过checked_type_字段)
    • 实现属性访问(VisitAttrs
    • 支持结构化相等比较(SEqualReduce

1. RelayExprNode 和 RelayExpr 的区别与用法

  RelayExprNode 是 Relay 表达式的实际实现类,是一个 C++ 类,包含了表达式的所有数据和功能实现。它是所有 Relay 表达式类型的基类。
  RelayExpr 是一个智能指针(relay::Expr),它指向 RelayExprNode 或其子类的实例。它提供了对 RelayExprNode 的安全访问和管理。

2. 主要区别

特性 RelayExprNode RelayExpr
类型 C++ 类 智能指针(std::shared_ptr 的封装)
生命周期管理 需要手动管理 自动管理
使用方式 通常不直接使用,作为实现细节 用户主要交互的接口
继承关系 作为基类定义表达式结构 作为访问接口

3. 使用模式

在 TVM 中,通常的模式是:

  1. 定义一个继承自 RelayExprNode 的具体表达式节点类
  2. 使用 RelayExpr 作为这些节点的引用
例子1:常量表达式
// 创建一个常量表达式
auto const_node = relay::ConstantNode::make(tvm::runtime::NDArray::Zeros(...));
RelayExpr const_expr = const_node;

// 通常更简洁的写法
RelayExpr const_expr = relay::Constant(tvm::runtime::NDArray::Zeros(...));
例子2:变量表达式
// 创建一个变量表达式
auto var_node = relay::VarNode::make("x", relay::Type());
RelayExpr var_expr = var_node;

// 或者更简洁地
RelayExpr var_expr = relay::Var("x", relay::Type());
例子3:函数应用
// 创建函数应用表达式
RelayExpr func = ...; // 某个函数
RelayExpr arg = ...;  // 某个参数
auto call_node = relay::CallNode::make(func, {arg});
RelayExpr call_expr = call_node;

// 或者
RelayExpr call_expr = relay::Call(func, {arg});

4. 实际使用建议

  1. 用户代码:在大多数情况下,你应该使用 RelayExpr 而不是直接操作 RelayExprNode

  2. 扩展 Relay:如果你想定义新的表达式类型,需要继承 RelayExprNode 并实现相应接口。

  3. 类型转换:可以使用 as 方法将 RelayExpr 向下转换为特定类型的节点指针:

RelayExpr expr = ...;
if (const auto* call = expr.as<CallNode>()) {
  // 现在可以访问 CallNode 的特定成员
  call->op;
  call->args;
}
  1. 创建新表达式:TVM 提供了辅助函数来创建表达式,通常以节点类型名去掉 “Node” 命名(如 relay::Var() 创建 VarNodeRelayExpr)。

这种分离设计使得 Relay IR 既灵活又安全,同时保持了良好的性能特性

(2) 具体表达式类型

表达式类型 说明 关键成员/方法
VarNode 变量(输入/中间结果) String name_hint, Type type_annotation, Id vid
ConstantNode 常量张量(如模型权重) runtime::NDArray data, tensor_type(), is_scalar()
CallNode 函数/运算符调用 Expr op, Array args, Attrs attrs, Array type_args
LetNode Let绑定(实现变量作用域) Var var, Expr value, Expr body
TupleNode 元组结构(多返回值) Array fields
TupleGetItemNode 从元组中获取元素 Expr tuple, int index
IfNode 条件表达式 Expr cond, Expr true_branch, Expr false_branch
OpNode 基本运算符(如add/concat) 通过Op::Get("op_name")获取
FunctionNode 函数定义(在function.h中声明,但属于表达式) Array params, Expr body, Type ret_type, Array type_params
RefCreateNode 创建可变引用(用于状态更新) Expr value
RefReadNode 读取引用值 Expr ref
RefWriteNode 更新引用值 Expr ref, Expr value
ConstructorNode 代数数据类型(ADT)的构造器(在adt.h中声明) String tag, Array inputs
MatchNode 模式匹配(ADT处理) Expr data, Array clauses
TempExprNode 临时表达式(用于优化过程中的中间表示) 通常作为优化Pass的中间载体
GlobalVarNode 全局函数引用(跨模块调用) String name_hint
SeqExprNode 顺序执行多个表达式(类似语句块) Array bindings, Expr body

1. 表达式类型 VarNode举例子

include/tvm/relay/expr.h

class Var;
/*! \brief Container for Var */
class VarNode : public ExprNode {
 public:
  /*!
   * \brief The unique identifier of the Var.
   *
   * vid will be preserved for the same Var during type inference
   * and other rewritings, while the VarNode might be recreated
   * to attach additional information.
   * This property can be used to keep track of parameter Var
   * information across passes.
   */
  Id vid;
  /*!
   * \brief type annotaion of the variable.
   * This field records user provided type annotation of the Var.
   * This field is optional and can be None.
   */
  Type type_annotation;

  /*! \return The name hint of the variable */
  const String& name_hint() const { return vid->name_hint; }

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("vid", &vid);
    v->Visit("type_annotation", &type_annotation);
    v->Visit("span", &span);
    v->Visit("mdata", &mdata);
    v->Visit("_checked_type_", &checked_type_);
  }

  bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {
    return equal(type_annotation, other->type_annotation) && equal.FreeVarEqualImpl(this, other);
  }

  void SHashReduce(SHashReducer hash_reduce) const {
    hash_reduce(type_annotation);
    hash_reduce.FreeVarHashImpl(this);
  }

  static constexpr const char* _type_key = "relay.Var";
  TVM_DECLARE_FINAL_OBJECT_INFO(VarNode, ExprNode);
};

class Var : public Expr {
 public:
  /*!
   * \brief The constructor
   * \param name_hint The name hint of a variable.
   * \param type_annotation The type annotation of a variable.
   * \param span The source span of the expression.
   */
  TVM_DLL Var(String name_hint, Type type_annotation, Span span = Span(), MetaData mdata = MetaData())
      : Var(Id(name_hint), type_annotation, span, mdata) {}

  /*!
   * \brief The constructor
   * \param vid The unique id of a variable.
   * \param type_annotation The type annotation of a variable.
   * \param span The source span of the expression.
   */
  TVM_DLL Var(Id vid, Type type_annotation, Span span = Span(), MetaData mdata = MetaData());

  TVM_DEFINE_OBJECT_REF_METHODS(Var, RelayExpr, VarNode);
};
1. 核心设计理念

VarNodeVar 共同实现了 Relay IR 的变量系统,采用 TVM 标准的 Object-ObjectRef 设计模式

  • VarNode:存储实际数据的节点类(继承自 ExprNode
  • Var:管理 VarNode智能指针包装类(继承自 Expr

2. 关键成员解析
(1) 核心字段
成员 类型 作用
vid Id 唯一标识符,跨 Pass 保持不变(即使节点被重建)
type_annotation Type 用户显式指定的类型注解(可空)
name_hint() String 通过 vid->name_hint 获取的可读名称(非唯一)
span Span 源码位置信息(用于错误定位)
mdata MetaData 扩展元数据
(2) 特殊方法
方法 功能
SEqualReduce 结构化相等比较(用于优化 Pass 的重复检测)
SHashReduce 哈希计算(支持快速查找)
VisitAttrs 属性序列化/反序列化

3. 变量标识系统
(1) vid (Unique ID)
class IdNode : public Object {
 public:
  String name_hint;
  // ... 其他元数据
};
  • 核心特性
    • 通过 Id(name_hint) 构造,但系统会保证其唯一性
    • 即使优化 Pass 重建变量节点,vid 保持不变
    • 用于跨 Pass 跟踪参数变量(如梯度更新时识别同一参数)
(2) name_hint 与 vid 的关系
x = relay.var("input", shape=(1,3))  # 实际创建:
                                      # vid = Id("input_0x7f") (自动去重)
                                      # name_hint = "input" (用户友好)

4. 类型系统整合
(1) 类型注解流程
graph TD
    A[用户构造] -->|relay.var(..., dtype="float32")| B(type_annotation)
    B --> C[类型检查]
    C -->|更新| D(_checked_type_)
(2) 类型推导规则
  • type_annotation 存在:必须与实际使用类型兼容
  • 若为空:从上下文推断类型

5. 内存模型与跨语言交互
(1) C++ 层构造
// 方式1:通过 name_hint
Var x("data", TensorType({1,3}, DataType::Float(32)));

// 方式2:直接指定 Id
Var x(Id("data_0x7f"), TensorType({1,3}, DataType::Float(32)));
(2) Python 绑定
# Python 前端接口
x = relay.var(
    name="input",
    shape=(1,3),
    dtype="float32",
    span=SourceSpan(...)
)
(3) 对象生命周期
sequenceDiagram
    Python->>C++: relay.var() 创建请求
    C++->>Heap: 分配 VarNode
    C++->>Python: 返回 Var(ObjectRef)
    Python->>C++: 析构时触发引用计数-1

6. 关键应用场景
(1) 函数参数定义
def build_linear():
    x = relay.var("x", shape=(1,3))
    w = relay.var("w", shape=(3,2))
    b = relay.var("b", shape=(2,))
    y = relay.add(relay.matmul(x, w), b)
    return relay.Function([x, w, b], y)
(2) 优化 Pass 中的变量处理
// 在 ConstantFolding 中识别变量引用
if (const VarNode* var = expr.as<VarNode>()) {
    if (var_map.count(var->vid)) {
        // 替换为已知常量
    }
}
(3) 类型检查
// 检查变量类型是否匹配
bool CheckType(const VarNode* var, const Type& expected) {
    return var->checked_type().as<TensorType>()->dtype == expected;
}

7. 设计亮点总结
  1. 稳定性vid 保证变量在优化过程中的持久标识
  2. 灵活性type_annotation 支持显式/隐式类型指定
  3. 安全性TVM_DECLARE_FINAL_OBJECT_INFO 防止错误继承
  4. 可调试性spanname_hint 增强错误可读性
  5. 性能SEqualReduce/SHashReduce 优化图操作效率

8. 典型问题分析

Q: 为什么需要同时存在 vidname_hint
A: 分工不同:

  • name_hint:面向用户,提供可读性(允许重复)
  • vid:面向系统,保证唯一性和跨Pass一致性

Q: 何时会重建 VarNode
A: 典型场景:

  • 类型推断后附加 _checked_type_
  • 优化 Pass 中克隆表达式时保留原 vid 但新建节点

(3) TVM_DECLARE_BASE_OBJECT_INFO 宏详解

  这个宏是 TVM 类型系统的核心,用于在 C++ 中动态注册和管理对象的类型信息。它的核心作用是: 为每个类自动生成类型注册代码,使其能被 TVM 运行时识别和操作


1. 宏的参数

#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)
  • TypeName:当前类名(如 ConstantNode
  • ParentType:父类名(如 ExprNode

2. 静态断言检查(防止非法继承)

static_assert(!ParentType::_type_final, "ParentObj marked as final");
  • 作用:如果父类被标记为 final(通过 _type_final),则禁止子类继承。

2. 运行时类型索引(RuntimeTypeIndex)

static uint32_t RuntimeTypeIndex() {
  // 检查子类槽位配置是否合法
  static_assert(TypeName::_type_child_slots == 0 || 
                ParentType::_type_child_slots == 0 ||
                TypeName::_type_child_slots < ParentType::_type_child_slots,
               "子类槽位数不能超过父类限制");

  // 如果已预分配类型ID,直接返回
  if (TypeName::_type_index != ::tvm::runtime::TypeIndex::kDynamic) {
    return TypeName::_type_index;
  }
  // 否则动态分配
  return _GetOrAllocRuntimeTypeIndex();
}
  • 功能:返回类的唯一类型 ID(uint32_t)。
  • 优化:优先使用预分配的 _type_index(性能更高),否则动态分配。

3. 动态分配类型索引(_GetOrAllocRuntimeTypeIndex)

static uint32_t _GetOrAllocRuntimeTypeIndex() {
  static uint32_t tidx = Object::GetOrAllocRuntimeTypeIndex(
      TypeName::_type_key,         // 类型名称字符串(如 "relay.Constant")
      TypeName::_type_index,       // 预分配的类型ID
      ParentType::RuntimeTypeIndex(), // 父类类型ID
      TypeName::_type_child_slots, // 为子类预留的槽位数
      TypeName::_type_child_slots_can_overflow // 是否允许超额
  );
  return tidx;
}
  • 作用:向 TVM 运行时注册类型,并分配唯一 ID。
  • 关键参数
    • _type_child_slots:限制子类数量(防止类型爆炸)。
    • _type_child_slots_can_overflow:为 true 时允许突破限制。

通俗版解释:TVM的类型身份证系统

你可以把TVM的类型系统想象成一个学校的学生管理系统,而TVM_DECLARE_BASE_OBJECT_INFO就是给学生(类)办身份证的机器:


1. 为什么要办身份证?
  • 每个学生(类)需要唯一学号(类型ID)
  • 需要知道他的班主任是谁(父类)
  • 防止有人冒充转校生(非法继承)
2. 办证过程(宏的作用)
// 给"小明同学"办证,班主任是"李老师"
TVM_DECLARE_BASE_OBJECT_INFO(小明, 李老师)

这个宏会自动做三件事:

  1. 检查家世清白

    static_assert(!李老师::是final班, "班主任明确不收新学生!");
    
    • 如果班主任声明"我们班不接收转学生",就报错
  2. 分配学号

    • 优先用预留的VIP学号(_type_index
    • 没有就现场摇号(_GetOrAllocRuntimeTypeIndex
  3. 登记亲属关系

    学号 = 教务处.登记(
     姓名:"小明",
     班主任:李老师.学号,
     可带小弟人数:3  // _type_child_slots
    );
    
3. 特殊班级(FINAL版)
TVM_DECLARE_FINAL_OBJECT_INFO(学霸班, 实验班)
  • 相当于在班级门口挂**“禁止转入”**牌子
  • 其他班同学想转学过来会直接报错
4. 实际有什么用?
  • 查身份证快obj->IsInstance<小明>() 比查户口本快
  • 安全转班obj.as<小明>() 能安全转换类型
  • 防止冒名顶替:禁止随便认爹(错误继承)

举个栗子
# Python前端定义一个"汉堡店"类
@register_relay_node("food.HamburgerShop")
class HamburgerShopNode(ExprNode):
    _type_key = "food.HamburgerShop"
    _type_child_slots = 2  # 允许开2家分店

C++层通过这个宏:

  1. 给汉堡店分配类型ID(比如9527)
  2. 记录它的父类是ExprNode
  3. 允许最多2个子类(比如CheeseBurgerShopChickenBurgerShop

一句话总结

这个宏就是TVM给类发身份证+建家族档案的工具,让系统能:

  • ✅ 快速识别"你是谁"(类型检查)
  • ✅ 知道"你爸是谁"(继承关系)
  • ❌ 防止"乱认亲戚"(非法继承)

(4) 遍历接口

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("data", &data);
    v->Visit("span", &span);
    v->Visit("mdata", &mdata);
    v->Visit("_checked_type_", &checked_type_);
  }

  VisitAttrs 是 TVM 中用于统一序列化、反序列化和属性访问的核心接口。以下是 ConstantNode 使用该函数的具体示例,涵盖 C++ 和 Python 场景:


1. C++ 场景示例
(1) 模型序列化(保存为JSON)
// 创建常量节点
runtime::NDArray arr = runtime::NDArray::Empty({2, 2}, DLDataType{kDLFloat, 32, 1}, DLContext{kDLCPU, 0});
ConstantNode* const_node = new ConstantNode();
const_node->data = arr;

// 序列化为JSON
JSONAttrVisitor visitor;
const_node->VisitAttrs(&visitor);  // 触发以下调用:
                                   // visitor.Visit("data", &data)
                                   // visitor.Visit("span", &span)...
std::string json = visitor.GetJSON();

输出JSON片段

{
  "type_key": "relay.Constant",
  "data": {"b64": "AABAA...", "dtype": "float32", "shape": [2, 2]},
  "span": null,
  "_checked_type_": "TensorType([2,2], float32)"
}
(2) 优化Pass中的常量修改
class ConstantMutator : public AttrMutator {
 public:
  void VisitAttrs(AttrVisitor* v) override {
    if (v->IsMutator()) {  // 检查是否为修改模式
      runtime::NDArray new_data = ...; // 生成新数据
      v->Visit("data", &new_data);    // 修改data字段
    }
  }
};

// 调用示例:
ConstantMutator mutator;
const_node->VisitAttrs(&mutator);  // 修改常量数据
(3) 调试打印
class DebugPrinter : public AttrVisitor {
 public:
  void Visit(const char* key, runtime::NDArray* data) override {
    std::cout << key << ": shape=" << data.Shape();
  }
};

DebugPrinter printer;
const_node->VisitAttrs(&printer);  // 输出:data: shape=[2,2]

2. Python 场景示例
(1) 直接属性访问
import tvm
from tvm import relay

# 创建常量
data = tvm.nd.array(np.zeros((2,2), dtype="float32"))
const = relay.Constant(data)

# Python属性访问(背后调用VisitAttrs)
print(const.data)      # 访问NDArray → 触发Visit("data", &data)
print(const.span)      # 访问源码位置 → Visit("span", &span)

输出


None  # 未设置span时的默认值
(2) 模型保存与加载
# 保存模型(触发序列化)
mod = tvm.IRModule.from_expr(const)
mod.save("const.json")  # 内部调用VisitAttrs

# 加载模型(触发反序列化)
loaded_mod = tvm.ir.load_json("const.json")
loaded_const = loaded_mod["main"].body
assert isinstance(loaded_const, relay.Constant)
(3) 自定义属性访问器
class MyVisitor(tvm.ir.AttrVisitor):
    def visit(self, name, value):
        print(f"Attribute {name} has type {type(value)}")

visitor = MyVisitor()
const.visit_attrs(visitor)  # 显式调用VisitAttrs

输出

Attribute data has type 
Attribute span has type 
...

class Constant;
/*!
 * \brief Constant tensor type.
 */
class ConstantNode : public ExprNode {
 public:
  /*! \brief The data of the tensor */
  runtime::NDArray data;

  /*! \return The corresponding tensor type of the data */
  TensorType tensor_type() const;

  /*! \return Whether it is scalar(rank-0 tensor) */
  bool is_scalar() const { return data->ndim == 0; }

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("data", &data);
    v->Visit("span", &span);
    v->Visit("mdata", &mdata);
    v->Visit("_checked_type_", &checked_type_);
  }

  bool SEqualReduce(const ConstantNode* other, SEqualReducer equal) const {
    return equal(data, other->data);
  }

  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(data); }

  static constexpr const char* _type_key = "relay.Constant";
  TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, ExprNode);
};

class Constant : public Expr {
 public:
  /*!
   * \brief The constructor
   * \param data The data of the constant tensor.
   * \param span The source span of the expression.
   */
  TVM_DLL explicit Constant(runtime::NDArray data, Span span = Span(), MetaData mdata = MetaData());

  TVM_DEFINE_OBJECT_REF_METHODS(Constant, RelayExpr, ConstantNode);
};

以下是关于 ConstantNodeConstant 类的详细解释与概括,结合它们在 TVM Relay IR 中的作用和实现设计:



你可能感兴趣的:(TVM,AI编译器,TVM,人工智能)