【功能模块】
myTrainOneStepCell
【操作步骤&问题现象】
1、代码参见下文。
2、我的数据集的每个样本是由:(图片数据、音频数据、标签)组成的,参考多标签数据集使用trainOneStepCell,发现报错。
【截图信息】
【日志信息】(可选,上传日志内容或者附件)
【代码】
import argparse import cv2 import os import random from glob import glob from os.path import dirname, join, basename, isfile import numpy as np import mindspore.nn as nn from tqdm import tqdm import audio from hparams import hparams, get_image_list import mindspore.ops as ops import mindspore as ms from models_ms import SyncNetColor from mindspore import Parameter, Tensor, load_param_into_net, save_checkpoint, load_checkpoint, context, Model from sklearn.metrics.pairwise import cosine_similarity from mindspore.nn import LossBase import mindspore.dataset as ds context.set_context(device_target='CPU', mode=context.GRAPH_MODE) parser = argparse.ArgumentParser(description='Code to train the expert lip-sync discriminator') parser.add_argument("--data_root", help="Root folder of the preprocessed LRS2 dataset", required=True) parser.add_argument('--save_ckpt_dir', help='Save checkpoints to this directory', required=True, type=str) parser.add_argument('--load_ckpt_path', help='Resumed from this checkpoint', default=None, type=str) args = parser.parse_args() global_step = 0 global_epoch = 0 syncnet_T = 5 syncnet_mel_step_size = 16 def get_image_list(data_root, split): filelist = [] with open('filelists/{}.txt'.format(split)) as f: for line in f: line = line.strip() if ' ' in line: line = line.split()[0] filelist.append(os.path.join(data_root, line)) return filelist class Dataset: def __init__(self, split): self.all_videos = get_image_list(args.data_root, split) @staticmethod def get_frame_id(frame): return int(basename(frame).split('.')[0]) def get_window(self, start_frame): start_id = self.get_frame_id(start_frame) vidname = dirname(start_frame) window_fnames = [] for frame_id in range(start_id, start_id + syncnet_T): frame = join(vidname, '{}.jpg'.format(frame_id)) if not isfile(frame): return None window_fnames.append(frame) return window_fnames def crop_audio_window(self, spec, start_frame): start_frame_num = self.get_frame_id(start_frame) start_idx = int(80. * (start_frame_num / float(hparams.fps))) end_idx = start_idx + syncnet_mel_step_size return spec[start_idx: end_idx, :] def __len__(self): return len(self.all_videos) def __getitem__(self, idx): while 1: idx = random.randint(0, len(self.all_videos) - 1) vidname = self.all_videos[idx] img_names = list(glob(join(vidname, '*.jpg'))) if len(img_names) <= 3 * syncnet_T: continue img_name = random.choice(img_names) wrong_img_name = random.choice(img_names) while wrong_img_name == img_name: wrong_img_name = random.choice(img_names) if random.choice([True, False]): y = ops.Ones()(1, ms.float32) chosen = img_name else: y = ops.Zeros()(1, ms.float32) chosen = wrong_img_name window_fnames = self.get_window(chosen) if window_fnames is None: continue window = [] all_read = True for fname in window_fnames: img = cv2.imread(fname) if img is None: all_read = False break try: img = cv2.resize(img, (hparams.img_size, hparams.img_size)) except Exception as e: all_read = False break window.append(img) if not all_read: continue try: wav_path = join(vidname, "audio.wav") wav = audio.load_wav(wav_path, hparams.sample_rate) orig_mel = audio.melspectrogram(wav).T except Exception as e: continue mel = self.crop_audio_window(orig_mel.copy(), img_name) if (mel.shape[0] != syncnet_mel_step_size): continue # H x W x 3 * T x = np.concatenate(window, axis=2) / 255. x = x.transpose(2, 0, 1) x = x[:, x.shape[1] // 2:] x = Tensor(x, ms.float32) mel = ops.ExpandDims()(Tensor(mel.T, ms.float32), 0) return x, mel, y class CosineLoss(nn.Cell): def __init__(self): super(CosineLoss, self).__init__() self.criterion = nn.BCELoss(reduction='mean') def construct(self, a, v, y): index, _ = a.shape cos_array = np.array([]) for i in range(index): cos_sim = cosine_similarity(a.asnumpy().reshape(1, -1), v.asnumpy().reshape(1, -1)) cos_array = np.append(cos_array, cos_sim[0][0]) t = ops.ExpandDims()(Tensor(cos_array, ms.float32), 1) result = self.criterion(t, y) return result def train(model, train_data_loader, test_data_loader, optimizer, checkpoint_dir=None, checkpoint_interval=None, nepochs=None): global global_epoch, global_step while global_epoch < 100000: prog_bar = tqdm(enumerate(train_data_loader)) for step, (x, mel, y) in prog_bar: loss = model(mel, x) print(loss) # a, v = model(mel, x) # loss = cosine_loss(a, v, y) # global_step += 1 prog_bar.set_description('Loss: {}'.format(loss)) global_epoch += 1 class CustomWithLossCell(nn.Cell): def __init__(self, backbone, loss_fn): super(CustomWithLossCell, self).__init__(auto_prefix=False) self._backbone = backbone self._loss_fn = loss_fn def construct(self, x, mel, y): a, v = self._backbone(mel, x) return self._loss_fn(a, v, y) if __name__ == "__main__": save_ckpt_dir = args.save_ckpt_dir load_ckpt_path = args.load_ckpt_path if not os.path.exists(save_ckpt_dir): os.mkdir(save_ckpt_dir) # Dataset and Dataloader setup train_dataset = Dataset('train') test_dataset = Dataset('val') train_data_loader = ds.GeneratorDataset(source=train_dataset, column_names=["img", "wav", "label"], shuffle=True) train_data_loader = train_data_loader.batch(batch_size=64, drop_remainder=True) test_data_loader = ds.GeneratorDataset(test_dataset, column_names=["img", "wav", "label"]) test_data_loader = test_data_loader.batch(batch_size=64, drop_remainder=True) # Model net = SyncNetColor() loss = CosineLoss() loss_net = CustomWithLossCell(net, loss) optimizer = nn.Adam([p for _, p in net.parameters_dict().items() if p.requires_grad], learning_rate=hparams.syncnet_lr) myTrainOneStepCell = nn.TrainOneStepCell(loss_net, optimizer) if load_ckpt_path is not None: load_checkpoint(load_ckpt_path, net) train(myTrainOneStepCell, train_data_loader, test_data_loader, optimizer, checkpoint_dir=save_ckpt_dir, checkpoint_interval=100, nepochs=100)
1.cosine_similarity 功能,可以使用ReduceSum、sqrt、div等小算子拼出来
2.我们的Reshape算子支持上述(-1,1)的运算。您可以查阅一下对应算子的使用是否满足您的需求