前言:(RVQ网络介绍)
RVQ(Residual Vector Quantization,残差向量量化)网络是一种用于高效数据压缩和表示的量化技术,主要用于处理高维数据如音频、图像和视频等。
RVQ网络通过多层量化器逐步逼近原始数据:
第一层量化器对原始向量进行粗略量化
后续每一层量化器都对前一层的残差(误差)进行进一步量化
最终表示由所有层量化结果的叠加组成
分层结构:由多个量化器级联构成
残差学习:每一层都学习前一层量化后的残差
渐进逼近:通过多层量化逐步减小重构误差
高效表示:可以用较少的比特数表示高维数据
音频编码:如语音和音乐压缩
图像/视频压缩:高效表示视觉数据
机器学习:作为特征提取或降维工具
神经音频合成:如WaveNet等模型中用于量化音频信号
相比单层VQ,RVQ能获得更好的重构质量
可以灵活调整层数来平衡质量和计算复杂度
适合处理高维数据分布
RVQ及其变体在现代深度学习模型中,特别是在生成模型和压缩领域有着广泛的应用
一 RVQ网络模型
今天小编将带领大家走进RVQ网络模型
import torch
from torch import nn
import torch.nn.functional as F
class RVQ(nn.Module):
"""
残差向量量化(RVQ)模块
通过一系列码本对输入向量进行量化,每一步量化前一步的残差
参数:
num_codebooks: 使用的码本数量
codebook_size: 每个码本中的向量数量
codebook_dim: 每个码本向量的维度
enlarge_residual_times: 量化步骤间残差的放大倍数
commitment_cost: 约束损失权重(当前被注释掉)
"""
def __init__(self, num_codebooks, codebook_size, codebook_dim, enlarge_residual_times, commitment_cost=0.25):
super().__init__()
# 创建码本列表(每个都是Embedding层)
self.codebooks = nn.ModuleList([nn.Embedding(codebook_size, codebook_dim) for _ in range(num_codebooks)])
self.codebook_size = codebook_size
self.codebook_dim = codebook_dim
self.num_codebooks = num_codebooks
self.enlarge_residual_times = enlarge_residual_times
#self.commitment_cost = commitment_cost # 当前未使用
self._init_codebooks()
def _init_codebooks(self):
"""初始化码本权重,均匀分布在[-1/codebook_size, 1/codebook_size]区间"""
for codebook in self.codebooks:
codebook.weight.data.uniform_(-1/self.codebook_size, 1/self.codebook_size)
def forward(self, X, training=False, all_used_index=None):
"""
通过残差向量量化对输入向量进行量化
参数:
X: 待量化的输入张量
training: 是否处于训练模式
all_used_index: 训练时跟踪每个码本使用情况的字典
返回:
all_indices: 每个码本的索引列表
vq_loss: 训练时的总量化损失
"""
residual = X # 从原始输入开始
all_indices = []
vq_loss = torch.tensor(0.0)
# 依次处理每个码本
for codebook_id, codebook in enumerate(self.codebooks):
# 计算残差与码本向量之间的距离
distance = (torch.sum(residual**2, dim=1, keepdim=True)
+ torch.sum(codebook.weight**2, dim=1)
- 2 * torch.matmul(residual, codebook.weight.T))
# 找到最接近的码本向量
indices = torch.argmin(distance, dim=1)
all_indices.append(indices)
# 获取量化后的向量
quantized = codebook(indices)
if training:
# 计算量化损失(量化结果与残差之间的MSE)
vq_loss = vq_loss + F.mse_loss(quantized, residual)
# 记录使用的索引(用于部分重新初始化)
all_used_index[codebook_id] = all_used_index[codebook_id].union(indices.tolist())
# 更新残差(当前残差与量化结果的差)
# 放大残差以突出剩余信息
residual = (residual - quantized).detach() * self.enlarge_residual_times
return all_indices, vq_loss
def part_embeddings_reinit_(self, all_used_index=None):
"""
重新初始化未使用的码本条目
帮助在训练期间探索更多码本空间
参数:
all_used_index: 包含每个码本已使用索引的字典
"""
for codebook_id, used_index in enumerate(all_used_index):
# 创建未使用条目的掩码
mask = torch.ones(self.codebook_size, dtype=torch.bool)
mask[list(used_index)] = False
# 重新初始化未使用的条目
self.codebooks[codebook_id].weight.data[mask] = self.codebooks[codebook_id].weight.data[mask].uniform_(-1/self.codebook_size, 1/self.codebook_size)
class RVQEncoder(nn.Module):
"""
使用RVQ将输入向量转换为量化索引的编码器模块
"""
def __init__(self, num_codebooks, codebook_size, codebook_dim, enlarge_residual_times, codebook_adress=None):
super().__init__()
self.rvq = RVQ(num_codebooks, codebook_size, codebook_dim, enlarge_residual_times)
# 如果提供了预训练码本则加载
if codebook_adress:
self.rvq.codebooks.load_state_dict(torch.load(codebook_adress, weights_only=True))
self.num_codebooks = num_codebooks
self.codebook_size = codebook_size
self.codebook_dim = codebook_dim
self.enlarge_residual_times = enlarge_residual_times
def forward(self, X):
"""将输入向量转换为量化索引"""
index, _ = self.rvq(X)
return index
def quantized_Embedding(self, X):
"""从量化索引重构向量"""
for codebook_id, index in enumerate(X):
if codebook_id == 0:
# 从第一个码本的向量开始
output = self.rvq.codebooks[codebook_id](index)
else:
# 添加后续码本向量(按enlarge_residual_times缩小)
output += self.rvq.codebooks[codebook_id](index) / self.enlarge_residual_times
return output
class RVQDecoder(nn.Module):
"""
从量化索引重构向量的解码器模块
"""
def __init__(self, num_codebooks, codebook_size, codebook_dim, enlarge_residual_times, codebook_adress=None):
super().__init__()
# 创建码本(与编码器的码本分开)
self.codebooks = nn.ModuleList([nn.Embedding(codebook_size, codebook_dim) for _ in range(num_codebooks)])
self.enlarge_residual_times = enlarge_residual_times
# 如果提供了预训练码本则加载
if codebook_adress:
self.codebooks.load_state_dict(torch.load(codebook_adress, weights_only=True))
def forward(self, X):
"""从量化索引重构向量"""
for codebook_id, index in enumerate(X):
if codebook_id == 0:
# 从第一个码本的向量开始
output = self.codebooks[codebook_id](index)
print(codebook_id) # 调试打印(生产代码中应移除)
else:
# 添加后续码本向量(按enlarge_residual_times缩小)
output += self.codebooks[codebook_id](index) / self.enlarge_residual_times
print(codebook_id) # 调试打印(生产代码中应移除)
return output
残差向量量化(RVQ):
通过多阶段量化输入向量
每个阶段量化前一步的残差(误差)
残差在阶段间会被放大以突出剩余信息
码本管理:
每个码本实现为一个Embedding层
训练期间定期重新初始化未使用的码本条目
这有助于探索更多码本空间
编码器-解码器结构:
编码器将向量转换为码本索引序列
解码器从这些索引重构向量
两者都可以加载预训练码本
重要参数:
enlarge_residual_times
:控制残差在阶段间的放大倍数
codebook_size
:每个码本中的向量数量
codebook_dim
:每个码本向量的维度
解码器中的调试打印在生产代码中应该移除。原始VQ-VAE论文中的约束损失项在当前实现中被注释掉了。
二 模型训练
import os
import torch
from torch import nn
import torch.optim as optim
# 训练超参数设置
lr = 0.001 # 学习率
epochs = 1 # 训练轮数
root = r"D:\xxx\xxx\xxx"
def train_codebooks(num_codebooks, codebook_size, enlarge_residual_times, lr, epochs,
n_dims, batch_size, dataiter, save_root, load_root=None):
"""
训练RVQ码本的函数
参数:
num_codebooks: 码本数量
codebook_size: 每个码本的大小(向量数量)
enlarge_residual_times: 残差放大倍数
lr: 学习率
epochs: 训练轮数
n_dims: 向量维度
batch_size: 批大小
dataiter: 数据迭代器
save_root: 模型保存路径
load_root: 预训练模型加载路径
"""
# 设置设备(GPU或CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 初始化RVQ编码器
if load_root:
# 如果提供了预训练模型路径,则加载现有码本
rvq_encoder = RVQEncoder(num_codebooks, codebook_size, n_dims, enlarge_residual_times,
codebook_adress=os.path.join(load_root, 'codebooks.pth')).to(device)
else:
# 否则创建新的RVQ编码器
rvq_encoder = RVQEncoder(num_codebooks, codebook_size, n_dims, enlarge_residual_times).to(device)
# 设置优化器(只优化码本参数)
optimizer = optim.Adam(rvq_encoder.rvq.codebooks.parameters(), lr=lr)
# 训练循环
for epoch in range(epochs):
l = torch.tensor(0.0).to(device) # 初始化损失值
# 初始化记录每个码本使用情况的集合
all_used_index = [set() for _ in range(rvq_encoder.rvq.num_codebooks)]
# 遍历数据批次
for X in dataiter:
# 梯度清零
optimizer.zero_grad(True)
# RVQ量化处理,返回索引和损失
_, loss = rvq_encoder.rvq(X, training=True, all_used_index=all_used_index)
# 反向传播和参数更新
loss.backward()
optimizer.step()
# 记录并打印损失
with torch.no_grad():
l += loss
print('loss', l)
break # 调试用,只处理一个批次
# 重新初始化未使用的码本条目
rvq_encoder.rvq.part_embeddings_reinit_(all_used_index=all_used_index)
# 保存训练好的码本
torch.save(rvq_encoder.rvq.codebooks.state_dict(), os.path.join(save_root, 'codebooks.pth'))
码本训练流程:
初始化RVQ编码器(可选择加载预训练码本)
设置优化器仅优化码本参数
每个epoch中:
遍历数据批次
前向传播计算量化结果和损失
反向传播更新码本参数
重新初始化未使用的码本条目
重要功能:
part_embeddings_reinit_
方法确保所有码本条目都能被充分利用
训练过程中会记录每个码本的使用情况
支持从检查点继续训练(通过load_root参数)
参数说明:
enlarge_residual_times
控制残差放大倍数,影响量化精度
codebook_size
和num_codebooks
决定模型的容量和表达能力
n_dims
是特征向量的维度
使用注意:
代码中的break
语句是调试用的,实际使用时应移除
encoder.train_forward
需要提前定义好
保存路径和加载路径可以相同,实现模型覆盖保存
这段代码实现了一个完整的RVQ码本训练流程,支持从零开始训练和继续训练两种模式,适用于语音等时序数据的向量量化任务。