Aloha 机械臂的学习记录1——AWE:Bimanual Simulation Suite(Save waypoints)

继续以TASK = sim_transfer_cube_scripted为例子:

Save waypoints

python example/act_waypoint.py --dataset=data/act/sim_transfer_cube_scripted --err_threshold=0.01 --save_waypoints

在Linux终端运行该命令后的结果如下截图:

Aloha 机械臂的学习记录1——AWE:Bimanual Simulation Suite(Save waypoints)_第1张图片

该结果所执行的文件为:sim_transfer_cube_scripted文件夹下0-49个episode.hdf5文件,对于下面act_waypoint.py 文件进行分析。

act_waypoint.py 文件中代码需要仔细阅读,其中:

def main(args):
    num_waypoints = []
    num_frames = []

    # load data
    for i in tqdm(range(args.start_idx, args.end_idx + 1)):
        dataset_path = os.path.join(args.dataset, f"episode_{i}.hdf5")
        with h5py.File(dataset_path, "r+") as root:
            qpos = root["/observations/qpos"][()]

            if args.use_ee:
                qpos = np.array(qpos)  # ts, dim

                # calculate EE pose
                from act.convert_ee import get_ee

                left_arm_ee = get_ee(qpos[:, :6], qpos[:, 6:7])
                right_arm_ee = get_ee(qpos[:, 7:13], qpos[:, 13:14])
                qpos = np.concatenate([left_arm_ee, right_arm_ee], axis=1)

            # select waypoints
            waypoints = dp_waypoint_selection( # if it's too slow, use greedy_waypoint_selection
                env=None,
                actions=qpos,
                gt_states=qpos,
                err_threshold=args.err_threshold,
                pos_only=True,
            )
            print(
                f"Episode {i}: {len(qpos)} frames -> {len(waypoints)} waypoints (ratio: {len(qpos)/len(waypoints):.2f})"
            )
            num_waypoints.append(len(waypoints))
            num_frames.append(len(qpos))

            # save waypoints
            if args.save_waypoints:
                name = f"/waypoints"
                try:
                    root[name] = waypoints
                except:
                    # if the waypoints dataset already exists, ask the user if they want to overwrite
                    print("waypoints dataset already exists. Overwrite? (y/n)")
                    ans = input()
                    if ans == "y":
                        del root[name]
                        root[name] = waypoints

            # visualize ground truth qpos and waypoints
            if args.plot_3d:
                if not args.use_ee:
                    qpos = np.array(qpos)  # ts, dim
                    from act.convert_ee import get_xyz

                    left_arm_xyz = get_xyz(qpos[:, :6])
                    right_arm_xyz = get_xyz(qpos[:, 7:13])
                else:
                    left_arm_xyz = left_arm_ee[:, :3]
                    right_arm_xyz = right_arm_ee[:, :3]

                # Find global min and max for each axis
                all_data = np.concatenate([left_arm_xyz, right_arm_xyz], axis=0)
                min_x, min_y, min_z = np.min(all_data, axis=0)
                max_x, max_y, max_z = np.max(all_data, axis=0)

                fig = plt.figure(figsize=(20, 10))
                ax1 = fig.add_subplot(121, projection="3d") 
                ax1.set_xlabel("x")
                ax1.set_ylabel("y")
                ax1.set_zlabel("z")
                ax1.set_title("Left", fontsize=20)
                ax1.set_xlim([min_x, max_x])
                ax1.set_ylim([min_y, max_y])
                ax1.set_zlim([min_z, max_z])

                plot_3d_trajectory(ax1, left_arm_xyz, label="ground truth", legend=False)

                ax2 = fig.add_subplot(122, projection="3d")
                ax2.set_xlabel("x")
                ax2.set_ylabel("y")
                ax2.set_zlabel("z")
                ax2.set_title("Right", fontsize=20)
                ax2.set_xlim([min_x, max_x])
                ax2.set_ylim([min_y, max_y])
                ax2.set_zlim([min_z, max_z])

                plot_3d_trajectory(ax2, right_arm_xyz, label="ground truth", legend=False)

                # prepend 0 to waypoints to include the initial state
                waypoints = [0] + waypoints

                plot_3d_trajectory(
                    ax1,
                    [left_arm_xyz[i] for i in waypoints],
                    label="waypoints",
                    legend=False,
                )  # Plot waypoints for left_arm_xyz
                plot_3d_trajectory(
                    ax2,
                    [right_arm_xyz[i] for i in waypoints],
                    label="waypoints",
                    legend=False,
                )  # Plot waypoints for right_arm_xyz

                fig.suptitle(f"Task: {args.dataset.split('/')[-1]}", fontsize=30) 

                handles, labels = ax1.get_legend_handles_labels()
                fig.legend(handles, labels, loc="lower center", ncol=2, fontsize=20)

                fig.savefig(
                    f"plot/act/{args.dataset.split('/')[-1]}_{i}_t_{args.err_threshold}_waypoints.png"
                )
                plt.close(fig)

            root.close()

    print(
        f"Average number of waypoints: {np.mean(num_waypoints)} \tAverage number of frames: {np.mean(num_frames)} \tratio: {np.mean(num_frames)/np.mean(num_waypoints)}"
    )

这是一个主要的 main 函数,它负责执行主要的数据处理、分析和可视化操作。下面来解释一下这个函数的主要步骤:

  1. num_waypointsnum_frames 初始化为两个空列表,用于记录每个轨迹的waypoints数量和帧数。

  2. 通过循环遍历 args.start_idxargs.end_idx 之间的轨迹索引,执行以下操作:

    a. 构建数据集文件路径 dataset_path,该路径基于传递的 --dataset 参数和当前的轨迹索引 i

    b. 使用 h5py 打开 HDF5 数据集文件,并从中读取关节位置数据 qpos

    c. 如果 --use_ee 参数被设置,它会计算末端执行器(end-effector)的位置并更新 qpos,否则会保持不变。

    d. 使用 dp_waypoint_selection 函数选择关键点(waypoints),并打印轨迹的长度和waypoints的数量。

    e. 将waypoints的数量和帧数添加到相应的列表 num_waypointsnum_frames 中。

    f. 如果 --save_waypoints 参数被设置,它会将选择的waypoints保存到HDF5文件中。如果waypoints数据集已经存在,则询问用户是否要覆盖。

    g. 如果 --plot_3d 参数被设置,它会创建一个三维图形以可视化轨迹和waypoints,然后将图形保存为PNG文件。这一部分包括设置图形属性、绘制轨迹和waypoints、设置图形标题和图例、保存图形并关闭图形。

    h. 最后,关闭 HDF5 文件。

  3. 完成所有轨迹的处理后,计算waypoints的平均数量和帧数,并将结果打印出来。

这个函数的目的是加载轨迹数据,选择waypoints,可视化轨迹和waypoints,并进行必要的记录和统计。具体的操作会根据传递的参数和数据集文件而有所不同。

下面依次分析每段代码的含义:

            if args.use_ee:
                qpos = np.array(qpos)  # ts, dim

                # calculate EE pose
                from act.convert_ee import get_ee

                left_arm_ee = get_ee(qpos[:, :6], qpos[:, 6:7])
                right_arm_ee = get_ee(qpos[:, 7:13], qpos[:, 13:14])
                qpos = np.concatenate([left_arm_ee, right_arm_ee], axis=1)

这部分代码主要用于计算末端执行器(end-effector)的姿态(EE pose),然后更新 qpos,以便在后续的处理中使用。具体的步骤如下:

  1. qpos = np.array(qpos):将 qpos 转换为 NumPy 数组,以便进行矩阵操作。qpos 通常是一个时间序列(ts)的关节位置数据。

  2. 导入 act.convert_ee 模块中的 get_ee 函数:这个函数可能包含了计算末端执行器姿态的逻辑。

  3. 使用 get_ee 函数计算左臂(left_arm_ee)和右臂(right_arm_ee)的末端执行器姿态。根据代码的写法,qpos 被分成了左臂和右臂的部分,并分别传递给 get_ee 函数。这个函数的作用是计算给定关节位置的末端执行器姿态。

  4. 将左臂和右臂的末端执行器姿态合并为一个新的 qpos 数组。这个新的 qpos 数组将包含左臂和右臂的末端执行器姿态数据,以便后续处理使用。

  5. left_arm_ee = get_ee(qpos[:, :6], qpos[:, 6:7]):这一行代码调用了 get_ee 函数,传递了左臂的关节位置(qpos[:, :6])和左臂的末端执行器夹爪状态(qpos[:, 6:7])。get_ee 函数计算左臂末端执行器的姿态,并将结果存储在 left_arm_ee 中。

  6. right_arm_ee = get_ee(qpos[:, 7:13], qpos[:, 13:14]):这一行代码类似于上一行,但是计算右臂末端执行器的姿态,并将结果存储在 right_arm_ee 中。右臂的关节位置范围为 qpos[:, 7:13],夹爪状态为 qpos[:, 13:14]

  7. qpos = np.concatenate([left_arm_ee, right_arm_ee], axis=1):在这一行代码中,left_arm_eeright_arm_ee 的结果被连接成一个新的数组 qpos,并且连接操作是在列方向上进行的(axis=1)。这个新的 qpos 数组包含了左臂和右臂末端执行器的姿态信息,以及夹爪状态。

  8. qpos 是一个包含关节位置数据的 NumPy 数组,通常是一个时间序列,其中每一行表示一个时间步的关节位置。qpos[:, :6]qpos[:, 6:7] 是对 qpos 数组的切片操作,用于选择特定的列数据。让我为你解释这两个切片的含义:

  9. qpos[:, :6]:这个切片操作选择了 qpos 数组的所有行(时间步)和前6列(从第0列到第5列)。在这个上下文中,通常前6列对应于机器人或物体的左臂的关节位置数据。因此,qpos[:, :6] 返回一个包含左臂关节位置数据的子数组。

  10. qpos[:, 6:7]:这个切片操作选择了 qpos 数组的所有行和第6列。在这个上下文中,第6列通常对应于机器人或物体的左臂的末端执行器夹爪状态数据。因此,qpos[:, 6:7] 返回一个包含左臂末端执行器夹爪状态数据的子数组。

这段代码的作用是将左臂和右臂的末端执行器姿态信息合并到一个新的数组中,以便在后续的处理中使用。总之,这部分代码用于将原始的关节位置数据转换为末端执行器姿态数据,以便后续处理能够基于末端执行器姿态进行操作。如果需要更多关于 get_ee 函数的详细信息,需要查看 act.convert_ee 模块中的实现代码。

下面给出act.convert_ee模块中的get_ee 函数,python文件对应于act/convert_ee.py文件。

def get_ee(joints, grippers):
    result = []
    rotation_transformer = RotationTransformer(from_rep="matrix", to_rep="rotation_6d")
    for joint, gripper in zip(joints, grippers):
        T_sb = mr.FKinSpace(vx300s.M, vx300s.Slist, joint)

        # Extract position vector
        xyz = T_sb[:3, 3]

        # Extract rotation matrix
        rot_matrix = T_sb[:3, :3]
        # Convert to 6d rotation
        rot_6d = rotation_transformer.forward(rot_matrix)

        # concatenate xyz with rotation
        result.append(np.concatenate((xyz, rot_6d, gripper)))
    return np.array(result)

这是一个名为 get_ee 的函数,它的主要功能是根据关节位置和末端执行器夹爪的状态来计算末端执行器(end-effector)的姿态。下面解释一下这个函数的主要步骤:

  1. 创建一个空列表 result,用于存储计算出的末端执行器姿态。

  2. 创建一个 RotationTransformer 对象,用于将旋转矩阵转换为6维旋转表示。这个对象的配置为从矩阵表示(from_rep="matrix")转换到6维旋转表示(to_rep="rotation_6d")。

  3. 使用循环遍历输入的 joints(关节位置)和 grippers(末端执行器夹爪状态)数据,其中 jointsgrippers 应该是长度相同的列表。

  4. 对于每组关节位置和夹爪状态,执行以下操作:

    a. 使用 mr.FKinSpace 函数计算基于给定的关节位置的末端执行器变换矩阵 T_sb。这个矩阵描述了末端执行器相对于基座(base)的姿态和位置。

    b. 从 T_sb 中提取位置向量 xyz,它包含了末端执行器的x、y和z坐标。

    c. 从 T_sb 中提取旋转矩阵 rot_matrix,它包含了末端执行器的旋转部分。

    d. 使用 RotationTransformer 对象将旋转矩阵 rot_matrix 转换为6维旋转表示 rot_6d

    e. 将 xyzrot_6d 和夹爪状态 gripper 连接成一个单独的数组,并将其添加到 result 列表中。

  5. 返回 result 列表,其中包含了每个时间步的末端执行器姿态数据。

这个函数的目的是将关节位置和夹爪状态转换为末端执行器的姿态信息,并将结果存储在一个列表中,这可以用于后续的轨迹分析和可视化操作。

# select waypoints
            waypoints = dp_waypoint_selection( # if it's too slow, use greedy_waypoint_selection
                env=None,
                actions=qpos,
                gt_states=qpos,
                err_threshold=args.err_threshold,
                pos_only=True,
            )
            print(
                f"Episode {i}: {len(qpos)} frames -> {len(waypoints)} waypoints (ratio: {len(qpos)/len(waypoints):.2f})"
            )
            num_waypoints.append(len(waypoints))
            num_frames.append(len(qpos))

这段代码用于选择轨迹中的关键点(waypoints)。具体的选择方法是通过调用名为 dp_waypoint_selection 的函数来实现的。下面解释一下这段代码的主要步骤:

  1. waypoints = dp_waypoint_selection(...):这一行代码调用了 dp_waypoint_selection 函数来选择轨迹中的关键点。函数的参数如下:

    • env=None:可能是用于指定环境信息的参数,但在这段代码中被设置为 None,表示未使用环境信息。
    • actions=qpos:这是关节位置数据,作为选择waypoints的输入。
    • gt_states=qpos:这也是关节位置数据,可能表示地面实际状态(ground truth states)。
    • err_threshold=args.err_threshold用于设置选择waypoints的误差阈值,只有当关节位置的变化大于该阈值时,才会选择该点作为waypoint。
    • pos_only=True可能表示只考虑位置信息而不考虑朝向信息。
  2. print(...)在选择完waypoints后,代码会打印一条消息,显示当前轨迹的帧数、waypoints数量以及两者之间的比率。

  3. num_waypoints.append(len(waypoints))num_frames.append(len(qpos)):这两行代码将当前轨迹的waypoints数量和帧数添加到 num_waypointsnum_frames 列表中,以便后续计算平均值。

这段代码的目的是选择轨迹中的关键点(waypoints),并记录每个轨迹的waypoints数量和帧数。选择的关键点通常用于后续的轨迹分析和控制任务中。选择方式可能会根据应用需求和数据集不同而有所不同。


 

# save waypoints
            if args.save_waypoints:
                name = f"/waypoints"
                try:
                    root[name] = waypoints
                except:
                    # if the waypoints dataset already exists, ask the user if they want to overwrite
                    print("waypoints dataset already exists. Overwrite? (y/n)") #路点数据集已存在。是否覆盖?
                    ans = input()
                    if ans == "y":
                        del root[name]
                        root[name] = waypoints

这段代码用于将选择的waypoints保存到 HDF5 数据集文件中。下面解释一下它的主要步骤:

  1. if args.save_waypoints::这是一个条件语句,检查是否传递了 --save_waypoints 参数。如果传递了该参数(即设置为 True),则执行以下操作来保存waypoints。

  2. name = f"/waypoints":这一行代码创建了一个名为 name 的变量,用于指定waypoints数据集的名称。在这里,数据集的名称被设置为 "/waypoints"。

  3. try:except::这是一个异常处理结构,用于处理可能发生的异常情况。

  4. root[name] = waypoints这一行代码尝试将waypoints数组保存到 HDF5 数据集文件中的指定路径。具体来说,它将waypoints数据存储在数据集名称为 "/waypoints" 的位置。

  5. except: 块:如果在尝试保存数据时出现异常,代码会进入 except: 块。

  6. print("waypoints dataset already exists. Overwrite? (y/n)"):如果waypoints数据集已经存在,程序会打印一条消息询问用户是否要覆盖现有的数据集。

  7. ans = input():程序等待用户输入一个答案("y" 表示是,"n" 表示否)。

  8. if ans == "y"::如果用户输入 "y",表示同意覆盖现有数据集,那么代码会执行以下操作:

    • del root[name]:这一行代码删除现有的waypoints数据集,以便后续重新创建。

    • root[name] = waypoints:然后,代码将新的waypoints数据保存到数据集中。

这样,通过上述代码,如果waypoints数据集已经存在,用户会被询问是否覆盖,以确保不会意外覆盖现有的数据。这是一种良好的数据保护和用户交互的方式。

 # visualize ground truth qpos and waypoints
            if args.plot_3d:
                if not args.use_ee:
                    qpos = np.array(qpos)  # ts, dim
                    from act.convert_ee import get_xyz

                    left_arm_xyz = get_xyz(qpos[:, :6])
                    right_arm_xyz = get_xyz(qpos[:, 7:13])
                else:
                    left_arm_xyz = left_arm_ee[:, :3]
                    right_arm_xyz = right_arm_ee[:, :3]

                # Find global min and max for each axis
                all_data = np.concatenate([left_arm_xyz, right_arm_xyz], axis=0)
                min_x, min_y, min_z = np.min(all_data, axis=0)
                max_x, max_y, max_z = np.max(all_data, axis=0)

                fig = plt.figure(figsize=(20, 10))
                ax1 = fig.add_subplot(121, projection="3d") 
                ax1.set_xlabel("x")
                ax1.set_ylabel("y")
                ax1.set_zlabel("z")
                ax1.set_title("Left", fontsize=20)
                ax1.set_xlim([min_x, max_x])
                ax1.set_ylim([min_y, max_y])
                ax1.set_zlim([min_z, max_z])

                plot_3d_trajectory(ax1, left_arm_xyz, label="ground truth", legend=False)

                ax2 = fig.add_subplot(122, projection="3d")
                ax2.set_xlabel("x")
                ax2.set_ylabel("y")
                ax2.set_zlabel("z")
                ax2.set_title("Right", fontsize=20)
                ax2.set_xlim([min_x, max_x])
                ax2.set_ylim([min_y, max_y])
                ax2.set_zlim([min_z, max_z])

                plot_3d_trajectory(ax2, right_arm_xyz, label="ground truth", legend=False)

                # prepend 0 to waypoints to include the initial state
                waypoints = [0] + waypoints

                plot_3d_trajectory(
                    ax1,
                    [left_arm_xyz[i] for i in waypoints],
                    label="waypoints",
                    legend=False,
                )  # Plot waypoints for left_arm_xyz
                plot_3d_trajectory(
                    ax2,
                    [right_arm_xyz[i] for i in waypoints],
                    label="waypoints",
                    legend=False,
                )  # Plot waypoints for right_arm_xyz

                fig.suptitle(f"Task: {args.dataset.split('/')[-1]}", fontsize=30) 

                handles, labels = ax1.get_legend_handles_labels()
                fig.legend(handles, labels, loc="lower center", ncol=2, fontsize=20)

                fig.savefig(
                    f"plot/act/{args.dataset.split('/')[-1]}_{i}_t_{args.err_threshold}_waypoints.png"
                )
                plt.close(fig)

            root.close()

    print(
        f"Average number of waypoints: {np.mean(num_waypoints)} \tAverage number of frames: {np.mean(num_frames)} \tratio: {np.mean(num_frames)/np.mean(num_waypoints)}"
    )

这段代码用于可视化地面实际关节位置(qpos)和选择的轨迹关键点(waypoints)。下面解释一下这段代码的主要步骤:

  1. if args.plot_3d::这是一个条件语句,检查是否传递了 --plot_3d 参数。如果传递了该参数(即设置为 True),则执行以下操作来进行3D可视化。

  2. if not args.use_ee::这是一个条件语句,检查是否传递了 --use_ee 参数。如果没有传递(即设置为 False),则执行以下操作,这表示要可视化关节位置信息。

  3. qpos = np.array(qpos):将关节位置数据 qpos 转换为 NumPy 数组以进行后续处理。这是为了确保数据格式正确。

  4. from act.convert_ee import get_xyz:导入 get_xyz 函数,该函数用于从关节位置数据中提取末端执行器的位置信息。

  5. left_arm_xyz = get_xyz(qpos[:, :6])right_arm_xyz = get_xyz(qpos[:, 7:13]):使用 get_xyz 函数从左臂和右臂的关节位置数据中提取末端执行器的位置信息,并将结果分别存储在 left_arm_xyzright_arm_xyz 中。

  6. all_data = np.concatenate([left_arm_xyz, right_arm_xyz], axis=0):将左臂和右臂的末端执行器位置信息合并成一个包含所有数据的数组 all_data,以便找到每个坐标轴上的全局最小值和最大值。

  7. min_x, min_y, min_z = np.min(all_data, axis=0)max_x, max_y, max_z = np.max(all_data, axis=0):找到 all_data 中每个坐标轴上的最小值和最大值,以确定绘图时坐标轴的范围。

  8. 创建一个包含两个子图的3D图形 fig,每个子图对应于左臂和右臂的可视化。

  9. 对于每个子图,设置坐标轴标签、标题和坐标轴范围,并使用 plot_3d_trajectory 函数绘制地面实际关节位置("ground truth")的轨迹。

  10. 在每个子图中,通过在waypoints前添加初始状态(0)来可视化选择的waypoints。这些waypoints以独特的标签进行绘制,但不包括在图例中。

  11. 添加图形标题、图例和标签以美化图形。

  12. 使用 fig.savefig(...) 将图形保存为图像文件,文件名包含有关数据集、轨迹索引、误差阈值和waypoints的信息。

  13. 使用 plt.close(fig) 关闭图形以释放资源。

  14. root.close():关闭 HDF5 数据集文件。

  15. 最后,通过打印消息,显示了平均waypoints数量、平均帧数以及两者之间的比率。

总之,这段代码的主要目的是通过3D可视化展示地面实际关节位置和选择的waypoints。这有助于对轨迹数据进行直观的分析和可视化。

下面给出act.convert_ee模块中的get_xyz函数,python文件对应于act/convert_ee.py文件。

def get_xyz(joints):
    xyz = []
    for joint in joints:
        T_sb = mr.FKinSpace(vx300s.M, vx300s.Slist, joint)
        xyz.append(T_sb[:3, 3])
    return np.array(xyz)

get_xyz 函数用于从关节数据中计算末端执行器(通常是机械臂末端的位置)。下面解释这个函数的主要步骤:

  1. get_xyz(joints):这是 get_xyz 函数的定义,它接受一个名为 joints 的参数,该参数是关节位置数据的数组。

  2. xyz = []:创建一个空列表 xyz,用于存储计算得到的末端执行器位置信息。

  3. for joint in joints::这是一个循环,迭代处理传递给函数的每个关节位置数据。

  4. T_sb = mr.FKinSpace(vx300s.M, vx300s.Slist, joint):在每次迭代中,通过调用 mr.FKinSpace 函数来计算末端执行器的位置。具体来说,该函数接受以下参数:

    • vx300s.M:可能是描述机械臂的位姿(姿态)的矩阵。
    • vx300s.Slist:可能是描述机械臂运动学参数的矩阵。
    • joint:当前时间步的关节位置数据。

    mr.FKinSpace 函数返回了一个变换矩阵 T_sb,该矩阵描述了机械臂末端执行器相对于基坐标系的位姿。

  5. xyz.append(T_sb[:3, 3]):从计算得到的变换矩阵 T_sb 中提取前三行、第四列的元素,这表示末端执行器的位置信息(x、y、z坐标),并将该位置信息添加到 xyz 列表中。

  6. return np.array(xyz):最后,将 xyz 列表转换为 NumPy 数组并返回,该数组包含了每个时间步末端执行器的位置信息。

综合起来,get_xyz 函数用于从关节位置数据计算出每个时间步末端执行器的位置信息,并将这些位置信息以 NumPy 数组的形式返回。这个功能在机器人运动学和轨迹分析中很常见,用于可视化机械臂末端执行器的运动轨迹。

最后,简单说一下act/convert_ee.py文件中的load_hdf5函数,对应代码如下:

def load_hdf5(dataset_dir, dataset_name):
    dataset_path = os.path.join(dataset_dir, dataset_name + ".hdf5")
    if not os.path.isfile(dataset_path):
        print(f"Dataset does not exist at \n{dataset_path}\n")
        exit()

    with h5py.File(dataset_path, "r") as root:
        is_sim = root.attrs["sim"]
        qpos = root["/observations/qpos"][()]
        qvel = root["/observations/qvel"][()]
        action = root["/action"][()]
        image_dict = dict()
        for cam_name in root[f"/observations/images/"].keys():
            image_dict[cam_name] = root[f"/observations/images/{cam_name}"][()]

    return qpos, qvel, action, image_dict

这个 load_hdf5 函数用于从 HDF5 数据集文件中加载数据。下面解释这个函数的主要步骤:

  1. dataset_path = os.path.join(dataset_dir, dataset_name + ".hdf5"):这一行代码构建了要加载的 HDF5 数据集文件的完整路径,将 dataset_name 添加到指定的 dataset_dir 中,并附加文件扩展名 ".hdf5"。

  2. if not os.path.isfile(dataset_path)::这是一个条件语句,检查指定路径下的 HDF5 文件是否存在。如果文件不存在,会打印一条消息并退出程序。

  3. with h5py.File(dataset_path, "r") as root::使用 h5py 库打开指定的 HDF5 文件,并将其存储在 root 变量中,使用 "r" 模式表示只读模式。

  4. is_sim = root.attrs["sim"]:从 HDF5 数据集文件的属性中获取名为 "sim" 的属性值,这个属性可能表示数据集是否用于模拟。

  5. qpos = root["/observations/qpos"][()]:从数据集文件中提取关节位置数据(qpos),并将其存储在变量 qpos 中。注意 [()] 用于从 HDF5 数据集中读取数据并将其转换为 NumPy 数组。

  6. qvel = root["/observations/qvel"][()]:类似地,从数据集文件中提取关节速度数据(qvel)。

  7. action = root["/action"][()]:从数据集文件中提取操作(action)数据。

  8. image_dict = dict():创建一个空的字典 image_dict,用于存储图像数据。

  9. 循环遍历图像数据,对于数据集中的每个摄像头(cam_name),执行以下操作:

    • root[f"/observations/images/{cam_name}"]:访问 HDF5 数据集中的图像数据集,其中 cam_name 是摄像头的名称。

    • root[f"/observations/images/{cam_name}"][()]:从图像数据集中读取图像数据并将其转换为 NumPy 数组。

    • image_dict[cam_name] = ...:将图像数据存储在 image_dict 字典中,以摄像头名称作为键。

  10. 最后,返回加载的数据,包括关节位置数据 qpos、关节速度数据 qvel、操作数据 action 以及图像数据 image_dict

这个函数的主要目的是加载来自 HDF5 数据集文件的数据,以便进一步处理和分析。它是一个通用的数据加载函数,适用于各种包含关节数据、图像数据和其他数据的机器学习或模拟数据集。

你可能感兴趣的:(Aloha,学习,机器人,python)