Mutual information between the input training image and the difference between two logits

Asked: 2020-04-21 16:58:32

Tags: python deep-learning pytorch resnet

I am trying to replicate an experiment by Rafael Müller et al. in "When does label smoothing help?" (NeurIPS 2019).

The method is described mainly in Section 4 and Figure 6, and is used to evaluate the effect of label smoothing on network distillation for a classification task: "We measure the mutual information between X and Y, where X is a discrete variable representing the index of the training example and Y is a continuous variable representing the difference between two logits (of the K classes)." The exact formula used to approximate the mutual information is written on page 7.
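As far as I understand it, if y_i is the logit difference for training example i, μ_i its mean over the L Monte Carlo (data-augmentation) samples, and σ² a single variance estimate shared across examples, the approximation is (my own transcription, so it may not match the paper exactly):

    I(X; Y) ≈ (1/N) Σ_i [ −(y_i − μ_i)²/(2σ²) − log( (1/N) Σ_j exp( −(y_i − μ_j)²/(2σ²) ) ) ]

which should be bounded between 0 and log(N).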

I am trying to replicate this experiment with a ResNet-18 on the CIFAR-10 dataset.

After each epoch, I run the code below to compute the mutual information as described in the paper. However, the final computed value ends up far outside its feasible range of 0 to log(N). Here batchsize = 300, N = 600 (the number of training examples used for the mutual information computation), L = 100 (Monte Carlo samples), and classes_mi is a 1-D array holding the indices of the two classes used in the computation. (How I build trainloader_sub and trainloader_sub_transforms is sketched after the code.)

mu_x = np.zeros(self.N)  # per-example mean of the logit difference
bs = 300                 # loader batch size (loaders must not shuffle, so that
                         # batch_idx * bs addresses the same examples every pass)
with torch.no_grad():    # assumes self.net is already in eval mode
    # Mean of Y over the L Monte Carlo (augmentation) samples of each example
    for i in range(self.L):
        for batch_idx, (inputs, targets) in enumerate(trainloader_sub_transforms[i]):
            inputs = inputs.to(self.device)
            outputs = self.net(inputs).cpu().numpy()
            # Y = |difference of the two chosen logits|
            diff = np.absolute(outputs[:, classes_mi[0]] - outputs[:, classes_mi[1]])
            mu_x[batch_idx*bs : batch_idx*bs + len(targets)] += diff
    mu_x /= self.L
    print('--> Finished mean calculation ', mu_x[:10])

    # Shared variance: mean squared deviation of Y from its per-example mean
    var = 0.0
    for batch_idx, (inputs, targets) in enumerate(trainloader_sub):
        inputs = inputs.to(self.device)
        outputs = self.net(inputs).cpu().numpy()
        diff = np.absolute(outputs[:, classes_mi[0]] - outputs[:, classes_mi[1]])
        var += np.sum((diff - mu_x[batch_idx*bs : batch_idx*bs + len(targets)]) ** 2)
    var /= self.N
    print('--> Finished variance calculation ', var)

    # Mutual information: per-example Gaussian log-likelihood term, minus one
    # shared log-normalisation term (term2) subtracted from every entry
    mutual_info_value = np.zeros(self.N)
    term2 = 0.0
    for batch_idx, (inputs, targets) in enumerate(trainloader_sub):
        inputs = inputs.to(self.device)
        outputs = self.net(inputs).cpu().numpy()
        diff = np.absolute(outputs[:, classes_mi[0]] - outputs[:, classes_mi[1]])
        main_term = -(diff - mu_x[batch_idx*bs : batch_idx*bs + len(targets)]) ** 2 / (2 * var)
        term2 += np.sum(np.exp(main_term))
        mutual_info_value[batch_idx*bs : batch_idx*bs + len(targets)] = main_term
    term2 = np.log(term2)
    mutual_info_value -= term2
    print(mutual_info_value.shape, mutual_info_value[:10])
    mutual_info_value = np.sum(mutual_info_value)

mi_sum += mutual_info_value  # running total across trials (initialised outside this snippet)
print('trial: ', t, ' --> MI = ', mutual_info_value)
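For context, the two loaders are built roughly as follows (a sketch rather than my exact code; the augmentations and classes_mi = [3, 5] are placeholder choices). The property the MI code relies on is shuffle=False, so that batch_idx * 300 addresses the same examples in every pass:

import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader, Subset

batchsize, N, L = 300, 600, 100
classes_mi = [3, 5]  # placeholder: indices of the two CIFAR-10 classes

# Plain (un-augmented) view of the data, used for the variance and MI passes
base = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,
                                    transform=T.ToTensor())
# First N training examples belonging to the two chosen classes
idx = [i for i, y in enumerate(base.targets) if y in classes_mi][:N]

# Augmented view: each pass over a loader redraws the random transforms,
# so every loader in the list yields one Monte Carlo sample per example
aug = T.Compose([T.RandomCrop(32, padding=4),
                 T.RandomHorizontalFlip(),
                 T.ToTensor()])
aug_set = torchvision.datasets.CIFAR10(root='./data', train=True, transform=aug)

# shuffle=False keeps a fixed example order, so batch_idx * batchsize
# indexes the same slice of mu_x in every loop
trainloader_sub_transforms = [
    DataLoader(Subset(aug_set, idx), batch_size=batchsize, shuffle=False)
    for _ in range(L)
]
trainloader_sub = DataLoader(Subset(base, idx), batch_size=batchsize, shuffle=False)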
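Independently of the network, here is a standalone NumPy sketch of the estimator as I read it from page 7, run on synthetic data in place of real logit differences (my transcription of the formula may itself be wrong). Note that, unlike my loop above, the log term here is a per-example sum over all N means and the final value is an average rather than a sum; written this way the result does stay within [0, log(N)]:

import numpy as np

rng = np.random.default_rng(0)
N, L = 600, 100

# Synthetic stand-in for the per-example logit difference: a fixed mean per
# example plus augmentation noise, shape (L, N)
true_mu = rng.normal(0.0, 3.0, size=N)
samples = true_mu[None, :] + rng.normal(0.0, 0.5, size=(L, N))

y = samples[0]                 # one measurement f(x_i) per example
mu = samples.mean(axis=0)      # per-example mean over the L MC samples
var = np.mean((y - mu) ** 2)   # shared variance estimate

# sq[i, j] = (y_i - mu_j)^2: each measurement against every example's mean
sq = (y[:, None] - mu[None, :]) ** 2
log_cond = -np.diag(sq) / (2 * var)                         # log p(y_i | X=i), up to a constant
log_marg = np.log(np.exp(-sq / (2 * var)).sum(axis=1) / N)  # log p(y_i), same constant

mi = np.mean(log_cond - log_marg)  # average over the N examples, not a sum
print(mi, np.log(N))               # should satisfy 0 <= mi <= log(N)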

0 answers