GAN Network

""" Module implementing GAN which will be trained using the Progressive growing
    technique -> https://arxiv.org/abs/1710.10196
"""
import datetime
import os
import time
import timeit
import numpy as np
import torch as th


class Generator(th.nn.Module):
    """ Generator of the GAN network  创建生成器"""

    def __init__(self, depth=7, latent_size=512, dilation=1, use_spectral_norm=True):
        """
        constructor for the Generator class
        :param depth: required depth of the Network
        :param latent_size: size of the latent manifold
        :param dilation: amount of dilation to be used by the 3x3 convs
                         in the Generator module.
        :param use_spectral_norm: whether to use spectral normalization
        """
        from torch.nn import ModuleList, Conv2d
        from MSG_GAN.CustomLayers import GenGeneralConvBlock, GenInitialBlock

        super().__init__()

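        # bit trick: a positive integer n is a power of 2 exactly when
        # (n & (n - 1)) == 0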
        assert latent_size != 0 and ((latent_size & (latent_size - 1)) == 0), \
            "latent size not a power of 2"
        if depth >= 4:
            assert latent_size >= np.power(2, depth - 4), "latent size will diminish to zero"

        # state of the generator:
        self.depth = depth
        self.latent_size = latent_size
        self.spectral_norm_mode = None
        self.dilation = dilation

        # register the modules required for the GAN below
        # create the ToRGB layers for various outputs:
        def to_rgb(in_channels):
            return Conv2d(in_channels, 3, (1, 1), bias=True)

        # create a module list of the other required general convolution blocks
        self.layers = ModuleList([GenInitialBlock(self.latent_size)])
        self.rgb_converters = ModuleList([to_rgb(self.latent_size)])

        # create the remaining layers
        for i in range(self.depth - 1):
            if i <= 2:
                layer = GenGeneralConvBlock(self.latent_size, self.latent_size,
                                            dilation=dilation)
                rgb = to_rgb(self.latent_size)
            else:
                layer = GenGeneralConvBlock(
                    int(self.latent_size // np.power(2, i - 3)),
                    int(self.latent_size // np.power(2, i - 2)),
                    dilation=dilation
                )
                rgb = to_rgb(int(self.latent_size // np.power(2, i - 2)))
            self.layers.append(layer)
            self.rgb_converters.append(rgb)
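        # e.g. with depth=7 and latent_size=512 the per-scale feature widths
        # work out to 512, 512, 512, 512, 256, 128 and 64 (hand-computed)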

        # if spectral normalization is on:
        if use_spectral_norm:
            self.turn_on_spectral_norm()

    def turn_on_spectral_norm(self):
        """
        private helper for turning on the spectral normalization
        :return: None (has side effect)
        """
        from torch.nn.utils import spectral_norm

        if self.spectral_norm_mode is not None:
            assert self.spectral_norm_mode is False, \
                "can't apply spectral_norm. It is already applied"

        # apply spectral normalization to both conv layers of every block
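        # note: this assumes every block in self.layers (including
        # GenInitialBlock) exposes conv_1 and conv_2 attributes, as the
        # blocks in MSG_GAN.CustomLayers do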
        for module in self.layers:
            module.conv_1 = spectral_norm(module.conv_1)
            module.conv_2 = spectral_norm(module.conv_2)

        # toggle the state variable:
        self.spectral_norm_mode = True

    def turn_off_spectral_norm(self):
        """
        private helper for turning off the spectral normalization
        :return: None (has side effect)
        """
        from torch.nn.utils import remove_spectral_norm

        if self.spectral_norm_mode is not None:
            assert self.spectral_norm_mode is True, \
                "can't remove spectral_norm. It is not applied"

        # remove the applied spectral norm
        for module in self.layers:
            remove_spectral_norm(module.conv_1)
            remove_spectral_norm(module.conv_2)

        # toggle the state variable:
        self.spectral_norm_mode = False

    def forward(self, x):
        """
        forward pass of the Generator
        :param x: input latent vector
        :return: list of generated images, one per scale (lowest to highest)
        """
        outputs = []
        y = x
        for block, converter in zip(self.layers, self.rgb_converters):
            y = block(y)
            outputs.append(converter(y))
        return outputs

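A minimal usage sketch follows. It is a hand-written illustration, not part of the original module: it assumes MSG_GAN.CustomLayers is importable and that, as in the reference MSG-GAN code, GenInitialBlock maps a (batch, latent_size) vector to a 4x4 feature map while each GenGeneralConvBlock upsamples by 2x, so depth=7 yields seven outputs from 4x4 up to 256x256.

if __name__ == "__main__":
    gen = Generator(depth=7, latent_size=512)
    noise = th.randn(4, 512)       # a batch of 4 latent vectors (assumed shape)
    with th.no_grad():
        images = gen(noise)        # one RGB tensor per scale
    for img in images:
        print(img.shape)           # expected: (4, 3, 4, 4) up to (4, 3, 256, 256)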