AI算法成长练习第一篇——Task-Adaptive Negative Envision for Few-Shot Open-Set Recognition代码复现

论文代码复现

代码结构
图1：代码结构示意图（原文图片：AI算法成长练习第一篇——Task-Adaptive Negative Envision for Few-Shot Open-Set Recognition代码复现）

Architectures

AttnClassifier.py

import torch.nn as nn
import torch
import torch.nn.functional as F

import numpy as np

class Classifier(nn.Module):
    """Few-shot open-set classifier head (TANE reproduction, with a modification).

    Holds the pretrained base-class classifier weights and the class semantic
    embeddings, calibrates the support prototypes against the base classes,
    builds extra negative ("fake class") prototypes with ``open_generator``
    and scores query / open-set features against all prototypes with
    ``metric``.

    Args:
        args: config object; ``n_ways``, ``base_seman_calib``,
            ``neg_gen_type`` and ``agg`` are read from it.
        feat_dim: visual feature dimension D.
        param_seman: ``(params, seman_dict)`` pair, see ``init_representation``.
        train_weight_base: whether the base classifier weight/bias are trainable.
    """

    def __init__(self, args, feat_dim, param_seman, train_weight_base=False):
        super(Classifier, self).__init__()

        # Weight & bias of the pretrained base-class classifier.
        self.train_weight_base = train_weight_base
        self.init_representation(param_seman)
        if train_weight_base:
            print('Enable training base class weights')

        self.calibrator = SupportCalibrator(
            nway=args.n_ways,
            feat_dim=feat_dim,
            n_head=1,
            base_seman_calib=args.base_seman_calib,
            neg_gen_type=args.neg_gen_type,
        )
        self.open_generator = OpenSetGenerater(
            args.n_ways, feat_dim, n_head=1,
            neg_gen_type=args.neg_gen_type, agg=args.agg,
        )
        self.metric = Metric_Cosine()

    def forward(self, features, cls_ids, test=False):
        """Score query and open-set features against support + negative prototypes.

        Args:
            features: ``(support_feat, query_feat, openset_feat)``.
                ``support_feat`` must be 4-D, unpacked as (nb, nway, nshot, D).
                ``query_feat`` / ``openset_feat`` are presumably (nb, n, D) —
                TODO confirm against the data loader.
            cls_ids: ``(supp_ids, base_ids)``; ``base_ids`` may be ``None``
                (see ``get_representation``).
            test: unused; kept for interface compatibility.

        Returns:
            ``((query_cls_scores, openset_cls_scores), supp_protos,
            fakeclass_protos, (base_weights, base_wgtmem), funit_distance)``,
            where ``funit_distance`` concatenates ``1 - metric(recip_unit, .)``
            for the query features and the open-set features along dim 1.

        Raises:
            ValueError: if the number of query features is not a multiple of
                the calibrator's ``nway`` (queries are calibrated in groups of
                ``nway``).
        """
        support_feat, query_feat, openset_feat = features
        nb, _, _, ndim = support_feat.size()
        supp_ids, base_ids = cls_ids
        base_weights, base_wgtmem, base_seman, support_seman = self.get_representation(supp_ids, base_ids)

        # Average the shots of each class into one feature per class.
        support_feat = torch.mean(support_feat, dim=2)
        supp_protos, _ = self.calibrator(support_feat, base_weights, support_seman, base_seman)

        # Modification over the original method: calibrate the query features
        # as if they were support features, in groups of ``nway`` (the
        # calibrator pairs row i of a group with support semantic i), then
        # average them all into one extra prototype fed to the generator.
        nway = self.calibrator.nway
        n_query = query_feat.size(1)
        if n_query % nway != 0:
            raise ValueError(
                'number of query features ({}) must be a multiple of n_ways ({})'.format(n_query, nway))
        calibrated_chunks = []
        for start in range(0, n_query, nway):
            chunk = query_feat[:, start:start + nway, :].contiguous()
            chunk_protos, _ = self.calibrator(chunk, base_weights, support_seman, base_seman)
            calibrated_chunks.append(chunk_protos)
        suppfake_protos = torch.cat(calibrated_chunks, dim=1)
        suppfake_protos = torch.mean(suppfake_protos, dim=1).view(nb, -1, ndim)
        new_supp_protos = torch.cat([supp_protos, suppfake_protos], dim=1)

        fakeclass_protos, recip_unit = self.open_generator(new_supp_protos, base_weights, support_seman, base_seman)
        cls_protos = torch.cat([supp_protos, fakeclass_protos], dim=1)

        query_cls_scores = self.metric(cls_protos, query_feat)
        openset_cls_scores = self.metric(cls_protos, openset_feat)
        test_cosine_scores = (query_cls_scores, openset_cls_scores)

        # Distances to the generated unit, for query and open-set samples.
        query_funit_distance = 1.0 - self.metric(recip_unit, query_feat)
        qopen_funit_distance = 1.0 - self.metric(recip_unit, openset_feat)
        funit_distance = torch.cat([query_funit_distance, qopen_funit_distance], dim=1)

        return test_cosine_scores, supp_protos, fakeclass_protos, (base_weights, base_wgtmem), funit_distance

    def init_representation(self, param_seman):
        """Load the base-class classifier weights and the semantic embeddings.

        Args:
            param_seman: ``(params, seman_dict)``. ``params`` must contain the
                ``'cls_classifier.weight'`` and ``'cls_classifier.bias'``
                tensors; ``seman_dict`` maps names to numpy arrays
                (``get_representation`` reads ``'base'`` and ``'novel_test'``).
        """
        params, seman_dict = param_seman
        self.weight_base = nn.Parameter(params['cls_classifier.weight'], requires_grad=self.train_weight_base)
        self.bias_base = nn.Parameter(params['cls_classifier.bias'], requires_grad=self.train_weight_base)
        # Frozen copy of the initial base weights (returned as ``base_wgtmem``).
        self.weight_mem = nn.Parameter(params['cls_classifier.weight'].clone(), requires_grad=False)
        # The semantic tables live in a plain dict, so they are not registered
        # with the module and are not moved by ``.to()`` / ``.cuda()``; put them
        # on the GPU up front when one exists, otherwise keep them on the CPU.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.seman = {k: torch.from_numpy(v).float().to(device) for k, v in seman_dict.items()}

    def get_representation(self, cls_ids, base_ids, randpick=False):
        """Gather base weights and semantic embeddings for a batch of episodes.

        Args:
            cls_ids: support class indices. With ``base_ids`` given they index
                the ``'base'`` semantic table, otherwise ``'novel_test'``; when
                ``base_ids`` is ``None`` its first dim is the batch size.
            base_ids: per-episode base-class indices, or ``None`` to use every
                base class for every episode.
            randpick: randomly keep ``self.base_size`` base classes (same subset
                for the whole batch). ``self.base_size`` is not set in this
                class and must be assigned externally before using this.

        Returns:
            ``(base_weights, base_wgtmem, base_seman, supp_seman)``;
            ``base_wgtmem`` is not subsampled by ``randpick``.
        """
        if base_ids is not None:
            base_weights = self.weight_base[base_ids, :]
            base_wgtmem = self.weight_mem[base_ids, :]
            base_seman = self.seman['base'][base_ids, :]
            supp_seman = self.seman['base'][cls_ids, :]
        else:
            bs = cls_ids.size(0)
            base_weights = self.weight_base.repeat(bs, 1, 1)
            base_wgtmem = self.weight_mem.repeat(bs, 1, 1)
            base_seman = self.seman['base'].repeat(bs, 1, 1)
            supp_seman = self.seman['novel_test'][cls_ids, :]
        if randpick:
            num_base = base_weights.shape[1]
            base_size = self.base_size
            idx = np.random.choice(list(range(num_base)), size=base_size, replace=False)
            base_weights = base_weights[:, idx, :]
            base_seman = base_seman[:, idx, :]
        return base_weights, base_wgtmem, base_seman, supp_seman

class SupportCalibrator(nn.Module):
    """Calibrate per-class support features by attending over the base classes.

    Args:
        nway: number of classes per episode; inputs are reshaped with it.
        feat_dim: visual feature dimension D.
        n_head: used to size the ``MultiHeadAttention`` projections.
        base_seman_calib: also pass the base-class semantics through ``map_sem``.
        neg_gen_type: ``'semang'`` gates the base weights and semantics with a
            task-level gate before attending; ``'attg'`` attends over the base
            weights without semantics; ``'att'`` attends over the support
            features themselves; any other value skips calibration.
    """

    def __init__(self, nway, feat_dim, n_head=1, base_seman_calib=True, neg_gen_type='semang'):
        super(SupportCalibrator, self).__init__()
        self.nway = nway
        self.feat_dim = feat_dim
        self.base_seman_calib = base_seman_calib

        # Maps 300-d semantic embeddings to a calibrated 300-d space.
        self.map_sem = nn.Sequential(nn.Linear(300, 300), nn.LeakyReLU(0.1), nn.Dropout(0.1), nn.Linear(300, 300))
        self.calibrator = MultiHeadAttention(feat_dim // n_head, feat_dim // n_head, (feat_dim, feat_dim))
        self.neg_gen_type = neg_gen_type
        if neg_gen_type == 'semang':
            # The gates read the mean of [visual ; semantic] features
            # concatenated on the last axis, i.e. feat_dim + 300 inputs.
            self.task_visfuse = nn.Linear(feat_dim + 300, feat_dim)
            self.task_semfuse = nn.Linear(feat_dim + 300, 300)

    def _seman_calib(self, seman):
        """Apply the semantic mapping network ``map_sem``."""
        seman = self.map_sem(seman)
        return seman

    def forward(self, support_feat, base_weights, support_seman, base_seman):
        """Return ``(support_center, support_attn)`` for each class of each episode.

        Expected shapes (per the original author's comment): support_feat
        bs*nway*D, base_weights bs*num_base*D, support_seman bs*nway*300,
        base_seman bs*num_base*300. Outputs are reshaped to (bs, nway, -1);
        ``support_attn`` is ``None`` when calibration is skipped.
        """
        n_bs, n_base_cls = base_weights.size()[:2]
        # One copy of the base weights per (episode, class) query row.
        base_weights = base_weights.unsqueeze(dim=1).repeat(1, self.nway, 1, 1).view(-1, n_base_cls, self.feat_dim)
        support_feat = support_feat.view(-1, 1, self.feat_dim)

        if self.neg_gen_type == 'semang':
            support_seman = self._seman_calib(support_seman)
            if self.base_seman_calib:
                base_seman = self._seman_calib(base_seman)
            base_seman = base_seman.unsqueeze(dim=1).repeat(1, self.nway, 1, 1).view(-1, n_base_cls, 300)
            support_seman = support_seman.view(-1, 1, 300)

            base_mem_vis = base_weights
            task_mem_vis = base_weights

            base_mem_seman = base_seman
            task_mem_seman = base_seman
            avg_task_mem = torch.mean(torch.cat([task_mem_vis, task_mem_seman], -1), 1, keepdim=True)

            # Gates lie in (1, 2): they rescale but never suppress the memory.
            gate_vis = torch.sigmoid(self.task_visfuse(avg_task_mem)) + 1.0
            gate_sem = torch.sigmoid(self.task_semfuse(avg_task_mem)) + 1.0

            base_weights = base_mem_vis * gate_vis
            base_seman = base_mem_seman * gate_sem

        elif self.neg_gen_type == 'attg':
            base_mem_vis = base_weights
            base_seman = None
            support_seman = None
        elif self.neg_gen_type == 'att':
            base_weights = support_feat
            base_mem_vis = support_feat
            support_seman = None
            base_seman = None
        else:
            return support_feat.view(n_bs, self.nway, -1), None

        # Keys are the (gated) base weights, values the ungated ones.
        support_center, _, support_attn, _ = self.calibrator(support_feat, base_weights, base_mem_vis, support_seman, base_seman)
        support_center = support_center.view(n_bs, self.nway, -1)
        support_attn = support_attn.view(n_bs, self.nway, -1)
        return support_center, support_attn
		 
class OpenSetGenerater(nn.Module):
 	def __init__(self, nway, featdim, n_head=1, neg_gen_type='semang', agg='avg'):
 		supper(OpenSetGenerater, self).__init__()
 		self.nway = nway
 		self.att = MultiHeadAttention(featdim // n_head, featdim // n_head, (featdim, featdim))
 		self.featdim = featdim
 		self.neg_gen_type = neg_gen_type
 		if neg_gen_type == 'semang':
 			self.task_visfuse = nn.Linear(featdim+300, featdim)
 			self.task_semfuse = nn.Linear(featdim+300, 300)
 			
 		self.agg = agg
 		if agg == 'mlp':
 			self.agg_func = nn.Sequential(
			nn.Linear(featdim, featdim),
			nn.LeakyReLU(0.5),
			nn.Dropout(0.5),
			nn.Linear(featdim, featdim))
		self.map_sem = nn.Sequential(nn.Linear(300, 300),
									nn.LeakyReLU(

你可能感兴趣的:(人工智能,算法,python,机器学习,深度学习)