ChatGLM2-INT4 + Lora 结构适配和改造

Lora 是目前公认的最好的微调方法,一方面,它并不像AdapterTuning 一样,改变原有模型的架构,不便于在不同框架之间迁移;另一方面,它不像 PTuning 一样改变所有任务下的单词生成概率,严重破坏已习得的知识。

ChatGLM2-INT4 这个量化版本使用自定义的QuantizedLinear作为线性模块。如果我们要使用 PEFT 库添加 Lora 参数时,它就会不认识,因为它是为torch.nn.Linear准备的,于是我们只能自己写个模块来实现这件事情。

一、编写LoraQuantizedLinear

LoraQuantizedLinear是我们自己的带Lora的线性层,包含QuantizedLinear所有参数/配置和Lora的主要配置项。

通过阅读quantization.py,我们确定QuantizedLinear的参数或配置有四个:

  • weight:量化后权重,形状为[OutDim, InDim]。注意在 INT4 模式下,一个 INT8 元素当两个 INT4 使用,InDim是 FP16 版本的一半。
  • weight_scale:量化的缩放系数,形状为[OutDim]。也就是说这个量化针对每一个隐藏状态确定一个范围,而不是整个参数。
  • bias:不量化的偏置,形状为[OutDim]
  • weight_bit_width:量化位数,4 或者 8。

新的线性层应该原样保存这些参数,并且应该包含Lora的三个主要配置:

  • r:较低的维度大小
  • alpha:和r一起组成缩放系数
  • dropout_rate:前置 Dropout 的比例

新的线性层创建如下:

class LoraQuantizedLinear(torch.nn.Module):

    def __init__(self, q_linear, lora_r=32, lora_alpha=32, lora_dropout_rate=0.0):
        super().__init__()
        
        # 保存原始参数和Lora配置
        self.lora_r = lora_r
        self.lora_alpha = lora_alpha
        self.lora_dropout_rate = lora_dropout_rate
        self.weight_bit_width = q_linear.weight_bit_width
        self.weight = q_linear.weight
        self.weight_scale = q_linear.weight_scale
        self.bias = q_linear.bias

        # 冻结原始参数
        self.weight.requires_grad = False
        self.weight_scale.requires_grad = False
        if self.bias is not None: self.bias.requires_grad = False

        # 创建 Lora 参数,FP16
        out_dim, in_dim = self.weight.shape
        # INT4 模型下,InDim 是原始大小的一半
        if self.weight_bit_width == 4: in_dim *= 2
        # LoraA 正态初始化
        self.lora_a = torch.nn.Parameter(torch.empty(
            [self.lora_r, in_dim],
            device=self.weight.device,
            dtype=torch.float16,
        ))
        torch.nn.init.kaiming_normal_(self.lora_a)
        # LoraB 全零初始化
        self.lora_b = torch.nn.Parameter(torch.zeros(
            [out_dim, self.lora_r],
            device=self.weight.device,
            dtype=torch.float16,
        ))
        self.lora_dropout = torch.nn.Dropout(self.lora_dropout_rate)
        self.lora_scale = self.lora_alpha / self.lora_r

正向传播过程中,先用之前的forward()方法完成原始输出的计算,之后再手动编写 Lora 输出的计算,然后相加并返回:

# class LoraQuantizedLinear

    def forward(self, input):
        ori_output = QuantizedLinear.forward(self, input)
        lora_output = (
            self.lora_dropout(input.half()) @ 
            self.lora_a.transpose(0, 1) @ 
            self.lora_b.transpose(0, 1) * 
            self.lora_scale
        )
        return ori_output + lora_output.to(ori_output.dtype)

合并方法,我们将原始参数解量化,将Lora两个参数相乘后与之相加,然后再次量化。

# class LoraQuantizedLinear

    def merge(self):
        # H = XW + b + XAB * s => H = X(W + AB * s) + b
        # 将 int 原始参数转成 fp16
        weight = extract_weight_to_half(self.weight, self.weight_scale, self.weight_bit_width)
        # 合并 lora 参数
        weight += self.lora_b @ self.lora_a * self.lora_scale
        # 再转回 int
        weight, weight_scale = half_weight_to_int(weight, self.weight_bit_width)
        self.weight = torch.nn.Parameter(weight, requires_grad=False)
        self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
        # 重新初始化 lora 两个矩阵
        torch.nn.init.kaiming_normal_(self.lora_a)
        torch.nn.init.zeros_(self.lora_b)

half_weight_to_int取自QuantizedLinear类的构造器:

def half_weight_to_int(weight: torch.Tensor, weight_bit_width: int):
    assert weight_bit_width in [4, 8]
    assert weight.ndim == 2
    weight_scale = weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)
    weight = torch.round(weight / weight_scale[:, None]).to(torch.int8)
    if weight_bit_width == 4:
        weight = compress_int4_weight(weight)
    return weight, weight_scale

二、辅助方法

之后实现几个辅助方法,完成参数挂载卸载和合并:

首先是attach_lora:将所有的QuantizedLinear改成LoraQuantizedLinear

我们搜索模型的所有模块,再搜索它的直接子模块,如果任何东西是QuantizedLinear,就把它替换掉。之后把非 Lora 参数冻结。

def attach_lora(model, lora_r=32, lora_alpha=32, lora_dropout_rate=0.0):
    if model.lora_attached: return model
    lora_conf = dict(lora_r=lora_r, lora_alpha=lora_alpha, lora_dropout_rate=lora_dropout_rate)
    for mod in model.modules():
        for name in dir(mod):
            submod = getattr(mod, name, None)
            if not isinstance(submod, QuantizedLinear):
                continue
            new_submod = LoraQuantizedLinear(submod, **lora_conf)
            setattr(mod, name, new_submod)
   
    for name, param in model.named_parameters():
        if 'lora_' not in name:
            param.requires_grad = False
    model.lora_attached = True
    return model

lora_state_dict:导出所有 Lora 参数。

def lora_state_dict(model):
    return {   
        k:v
        for k, v in model.state_dict().items()
        if 'lora_' in k
    }

merge_lora:将所有LoraQuantizedLinear的参数合并。

def merge_lora(model):
    for mod in model.modules():
        if isinstance(mod, LoraQuantizedLinear):
            mod.merge()
    return model

detach_lora:搜索模型的所有模块,再搜索它的直接子模块,如果任何东西是LoraQuantizedLinear,就把它替换回QuantizedLinear

def detach_lora(model):
    if not model.lora_attached: return model
    
    for mod in model.modules():
        for name in dir(mod):
            submod = getattr(mod, name, None)
            if not isinstance(submod, LoraQuantizedLinear):
                continue
            new_submod = QuantizedLinear.from_params(
                submod.weight_bit_width,
                submod.weight,
                submod.weight_scale,
                submod.bias,
            )
            setattr(mod, name, new_submod)
    
    model.lora_attached = False
    return model

这就需要给QuantizedLinear加一个工厂方法,接受它的四个参数并直接保存。原有的构造器与之不兼容,于是实现为类方法:

# class QuantizedLinear

    @classmethod
    def from_params(cls, weight_bit_width, weight, weight_scale, bias):
        obj = cls.__new__(cls)
        super(cls, obj).__init__()
        obj.weight_bit_width = weight_bit_width
        obj.weight = weight
        obj.weight_scale = weight_scale
        obj.bias = bias
        return obj

由于这些方法是实现在单独的文件中的(例如lora.py),我们在modeling_chatglm.py中导入这个文件,然后添加到ChatGLMForConditionalGeneration类里面,便于使用:

from .lora import attach_lora, detach_lora, merge_lora, lora_state_dict

# class ChatGLMForConditionalGeneration

    # def __init__
        self.lora_attached = False

    attach_lora = attach_lora
    detach_lora = detach_lora
    merge_lora = merge_lora
    lora_state_dict = lora_state_dict

然后一切就完成了。

你可能感兴趣的:(人工智能,pytorch,深度学习,python)