Dice loss values:
Weighted cross-entropy loss values:
The optimizer is Adam with lr = 0.0002, beta1 = 0.5, beta2 = 0.999. Has anyone run into the same problem? Could you tell me the cause and how to fix it?
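For reference, the optimizer is set up roughly like this (a minimal sketch; model stands for the UNet3D instance defined below):

import torch

# Adam with the settings quoted above: lr = 0.0002, betas = (0.5, 0.999)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0002, betas=(0.5, 0.999))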
The 3D U-Net model is as follows.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class UNet3D(nn.Module):
    def __init__(self, name, in_channels, out_channels, min_filters=2, norm_layer=nn.BatchNorm3d):
        super(UNet3D, self).__init__()
        # Store the name as a plain attribute. (A name() method would be
        # shadowed by this assignment, so no accessor method is defined.)
        self.name = name
        activation = nn.LeakyReLU(0.2, True)
        self.input = Conv2Block(in_channels, min_filters, activation=activation, norm_layer=norm_layer)
        self.down1 = DownBlock(min_filters, min_filters*2, activation=activation, norm_layer=norm_layer)
        self.down2 = DownBlock(min_filters*2, min_filters*4, activation=activation, norm_layer=norm_layer)
        self.down3 = DownBlock(min_filters*4, min_filters*8, activation=activation, norm_layer=norm_layer)
        self.up1 = UpBlock(min_filters*8, min_filters*4, activation=activation, norm_layer=norm_layer)
        self.up2 = UpBlock(min_filters*4, min_filters*2, activation=activation, norm_layer=norm_layer)
        self.up3 = UpBlock(min_filters*2, min_filters, activation=activation, norm_layer=norm_layer)
        if self.name == 'netS':
            # segmentation network: sigmoid output
            self.out = OutBlock(min_filters, out_channels, is_image=False)
        elif self.name == 'netG':
            # generator network: tanh output
            self.out = OutBlock(min_filters, out_channels, is_image=True)
    def forward(self, x):
        # print('x', np.unique(x))
        x1 = self.input(x)
        # print('input', np.unique(x1.detach().numpy()))
        x2 = self.down1(x1)
        # print('down1', np.unique(x2.detach().numpy()))
        x3 = self.down2(x2)
        # print('down2', np.unique(x3.detach().numpy()))
        x = self.down3(x3)
        # print('down3', np.unique(x.detach().numpy()))
        x = self.up1(x, x3)
        # print('up1', np.unique(x.detach().numpy()))
        x = self.up2(x, x2)
        # print('up2', np.unique(x.detach().numpy()))
        x = self.up3(x, x1)
        # print('up3', np.unique(x.detach().numpy()))
        x = self.out(x)
        # print('out', np.unique(x.detach().numpy()))
        return x
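For completeness, this is roughly how the model is built and called (a minimal sketch; the input size and out_channels=6 are assumptions for illustration, with 6 matching the class indices 0-5 used in the loss below):

netS = UNet3D('netS', in_channels=1, out_channels=6, min_filters=2)
x = torch.randn(1, 1, 64, 64, 64)  # (N, C, D, H, W); spatial dims divisible by 8 for the three poolings
out = netS(x)                      # sigmoid probabilities from OutBlock
print(out.shape)                   # torch.Size([1, 6, 64, 64, 64])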
class Conv2Block(nn.Module):
    '''
    Two successive Conv3d layers, each followed by norm_layer and activation.
    first:  in: in_channels,  out: out_channels
    second: in: out_channels, out: out_channels
    '''
    def __init__(self, in_channels, out_channels, activation, norm_layer=nn.BatchNorm3d, kernel_size=3):
        super(Conv2Block, self).__init__()
        # 'same' padding for odd kernel sizes
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size, padding=int(np.ceil((kernel_size-1)/2)))
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size, padding=int(np.ceil((kernel_size-1)/2)))
        self.norm_layer = norm_layer
        if norm_layer is not None:
            self.bn1 = norm_layer(out_channels)
            self.bn2 = norm_layer(out_channels)
        self.activation = activation

    def forward(self, x):
        if self.norm_layer is not None:
            x = self.activation(self.bn1(self.conv1(x)))
            x = self.activation(self.bn2(self.conv2(x)))
        else:
            x = self.activation(self.conv1(x))
            x = self.activation(self.conv2(x))
        return x
class DownBlock(nn.Module):
    '''
    MaxPool3d + Conv2Block
    '''
    def __init__(self, in_channels, out_channels, activation, norm_layer=nn.BatchNorm3d, kernel_size=3):
        super(DownBlock, self).__init__()
        self.max_pool = nn.MaxPool3d(2)  # stateless, so defined once rather than per forward pass
        self.conv = Conv2Block(in_channels, out_channels, activation, norm_layer=norm_layer, kernel_size=kernel_size)

    def forward(self, x):
        x = self.conv(self.max_pool(x))
        return x
class UpConv(nn.Module):
    '''
    interpolate + Conv3d + activation
    '''
    def __init__(self, in_channels, out_channels, activation, kernel_size=3):
        super(UpConv, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, padding=int(np.ceil((kernel_size-1)/2)))
        self.activation = activation

    def forward(self, x, sz):
        x = F.interpolate(x, sz, mode='trilinear', align_corners=True)
        # x = F.interpolate(x, sz, mode='nearest')
        x = self.activation(self.conv(x))  # TODO: no normalization applied here
        return x
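Because upsampling uses F.interpolate to an explicit target size rather than ConvTranspose3d, the decoder can match skip connections of any spatial size, not only exact multiples of two. A tiny sketch with assumed shapes:

up = UpConv(16, 8, nn.LeakyReLU(0.2, True))
x = torch.randn(1, 16, 5, 5, 5)
y = up(x, (10, 11, 12))  # arbitrary target size, not restricted to 2x
print(y.shape)           # torch.Size([1, 8, 10, 11, 12])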
class UpBlock(nn.Module):
    '''
    UpConv + cat + Conv2Block
    '''
    def __init__(self, in_channels, out_channels, activation, norm_layer=nn.BatchNorm3d, kernel_size=3):
        super(UpBlock, self).__init__()
        self.upconv = UpConv(in_channels, out_channels, activation, kernel_size=kernel_size)
        self.conv = Conv2Block(in_channels, out_channels, activation, norm_layer=norm_layer, kernel_size=kernel_size)

    def forward(self, x, x2):
        # upsample x to the spatial size of the skip connection x2, then concatenate
        x = self.upconv(x, (x2.shape[-3], x2.shape[-2], x2.shape[-1]))
        x = torch.cat([x, x2], dim=1)
        x = self.conv(x)
        return x
class OutBlock(nn.Module):
    '''
    1x1x1 Conv3d + output nonlinearity:
    tanh for image synthesis, sigmoid for segmentation.
    '''
    def __init__(self, in_channels, out_channels, is_image):
        super(OutBlock, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, 1)
        if is_image:
            self.out = nn.Tanh()
        else:
            self.out = nn.Sigmoid()

    def forward(self, x):
        x = self.conv(x)
        x = self.out(x)
        return x
The weighted cross-entropy loss and soft Dice loss are summarized below.
class SoftDiceLoss(nn.Module):
    def __init__(self):
        super(SoftDiceLoss, self).__init__()

    def forward(self, preds, labels):
        preds = F.softmax(preds, dim=1)
        num = labels.size(0)
        m1 = preds.view(num, -1)
        m2 = labels.view(num, -1)
        intersection = m1 * m2
        # +1 is a smoothing term; m1.sum(1) equals the cube size, since softmax sums to 1 per voxel
        score = 2. * (intersection.sum(1) + 1) / (m1.sum(1) + m2.sum(1) + 1)
        score = 1 - score.sum() / num
        return score
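One detail worth checking when reading the Dice numbers: with this smoothing, a perfect prediction does not give exactly 0. For a 4x4x4 one-hot volume, the per-sample score is 2*(64+1)/(64+64+1) = 130/129, so the loss is about -0.008. A quick sanity-check sketch (shapes assumed):

labels = F.one_hot(torch.randint(0, 6, (1, 4, 4, 4)), num_classes=6)
labels = labels.permute(0, 4, 1, 2, 3).float()  # to (N, C, D, H, W)
logits = labels * 100.0  # softmax of these is essentially the one-hot labels
print(SoftDiceLoss()(logits, labels).item())  # about -0.008: the +1 smoothing lets the score exceed 1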
class WeightedCrossEntropyLoss(nn.Module):
    '''
    Cross-entropy loss with online hard negative mining:
    all foreground voxels are kept, and only the hardest
    alpha * (number of foreground voxels) background voxels contribute.
    '''
    def __init__(self):
        super(WeightedCrossEntropyLoss, self).__init__()
        # self.loss = nn.NLLLoss()

    def forward(self, pred, gt, alpha=3.0):
        # gt is one-hot along dim 1; argmax converts it to class indices
        # and already drops that dim, so no extra squeeze is needed
        gt = gt.argmax(dim=1)
        assert gt.dtype == torch.long
        mtx = F.cross_entropy(pred, gt, reduction='none')
        bg = (gt == 0) | (gt == 5)  # background classes
        neg = mtx[bg]
        pos = mtx[~bg]  # ~bg, not 1-bg: subtraction is invalid on bool tensors
        Np, Nn = pos.numel(), neg.numel()
        pos = pos.sum()
        # keep at most alpha * Np hardest negatives
        k = min(Np * alpha, Nn)
        if k > 0:
            neg, _ = torch.topk(neg, int(k))
            neg = neg.sum()
        else:
            neg = 0.0
        loss = (pos + neg) / (Np + k)
        return loss
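Finally, this is roughly how I combine the two losses in a training step (a sketch; the tensor shapes and the unweighted sum are assumptions for illustration):

criterion_wce = WeightedCrossEntropyLoss()
criterion_dice = SoftDiceLoss()
pred = torch.randn(2, 6, 8, 8, 8, requires_grad=True)  # raw logits
gt = F.one_hot(torch.randint(0, 6, (2, 8, 8, 8)), num_classes=6)
gt = gt.permute(0, 4, 1, 2, 3).float()  # one-hot, (N, C, D, H, W)
loss = criterion_wce(pred, gt) + criterion_dice(pred, gt)
loss.backward()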