实施MSE损失功能时出错

时间:2020-11-02 18:17:31

标签: python

我尝试用均方误差损失替换si-snr-loss,但是损失方程式出现错误 TypeError:无法转换numpy.object_类型的np.ndarray。唯一受支持的类型是float64,float32,float16,int64,int32,int16,int8,uint8和bool。

有人可以帮忙吗?

原始Si-SNR损耗如下: '''

def si_snr_loss(ests, egs):
    refs = egs["ref"]
    num_spks = len(refs)

    def sisnr_loss(permute):
        return sum([sisnr(ests[s], refs[t]) for s, t in enumerate(permute)]) / len(permute)
    N = egs["mix"].size(0)
    sisnr_mat = torch.stack([sisnr_loss(p) for p in permutations(range(num_spks))])
    max_perutt,_ = torch.max(sisnr_mat, dim=0)
    return -torch.sum(max_perutt) / N
for egs in val_dataloader:
                current_step += 1
                egs = to_device(egs, self.device)
                ests = data_parallel(self.net, egs['mix'], device_ids=self.gpuid)
                #loss = si_snr_loss(ests, egs)
                loss = (ests - torch.Tensor(np.array(egs.values())))**2
                losses.append(loss.item())

'''

0 个答案:

没有答案