Paper Code Reproduction
Code Structure

Architectures
AttnClassifier.py
import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np

class Classifier(nn.Module):
    def __init__(self, args, feat_dim, param_seman, train_weight_base=False):
        super(Classifier, self).__init__()
        self.train_weight_base = train_weight_base
        self.init_representation(param_seman)
        if train_weight_base:
            print('Enable training base class weights')
        self.calibrator = SupportCalibrator(nway=args.n_ways, feat_dim=feat_dim, n_head=1,
                                            base_seman_calib=args.base_seman_calib,
                                            neg_gen_type=args.neg_gen_type)
        self.open_generator = OpenSetGenerater(args.n_ways, feat_dim, n_head=1,
                                               neg_gen_type=args.neg_gen_type, agg=args.agg)
        self.metric = Metric_Cosine()

    def forward(self, features, cls_ids, test=False):
        # nb: batch of episodes, nc: n-way, ns: n-shot, nq: query count, ndim: feature dim
        (support_feat, query_feat, openset_feat) = features
        (nb, nc, ns, ndim), nq = support_feat.size(), query_feat.size(1)
        (supp_ids, base_ids) = cls_ids
        base_weights, base_wgtmem, base_seman, support_seman = self.get_representation(supp_ids, base_ids)

        # Average over the shot dimension, then calibrate the per-class support prototypes
        support_feat = torch.mean(support_feat, dim=2)
        supp_protos, support_attn = self.calibrator(support_feat, base_weights, support_seman, base_seman)

        # Calibrate query features in chunks of 5 to synthesize additional "fake" support prototypes
        n = query_feat.size(1)
        sup_list = []
        for i in range(0, n, 5):
            supp_fk = query_feat[:, i:i+5, :].contiguous()
            ss, _ = self.calibrator(supp_fk, base_weights, support_seman, base_seman)
            sup_list.append(ss)
        suppfake_protos = torch.cat(sup_list, dim=1)
        suppfake_protos = torch.mean(suppfake_protos, dim=1).view(nb, -1, ndim)

        # Generate the negative (fake-class) prototype used for open-set rejection
        new_supp_protos = torch.cat([supp_protos, suppfake_protos], dim=1)
        fakeclass_protos, recip_unit = self.open_generator(new_supp_protos, base_weights, support_seman, base_seman)

        cls_protos = torch.cat([supp_protos, fakeclass_protos], dim=1)
        query_cls_scores = self.metric(cls_protos, query_feat)
        openset_cls_scores = self.metric(cls_protos, openset_feat)
        test_cosine_scores = (query_cls_scores, openset_cls_scores)

        query_funit_distance = 1.0 - self.metric(recip_unit, query_feat)
        qopen_funit_distance = 1.0 - self.metric(recip_unit, openset_feat)
        funit_distance = torch.cat([query_funit_distance, qopen_funit_distance], dim=1)

        return test_cosine_scores, supp_protos, fakeclass_protos, (base_weights, base_wgtmem), funit_distance

    def init_representation(self, param_seman):
        (params, seman_dict) = param_seman
        self.weight_base = nn.Parameter(params['cls_classifier.weight'], requires_grad=self.train_weight_base)
        self.bias_base = nn.Parameter(params['cls_classifier.bias'], requires_grad=self.train_weight_base)
        self.weight_mem = nn.Parameter(params['cls_classifier.weight'].clone(), requires_grad=False)
        # Semantic embeddings are fixed; a plain dict suffices since they are never trained
        self.seman = {k: torch.from_numpy(v).float().cuda() for k, v in seman_dict.items()}

    def get_representation(self, cls_ids, base_ids, randpick=False):
        if base_ids is not None:
            base_weights = self.weight_base[base_ids, :]
            base_wgtmem = self.weight_mem[base_ids, :]
            base_seman = self.seman['base'][base_ids, :]
            supp_seman = self.seman['base'][cls_ids, :]
        else:
            bs = cls_ids.size(0)
            base_weights = self.weight_base.repeat(bs, 1, 1)
            base_wgtmem = self.weight_mem.repeat(bs, 1, 1)
            base_seman = self.seman['base'].repeat(bs, 1, 1)
            supp_seman = self.seman['novel_test'][cls_ids, :]
        if randpick:
            # Randomly subsample base classes; self.base_size is assumed to be configured elsewhere (e.g. from args)
            num_base = base_weights.shape[1]
            base_size = self.base_size
            idx = np.random.choice(list(range(num_base)), size=base_size, replace=False)
            base_weights = base_weights[:, idx, :]
            base_seman = base_seman[:, idx, :]
        return base_weights, base_wgtmem, base_seman, supp_seman
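
A minimal smoke test of the forward pass might look like the following. All episode sizes, the feat_dim=640, and the args fields are assumptions for illustration, and it presumes a CUDA device plus the remaining modules of the file (MultiHeadAttention, Metric_Cosine, OpenSetGenerater.forward) being defined; it is a sketch, not part of the original code.

    # Hypothetical smoke test for Classifier; shapes and args values are illustrative assumptions.
    import torch
    import numpy as np
    from argparse import Namespace

    n_base, feat_dim, nb, ns = 64, 640, 2, 1
    args = Namespace(n_ways=5, base_seman_calib=True, neg_gen_type='semang', agg='avg')
    params = {'cls_classifier.weight': torch.randn(n_base, feat_dim),
              'cls_classifier.bias': torch.randn(n_base)}
    seman_dict = {'base': np.random.randn(n_base, 300).astype(np.float32),
                  'novel_test': np.random.randn(20, 300).astype(np.float32)}

    model = Classifier(args, feat_dim, (params, seman_dict)).cuda()
    support = torch.randn(nb, args.n_ways, ns, feat_dim).cuda()   # (nb, nc, ns, ndim)
    query = torch.randn(nb, 75, feat_dim).cuda()                  # 75 = 5 ways x 15 queries
    openset = torch.randn(nb, 75, feat_dim).cuda()
    supp_ids = torch.randint(0, n_base, (nb, args.n_ways))
    base_ids = torch.randint(0, n_base, (nb, 10))

    scores, protos, fake_protos, _, funit = model((support, query, openset), (supp_ids, base_ids))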
class SupportCalibrator(nn.Module):
    def __init__(self, nway, feat_dim, n_head=1, base_seman_calib=True, neg_gen_type='semang'):
        super(SupportCalibrator, self).__init__()
        self.nway = nway
        self.feat_dim = feat_dim
        self.base_seman_calib = base_seman_calib
        self.map_sem = nn.Sequential(nn.Linear(300, 300), nn.LeakyReLU(0.1), nn.Dropout(0.1), nn.Linear(300, 300))
        self.calibrator = MultiHeadAttention(feat_dim // n_head, feat_dim // n_head, (feat_dim, feat_dim))
        self.neg_gen_type = neg_gen_type
        if neg_gen_type == 'semang':
            # Fuses visual (feat_dim) and semantic (300-d) memories, hence feat_dim + 300 input features
            self.task_visfuse = nn.Linear(feat_dim + 300, feat_dim)
            self.task_semfuse = nn.Linear(feat_dim + 300, 300)

    def _seman_calib(self, seman):
        seman = self.map_sem(seman)
        return seman

    def forward(self, support_feat, base_weights, support_seman, base_seman):
        n_bs, n_base_cls = base_weights.size()[:2]
        # Expand base memories so each of the nway support classes attends over all base classes
        base_weights = base_weights.unsqueeze(dim=1).repeat(1, self.nway, 1, 1).view(-1, n_base_cls, self.feat_dim)
        support_feat = support_feat.view(-1, 1, self.feat_dim)

        if self.neg_gen_type == 'semang':
            support_seman = self._seman_calib(support_seman)
            if self.base_seman_calib:
                base_seman = self._seman_calib(base_seman)
            base_seman = base_seman.unsqueeze(dim=1).repeat(1, self.nway, 1, 1).view(-1, n_base_cls, 300)
            support_seman = support_seman.view(-1, 1, 300)

            base_mem_vis = base_weights
            task_mem_vis = base_weights
            base_mem_seman = base_seman
            task_mem_seman = base_seman

            # Task-conditioned gating: average the concatenated memories, then rescale both modalities
            avg_task_mem = torch.mean(torch.cat([task_mem_vis, task_mem_seman], -1), 1, keepdim=True)
            gate_vis = torch.sigmoid(self.task_visfuse(avg_task_mem)) + 1.0
            gate_sem = torch.sigmoid(self.task_semfuse(avg_task_mem)) + 1.0
            base_weights = base_mem_vis * gate_vis
            base_seman = base_mem_seman * gate_sem
        elif self.neg_gen_type == 'attg':
            base_mem_vis = base_weights
            base_seman = None
            support_seman = None
        elif self.neg_gen_type == 'att':
            base_weights = support_feat
            base_mem_vis = support_feat
            support_seman = None
            base_seman = None
        else:
            return support_feat.view(n_bs, self.nway, -1), None

        support_center, _, support_attn, _ = self.calibrator(support_feat, base_weights, base_mem_vis, support_seman, base_seman)
        support_center = support_center.view(n_bs, self.nway, -1)
        support_attn = support_attn.view(n_bs, self.nway, -1)
        return support_center, support_attn
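
In the 'semang' branch above, sigmoid(x) + 1.0 lies in (1, 2), so the task gate can only amplify a memory channel, never suppress it below its original scale. A standalone, self-contained sketch of the gating step (all sizes here are illustrative assumptions, not values from the paper):

    # Standalone sketch of the 'semang' task-gating step; dimensions are illustrative.
    import torch
    import torch.nn as nn

    feat_dim, n_base = 640, 64
    task_mem_vis = torch.randn(5, n_base, feat_dim)   # expanded base visual weights
    task_mem_seman = torch.randn(5, n_base, 300)      # expanded base semantic embeddings
    task_visfuse = nn.Linear(feat_dim + 300, feat_dim)
    task_semfuse = nn.Linear(feat_dim + 300, 300)

    # One task descriptor per row: mean over base classes of the concatenated memories
    avg_task_mem = torch.mean(torch.cat([task_mem_vis, task_mem_seman], -1), 1, keepdim=True)
    gate_vis = torch.sigmoid(task_visfuse(avg_task_mem)) + 1.0   # gates lie in (1, 2)
    gate_sem = torch.sigmoid(task_semfuse(avg_task_mem)) + 1.0
    print((task_mem_vis * gate_vis).shape, (task_mem_seman * gate_sem).shape)
    # torch.Size([5, 64, 640]) torch.Size([5, 64, 300])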
class OpenSetGenerater(nn.Module):
    def __init__(self, nway, featdim, n_head=1, neg_gen_type='semang', agg='avg'):
        super(OpenSetGenerater, self).__init__()
        self.nway = nway
        self.att = MultiHeadAttention(featdim // n_head, featdim // n_head, (featdim, featdim))
        self.featdim = featdim
        self.neg_gen_type = neg_gen_type
        if neg_gen_type == 'semang':
            self.task_visfuse = nn.Linear(featdim + 300, featdim)
            self.task_semfuse = nn.Linear(featdim + 300, 300)
        self.agg = agg
        if agg == 'mlp':
            self.agg_func = nn.Sequential(
                nn.Linear(featdim, featdim),
                nn.LeakyReLU(0.5),
                nn.Dropout(0.5),
                nn.Linear(featdim, featdim))
        self.map_sem = nn.Sequential(nn.Linear(300, 300),
                                     nn.LeakyReLU(0.1),   # truncated in the source; completed to mirror SupportCalibrator.map_sem (assumption)
                                     nn.Dropout(0.1),
                                     nn.Linear(300, 300))
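
The excerpt cuts off here. Metric_Cosine is used by Classifier but not shown above; a minimal sketch consistent with its call sites (batched cosine similarity between class prototypes and query features, with a learnable temperature whose initial value of 10 is an assumption) could look like this:

    # Sketch only: scaled cosine-similarity metric matching the call sites above.
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class Metric_Cosine(nn.Module):
        def __init__(self, temperature=10):  # initial temperature is an assumption
            super(Metric_Cosine, self).__init__()
            self.temp = nn.Parameter(torch.tensor(float(temperature)))

        def forward(self, supp_center, query_feature):
            # supp_center: (nb, n_cls, ndim); query_feature: (nb, n_query, ndim)
            supp_center = F.normalize(supp_center, dim=-1)
            query_feature = F.normalize(query_feature, dim=-1)
            logits = torch.bmm(query_feature, supp_center.transpose(1, 2))
            return logits * self.temp   # (nb, n_query, n_cls)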