U-Net:如何在 PyTorch 中调整背景的权重

时间:2018-01-11 04:19:34

标签: python pytorch

我使用的数据集是脑部 MRI 图像，我已将这些图像转换为 2D 图像。在训练过程中我发现，标注图像的大部分区域都是黑色背景，只有一小部分是肿瘤，这导致了类别不平衡。我想提高肿瘤像素的权重、降低背景像素的权重。我的想法是统计背景像素出现的次数：出现得越多，它的权重就越低。但我不知道该如何修改程序。我的代码如下（注：紧接着的第一段代码是 Ruby 代码，与本问题无关，疑为原帖抓取错误）: unet.py:

# NOTE(review): this Rails controller fragment is unrelated to the PyTorch
# question around it; it is reviewed here on its own terms.
if @register.save
  @payment = Payment.new({email: params["school_registration"]["email"],
                          token: params[:payment]["token"], school_registration_id: @register.id
                         })
  # NOTE(review): an invalid payment only sets the flash message; the
  # processing below still runs. Confirm whether it should stop here instead.
  flash[:error] = "Please check registration errors" unless @payment.valid?
  # Rescue StandardError rather than Exception: rescuing Exception would also
  # swallow Interrupt, SystemExit, NoMemoryError and other non-application errors.
  begin
    @payment.process_payment
    # NOTE(review): the return value of save is ignored here.
    @payment.save
  rescue StandardError => e
    flash[:error] = e.message
    # Roll back the registration saved above, since payment failed.
    @register.destroy
    render :new and return #=> :new means your registration form
  end
else
  #=> Code
end

unet.py（原文此处标为 train.py，但下面的代码定义的是 UNet 模型）:

    import torch.nn as nn
    import torch.nn.functional as F
    import torch
    from numpy.linalg import svd
    from numpy.random import normal
    from math import sqrt


    class UNet(nn.Module):
        """U-Net style encoder/decoder mapping a ``colordim``-channel image to a
        single-channel map of the same height and width.

        Encoder: three levels (32, 64, 128 channels) separated by 2x2 max-pooling,
        followed by a 256-channel bottleneck. Decoder: three stages that
        bilinearly upsample, reduce channels with a 1x1 conv, concatenate the
        encoder feature map of matching resolution (skip connection) and apply
        two 3x3 convs. Every 3x3 conv uses ``padding=1``, so only pooling and
        upsampling change the spatial size.

        The output goes through ReLU -> BatchNorm -> softsign, so values lie in
        (-1, 1).

        Args:
            colordim: number of input channels (default 4).
        """

        def __init__(self, colordim=4):
            super(UNet, self).__init__()
            # NOTE: modules are created in a fixed order; keep it so state_dict
            # keys (and RNG consumption during init) stay the same.

            # Encoder level 1: colordim -> 32 channels, full resolution.
            self.conv1_1 = nn.Conv2d(colordim, 32, 3, padding=1, stride=1)
            self.conv1_2 = nn.Conv2d(32, 32, 3, padding=1, stride=1)
            self.bn1 = nn.BatchNorm2d(32)

            # Encoder level 2: 32 -> 64 channels, 1/2 resolution.
            self.conv2_1 = nn.Conv2d(32, 64, 3, padding=1)
            self.conv2_2 = nn.Conv2d(64, 64, 3, padding=1)
            self.bn2 = nn.BatchNorm2d(64)

            # Encoder level 3: 64 -> 128 channels, 1/4 resolution.
            self.conv3_1 = nn.Conv2d(64, 128, 3, padding=1)
            self.conv3_2 = nn.Conv2d(128, 128, 3, padding=1)
            self.bn3 = nn.BatchNorm2d(128)

            # Bottleneck (1/8 resolution) and first decoder stage:
            # 1x1 conv 256 -> 128 after upsampling, then BN over the 256-channel
            # concatenation [x3, upsampled].
            self.conv4_1 = nn.Conv2d(128, 256, 3, padding=1)
            self.conv4_2 = nn.Conv2d(256, 256, 3, padding=1)
            self.upconv4 = nn.Conv2d(256, 128, 1)
            self.bn4 = nn.BatchNorm2d(128)
            self.bn4_out = nn.BatchNorm2d(256)

            # Decoder stage 2: 256 -> 128 convs, then 1x1 128 -> 64 after upsampling.
            self.conv5_1 = nn.Conv2d(256, 128, 3, padding=1)
            self.conv5_2 = nn.Conv2d(128, 128, 3, padding=1)
            self.upconv5 = nn.Conv2d(128, 64, 1)
            self.bn5 = nn.BatchNorm2d(64)
            self.bn5_out = nn.BatchNorm2d(128)

            # Decoder stage 3: 128 -> 64 convs, then 1x1 64 -> 32 after upsampling.
            self.conv6_1 = nn.Conv2d(128, 64, 3, padding=1)
            self.conv6_2 = nn.Conv2d(64, 64, 3, padding=1)
            self.upconv6 = nn.Conv2d(64, 32, 1)
            self.bn6 = nn.BatchNorm2d(32)
            self.bn6_out = nn.BatchNorm2d(64)

            # Output head: 64 -> 32 -> 32 -> 1 channel.
            self.conv7_1 = nn.Conv2d(64, 32, 3, padding=1)
            self.conv7_2 = nn.Conv2d(32, 32, 3, padding=1)
            self.conv7_3 = nn.Conv2d(32, 1, 1)
            self.bn7 = nn.BatchNorm2d(1)

            self.maxpool = nn.MaxPool2d(2, stride=2, return_indices=False, ceil_mode=False)
            # Kept for backward compatibility; forward() upsamples with
            # _upsample_to so odd spatial sizes also work. (No parameters, so
            # the state_dict is unaffected.)
            self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)
            self._initialize_weights()

        def _upsample_to(self, x, ref):
            """Bilinearly upsample ``x`` to the spatial size of ``ref``.

            Uses ``align_corners=True`` like ``self.upsample``, so when ``ref``
            is exactly twice the size of ``x`` the result matches a scale-2
            upsample. It also handles dimensions that MaxPool2d floor-divided
            (odd sizes), where a fixed scale-2 upsample would produce a
            tensor that cannot be concatenated with the skip connection.
            """
            return F.interpolate(x, size=tuple(ref.shape[2:]), mode='bilinear', align_corners=True)

        def forward(self, x1):
            """Run the network.

            Args:
                x1: input batch of shape (N, colordim, H, W). H and W need not
                    be multiples of 8; they only need to survive three 2x2
                    poolings (>= 8).

            Returns:
                Tensor of shape (N, 1, H, W) with values in (-1, 1).
            """
            # Encoder.
            x1 = F.relu(self.bn1(self.conv1_2(F.relu(self.conv1_1(x1)))))
            x2 = F.relu(self.bn2(self.conv2_2(F.relu(self.conv2_1(self.maxpool(x1))))))
            x3 = F.relu(self.bn3(self.conv3_2(F.relu(self.conv3_1(self.maxpool(x2))))))
            # Bottleneck (no batch norm at this level).
            xup = F.relu(self.conv4_2(F.relu(self.conv4_1(self.maxpool(x3)))))

            # Decoder stage 1: back to x3's resolution, concat skip, two convs.
            xup = self.bn4(self.upconv4(self._upsample_to(xup, x3)))
            xup = self.bn4_out(torch.cat((x3, xup), 1))
            xup = F.relu(self.conv5_2(F.relu(self.conv5_1(xup))))

            # Decoder stage 2: back to x2's resolution.
            xup = self.bn5(self.upconv5(self._upsample_to(xup, x2)))
            xup = self.bn5_out(torch.cat((x2, xup), 1))
            xup = F.relu(self.conv6_2(F.relu(self.conv6_1(xup))))

            # Decoder stage 3: back to x1's (input) resolution, then output head.
            xup = self.bn6(self.upconv6(self._upsample_to(xup, x1)))
            xup = self.bn6_out(torch.cat((x1, xup), 1))
            xup = F.relu(self.conv7_3(F.relu(self.conv7_2(F.relu(self.conv7_1(xup))))))
            return F.softsign(self.bn7(xup))

        def _initialize_weights(self):
            """He (fan-out) normal init for conv weights, zero conv biases;
            BatchNorm weight=1, bias=0."""
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    # fan_out = k_h * k_w * out_channels; std = sqrt(2 / fan_out).
                    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                    m.weight.data.normal_(0, sqrt(2. / n))
                    if m.bias is not None:
                        m.bias.data.zero_()
                elif isinstance(m, nn.BatchNorm2d):
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()


    # Module-level instance kept for code that does `from unet import unet`.
    # Fall back to the CPU so that merely importing this module (e.g. for the
    # UNet class) does not crash on machines without CUDA.
    unet = UNet().cuda() if torch.cuda.is_available() else UNet()

train.py（原文此处标为 load_data.py，但下面的代码是训练脚本）:

#-*- coding:utf-8 -*-
from __future__ import print_function
from math import log10
import numpy as np
import random
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from unet import UNet
from load_data import get_train_img,get_test_img
import torchvision


# Training settings

class option:
    """Hard-coded training settings, used in place of command-line flags."""

    def __init__(self):
        self.cuda = True            # train on the GPU (checked against torch.cuda.is_available())
        self.batchSize = 3          # samples per training batch
        self.testBatchSize = 3      # samples per test batch
        self.nEpochs = 5            # how many epochs the main loop should run
        self.lr = 0.001             # SGD learning rate
        self.threads = 0            # DataLoader num_workers (0 = load in the main process)
        self.seed = 123             # seed for torch / torch.cuda RNGs
        self.size = 240             # image size handed to the dataset builders
        self.colordim = 4           # input channel count (UNet() is built with its own default)
        # Checkpoint loaded when pre-training is switched on.
        self.pretrain_net = 'F:\\image\\100data\\model\\model_epoch_140.pth'

def map01(tensor, eps=1e-5):
    """Min-max rescale each sample of a batch to roughly [0, 1].

    Each sample (index along dim 0) is rescaled independently using the
    minimum and maximum over all of its remaining dimensions. ``eps`` is
    always added to the value range so the division is never by zero: a
    constant sample maps to all zeros (previously it produced NaN when every
    sample in the batch was constant).

    Args:
        tensor: CPU tensor with a leading batch dimension and at least one
            more dimension (``.numpy()`` is called on it).
        eps: small constant added to each sample's value range.

    Returns:
        A new tensor (via ``torch.from_numpy``) with the same shape.
    """
    arr = tensor.numpy()
    reduce_axes = tuple(range(1, arr.ndim))  # every dim except the batch dim
    lo = np.min(arr, axis=reduce_axes, keepdims=True)
    hi = np.max(arr, axis=reduce_axes, keepdims=True)
    return torch.from_numpy((arr - lo) / (hi - lo + eps))

opt = option()

cuda = opt.cuda
if cuda and not torch.cuda.is_available():
    # There is no command-line parser in this script; the switch is option.cuda.
    raise RuntimeError("No GPU found, please set option.cuda = False to run on the CPU")

torch.manual_seed(opt.seed)
if cuda:
    torch.cuda.manual_seed(opt.seed)
print('===>Use the default training, test data')
print('===> Loading datasets')
train_set = get_train_img(opt.size)
test_set = get_test_img(opt.size)
training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)
testing_data_loader = DataLoader(dataset=test_set, num_workers=opt.threads, batch_size=opt.testBatchSize, shuffle=False)

print('===> Building unet')
unet = UNet()

# Plain per-pixel MSE: every pixel (background or tumour) counts equally.
# NOTE(review): to counter the background/tumour imbalance, this is where a
# pixel-weighted loss would be substituted.
criterion = nn.MSELoss()
if cuda:
    unet = unet.cuda()
    criterion = criterion.cuda()

pretrained = False
if pretrained:
    # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only run;
    # load_state_dict then copies the tensors onto the model's device.
    unet.load_state_dict(torch.load(opt.pretrain_net, map_location='cpu'))

optimizer = optim.SGD(unet.parameters(), lr=opt.lr)
print('===> Training unet')
print(training_data_loader)
def train(epoch):
    """Train ``unet`` for one pass over ``training_data_loader``.

    Prints the batch loss every 10 iterations, saves the last batch's
    predictions as an image, and prints the mean batch loss of the epoch.

    Args:
        epoch: epoch number, used in log messages and the output filename.
    """
    unet.train()  # make sure BatchNorm uses batch statistics while training
    epoch_loss = 0.0
    prediction = None
    num_batches = len(training_data_loader)
    for iteration, batch in enumerate(training_data_loader, 1):
        # batch = (input, target, filename(s)); filenames are not needed here.
        images, target = batch[0], batch[1]
        if cuda:
            images = images.cuda()
            target = target.cuda()

        # backward() adds into .grad; without clearing it every step would use
        # the sum of all gradients computed so far.
        optimizer.zero_grad()
        prediction = unet(images)
        loss = criterion(prediction, target)
        loss.backward()
        optimizer.step()

        # .item() replaces loss.data[0], which fails on 0-dim loss tensors in
        # current PyTorch.
        batch_loss = loss.item()
        epoch_loss += batch_loss
        if iteration % 10 == 0:
            print("===> Epoch[{}]({}/{}): Loss: {:.4f}".format(epoch, iteration, num_batches, batch_loss))

    if prediction is not None:
        torchvision.utils.save_image(prediction.data, "F:\\image\\100data\\output\\train_out\\"+str(epoch)+'.png', padding=0)
    # max(..., 1) avoids dividing by zero for an empty loader.
    print("===> Epoch {} Complete: Avg. Loss: {}".format(epoch, epoch_loss / max(num_batches, 1)))
    print('train ok')
def test(epoch):
    """Run ``unet`` over ``testing_data_loader`` and save its predictions.

    Every batch targets the same ``<epoch>.png`` file, so only the last
    batch's predictions end up on disk; the image is written once, after the
    loop. The model's train/eval mode is restored afterwards.

    Args:
        epoch: epoch number, used in the output filename.
    """
    # Inference needs no gradients; clear any left in .grad by training so
    # they are not carried into the next training step.
    optimizer.zero_grad()
    was_training = unet.training
    # eval(): BatchNorm uses its running statistics and does not update them
    # with test data; no_grad(): no autograd graph is built.
    unet.eval()
    prediction = None
    try:
        with torch.no_grad():
            for batch in testing_data_loader:
                # batch = (input, filename(s)) for the label-free test set.
                images = batch[0]
                if cuda:
                    images = images.cuda()
                prediction = unet(images)
    finally:
        unet.train(was_training)

    if prediction is not None:
        torchvision.utils.save_image(prediction.data, "F:\\image\\100data\\output\\test_out\\"+str(epoch)+'.png', padding=0)


def checkpoint(epoch):
    """Write the current ``unet`` weights to the model folder, tagged with ``epoch``."""
    weights = unet.state_dict()
    target_file = "F:\\image\\100data\\model\\model_epoch_{}.pth".format(epoch)
    torch.save(weights, target_file)
    print("Checkpoint saved to " + target_file)

# Run exactly opt.nEpochs epochs (1..nEpochs); the old bound
# range(1, 1+opt.nEpochs + 1) ran one extra epoch.
# Checkpoint every 10 epochs and run the test pass after each epoch.
for epoch in range(1, opt.nEpochs + 1):
    train(epoch)
    if epoch % 10 == 0:
        checkpoint(epoch)
    test(epoch)
# Always keep the final weights; guarded so nEpochs == 0 does not hit an
# undefined loop variable.
if opt.nEpochs > 0:
    checkpoint(opt.nEpochs)

load_data.py（原文此处标为“我的目录”，目录结构在抓取中缺失；下面的代码是数据加载模块）:

#-*- coding:utf-8 -*-
from os.path import exists,join
from os import listdir
import numpy as np
from PIL import Image
import torch.utils.data as data
from skimage import io
from torchvision.transforms import Compose, CenterCrop, ToTensor, Scale,ToPILImage


def brats2015(dest=r'F:\image'):
    """Return ``dest`` unchanged, printing a warning if the path does not exist.

    A missing path is only reported, not raised; callers that go on to list
    files under it (e.g. via ``DatasetFromFolder``) will fail at that point.

    Args:
        dest: dataset root directory.

    Returns:
        ``dest``, exactly as passed in.
    """
    if not exists(dest):
        print('Sorry, dataset does not exist: {}'.format(dest))
        print('please check the file path')
    return dest

def input_transform(crop_size):
    """Return the transform applied to both inputs and labels.

    Only ``ToTensor()`` is applied. ``crop_size`` is accepted but currently
    ignored, because the ``CenterCrop(crop_size)`` step is disabled.
    """
    pipeline = [ToTensor()]  # CenterCrop(crop_size) disabled
    return Compose(pipeline)
def get_train_img(size, train_path=r'H:\deepfortest\2Dpng'):
    """Build the training dataset: stacked modality inputs plus ``OT`` labels.

    Sample filenames are listed from ``<train_path>/T1``. ``size`` is passed
    to ``input_transform``, which currently ignores it.
    """
    listing_dir = join(brats2015(train_path), 'T1')
    return DatasetFromFolder(
        image_dir=listing_dir,
        image_path=train_path,
        input_transform=input_transform(size),
        target_transform=input_transform(size),
        OT_point=True,
    )

def get_test_img(size, test_path=r'F:\image\100data\test'):
    """Build the test dataset: stacked modality inputs only (no labels).

    Sample filenames are listed from ``<test_path>/T1``. ``size`` is passed
    to ``input_transform``, which currently ignores it.
    """
    listing_dir = join(brats2015(test_path), 'T1')
    return DatasetFromFolder(
        image_dir=listing_dir,
        image_path=test_path,
        input_transform=input_transform(size),
        target_transform=input_transform(size),
        OT_point=False,
    )

def is_image_file(filename):
    """Return True if ``filename`` ends in ``.png`` or ``.jpg`` (case-sensitive)."""
    image_suffixes = (".png", ".jpg")
    return filename.endswith(image_suffixes)

def load_img(image_path, name, label_bool=False):
    """Load one sample, or its label, named ``name`` under ``image_path``.

    Args:
        image_path: root folder containing ``Flair``, ``T1``, ``T1C``, ``T2``
            and ``OT`` subfolders.
        name: image filename, looked up in each subfolder.
        label_bool: if truthy, load the label image from ``OT`` instead of
            the modality images. (Any truthy/falsy value works; previously
            only the exact objects True/False did, and anything else raised
            UnboundLocalError.)

    Returns:
        For inputs: ``np.dstack`` of the four modality images in the order
        Flair, T1, T1C, T2, i.e. (H, W, 4) if each image is 2-D grayscale.
        For labels: a ``PIL.Image`` opened from ``OT/name``.
    """
    if label_bool:
        return Image.open(join(image_path, 'OT', name))
    modalities = ('Flair', 'T1', 'T1C', 'T2')  # stacking order = channel order
    return np.dstack([io.imread(join(image_path, modality, name)) for modality in modalities])
#image_dir = r'F:\image\train\T1',image_path = r'F:\image\train'
class DatasetFromFolder(data.Dataset):
    """Dataset of stacked modality images, optionally paired with ``OT`` labels.

    The sample list is the ``.png``/``.jpg`` files found in ``image_dir``;
    each sample is then read by filename from the subfolders of
    ``image_path`` via ``load_img``.

    Args:
        image_dir: folder whose image files define the dataset entries.
        image_path: root folder with the modality and ``OT`` subfolders.
        input_transform: optional callable applied to the stacked input.
        target_transform: optional callable applied to the label image.
        OT_point: if truthy, also load and return the ``OT`` label.
    """

    def __init__(self, image_dir, image_path, input_transform=None, target_transform=None, OT_point=True):
        super(DatasetFromFolder, self).__init__()
        # os.listdir order is arbitrary; sort so index -> file is reproducible.
        self.image_filenames = sorted(x for x in listdir(image_dir) if is_image_file(x))
        self.input_transform = input_transform
        self.target_transform = target_transform
        self.image_dir = image_dir
        self.image_path = image_path
        self.OT_point = OT_point

    def __getitem__(self, index):
        """Return ``(input, target, filename)``, or ``(input, filename)`` without labels.

        ``filename`` is the name of this sample only (previously the whole
        filename list was returned for every item).
        """
        name = self.image_filenames[index]
        sample = load_img(self.image_path, name)
        if self.input_transform:
            sample = self.input_transform(sample)
        if not self.OT_point:
            return sample, name
        target = load_img(self.image_path, name, label_bool=True)
        if self.target_transform:
            target = self.target_transform(target)
        return sample, target, name

    def __len__(self):
        """Number of image files found in ``image_dir``."""
        return len(self.image_filenames)

0 个答案:

没有答案