Unable to use PyTorch frame interpolation framework

Date: 2018-07-27 09:21:53

Tags: python pytorch

So, I am trying to test a frame-interpolation algorithm I found on a GitHub page. I am testing it on the Middlebury dataset provided at http://vision.middlebury.edu/flow/data/. (I am only using the color images, which come as pairs of two frames; the error also occurred with the Yosemite image files.) The algorithm works on all of them except the last one, where it raises this error:

Traceback (most recent call last): 
File "run.py", line 257, in  
tensorInputFirst = torch.FloatTensor(numpy.rollaxis(numpy.asarray(PIL.Image.open(arguments_strFirst))[:,:,::-1], 2, 0).astype(numpy.float32) / 255.0) 
IndexError: too many indices for array

I know what the error means (the indexing uses more dimensions than the array has), but I do not know how this relates to frame interpolation or why it happens. Can someone help?
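For context, here is a minimal sketch of what the failing line assumes (the path is the script's default): PIL must return a 3-D (height, width, channels) array, whose [:, :, ::-1] slice reverses the channel order from RGB to BGR before the channel axis is moved to the front. On an array with only two dimensions the same slice uses one index too many, which reproduces exactly this IndexError:

import numpy
import PIL.Image
import torch

# What the failing line does on a (height, width, 3) color image:
image = numpy.asarray(PIL.Image.open('./images/first.png'))
bgr = image[:, :, ::-1]                        # reverse the channel axis: RGB -> BGR
chw = numpy.rollaxis(bgr, 2, 0)                # move channels to the front: (3, height, width)
tensorInput = torch.FloatTensor(chw.astype(numpy.float32) / 255.0)  # scale to [0, 1]

# The same slice on a 2-D array reproduces the error:
numpy.zeros((4, 4))[:, :, ::-1]                # IndexError: too many indices for array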

The code I found on GitHub:

#!/usr/bin/env python2.7

import sys
import getopt
import math
import numpy
import torch
import torch.utils.serialization
import PIL
import PIL.Image

from moviepy.video.io.ffmpeg_reader import FFMPEG_VideoReader
from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter

from SeparableConvolution import SeparableConvolution # the custom SeparableConvolution layer

torch.cuda.device(1) # intended to select the second GPU; adjust or remove this for your setup

torch.backends.cudnn.enabled = True # use cudnn for better performance

##########################################################

arguments_strModel = 'lf'
arguments_strFirst = './images/first.png'
arguments_strSecond = './images/second.png'
arguments_strOut = './result.png'
arguments_strVideo = ''
arguments_strVideoOut = ''

for strOption, strArgument in getopt.getopt(sys.argv[1:], '', [ strParameter[2:] + '=' for strParameter in sys.argv[1::2] ])[0]:
    if strOption == '--model':
        arguments_strModel = strArgument # which model to use, l1 or lf, please see our paper for more details

    elif strOption == '--first':
        arguments_strFirst = strArgument # path to the first frame

    elif strOption == '--second':
        arguments_strSecond = strArgument # path to the second frame

    elif strOption == '--out':
        arguments_strOut = strArgument # path to where the output should be stored

    elif strOption == '--video':
        arguments_strVideo = strArgument # path to the video

    elif strOption == '--video-out':
        arguments_strVideoOut = strArgument # path to the video


##########################################################

class Network(torch.nn.Module):
    def __init__(self):
        super(Network, self).__init__()

        def Basic(intInput, intOutput):
            return torch.nn.Sequential(
                torch.nn.Conv2d(in_channels=intInput, out_channels=intOutput, kernel_size=3, stride=1, padding=1),
                torch.nn.ReLU(inplace=False),
                torch.nn.Conv2d(in_channels=intOutput, out_channels=intOutput, kernel_size=3, stride=1, padding=1),
                torch.nn.ReLU(inplace=False),
                torch.nn.Conv2d(in_channels=intOutput, out_channels=intOutput, kernel_size=3, stride=1, padding=1),
                torch.nn.ReLU(inplace=False)
            )


        def Subnet():
            return torch.nn.Sequential(
                torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
                torch.nn.ReLU(inplace=False),
                torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
                torch.nn.ReLU(inplace=False),
                torch.nn.Conv2d(in_channels=64, out_channels=51, kernel_size=3, stride=1, padding=1),
                torch.nn.ReLU(inplace=False),
                torch.nn.Upsample(scale_factor=2, mode='bilinear'),
                torch.nn.Conv2d(in_channels=51, out_channels=51, kernel_size=3, stride=1, padding=1)
            )

        self.moduleConv1 = Basic(6, 32)
        self.modulePool1 = torch.nn.AvgPool2d(kernel_size=2, stride=2)

        self.moduleConv2 = Basic(32, 64)
        self.modulePool2 = torch.nn.AvgPool2d(kernel_size=2, stride=2)

        self.moduleConv3 = Basic(64, 128)
        self.modulePool3 = torch.nn.AvgPool2d(kernel_size=2, stride=2)

        self.moduleConv4 = Basic(128, 256)
        self.modulePool4 = torch.nn.AvgPool2d(kernel_size=2, stride=2)

        self.moduleConv5 = Basic(256, 512)
        self.modulePool5 = torch.nn.AvgPool2d(kernel_size=2, stride=2)

        self.moduleDeconv5 = Basic(512, 512)
        self.moduleUpsample5 = torch.nn.Sequential(
            torch.nn.Upsample(scale_factor=2, mode='bilinear'),
            torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False)
        )

        self.moduleDeconv4 = Basic(512, 256)
        self.moduleUpsample4 = torch.nn.Sequential(
            torch.nn.Upsample(scale_factor=2, mode='bilinear'),
            torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False)
        )

        self.moduleDeconv3 = Basic(256, 128)
        self.moduleUpsample3 = torch.nn.Sequential(
            torch.nn.Upsample(scale_factor=2, mode='bilinear'),
            torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False)
        )

        self.moduleDeconv2 = Basic(128, 64)
        self.moduleUpsample2 = torch.nn.Sequential(
            torch.nn.Upsample(scale_factor=2, mode='bilinear'),
            torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False)
        )

        self.moduleVertical1 = Subnet()
        self.moduleVertical2 = Subnet()
        self.moduleHorizontal1 = Subnet()
        self.moduleHorizontal2 = Subnet()

        self.modulePad = torch.nn.ReplicationPad2d([ int(math.floor(51 / 2.0)), int(math.floor(51 / 2.0)), int(math.floor(51 / 2.0)), int(math.floor(51 / 2.0)) ])

        self.load_state_dict(torch.load('./network-' + arguments_strModel + '.pytorch'))

    def forward(self, variableInput1, variableInput2):
        variableJoin = torch.cat([variableInput1, variableInput2], 1)

        variableConv1 = self.moduleConv1(variableJoin)
        variablePool1 = self.modulePool1(variableConv1)

        variableConv2 = self.moduleConv2(variablePool1)
        variablePool2 = self.modulePool2(variableConv2)

        variableConv3 = self.moduleConv3(variablePool2)
        variablePool3 = self.modulePool3(variableConv3)

        variableConv4 = self.moduleConv4(variablePool3)
        variablePool4 = self.modulePool4(variableConv4)

        variableConv5 = self.moduleConv5(variablePool4)
        variablePool5 = self.modulePool5(variableConv5)

        variableDeconv5 = self.moduleDeconv5(variablePool5)
        variableUpsample5 = self.moduleUpsample5(variableDeconv5)

        variableDeconv4 = self.moduleDeconv4(variableUpsample5 + variableConv5)
        variableUpsample4 = self.moduleUpsample4(variableDeconv4)

        variableDeconv3 = self.moduleDeconv3(variableUpsample4 + variableConv4)
        variableUpsample3 = self.moduleUpsample3(variableDeconv3)

        variableDeconv2 = self.moduleDeconv2(variableUpsample3 + variableConv3)
        variableUpsample2 = self.moduleUpsample2(variableDeconv2)

        variableCombine = variableUpsample2 + variableConv2

        variableDot1 = SeparableConvolution()(self.modulePad(variableInput1), self.moduleVertical1(variableCombine), self.moduleHorizontal1(variableCombine))
        variableDot2 = SeparableConvolution()(self.modulePad(variableInput2), self.moduleVertical2(variableCombine), self.moduleHorizontal2(variableCombine))

        return variableDot1 + variableDot2


moduleNetwork = Network().cuda()

##########################################################

def process(tensorInputFirst, tensorInputSecond, tensorOutput):
    assert(tensorInputFirst.size(1) == tensorInputSecond.size(1))
    assert(tensorInputFirst.size(2) == tensorInputSecond.size(2))

    intWidth = tensorInputFirst.size(2)
    intHeight = tensorInputFirst.size(1)

    assert(intWidth <= 1280) # while our approach works with larger images, we do not recommend it unless you are aware of the implications
    assert(intHeight <= 720) # while our approach works with larger images, we do not recommend it unless you are aware of the implications

    intPaddingLeft = int(math.floor(51 / 2.0))
    intPaddingTop = int(math.floor(51 / 2.0))
    intPaddingRight = int(math.floor(51 / 2.0))
    intPaddingBottom = int(math.floor(51 / 2.0))
    modulePaddingInput = torch.nn.Module()
    modulePaddingOutput = torch.nn.Module()

    if True:
        intPaddingWidth = intPaddingLeft + intWidth + intPaddingRight
        intPaddingHeight = intPaddingTop + intHeight + intPaddingBottom

        if intPaddingWidth != ((intPaddingWidth >> 7) << 7):
            intPaddingWidth = (((intPaddingWidth >> 7) + 1) << 7) # more than necessary

        if intPaddingHeight != ((intPaddingHeight >> 7) << 7):
            intPaddingHeight = (((intPaddingHeight >> 7) + 1) << 7) # more than necessary

        intPaddingWidth = intPaddingWidth - (intPaddingLeft + intWidth + intPaddingRight)
        intPaddingHeight = intPaddingHeight - (intPaddingTop + intHeight + intPaddingBottom)

        modulePaddingInput = torch.nn.ReplicationPad2d([intPaddingLeft, intPaddingRight + intPaddingWidth, intPaddingTop, intPaddingBottom + intPaddingHeight])
        modulePaddingOutput = torch.nn.ReplicationPad2d([0 - intPaddingLeft, 0 - intPaddingRight - intPaddingWidth, 0 - intPaddingTop, 0 - intPaddingBottom - intPaddingHeight])

    if True:
        tensorInputFirst = tensorInputFirst.cuda()
        tensorInputSecond = tensorInputSecond.cuda()

        modulePaddingInput = modulePaddingInput.cuda()
        modulePaddingOutput = modulePaddingOutput.cuda()

    if True:
        variablePaddingFirst = modulePaddingInput(torch.autograd.Variable(data=tensorInputFirst.view(1, 3, intHeight, intWidth), volatile=True))
        variablePaddingSecond = modulePaddingInput(torch.autograd.Variable(data=tensorInputSecond.view(1, 3, intHeight, intWidth), volatile=True))
        variablePaddingOutput = modulePaddingOutput(moduleNetwork(variablePaddingFirst, variablePaddingSecond))

        tensorOutput.resize_(3, intHeight, intWidth).copy_(variablePaddingOutput.data[0])

    if True:
        tensorInputFirst.cpu()
        tensorInputSecond.cpu()
        tensorOutput.cpu()


tensorOutput = torch.FloatTensor()

if arguments_strVideo and arguments_strVideoOut:
    reader = FFMPEG_VideoReader(arguments_strVideo, False)
    writer = FFMPEG_VideoWriter(arguments_strVideoOut, reader.size, reader.fps*2)
    reader.initialize()
    nextFrame = reader.read_frame()
    for x in range(0, reader.nframes):
        firstFrame = nextFrame
        nextFrame = reader.read_frame()
        tensorInputFirst = torch.FloatTensor(numpy.rollaxis(firstFrame[:,:,::-1], 2, 0) / 255.0)
        tensorInputSecond = torch.FloatTensor(numpy.rollaxis(nextFrame[:,:,::-1], 2, 0) / 255.0)
        process(tensorInputFirst, tensorInputSecond, tensorOutput)
        writer.write_frame(firstFrame)
        writer.write_frame((numpy.rollaxis(tensorOutput.clamp(0.0, 1.0).numpy(), 0, 3)[:,:,::-1] * 255.0).astype(numpy.uint8))
    writer.write_frame(nextFrame)
    writer.close()
else:
    tensorInputFirst = torch.FloatTensor(numpy.rollaxis(numpy.asarray(PIL.Image.open(arguments_strFirst))[:,:,::-1], 2, 0).astype(numpy.float32) / 255.0)
    tensorInputSecond = torch.FloatTensor(numpy.rollaxis(numpy.asarray(PIL.Image.open(arguments_strSecond))[:,:,::-1], 2, 0).astype(numpy.float32) / 255.0) 
    process(tensorInputFirst, tensorInputSecond, tensorOutput)
    PIL.Image.fromarray((numpy.rollaxis(tensorOutput.clamp(0.0, 1.0).numpy(), 0, 3)[:,:,::-1] * 255.0).astype(numpy.uint8)).save(arguments_strOut)

Update

I found the cause of the error: the failing images are grayscale, while the other images I tested were color. The color images are 24-bit, but the grayscale ones are 8-bit, so PIL loads them as 2-D arrays with no channel axis and the [:,:,::-1] slice fails. How can I modify the code so it also works on grayscale images?
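One possible fix (a sketch, not code from the repository): convert every input image to RGB with PIL before turning it into an array. An 8-bit grayscale image (PIL mode 'L') then also comes out with the (height, width, 3) shape the slicing expects, and images that are already RGB pass through unchanged:

import numpy
import PIL.Image
import torch

def load_frame(strPath):
    # .convert('RGB') expands an 8-bit grayscale image to three channels;
    # 24-bit color images are unaffected by the conversion.
    numpyImage = numpy.asarray(PIL.Image.open(strPath).convert('RGB'))
    # same preprocessing as run.py: RGB -> BGR, channels first, scale to [0, 1]
    return torch.FloatTensor(numpy.rollaxis(numpyImage[:, :, ::-1], 2, 0).astype(numpy.float32) / 255.0)

tensorInputFirst = load_frame(arguments_strFirst)
tensorInputSecond = load_frame(arguments_strSecond)

Note that the result will still be saved as a color image; if you need a grayscale output, convert it back with PIL.Image.fromarray(...).convert('L').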

0 Answers:

No answers yet