Main code
import argparse
import os
import torch
import torch.nn.parallel
import torch.utils.data
from utils import to_categorical
from collections import defaultdict
from torch.autograd import Variable
from data_utils.ShapeNetDataLoader import PartNormalDataset
import torch.nn.functional as F
import datetime
import logging
from pathlib import Path
from utils import test_partseg
from tqdm import tqdm
from model.pointnet2 import PointNet2PartSeg_msg_one_hot
from model.pointnet import PointNetDenseCls,PointNetLoss
seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43], 'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46], 'Mug': [36, 37], 'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27], 'Table': [47, 48, 49], 'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40], 'Chair': [12, 13, 14, 15], 'Knife': [22, 23]}
seg_label_to_cat = {} # {0:Airplane, 1:Airplane, ...49:Table}
for cat in seg_classes.keys():
for label in seg_classes[cat]:
seg_label_to_cat[label] = cat
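# Reader's note: seg_label_to_cat inverts seg_classes, mapping each of the 50
# part labels back to its object category, e.g. seg_label_to_cat[13] == 'Chair'
# and seg_label_to_cat[49] == 'Table'.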
def parse_args():
parser = argparse.ArgumentParser('PointNet2')
parser.add_argument('--batchsize', type=int, default=8, help='input batch size')
parser.add_argument('--workers', type=int, default=0, help='number of data loading workers')
parser.add_argument('--epoch', type=int, default=4, help='number of epochs for training')
    parser.add_argument('--pretrain', type=str, default=None, help='path of a pretrained model to load')
parser.add_argument('--gpu', type=str, default='0', help='specify gpu device')
parser.add_argument('--model_name', type=str, default='pointnet', help='Name of model')
parser.add_argument('--learning_rate', type=float, default=0.001, help='learning rate for training')
parser.add_argument('--decay_rate', type=float, default=1e-4, help='weight decay')
parser.add_argument('--optimizer', type=str, default='Adam', help='type of optimizer')
parser.add_argument('--multi_gpu', type=str, default=None, help='whether use multi gpu training')
parser.add_argument('--jitter', default=False, help="randomly jitter point cloud")
    parser.add_argument('--step_size', type=int, default=20, help='decay the learning rate every step_size epochs')
return parser.parse_args()
def main(args):
    # create output directories
# os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu if args.multi_gpu is None else '0,1,2,3'
'''CREATE DIR'''
experiment_dir = Path('./experiment/')
experiment_dir.mkdir(exist_ok=True)
file_dir = Path(str(experiment_dir) +'/%sPartSeg-'%args.model_name + str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')))
file_dir.mkdir(exist_ok=True)
checkpoints_dir = file_dir.joinpath('checkpoints/')
checkpoints_dir.mkdir(exist_ok=True)
log_dir = file_dir.joinpath('logs/')
log_dir.mkdir(exist_ok=True)
'''LOG'''
    # set up logging
    logger = logging.getLogger(args.model_name)  # create the logger
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')  # log-message layout
    file_handler = logging.FileHandler(str(log_dir) + '/train_%s_partseg.txt' % args.model_name)  # write to a per-run file
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
    logger.info('---------------------------------------------------TRAINING---------------------------------------------------')
logger.info('PARAMETER ...')
logger.info(args)
    norm = args.model_name == 'pointnet'  # normalize inputs only for the PointNet model
    # load the datasets
    TRAIN_DATASET = PartNormalDataset(npoints=2048, split='trainval', normalize=norm, jitter=args.jitter)
    dataloader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batchsize, shuffle=True, num_workers=int(args.workers))
    TEST_DATASET = PartNormalDataset(npoints=2048, split='test', normalize=norm, jitter=False)
    testdataloader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=8, shuffle=True, num_workers=int(args.workers))
print("The number of training data is:",len(TRAIN_DATASET))
logger.info("The number of training data is:%d",len(TRAIN_DATASET))
print("The number of test data is:", len(TEST_DATASET))
logger.info("The number of test data is:%d", len(TEST_DATASET))
num_classes = 16
num_part = 50
blue = lambda x: '\033[94m' + x + '\033[0m'
    model = PointNet2PartSeg_msg_one_hot(num_part) if args.model_name == 'pointnet2' else PointNetDenseCls(cat_num=num_classes, part_num=num_part)
if args.pretrain is not None:
model.load_state_dict(torch.load(args.pretrain))
print('load model %s'%args.pretrain)
logger.info('load model %s'%args.pretrain)
else:
print('Training from scratch')
logger.info('Training from scratch')
    pretrain = args.pretrain
    init_epoch = int(pretrain[-14:-11]) if args.pretrain is not None else 0  # epoch index parsed from a checkpoint name like 'pointnet_025_0.9123.pth'
if args.optimizer == 'SGD':
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
elif args.optimizer == 'Adam':
optimizer = torch.optim.Adam(
model.parameters(),
lr=args.learning_rate,
betas=(0.9, 0.999),
eps=1e-08,
weight_decay=args.decay_rate
)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.5)  # multiply the learning rate by gamma=0.5 every step_size epochs
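    # Reader's note: with the defaults (lr=0.001, step_size=20, gamma=0.5) this
    # gives 1e-3 for epochs 0-19, 5e-4 for 20-39, 2.5e-4 for 40-59, and so on,
    # floored at LEARNING_RATE_CLIP = 1e-5 by the loop below.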
'''GPU selection and multi-GPU'''
# if args.multi_gpu is not None:
# device_ids = [int(x) for x in args.multi_gpu.split(',')]
# torch.backends.cudnn.benchmark = True
# model.cuda(device_ids[0])
# model = torch.nn.DataParallel(model, device_ids=device_ids)
# else:
    #     model.cuda()
criterion = PointNetLoss()
LEARNING_RATE_CLIP = 1e-5
    history = defaultdict(lambda: list())  # record the loss of every training step
best_acc = 0
best_class_avg_iou = 0
    best_instance_avg_iou = 0
for epoch in range(init_epoch,args.epoch):
        scheduler.step()  # advance the learning-rate schedule
        lr = max(optimizer.param_groups[0]['lr'], LEARNING_RATE_CLIP)  # never let the LR fall below the clip value
print('Learning rate:%f' % lr)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
for i, data in tqdm(enumerate(dataloader, 0),total=len(dataloader),smoothing=0.9):
points, label, target, norm_plt = data
'''
points(8,2048,3)
label(8,1)
target(8,2048)
norm_plt(8,2048,3)
'''
            points, label, target = Variable(points.float()), Variable(label.long()), Variable(target.long())  # wrap as torch Variables
            points = points.transpose(2, 1)  # points (8, 3, 2048)
            norm_plt = norm_plt.transpose(2, 1)  # norm_plt (8, 3, 2048)
            # points, label, target, norm_plt = points.cuda(), label.squeeze().cuda(), target.cuda(), norm_plt.cuda()
            points, label, target, norm_plt = points, label.squeeze(), target, norm_plt
optimizer.zero_grad()
model = model.train()
if args.model_name == 'pointnet':
labels_pred, seg_pred, trans_feat = model(points, to_categorical(label, 16))
'''
labels_pred (8,16)
seg_pred (8,2048,50)
trans_feat
'''
seg_pred = seg_pred.contiguous().view(-1, num_part)#(8*2048,50)
target = target.view(-1, 1)[:, 0]#(8*2048)
loss, seg_loss, label_loss = criterion(labels_pred, label, seg_pred, target, trans_feat)
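                # Reader's note: the total is
                #   weight * seg_loss + (1 - weight) * label_loss + 0.001 * orthogonality penalty
                # (see PointNetLoss in the loss-function section below).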
else:
seg_pred = model(points, norm_plt, to_categorical(label, 16))
seg_pred = seg_pred.contiguous().view(-1, num_part)
target = target.view(-1, 1)[:, 0]
loss = F.nll_loss(seg_pred, target)
history['loss'].append(loss.cpu().data.numpy())
loss.backward()
optimizer.step()
        forpointnet2 = args.model_name == 'pointnet2'
        test_metrics, test_hist_acc, cat_mean_iou = test_partseg(model.eval(), testdataloader, seg_label_to_cat, 50, forpointnet2)  # evaluate on the test set after every epoch
        # test_metrics: dict of evaluation metrics
        # test_hist_acc: history of per-batch segmentation accuracies
        # cat_mean_iou: per-category mean IoU
        print('Epoch %d %s accuracy: %f  Class avg mIOU: %f  Instance avg mIOU: %f' % (
            epoch, blue('test'), test_metrics['accuracy'], test_metrics['class_avg_iou'], test_metrics['instance_avg_iou']))
        # write the same summary to the log file
        logger.info('Epoch %d %s Accuracy: %f  Class avg mIOU: %f  Instance avg mIOU: %f' % (
            epoch, blue('test'), test_metrics['accuracy'], test_metrics['class_avg_iou'], test_metrics['instance_avg_iou']))
        # keep the best model
        if test_metrics['accuracy'] > best_acc:
            best_acc = test_metrics['accuracy']
            torch.save(model.state_dict(), '%s/%s_%.3d_%.4f.pth' % (checkpoints_dir, args.model_name, epoch, best_acc))
            logger.info(cat_mean_iou)
            logger.info('Save model..')
            print('Save model..')
            print(cat_mean_iou)
        if test_metrics['class_avg_iou'] > best_class_avg_iou:
            best_class_avg_iou = test_metrics['class_avg_iou']
        if test_metrics['instance_avg_iou'] > best_instance_avg_iou:
            best_instance_avg_iou = test_metrics['instance_avg_iou']
        print('Best accuracy is: %.5f' % best_acc)
        logger.info('Best accuracy is: %.5f' % best_acc)
        print('Best class avg mIOU is: %.5f' % best_class_avg_iou)
        logger.info('Best class avg mIOU is: %.5f' % best_class_avg_iou)
        print('Best instance avg mIOU is: %.5f' % best_instance_avg_iou)
        logger.info('Best instance avg mIOU is: %.5f' % best_instance_avg_iou)
if __name__ == '__main__':
args = parse_args()
main(args)
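The script imports to_categorical from utils; it one-hot encodes the object-category index so the category can be fed to the network alongside the points. A minimal sketch of what that helper needs to do (the version in the repo's utils may differ in details such as device handling):

import torch

def to_categorical(y, num_classes):
    # one-hot encode a tensor of class indices: (B,) -> (B, num_classes)
    new_y = torch.eye(num_classes)[y.cpu().data.numpy(), ]
    if y.is_cuda:
        return new_y.cuda()
    return new_y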
Data loading (annotations still in progress)
# -*- coding: utf-8 -*-
import os
import json
import warnings
import numpy as np
from torch.utils.data import Dataset
warnings.filterwarnings('ignore')
def pc_normalize(pc):
centroid = np.mean(pc, axis=0)
pc = pc - centroid
m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
pc = pc / m
return pc
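# Reader's note: pc_normalize centers the cloud at its centroid and scales by
# the largest point radius, so every normalized cloud fits in the unit sphere.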
def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
    """ Randomly jitter points. Jittering is per point.
        Input:
          NxC array, original point cloud
        Return:
          NxC array, jittered point cloud
    """
    N, C = batch_data.shape
    assert (clip > 0)
    jittered_data = np.clip(sigma * np.random.randn(N, C), -1 * clip, clip)
    jittered_data += batch_data
    return jittered_data
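# Reader's note: jitter_point_cloud returns a new array rather than modifying
# its input in place, so callers must keep the return value.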
class PartNormalDataset(Dataset):
def __init__(self, npoints=2500, split='train', normalize=True, jitter=False):
self.npoints = npoints
        self.root = r'G:\Pointnet_Pointnet2_pytorch-master\data\ShapeNet'  # adjust to your local ShapeNet path
        self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')  # category-name -> synset-offset mapping file
self.cat = {}
self.normalize = normalize
self.jitter = jitter
        with open(self.catfile, 'r') as f:
            for line in f:
                ls = line.strip().split()  # split on whitespace into [category name, synset offset]
                self.cat[ls[0]] = ls[1]  # e.g. {'Airplane': '02691156', ...}
        self.cat = {k: v for k, v in self.cat.items()}
# print(self.cat)
self.meta = {}
with open(os.path.join(self.root, 'train_test_split', 'shuffled_train_file_list.json'), 'r') as f:
train_ids = set([str(d.split('/')[2]) for d in json.load(f)])
with open(os.path.join(self.root, 'train_test_split', 'shuffled_val_file_list.json'), 'r') as f:
val_ids = set([str(d.split('/')[2]) for d in json.load(f)])
with open(os.path.join(self.root, 'train_test_split', 'shuffled_test_file_list.json'), 'r') as f:
test_ids = set([str(d.split('/')[2]) for d in json.load(f)])
for item in self.cat:
# print('category', item)
self.meta[item] = []
dir_point = os.path.join(self.root, self.cat[item])
fns = sorted(os.listdir(dir_point))
# print(fns[0][0:-4])
if split == 'trainval':
fns = [fn for fn in fns if ((fn[0:-4] in train_ids) or (fn[0:-4] in val_ids))]
elif split == 'train':
fns = [fn for fn in fns if fn[0:-4] in train_ids]
elif split == 'val':
fns = [fn for fn in fns if fn[0:-4] in val_ids]
elif split == 'test':
fns = [fn for fn in fns if fn[0:-4] in test_ids]
else:
print('Unknown split: %s. Exiting..' % (split))
exit(-1)
# print(os.path.basename(fns))
for fn in fns:
token = (os.path.splitext(os.path.basename(fn))[0])
                self.meta[item].append(os.path.join(dir_point, token + '.txt'))  # full path of every sample in this category
self.datapath = []
for item in self.cat:
for fn in self.meta[item]:
self.datapath.append((item, fn))
        if split == 'trainval':
            self.datapath = self.datapath[:2000]  # keep only the first 2000 samples
self.classes = dict(zip(self.cat, range(len(self.cat))))
# Mapping from category ('Chair') to a list of int [10,11,12,13] as segmentation labels
self.seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43],
'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46],
'Mug': [36, 37], 'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27],
'Table': [47, 48, 49], 'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40],
'Chair': [12, 13, 14, 15], 'Knife': [22, 23]}
for cat in sorted(self.seg_classes.keys()):
print(cat, self.seg_classes[cat])
self.cache = {} # from index to (point_set, cls, seg) tuple
self.cache_size = 20000
        # at this point, self.datapath holds a (category, file path) pair for every sample

    def __getitem__(self, index):
        # return one sample: points, object class, per-point part labels, normals
if index in self.cache:
point_set, normal, seg, cls = self.cache[index]
else:
fn = self.datapath[index]
cat = self.datapath[index][0]
cls = self.classes[cat]
cls = np.array([cls]).astype(np.int32)
data = np.loadtxt(fn[1]).astype(np.float32)
point_set = data[:, 0:3]
normal = data[:, 3:6]
seg = data[:, -1].astype(np.int32)
if len(self.cache) < self.cache_size:
self.cache[index] = (point_set, normal, seg, cls)
if self.normalize:
point_set = pc_normalize(point_set)
        if self.jitter:
            point_set = jitter_point_cloud(point_set)
choice = np.random.choice(len(seg), self.npoints, replace=True)
# resample
point_set = point_set[choice, :]
seg = seg[choice]
normal = normal[choice, :]
return point_set,cls, seg, normal
def __len__(self):
return len(self.datapath)
if __name__ == '__main__':
TRAIN_DATASET = PartNormalDataset(npoints=2048, split='trainval',normalize=True, jitter=False)
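A quick way to see what __getitem__ returns (an illustrative snippet, not part of the original file; shapes assume npoints=2048):

point_set, cls, seg, normal = TRAIN_DATASET[0]
print(point_set.shape)  # (2048, 3)  xyz coordinates
print(cls.shape)        # (1,)       object-category index
print(seg.shape)        # (2048,)    per-point part labels
print(normal.shape)     # (2048, 3)  per-point normals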
Network architecture
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
class STN3d(nn.Module):
def __init__(self):
super(STN3d, self).__init__()
        # Note: the MLPs described in the paper are implemented with 1x1 convolutions;
        # for example, mapping 3 to 64 dimensions uses 64 kernels of size 1x3.
self.conv1 = torch.nn.Conv1d(3, 64, 1)
self.conv2 = torch.nn.Conv1d(64, 128, 1)
self.conv3 = torch.nn.Conv1d(128, 1024, 1)
self.fc1 = nn.Linear(1024, 512)
self.fc2 = nn.Linear(512, 256)
self.fc3 = nn.Linear(256, 9)
self.relu = nn.ReLU()
self.bn1 = nn.BatchNorm1d(64)
self.bn2 = nn.BatchNorm1d(128)
self.bn3 = nn.BatchNorm1d(1024)
self.bn4 = nn.BatchNorm1d(512)
self.bn5 = nn.BatchNorm1d(256)
def forward(self, x):
batchsize = x.size()[0]
x = F.relu(self.bn1(self.conv1(x)))
x = F.relu(self.bn2(self.conv2(x)))
x = F.relu(self.bn3(self.conv3(x)))
x = torch.max(x, 2, keepdim=True)[0]
x = x.view(-1, 1024)
x = F.relu(self.bn4(self.fc1(x)))
x = F.relu(self.bn5(self.fc2(x)))
x = self.fc3(x)
iden = Variable(torch.from_numpy(np.array([1, 0, 0, 0, 1, 0, 0, 0, 1]).astype(np.float32))).view(1, 9).repeat(
batchsize, 1)
if x.is_cuda:
iden = iden.cuda()
x = x + iden
x = x.view(-1, 3, 3)
return x
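# Reader's note: STN3d is the input T-Net. It maps a (B, 3, N) cloud to a
# (B, 3, 3) transform; adding the flattened identity ('iden') biases the
# predicted transform towards the identity at initialization.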
class STNkd(nn.Module):
def __init__(self, k=64):
super(STNkd, self).__init__()
self.conv1 = torch.nn.Conv1d(k, 64, 1)
self.conv2 = torch.nn.Conv1d(64, 128, 1)
self.conv3 = torch.nn.Conv1d(128, 1024, 1)
self.fc1 = nn.Linear(1024, 512)
self.fc2 = nn.Linear(512, 256)
self.fc3 = nn.Linear(256, k * k)
self.relu = nn.ReLU()
self.bn1 = nn.BatchNorm1d(64)
self.bn2 = nn.BatchNorm1d(128)
self.bn3 = nn.BatchNorm1d(1024)
self.bn4 = nn.BatchNorm1d(512)
self.bn5 = nn.BatchNorm1d(256)
self.k = k
def forward(self, x):
batchsize = x.size()[0]
x = F.relu(self.bn1(self.conv1(x)))
x = F.relu(self.bn2(self.conv2(x)))
x = F.relu(self.bn3(self.conv3(x)))
x = torch.max(x, 2, keepdim=True)[0]
x = x.view(-1, 1024)
x = F.relu(self.bn4(self.fc1(x)))
x = F.relu(self.bn5(self.fc2(x)))
x = self.fc3(x)
iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1, self.k * self.k).repeat(
batchsize, 1)
if x.is_cuda:
iden = iden.cuda()
x = x + iden
x = x.view(-1, self.k, self.k)
return x
class PointNetEncoder(nn.Module):
def __init__(self, global_feat=True, feature_transform=False, semseg = False):
super(PointNetEncoder, self).__init__()
self.stn = STN3d() if not semseg else STNkd(k=9)
self.conv1 = torch.nn.Conv1d(3, 64, 1) if not semseg else torch.nn.Conv1d(9, 64, 1)
self.conv2 = torch.nn.Conv1d(64, 128, 1)
self.conv3 = torch.nn.Conv1d(128, 1024, 1)
self.bn1 = nn.BatchNorm1d(64)
self.bn2 = nn.BatchNorm1d(128)
self.bn3 = nn.BatchNorm1d(1024)
self.global_feat = global_feat
self.feature_transform = feature_transform
if self.feature_transform:
self.fstn = STNkd(k=64)
def forward(self, x):
n_pts = x.size()[2]
trans = self.stn(x)
x = x.transpose(2, 1)
x = torch.bmm(x, trans)
x = x.transpose(2, 1)
x = F.relu(self.bn1(self.conv1(x)))
if self.feature_transform:
trans_feat = self.fstn(x)
x = x.transpose(2, 1)
x = torch.bmm(x, trans_feat)
x = x.transpose(2, 1)
else:
trans_feat = None
pointfeat = x
x = F.relu(self.bn2(self.conv2(x)))
x = self.bn3(self.conv3(x))
x = torch.max(x, 2, keepdim=True)[0]
x = x.view(-1, 1024)
if self.global_feat:
return x, trans, trans_feat
else:
x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
return torch.cat([x, pointfeat], 1), trans, trans_feat
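# Reader's note: with global_feat=False the encoder concatenates the 1024-d
# global feature with the 64-d per-point feature, which is where the 1088
# input channels of PointNetSeg's first convolution come from.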
class PointNetDenseCls(nn.Module):
def __init__(self, cat_num=16,part_num=50):
super(PointNetDenseCls, self).__init__()
self.cat_num = cat_num
self.part_num = part_num
self.stn = STN3d()
self.conv1 = torch.nn.Conv1d(3, 64, 1)
self.conv2 = torch.nn.Conv1d(64, 128, 1)
self.conv3 = torch.nn.Conv1d(128, 128, 1)
self.conv4 = torch.nn.Conv1d(128, 512, 1)
self.conv5 = torch.nn.Conv1d(512, 2048, 1)
self.bn1 = nn.BatchNorm1d(64)
self.bn2 = nn.BatchNorm1d(128)
self.bn3 = nn.BatchNorm1d(128)
self.bn4 = nn.BatchNorm1d(512)
self.bn5 = nn.BatchNorm1d(2048)
self.fstn = STNkd(k=128)
# classification network
self.fc1 = nn.Linear(2048, 256)
self.fc2 = nn.Linear(256, 256)
self.fc3 = nn.Linear(256, cat_num)
self.dropout = nn.Dropout(p=0.3)
self.bnc1 = nn.BatchNorm1d(256)
self.bnc2 = nn.BatchNorm1d(256)
# segmentation network
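        # input channels of convs1: expand (2048 + 16) + out1 (64) + out2 (128)
        # + out3 (128) + out4 (512) + out5 (2048) = 4944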
self.convs1 = torch.nn.Conv1d(4944, 256, 1)
self.convs2 = torch.nn.Conv1d(256, 256, 1)
self.convs3 = torch.nn.Conv1d(256, 128, 1)
self.convs4 = torch.nn.Conv1d(128, part_num, 1)
self.bns1 = nn.BatchNorm1d(256)
self.bns2 = nn.BatchNorm1d(256)
self.bns3 = nn.BatchNorm1d(128)
def forward(self, point_cloud,label):
batchsize,_ , n_pts = point_cloud.size()
# point_cloud_transformed
trans = self.stn(point_cloud)
point_cloud = point_cloud.transpose(2, 1)
point_cloud_transformed = torch.bmm(point_cloud, trans)
point_cloud_transformed = point_cloud_transformed.transpose(2, 1)
# MLP
out1 = F.relu(self.bn1(self.conv1(point_cloud_transformed)))
out2 = F.relu(self.bn2(self.conv2(out1)))
out3 = F.relu(self.bn3(self.conv3(out2)))
# net_transformed
trans_feat = self.fstn(out3)
x = out3.transpose(2, 1)
net_transformed = torch.bmm(x, trans_feat)
net_transformed = net_transformed.transpose(2, 1)
# MLP
out4 = F.relu(self.bn4(self.conv4(net_transformed)))
out5 = self.bn5(self.conv5(out4))
out_max = torch.max(out5, 2, keepdim=True)[0]
out_max = out_max.view(-1, 2048)
# classification network
net = F.relu(self.bnc1(self.fc1(out_max)))
net = F.relu(self.bnc2(self.dropout(self.fc2(net))))
net = self.fc3(net) # [B,16]
# segmentation network
out_max = torch.cat([out_max,label],1)
expand = out_max.view(-1, 2048+16, 1).repeat(1, 1, n_pts)
concat = torch.cat([expand, out1, out2, out3, out4, out5], 1)
net2 = F.relu(self.bns1(self.convs1(concat)))
net2 = F.relu(self.bns2(self.convs2(net2)))
net2 = F.relu(self.bns3(self.convs3(net2)))
net2 = self.convs4(net2)
net2 = net2.transpose(2, 1).contiguous()
net2 = F.log_softmax(net2.view(-1, self.part_num), dim=-1)
        net2 = net2.view(batchsize, n_pts, self.part_num)  # [B, N, 50]
return net, net2, trans_feat
def feature_transform_reguliarzer(trans):
    # penalize deviation of the predicted feature transform from an orthogonal matrix
    d = trans.size()[1]
    I = torch.eye(d)[None, :, :]
    if trans.is_cuda:
        I = I.cuda()
    loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2, 1)) - I, dim=(1, 2)))
    return loss
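# Reader's note: this is || A A^T - I ||, which is zero exactly when the
# predicted feature transform A is orthogonal; e.g. an identity transform
# torch.eye(64).unsqueeze(0) incurs zero penalty.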
class PointNetSeg(nn.Module):
def __init__(self,num_class,feature_transform=False, semseg = False):
super(PointNetSeg, self).__init__()
self.k = num_class
self.feat = PointNetEncoder(global_feat=False,feature_transform=feature_transform, semseg = semseg)
self.conv1 = torch.nn.Conv1d(1088, 512, 1)
self.conv2 = torch.nn.Conv1d(512, 256, 1)
self.conv3 = torch.nn.Conv1d(256, 128, 1)
self.conv4 = torch.nn.Conv1d(128, self.k, 1)
self.bn1 = nn.BatchNorm1d(512)
        self.bn1_1 = nn.BatchNorm1d(1024)  # (unused)
self.bn2 = nn.BatchNorm1d(256)
self.bn3 = nn.BatchNorm1d(128)
def forward(self, x):
batchsize = x.size()[0]
n_pts = x.size()[2]
x, trans, trans_feat = self.feat(x)
x = F.relu(self.bn1(self.conv1(x)))
x = F.relu(self.bn2(self.conv2(x)))
x = F.relu(self.bn3(self.conv3(x)))
x = self.conv4(x)
x = x.transpose(2,1).contiguous()
x = F.log_softmax(x.view(-1,self.k), dim=-1)
x = x.view(batchsize, n_pts, self.k)
return x, trans_feat
if __name__ == '__main__':
    # sanity-check the network with random inputs
    point = torch.randn(8, 3, 1024)
    label = torch.randn(8, 16)
    model = PointNetDenseCls()
    net, net2, trans_feat = model(point, label)
    print('net', net.shape)                 # torch.Size([8, 16])
    print('net2', net2.shape)               # torch.Size([8, 1024, 50])
    print('trans_feat', trans_feat.shape)   # torch.Size([8, 128, 128])
(Trace the shapes through the forward pass above and the data flow becomes clear.)
Loss function
class PointNetLoss(torch.nn.Module):
def __init__(self, weight=1,mat_diff_loss_scale=0.001):
super(PointNetLoss, self).__init__()
self.mat_diff_loss_scale = mat_diff_loss_scale
self.weight = weight
    def forward(self, labels_pred, label, seg_pred, seg, trans_feat):
        seg_loss = F.nll_loss(seg_pred, seg)  # per-point segmentation loss
        mat_diff_loss = feature_transform_reguliarzer(trans_feat)  # orthogonality penalty on the feature transform
        label_loss = F.nll_loss(labels_pred, label)  # object-classification loss
        loss = self.weight * seg_loss + (1 - self.weight) * label_loss + mat_diff_loss * self.mat_diff_loss_scale
return loss, seg_loss, label_loss
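Note that with the default weight=1 the classification term is multiplied by (1 - weight) = 0, so the total reduces to seg_loss + 0.001 * mat_diff_loss; the classification head only contributes to training when weight < 1.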
Test-set code
import numpy as np
import pandas as pd
from collections import defaultdict
from tqdm import tqdm
from torch.autograd import Variable
from utils import to_categorical

def compute_cat_iou(pred, target, iou_tabel):
iou_list = []
target = target.cpu().data.numpy()
for j in range(pred.size(0)):
batch_pred = pred[j]
batch_target = target[j]
batch_choice = batch_pred.data.max(1)[1].cpu().data.numpy()
for cat in np.unique(batch_target):
I = np.sum(np.logical_and(batch_choice == cat, batch_target == cat))
U = np.sum(np.logical_or(batch_choice == cat, batch_target == cat))
if U == 0:
iou = 1 # If the union of groundtruth and prediction points is empty, then count part IoU as 1
else:
iou = I / float(U)
iou_tabel[cat,0] += iou
iou_tabel[cat,1] += 1
iou_list.append(iou)
return iou_tabel,iou_list
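The per-part IoU arithmetic that compute_cat_iou applies to each shape, stripped of the tensor plumbing (an illustrative example, not part of the original file):

import numpy as np

pred   = np.array([0, 0, 1, 1, 1, 2])  # predicted part label per point
target = np.array([0, 1, 1, 1, 2, 2])  # ground-truth part label per point
for cat in np.unique(target):
    I = np.sum(np.logical_and(pred == cat, target == cat))
    U = np.sum(np.logical_or(pred == cat, target == cat))
    print(cat, I / U)  # part 0: 1/2, part 1: 2/4, part 2: 1/2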
def test_partseg(model, loader, catdict, num_classes = 50,forpointnet2=False):
''' catdict = {0:Airplane, 1:Airplane, ...49:Table} '''
iou_tabel = np.zeros((len(catdict),3))
iou_list = []
metrics = defaultdict(lambda:list())
hist_acc = []
# mean_correct = []
for batch_id, (points, label, target, norm_plt) in tqdm(enumerate(loader), total=len(loader), smoothing=0.9):
batchsize, num_point,_= points.size()
points, label, target, norm_plt = Variable(points.float()),Variable(label.long()), Variable(target.long()),Variable(norm_plt.float())
points = points.transpose(2, 1)
norm_plt = norm_plt.transpose(2, 1)
# points, label, target, norm_plt = points.cuda(), label.squeeze().cuda(), target.cuda(), norm_plt.cuda()
points, label, target, norm_plt = points, label.squeeze(), target, norm_plt
if forpointnet2:
seg_pred = model(points, norm_plt, to_categorical(label, 16))
else:
labels_pred, seg_pred, _ = model(points,to_categorical(label,16))#(B,16)/(B,2048,50)
# labels_pred_choice = labels_pred.data.max(1)[1]
# labels_correct = labels_pred_choice.eq(label.long().data).cpu().sum()
# mean_correct.append(labels_correct.item() / float(points.size()[0]))
# print(pred.size())
        iou_tabel, iou = compute_cat_iou(seg_pred, target, iou_tabel)  # accumulate per-part IoU for this batch
iou_list+=iou
# shape_ious += compute_overall_iou(pred, target, num_classes)
seg_pred = seg_pred.contiguous().view(-1, num_classes)#(B*2048,50)
target = target.view(-1, 1)[:, 0]#(B*2048)
pred_choice = seg_pred.data.max(1)[1]
        correct = pred_choice.eq(target.data).cpu().sum()  # number of correctly labelled points
        metrics['accuracy'].append(correct.item() / (batchsize * num_point))  # per-batch point accuracy
iou_tabel[:,2] = iou_tabel[:,0] /iou_tabel[:,1]
hist_acc += metrics['accuracy']
metrics['accuracy'] = np.mean(hist_acc)
    metrics['instance_avg_iou'] = np.mean(iou_list)
iou_tabel = pd.DataFrame(iou_tabel,columns=['iou','count','mean_iou'])
iou_tabel['Category_IOU'] = [catdict[i] for i in range(len(catdict)) ]
cat_iou = iou_tabel.groupby('Category_IOU')['mean_iou'].mean()
    metrics['class_avg_iou'] = np.mean(cat_iou)  # mean IoU averaged over the 16 categories
return metrics, hist_acc, cat_iou
Addendum: the original author's code is at https://github.com/yanx27/Pointnet_Pointnet2_pytorch