nnUNet V2代码——图像增强(三)

本文阅读的nnU-Net V2图像增强有亮度调整对比度调整低分辨率调整

各个类内的各个函数的调用关系见前文nnUNet V2代码——图像增强(一)的BasicTransform

安装batchgeneratorsv2,nnU-Net V2关于图像增强的代码都在这个库中,点击链接,将其clone到本地后,在命令行进入文件夹内,pip install -e . 即可(注意-e后有个点)。

本文目录

  • 一 MultiplicativeBrightnessTransform
    • 1. __init__函数
    • 2. get_parameters函数
    • 3. _apply_to_image函数
  • 二 ContrastTransform
    • 1. __init__函数
    • 2. get_parameters函数
    • 3. _apply_to_image函数
  • 三 SimulateLowResolutionTransform类
    • 1. __init__函数
    • 2. get_parameters函数
    • 3. _apply_to_image函数

一 MultiplicativeBrightnessTransform

该类包含亮度调整,继承自ImageOnlyTransform类,只对image施加,seg不施加。

代码在batchgeneratorsv2 \ transforms \ intensity \ brightness.py文件中

MultiplicativeBrightnessTransform代码比SpatialTransform类、GaussianNoiseTransform类、GaussianBlurTransform类简洁,但代码逻辑一致

1. __init__函数

定义必要的类内变量,代码清晰,不做粘贴,变量在用到时再介绍作用

2. get_parameters函数

def get_parameters(self, **data_dict) -> dict:
	## 获取image大小
    shape = data_dict['image'].shape
    ## 确定哪些通道要施加亮度调整
    ### self.p_per_channel = 1,nnU-Net V2对每个通道都施加亮度调整
    apply_to_channel = torch.where(torch.rand(shape[0]) < self.p_per_channel)[0]
    ## 各通道同步施加相同的亮度调整
    if self.synchronize_channels:
        multipliers = torch.Tensor([sample_scalar(self.multiplier_range, image=data_dict['image'], channel=None)] * len(apply_to_channel))
    ## 各通道各自施加自己的亮度调整
    else:	### self.synchronize_channels = False,nnU-Net V2不同步施加
        multipliers = torch.Tensor([sample_scalar(self.multiplier_range, image=data_dict['image'], channel=c) for c in apply_to_channel])
    ## 收集参数后返回
    return {
        'apply_to_channel': apply_to_channel,
        'multipliers': multipliers
    }

sample_scalar函数见nnUNet V2代码——图像增强(一)的其余函数

3. _apply_to_image函数

def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
	## 没有要调整的通道,直接返回
    if len(params['apply_to_channel']) == 0:
        return img
    ## 遍历施加亮度调整
    for c, m in zip(params['apply_to_channel'], params['multipliers']):
        img[c] *= m
    return img

二 ContrastTransform

该类负责对比度调整,继承自ImageOnlyTransform类,只对image施加,seg不施加。

代码在batchgeneratorsv2 \ transforms \ intensity \ contrast.py文件中

1. __init__函数

定义必要的类内变量,代码清晰,不做粘贴,变量在用到时再介绍作用

2. get_parameters函数

MultiplicativeBrightnessTransform的get_parameters函数一致

3. _apply_to_image函数

def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
    if len(params['apply_to_channel']) == 0:
        return img
    ## 遍历通道
    for i in range(len(params['apply_to_channel'])):
        c = params['apply_to_channel'][i]
        ## 获取图像某一通道的平均值
        mean = img[c].mean()
        ## 是否保留数值范围,nnU-Net V2设置self.preserve_range = True
        if self.preserve_range:
            minm = img[c].min()
            maxm = img[c].max()

        ## 对比度调整
        img[c] -= mean
        img[c] *= params['multipliers'][i]
        img[c] += mean
		
		## 是否保留数值范围
        if self.preserve_range:
            img[c].clamp_(minm, maxm)

    return img

三 SimulateLowResolutionTransform类

该类负责施加低分辨率,继承自ImageOnlyTransform类,只对image施加,seg不施加。

代码在batchgeneratorsv2 \ transforms \ spatial\ low_resolution.py文件中

1. __init__函数

变量名称 含义
self.scale 图像放缩范围
self.synchronize_channels 通道是否施加相同的低分辨率处理
self.synchronize_axes 是否施加相同的低分辨率处理
self.ignore_axes 某轴不能施加低分辨率处理,与nnUNet V2代码——图像增强(一)的Convert3DTo2DTransformConvert2DTo3DTransform有关
self.allowed_channels 可能会施加低分辨率处理的通道
self.p_per_channel 某通道施加低分辨率处理的概率
self.upmodes 各维度采样方法
self.scale = scale
self.synchronize_channels = synchronize_channels
self.synchronize_axes = synchronize_axes
self.ignore_axes = ignore_axes
self.allowed_channels = allowed_channels
self.p_per_channel = p_per_channel

self.upmodes = {
    1: 'linear',
    2: 'bilinear',
    3: 'trilinear'
}

2. get_parameters函数

def get_parameters(self, **data_dict) -> dict:
    shape = data_dict['image'].shape
    ## nnU-Net V2设置为None,所有通道按概率施加低分辨率处理
    if self.allowed_channels is None:
        apply_to_channel = torch.where(torch.rand(shape[0]) < self.p_per_channel)[0]
    else:
        apply_to_channel = [i for i in self.allowed_channels if torch.rand(1) < self.p_per_channel]
    ## nnU-Net V2设置为False
    if self.synchronize_channels:
        ## nnU-Net V2设置为True
        if self.synchronize_axes:
            ## 各通道、各轴施加相同的
            scales = torch.Tensor([[sample_scalar(self.scale, image=data_dict['image'], channel=None, dim=None)] * (len(shape) - 1)] * len(apply_to_channel))
        else:
        	## 各通道施加相同的,各轴施加各自的
            scales = torch.Tensor([[sample_scalar(self.scale, image=data_dict['image'], channel=None, dim=d) for d in range(len(shape) - 1)]] * len(apply_to_channel))
    else:
        if self.synchronize_axes:
        	## 各轴施加相同的,各通道施加各自的
            scales = torch.Tensor([[sample_scalar(self.scale, image=data_dict['image'], channel=c, dim=None)]  * (len(shape) - 1) for c in apply_to_channel])
        else:
        	## 各通道、各轴施加各自的
            scales = torch.Tensor([[sample_scalar(self.scale, image=data_dict['image'], channel=c, dim=d) for d in range(len(shape) - 1)] for c in apply_to_channel])
    ## 对忽略的轴单独处理
    if len(scales) > 0:
        scales[:, self.ignore_axes] = 1
    return {
        'apply_to_channel': apply_to_channel,
        'scales': scales
    }

3. _apply_to_image函数

def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
    orig_shape = img.shape[1:]
    # 注释机翻:我们无法对这些内容进行批处理,因为每个通道的下采样 SHAP 值会有所不同。
    for c, s in zip(params['apply_to_channel'], params['scales']):
        ## 按照放缩尺度确定下采样后的图像大小
        new_shape = [round(i * j.item()) for i, j in zip(orig_shape, s)]
        ## 使用某一种最近邻插值进行采样
        downsampled = interpolate(img[c][None, None], new_shape, mode='nearest-exact')
        ## 还原图像大小
        img[c] = interpolate(downsampled, orig_shape, mode=self.upmodes[img.ndim - 1])[0, 0]
    return img

你可能感兴趣的:(阅读nnUNet,V2代码,图像增强,计算机视觉,nnunet,机器学习,深度学习,人工智能,图像增强)