我尝试用均方误差(MSE)损失替换 si-snr-loss,但计算损失的那一行代码报错:TypeError: 无法转换 numpy.object_ 类型的 np.ndarray。唯一受支持的类型是 float64、float32、float16、int64、int32、int16、int8、uint8 和 bool。
有人可以帮忙吗?
原始的 Si-SNR 损失函数以及我修改后的验证循环如下: '''
def si_snr_loss(ests, egs):
    """Return the negated SI-SNR of the best speaker permutation, divided by N.

    Every ordering of the reference speakers is scored with ``sisnr`` against
    the estimates; the best score per position along dim 0 is kept, summed,
    negated and divided by ``N``.

    Args:
        ests: Indexable collection of estimated signals, one entry per speaker.
        egs: Mapping holding ``"ref"`` (reference signals, one per speaker)
            and ``"mix"`` (a tensor; its ``size(0)`` is used as ``N``,
            presumably the batch size -- confirm against the data loader).

    Returns:
        ``-sum(max over permutations) / N``.
    """
    refs = egs["ref"]
    num_spks = len(refs)

    def sisnr_loss(permute):
        # Mean SI-SNR when estimate s is paired with reference permute[s].
        pair_scores = [sisnr(ests[s], refs[t]) for s, t in enumerate(permute)]
        return sum(pair_scores) / len(permute)

    N = egs["mix"].size(0)
    # One row per candidate permutation of the speakers.
    sisnr_mat = torch.stack(
        [sisnr_loss(p) for p in permutations(range(num_spks))]
    )
    # Keep the best-scoring permutation for each position along dim 0.
    best_per_utt, _ = torch.max(sisnr_mat, dim=0)
    return -torch.sum(best_per_utt) / N
# Validation loop using a mean-squared-error loss in place of si_snr_loss.
for egs in val_dataloader:
    current_step += 1
    egs = to_device(egs, self.device)
    ests = data_parallel(self.net, egs['mix'], device_ids=self.gpuid)
    # egs is indexed by string keys ("mix", "ref"), so np.array(egs.values())
    # wraps the whole values view in an object-dtype array, which torch.Tensor
    # cannot convert -- that is the reported TypeError. Compare each estimate
    # with its reference tensor directly instead.
    # NOTE(review): assumes ests and egs["ref"] are equal-length sequences of
    # same-shaped tensors (they are indexed that way in si_snr_loss) -- confirm.
    # Unlike si_snr_loss, this pairs speakers in fixed order and is therefore
    # not permutation-invariant.
    refs = egs["ref"]
    per_source_mse = [
        torch.mean((est - ref) ** 2) for est, ref in zip(ests, refs)
    ]
    # Reduce to a single scalar so that loss.item() is valid.
    loss = torch.stack(per_source_mse).mean()
    losses.append(loss.item())
'''