【MindSpore】【训练】使用TrainOneStepCell报错

【功能模块】

myTrainOneStepCell

【操作步骤&问题现象】

1、代码参见下文。

2、我的数据集的每个样本是由:(图片数据、音频数据、标签)组成的,参考多标签数据集使用trainOneStepCell,发现报错。

【截图信息】

【MindSpore】【训练】使用TrainOneStepCell报错_第1张图片

【日志信息】(可选,上传日志内容或者附件)

【代码】

import argparse
import cv2
import os
import random
from glob import glob
from os.path import dirname, join, basename, isfile

import numpy as np
import mindspore.nn as nn
from tqdm import tqdm

import audio
from hparams import hparams, get_image_list
import mindspore.ops as ops
import mindspore as ms
from models_ms import SyncNetColor
from mindspore import Parameter, Tensor, load_param_into_net, save_checkpoint, load_checkpoint, context, Model
from sklearn.metrics.pairwise import cosine_similarity
from mindspore.nn import LossBase
import mindspore.dataset as ds

context.set_context(device_target='CPU', mode=context.GRAPH_MODE)

parser = argparse.ArgumentParser(description='Code to train the expert lip-sync discriminator')

parser.add_argument("--data_root", help="Root folder of the preprocessed LRS2 dataset", required=True)

parser.add_argument('--save_ckpt_dir', help='Save checkpoints to this directory', required=True, type=str)
parser.add_argument('--load_ckpt_path', help='Resumed from this checkpoint', default=None, type=str)

args = parser.parse_args()

global_step = 0
global_epoch = 0

syncnet_T = 5
syncnet_mel_step_size = 16


def get_image_list(data_root, split):
    filelist = []

    with open('filelists/{}.txt'.format(split)) as f:
        for line in f:
            line = line.strip()
            if ' ' in line:
                line = line.split()[0]
            filelist.append(os.path.join(data_root, line))

    return filelist


class Dataset:
    def __init__(self, split):
        self.all_videos = get_image_list(args.data_root, split)

    @staticmethod
    def get_frame_id(frame):
        return int(basename(frame).split('.')[0])

    def get_window(self, start_frame):
        start_id = self.get_frame_id(start_frame)
        vidname = dirname(start_frame)

        window_fnames = []
        for frame_id in range(start_id, start_id + syncnet_T):
            frame = join(vidname, '{}.jpg'.format(frame_id))
            if not isfile(frame):
                return None
            window_fnames.append(frame)
        return window_fnames

    def crop_audio_window(self, spec, start_frame):
        start_frame_num = self.get_frame_id(start_frame)
        start_idx = int(80. * (start_frame_num / float(hparams.fps)))

        end_idx = start_idx + syncnet_mel_step_size

        return spec[start_idx: end_idx, :]

    def __len__(self):
        return len(self.all_videos)

    def __getitem__(self, idx):
        while 1:
            idx = random.randint(0, len(self.all_videos) - 1)
            vidname = self.all_videos[idx]

            img_names = list(glob(join(vidname, '*.jpg')))
            if len(img_names) <= 3 * syncnet_T:
                continue
            img_name = random.choice(img_names)
            wrong_img_name = random.choice(img_names)
            while wrong_img_name == img_name:
                wrong_img_name = random.choice(img_names)

            if random.choice([True, False]):
                y = ops.Ones()(1, ms.float32)
                chosen = img_name
            else:
                y = ops.Zeros()(1, ms.float32)
                chosen = wrong_img_name

            window_fnames = self.get_window(chosen)
            if window_fnames is None:
                continue

            window = []
            all_read = True
            for fname in window_fnames:
                img = cv2.imread(fname)
                if img is None:
                    all_read = False
                    break
                try:
                    img = cv2.resize(img, (hparams.img_size, hparams.img_size))
                except Exception as e:
                    all_read = False
                    break

                window.append(img)

            if not all_read:
                continue

            try:
                wav_path = join(vidname, "audio.wav")
                wav = audio.load_wav(wav_path, hparams.sample_rate)

                orig_mel = audio.melspectrogram(wav).T
            except Exception as e:
                continue

            mel = self.crop_audio_window(orig_mel.copy(), img_name)

            if (mel.shape[0] != syncnet_mel_step_size):
                continue

            # H x W x 3 * T
            x = np.concatenate(window, axis=2) / 255.
            x = x.transpose(2, 0, 1)
            x = x[:, x.shape[1] // 2:]

            x = Tensor(x, ms.float32)
            mel = ops.ExpandDims()(Tensor(mel.T, ms.float32), 0)

            return x, mel, y


class CosineLoss(nn.Cell):
    def __init__(self):
        super(CosineLoss, self).__init__()
        self.criterion = nn.BCELoss(reduction='mean')

    def construct(self, a, v, y):
        index, _ = a.shape
        cos_array = np.array([])
        for i in range(index):
            cos_sim = cosine_similarity(a.asnumpy().reshape(1, -1), v.asnumpy().reshape(1, -1))
            cos_array = np.append(cos_array, cos_sim[0][0])
        t = ops.ExpandDims()(Tensor(cos_array, ms.float32), 1)
        result = self.criterion(t, y)
        return result


def train(model, train_data_loader, test_data_loader, optimizer, checkpoint_dir=None, checkpoint_interval=None,
          nepochs=None):
    global global_epoch, global_step
    while global_epoch < 100000:
        prog_bar = tqdm(enumerate(train_data_loader))
        for step, (x, mel, y) in prog_bar:
            loss = model(mel, x)
            print(loss)


            # a, v = model(mel, x)
            # loss = cosine_loss(a, v, y)
            # global_step += 1
            prog_bar.set_description('Loss: {}'.format(loss))

        global_epoch += 1



class CustomWithLossCell(nn.Cell):
    def __init__(self, backbone, loss_fn):
        super(CustomWithLossCell, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._loss_fn = loss_fn

    def construct(self, x, mel, y):
        a, v = self._backbone(mel, x)
        return self._loss_fn(a, v, y)


if __name__ == "__main__":

    save_ckpt_dir = args.save_ckpt_dir
    load_ckpt_path = args.load_ckpt_path

    if not os.path.exists(save_ckpt_dir):
        os.mkdir(save_ckpt_dir)

    # Dataset and Dataloader setup
    train_dataset = Dataset('train')
    test_dataset = Dataset('val')

    train_data_loader = ds.GeneratorDataset(source=train_dataset,
                                            column_names=["img", "wav", "label"],
                                            shuffle=True)
    train_data_loader = train_data_loader.batch(batch_size=64, drop_remainder=True)
    test_data_loader = ds.GeneratorDataset(test_dataset, column_names=["img", "wav", "label"])
    test_data_loader = test_data_loader.batch(batch_size=64, drop_remainder=True)

    # Model
    net = SyncNetColor()

    loss = CosineLoss()
    loss_net = CustomWithLossCell(net, loss)

    optimizer = nn.Adam([p for _, p in net.parameters_dict().items() if p.requires_grad], learning_rate=hparams.syncnet_lr)

    myTrainOneStepCell = nn.TrainOneStepCell(loss_net, optimizer)

    if load_ckpt_path is not None:
        load_checkpoint(load_ckpt_path, net)



    train(myTrainOneStepCell, train_data_loader, test_data_loader, optimizer,
          checkpoint_dir=save_ckpt_dir,
          checkpoint_interval=100,
          nepochs=100)

 1.cosine_similarity 功能,可以使用ReduceSum、sqrt、div等小算子拼出来

2.我们的Reshape算子支持上述(-1,1)的运算。您可以查阅一下对应算子的使用是否满足您的需求

你可能感兴趣的:(python,开发语言)