3DGS源码解读 - duplicateWithKeys 和 RadixSort

duplicateWithKeys 和 RadixSort

我们先来看一下 duplicateWithKeys 和 RadixSort 的大致流程:
3DGS源码解读 - duplicateWithKeys 和 RadixSort_第1张图片
图片来源于文献 FlashGS: Efficient 3D Gaussian Splatting for Large-scale and High-resolution Rendering
duplicateWithKeys 部分的关键代码如下:

// 如果 radii[idx] <= 0,说明该 Gaussian 在屏幕上不产生任何影响,可直接跳过,节省后续计算和内存写入
if (radii[idx] > 0) {
    // 每个高斯点在输出数组中的起始偏移量
	uint32_t off = (idx == 0) ? 0 : offsets[idx - 1];

	// 根据 Gaussian 的中心屏幕坐标与半径,计算它在 tile 网格下的最小/最大 tile 坐标 rect_min、rect_max
	uint2 rect_min, rect_max;
	getRect(points_xy[idx], radii[idx], rect_min, rect_max, grid);

	for (int y = rect_min.y; y < rect_max.y; y++) {
		for (int x = rect_min.x; x < rect_max.x; x++) {
			uint64_t key = y * grid.x + x;        // 创建 64 位的 Key,通过 y * grid.x + x 计算 tile ID
			key <<= 32;                           // 高 32 位用于存储 tile id
			key |= *((uint32_t*)&depths[idx]);    // 低 32 位用于存储深度值
			gaussian_keys_unsorted[off] = key;    // 键 = tile id + 深度值
			gaussian_values_unsorted[off] = idx;  // 高斯点的索引 idx
			off++;
		}
	}
}

在后续的基数排序(RadixSort)中,会先按 Tile ID 排序(将同一瓦片的高斯点放在一起),再按深度排序(确保远到近或近到远的渲染顺序)。
论文中关于这部分的描述如下:

We then instantiate each Gaussian according to the number of tiles they overlap and assign each instance a key that combines view space depth and tile ID. We then sort Gaussians based on these keys using a single fast GPU Radix sort.

问:为什么需要执行 duplicateWithKeys 和 RadixSort 呢?
答:3DGS 的渲染需要解决以下两个关键问题:
1)确定每个高斯点覆盖哪些 tile,以便按 tile 并行渲染;
2)同一 Tile 内的高斯点需按深度排序,以支持 alpha blending。
执行 duplicateWithKeys 和 RadixSort 的作用主要有两个:
1)按 tile 分组后,每个线程块可以独立处理一个 tile 的高斯点,避免跨 tile 的线程竞争;
2)按深度排序后,alpha blending 可直接按顺序叠加颜色,无需额外的深度测试。

基数排序(RadixSort)

基数排序的本质是一种非比较型、基于分桶的稳定排序算法,其核心思想是通过逐位处理数字的每一位(如个位、十位、百位等),将数据分配到不同的桶中,并在每一轮排序中利用稳定排序的特性,逐步构建最终的有序序列。基数排序的形象化演示可以参考视频:基数排序-2分钟极速掌握。
基数排序的时间复杂度是 O ( k × n ) O(k×n) O(k×n),其中 n n n 是排序元素的个数, k k k 是最大数的位数。
基数排序的稳定性是其成立的关键条件:
1)稳定排序:在每一轮排序中,相同位值的元素会按照它们在原序列中的相对顺序进入同一个桶;
2)保留历史信息:前一轮排序的结果在后续轮次中不会被破坏。例如,十位排序后,十位相同的元素会根据个位排序的结果保持顺序。
现在使用 Python 实现基数排序,代码如下:

def radix_sort(arr):
    if not arr:
        return arr

    max_num = max(arr)  # 找出最大值,以确定最大数的位数
    exp = 1  # 当前处理的位数(从个位开始)
    # 当最大数在当前位上仍不为 0 时,继续处理更高位
    while max_num // exp > 0:
        buckets = [[] for _ in range(10)]  # 初始化 10 个桶(0-9)

        # 将每个数字放入对应的桶中
        for num in arr:
            digit = (num // exp) % 10  # 提取当前位的数字
            buckets[digit].append(num)

        # 按桶的顺序收集元素,形成新的数组
        arr = [num for bucket in buckets for num in bucket]
        exp *= 10  # 处理下一位(十位、百位...)
    return arr


if __name__ == "__main__":
    nums = [170, 45, 75, 90, 802, 24, 2, 66]
    sorted_nums = radix_sort(nums)
    print("排序后:", sorted_nums)  # 排序后: [2, 24, 45, 66, 75, 90, 170, 802]

你可能感兴趣的:(3DGS,人工智能)