运行 Python 代码时出现 CUDA 错误:内存不足(out of memory)

时间:2018-08-28 07:26:10

标签: python pytorch

我正在尝试运行以下代码,但收到了一个错误。该代码来自从 GitHub 下载的 Deep Image Prior。谁能告诉我为什么会出现此错误,以及我哪里出错了?我读到一些相关资料,说可以通过减小批量大小(batch size)来解决。该如何减小呢?是因为 .tif 文件大约有 100 MB 吗?错误信息是 RuntimeError: CUDA error: out of memory,我使用的是 12 GB 显存的 GPU 和 CUDA 9.2。

from __future__ import print_function
import matplotlib.pyplot as plt
%matplotlib inline
import argparse
import os
#os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import numpy as np
from models import *
import torch
import torch.optim
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import warnings
warnings.filterwarnings("ignore")
from skimage.measure import compare_psnr
from models.downsampler import Downsampler
from utils.sr_utils import *
# Enable cuDNN and let it benchmark candidate convolution algorithms.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

# Use the first GPU when CUDA is available and fall back to the CPU otherwise.
# The tensor type follows the same decision so that every later
# `.type(dtype)` call matches the selected device.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dtype = torch.cuda.FloatTensor if device.type == "cuda" else torch.FloatTensor

imsize = (819, 819)
factor = 2  # 8
# we usually need the dimensions to be divisible by a power of two (32 in this case)
enforse_div32 = 'CROP'
PLOT = True
# To produce images from the paper we took *_GT.png images from LapSRN viewer for corresponding factor,
# e.g. x4/zebra_GT.png for factor=4, and x8/zebra_GT.png for factor=8
path_to_image = '/home/smitha/deep-image-prior/resize.tif'

# Load the low-/high-resolution image pair and build the interpolation
# baselines. Gradient tracking is switched off here; it only matters if the
# helpers run torch operations (NOTE(review): they look like PIL/numpy
# helpers, in which case the context manager is a harmless no-op — confirm).
with torch.no_grad():
    imgs = load_LR_HR_imgs_sr(path_to_image, imsize, factor, enforse_div32)
    imgs['bicubic_np'], imgs['sharp_np'], imgs['nearest_np'] = get_baselines(
        imgs['LR_pil'], imgs['HR_pil'])

# Show the ground truth next to the three baselines.
if PLOT:
    plot_image_grid(
        [imgs['HR_np'], imgs['bicubic_np'], imgs['sharp_np'], imgs['nearest_np']],
        4, 12)

# Report how close the bicubic and nearest baselines are to the HR image.
print('PSNR bicubic: %.4f   PSNR nearest: %.4f' % (
    compare_psnr(imgs['HR_np'], imgs['bicubic_np']),
    compare_psnr(imgs['HR_np'], imgs['nearest_np'])))
# Hyper-parameters for the optimisation run.
input_depth = 8

INPUT = 'noise'
pad = 'reflection'
OPT_OVER = 'net'
KERNEL_TYPE = 'lanczos2'

# NOTE(review): a step size of 1 is unusually large for Adam — confirm intended.
LR = 1
tv_weight = 0.0

OPTIMIZER = 'adam'

# Iteration count and input-noise strength are chosen per upscaling factor.
if factor == 2:
    num_iter = 10
    reg_noise_std = 0.01
elif factor == 8:
    num_iter = 40
    reg_noise_std = 0.05
else:
    # Raised explicitly (rather than `assert False`) so the check is not
    # stripped when Python runs with -O; the exception type is unchanged.
    raise AssertionError('We did not experiment with other factors')
# Network input, sized from the HR image.
# NOTE(review): presumably imgs['HR_pil'] is a PIL image whose .size is
# (width, height), which is why the two indices are swapped — confirm.
net_input = get_noise(
    input_depth, INPUT,
    (imgs['HR_pil'].size[1], imgs['HR_pil'].size[0])).type(dtype).detach()

NET_TYPE = 'skip'  # UNet, ResNet
# NOTE(review): the layer widths below (512) together with the image size
# presumably dominate GPU memory use; if CUDA reports out-of-memory, try
# smaller skip_n33d / skip_n33u or a smaller imsize — confirm by measuring.
net = get_net(input_depth, NET_TYPE, pad,
              skip_n33d=512,
              skip_n33u=512,
              skip_n11=4,
              num_scales=5,
              upsample_mode='bilinear').type(dtype)

# Losses
mse = torch.nn.MSELoss().type(dtype)
img_LR_var = np_to_torch(imgs['LR_np']).type(dtype)
downsampler = Downsampler(n_planes=3, factor=factor, kernel_type=KERNEL_TYPE,
                          phase=0.5, preserve_size=True).type(dtype)


def closure():
    """Compute the loss for one optimisation step and back-propagate it.

    When ``reg_noise_std`` is positive, the network input is rebuilt as the
    saved input plus freshly sampled noise scaled by ``reg_noise_std``. The
    network output is downsampled and compared with the low-resolution
    target using MSE, plus a TV term when ``tv_weight`` is positive.
    Both PSNR values are printed and appended to ``psnr_history``, and every
    100th call plots the current result when ``PLOT`` is set.

    Returns:
        The total loss tensor, after ``backward()`` has been called on it.
    """
    global i, net_input

    if reg_noise_std > 0:
        # normal_() refills `noise` in place with new samples on every call.
        net_input = net_input_saved + (noise.normal_() * reg_noise_std)

    out_HR = net(net_input)
    out_LR = downsampler(out_HR)

    total_loss = mse(out_LR, img_LR_var)

    if tv_weight > 0:
        total_loss += tv_weight * tv_loss(out_HR)

    total_loss.backward()

    # Log
    psnr_LR = compare_psnr(imgs['LR_np'], torch_to_np(out_LR))
    psnr_HR = compare_psnr(imgs['HR_np'], torch_to_np(out_HR))
    print('Iteration %05d    PSNR_LR %.3f   PSNR_HR %.3f' % (i, psnr_LR, psnr_HR), '\r', end='')

    # History
    psnr_history.append([psnr_LR, psnr_HR])

    if PLOT and i % 100 == 0:
        out_HR_np = torch_to_np(out_HR)
        plot_image_grid([imgs['HR_np'], imgs['bicubic_np'], np.clip(out_HR_np, 0, 1)],
                        factor=13, nrow=3)

    i += 1

    return total_loss


# State shared with closure() through module-level names.
psnr_history = []
net_input_saved = net_input.detach().clone()
noise = net_input.clone()
i = 0

p = get_params(OPT_OVER, net, net_input)
optimize(OPTIMIZER, p, closure, LR, num_iter)

# The final forward pass is inference only, so skip building an autograd
# graph for it; this avoids holding extra activations in GPU memory.
with torch.no_grad():
    out_HR_np = np.clip(torch_to_np(net(net_input)), 0, 1)
result_deep_prior = put_in_center(out_HR_np, imgs['orig_np'].shape[1:])

# For the paper we actually took `_bicubic.png` files from LapSRN viewer and used `result_deep_prior` as our result
# The trailing ';' keeps a notebook from echoing the call's return value.
plot_image_grid([imgs['HR_np'],
                 imgs['bicubic_np'],
                 out_HR_np], factor=4, nrow=1);

0 个答案:

没有答案