Chainer 报错:Invalid operation is performed in: LinearFunction (Forward)(在 LinearFunction 的前向计算中执行了无效的操作)

时间:2019-02-02 16:05:45

标签: python-3.x deep-learning chainer dcgan

下面是主要的训练代码。训练图片的尺寸为 420 × 420 像素。

import os

import chainer
from chainer import training
from chainer.training import extensions



# --- Training schedule -------------------------------------------------
epoch = 10                # number of epochs before the trainer stops
batchsize = 64            # minibatch size handed to the iterator
seed = 0                  # seed forwarded to the preview-image extension

# --- Model / device ----------------------------------------------------
n_hidden = 100            # size of the generator's latent vector
gpu = 0                   # GPU id; a negative value skips the GPU setup

# --- Paths -------------------------------------------------------------
dataset = "/content/gdrive/My Drive/images/"  # image directory; '' selects CIFAR10
out = "/content/gdrive/My Drive/DCGAN/"       # trainer output directory
resume = ""               # snapshot to resume from; empty means start fresh

# --- Reporting intervals (in iterations) -------------------------------
display_interval = 100
snapshot_interval = 200


# Set up a neural network to train
gen = Generator(n_hidden=n_hidden)

# The discriminator's last layer is a Linear unit sized as
# bottom_width * bottom_width * ch, so bottom_width has to equal the spatial
# size left after its three stride-2 convolutions (c0_1, c1_1, c2_1; each has
# ksize=4, stride=2, pad=1 and therefore floors the size in half).
# For 420x420 inputs that is 420 -> 210 -> 105 -> 52, i.e. 420 // 8.
# Leaving the default bottom_width=4 (valid for 32x32 inputs only) raises
# "InvalidType ... Actual: 1384448 != 8192" (52*52*512 vs. 4*4*512).
image_size = 32 if dataset == '' else 420  # CIFAR10 fallback is 32x32
dis_bottom_width = image_size // 8
dis = Discriminator(bottom_width=dis_bottom_width)
# NOTE(review): Generator is defined elsewhere. Its output must have the same
# spatial size as the real images, because the discriminator scores generated
# images too -- confirm that it is configured for image_size as well.

if gpu >= 0:
    # Make a specified GPU current
    chainer.backends.cuda.get_device_from_id(gpu).use()
    gen.to_gpu()  # Copy the model to the GPU
    dis.to_gpu()

# Setup an optimizer
def make_optimizer(model, alpha=0.0002, beta1=0.5):
    """Build an Adam optimizer bound to ``model`` with weight decay.

    Args:
        model: The link whose parameters the optimizer will update.
        alpha: Value forwarded to Adam as ``alpha``.
        beta1: Value forwarded to Adam as ``beta1``.

    Returns:
        The configured Adam optimizer, already set up on ``model`` and
        carrying a ``WeightDecay(0.0001)`` hook registered as 'hook_dec'.
    """
    adam = chainer.optimizers.Adam(alpha=alpha, beta1=beta1)
    adam.setup(model)
    weight_decay = chainer.optimizer_hooks.WeightDecay(0.0001)
    adam.add_hook(weight_decay, 'hook_dec')
    return adam


# One independent optimizer per network (generator first, then discriminator).
opt_gen, opt_dis = make_optimizer(gen), make_optimizer(dis)

if dataset == '':
    # No image directory configured: fall back to the CIFAR10 dataset
    train, _ = chainer.datasets.get_cifar10(withlabel=False, scale=255.)
else:
    all_files = os.listdir(dataset)
    # Select by file extension (case-insensitively). A plain substring test
    # such as `'png' in f` would also pick up names like "jpg_list.txt" or a
    # "pngs" sub-directory, which then fail when the dataset tries to load
    # them as images.
    image_extensions = ('.png', '.jpg', '.jpeg')
    image_files = [f for f in all_files
                   if f.lower().endswith(image_extensions)]
    print('{} contains {} image files'
          .format(dataset, len(image_files)))
    if not image_files:
        # Fail here with a clear message instead of deep inside the
        # iterator / training loop.
        raise ValueError(
            'no .png/.jpg/.jpeg files found in {!r}'.format(dataset))
    # NOTE(review): the images are batched as-is, so they presumably all need
    # the same size (420x420 here) -- confirm, or resize them beforehand.
    train = chainer.datasets\
        .ImageDataset(paths=image_files, root=dataset)

# Setup an iterator
train_iter = chainer.iterators.SerialIterator(train, batchsize)

# Setup an updater: it receives both networks, the training iterator and one
# optimizer per network, keyed by 'gen' / 'dis'.
optimizers = {'gen': opt_gen, 'dis': opt_dis}
updater = DCGANUpdater(
    models=(gen, dis),
    iterator=train_iter,
    optimizer=optimizers,
    device=gpu)

# Setup a trainer that stops after `epoch` epochs and writes into `out`
stop_trigger = (epoch, 'epoch')
trainer = training.Trainer(updater, stop_trigger, out=out)

# Turn the plain iteration counts into (count, unit) trigger tuples.
# Note: this rebinds the two config names from int to tuple.
snapshot_interval = (snapshot_interval, 'iteration')
display_interval = (display_interval, 'iteration')

# Periodic snapshots: the whole trainer first, then each network on its own.
trainer_snapshot = extensions.snapshot(
    filename='snapshot_iter_{.updater.iteration}.npz')
trainer.extend(trainer_snapshot, trigger=snapshot_interval)

gen_snapshot = extensions.snapshot_object(
    gen, 'gen_iter_{.updater.iteration}.npz')
trainer.extend(gen_snapshot, trigger=snapshot_interval)

dis_snapshot = extensions.snapshot_object(
    dis, 'dis_iter_{.updater.iteration}.npz')
trainer.extend(dis_snapshot, trigger=snapshot_interval)

# Logging: collect, print the selected entries, and show a progress bar.
trainer.extend(extensions.LogReport(trigger=display_interval))
report_keys = ['epoch', 'iteration', 'gen/loss', 'dis/loss']
trainer.extend(extensions.PrintReport(report_keys), trigger=display_interval)
trainer.extend(extensions.ProgressBar(update_interval=10))

# Preview images written at every snapshot.
# NOTE(review): out_generated_image is defined elsewhere. In the upstream
# Chainer DCGAN example its 3rd/4th arguments are the rows/cols of the preview
# grid (10, 10 there), not the image size -- if that holds here, 420, 420 asks
# for 176400 generated images per snapshot. Confirm before running.
preview = out_generated_image(gen, dis, 420, 420, seed, out)
trainer.extend(preview, trigger=snapshot_interval)

# Resume from a snapshot when a path was configured
if resume:
    chainer.serializers.load_npz(resume, trainer)

# Run the training
trainer.run()

完整的错误信息(Traceback)

Exception in main training loop: 
Invalid operation is performed in: LinearFunction (Forward)

Expect: x.shape[1] == W.shape[1]
Actual: 1384448 != 8192
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/chainer/training/trainer.py", line 315, in run
    update()
  File "/usr/local/lib/python3.6/dist-packages/chainer/training/updaters/standard_updater.py", line 165, in update
    self.update_core()
  File "<ipython-input-3-1c9eda353b43>", line 37, in update_core
    y_real = dis(x_real)
  File "/usr/local/lib/python3.6/dist-packages/chainer/link.py", line 242, in __call__
    out = forward(*args, **kwargs)
  File "<ipython-input-2-8321f7283f65>", line 81, in forward
    return self.l4(h)
  File "/usr/local/lib/python3.6/dist-packages/chainer/link.py", line 242, in __call__
    out = forward(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/chainer/links/connection/linear.py", line 138, in forward
    return linear.linear(x, self.W, self.b, n_batch_axes=n_batch_axes)
  File "/usr/local/lib/python3.6/dist-packages/chainer/functions/connection/linear.py", line 289, in linear
    y, = LinearFunction().apply(args)
  File "/usr/local/lib/python3.6/dist-packages/chainer/function_node.py", line 245, in apply
    self._check_data_type_forward(in_data)
  File "/usr/local/lib/python3.6/dist-packages/chainer/function_node.py", line 330, in _check_data_type_forward
    self.check_type_forward(in_type)
  File "/usr/local/lib/python3.6/dist-packages/chainer/functions/connection/linear.py", line 27, in check_type_forward
    x_type.shape[1] == w_type.shape[1],
  File "/usr/local/lib/python3.6/dist-packages/chainer/utils/type_check.py", line 546, in expect
    expr.expect()
  File "/usr/local/lib/python3.6/dist-packages/chainer/utils/type_check.py", line 483, in expect
    '{0} {1} {2}'.format(left, self.inv, right))
Will finalize trainer extensions and updater before reraising the exception.
---------------------------------------------------------------------------
InvalidType                               Traceback (most recent call last)
<ipython-input-8-a0fb675be455> in <module>()
     89 
     90 # Run the training
---> 91 trainer.run()
     92 
     93 

/usr/local/lib/python3.6/dist-packages/chainer/training/trainer.py in run(self, show_loop_exception_msg)
    327                 f.write('Will finalize trainer extensions and updater before '
    328                         'reraising the exception.\n')
--> 329             six.reraise(*sys.exc_info())
    330         finally:
    331             for _, entry in extensions:

/usr/local/lib/python3.6/dist-packages/six.py in reraise(tp, value, tb)
    691             if value.__traceback__ is not tb:
    692                 raise value.with_traceback(tb)
--> 693             raise value
    694         finally:
    695             value = None

/usr/local/lib/python3.6/dist-packages/chainer/training/trainer.py in run(self, show_loop_exception_msg)
    313                 self.observation = {}
    314                 with reporter.scope(self.observation):
--> 315                     update()
    316                     for name, entry in extensions:
    317                         if entry.trigger(self):

/usr/local/lib/python3.6/dist-packages/chainer/training/updaters/standard_updater.py in update(self)
    163 
    164         """
--> 165         self.update_core()
    166         self.iteration += 1
    167 

<ipython-input-3-1c9eda353b43> in update_core(self)
     35         batchsize = len(batch)
     36 
---> 37         y_real = dis(x_real)
     38 
     39         z = Variable(xp.asarray(gen.make_hidden(batchsize)))

/usr/local/lib/python3.6/dist-packages/chainer/link.py in __call__(self, *args, **kwargs)
    240         if forward is None:
    241             forward = self.forward
--> 242         out = forward(*args, **kwargs)
    243 
    244         # Call forward_postprocess hook

<ipython-input-2-8321f7283f65> in forward(self, x)
     79         h = F.leaky_relu(add_noise(self.bn2_1(self.c2_1(h))))
     80         h = F.leaky_relu(add_noise(self.bn3_0(self.c3_0(h))))
---> 81         return self.l4(h)

/usr/local/lib/python3.6/dist-packages/chainer/link.py in __call__(self, *args, **kwargs)
    240         if forward is None:
    241             forward = self.forward
--> 242         out = forward(*args, **kwargs)
    243 
    244         # Call forward_postprocess hook

/usr/local/lib/python3.6/dist-packages/chainer/links/connection/linear.py in forward(self, x, n_batch_axes)
    136             in_size = functools.reduce(operator.mul, x.shape[1:], 1)
    137             self._initialize_params(in_size)
--> 138         return linear.linear(x, self.W, self.b, n_batch_axes=n_batch_axes)

/usr/local/lib/python3.6/dist-packages/chainer/functions/connection/linear.py in linear(x, W, b, n_batch_axes)
    287         args = x, W, b
    288 
--> 289     y, = LinearFunction().apply(args)
    290     if n_batch_axes > 1:
    291         y = y.reshape(batch_shape + (-1,))

/usr/local/lib/python3.6/dist-packages/chainer/function_node.py in apply(self, inputs)
    243 
    244         if configuration.config.type_check:
--> 245             self._check_data_type_forward(in_data)
    246 
    247         hooks = chainer.get_function_hooks()

/usr/local/lib/python3.6/dist-packages/chainer/function_node.py in _check_data_type_forward(self, in_data)
    328         in_type = type_check.get_types(in_data, 'in_types', False)
    329         with type_check.get_function_check_context(self):
--> 330             self.check_type_forward(in_type)
    331 
    332     def check_type_forward(self, in_types):

/usr/local/lib/python3.6/dist-packages/chainer/functions/connection/linear.py in check_type_forward(self, in_types)
     25             x_type.ndim == 2,
     26             w_type.ndim == 2,
---> 27             x_type.shape[1] == w_type.shape[1],
     28         )
     29         if type_check.eval(n_in) == 3:

/usr/local/lib/python3.6/dist-packages/chainer/utils/type_check.py in expect(*bool_exprs)
    544         for expr in bool_exprs:
    545             assert isinstance(expr, Testable)
--> 546             expr.expect()
    547 
    548 

/usr/local/lib/python3.6/dist-packages/chainer/utils/type_check.py in expect(self)
    481             raise InvalidType(
    482                 '{0} {1} {2}'.format(self.lhs, self.exp, self.rhs),
--> 483                 '{0} {1} {2}'.format(left, self.inv, right))
    484 
    485 

InvalidType: 
Invalid operation is performed in: LinearFunction (Forward)

Expect: x.shape[1] == W.shape[1]
Actual: 1384448 != 8192

鉴别器

class Discriminator(chainer.Chain):
    """Convolutional DCGAN discriminator ending in a single linear output.

    The network applies seven convolutions and then the linear layer ``l4``.
    Three of the convolutions (``c0_1``, ``c1_1``, ``c2_1``) use ksize=4,
    stride=2, pad=1 and therefore halve (floor) the spatial size; the others
    use ksize=3, stride=1, pad=1 and keep it. An input of side ``s`` thus
    reaches ``l4`` with side ``s // 8`` and ``ch`` channels.

    Args:
        bottom_width: Side length of the feature map fed to ``l4``. It must
            equal ``input_side // 8`` (4 for 32x32 images, 52 for 420x420),
            otherwise ``l4`` raises ``InvalidType`` on the first forward
            pass. Pass ``None`` to let ``l4`` infer its input size from the
            first batch instead.
        ch: Channel count of the last convolutions; earlier layers use
            ``ch // 8``, ``ch // 4`` and ``ch // 2``.
        wscale: Value passed to the ``Normal`` initializer used for every
            convolution and for the linear layer.
    """

    def __init__(self, bottom_width=4, ch=512, wscale=0.02):
        w = chainer.initializers.Normal(wscale)
        super(Discriminator, self).__init__()
        # in_size=None makes L.Linear postpone creating its weight until the
        # first forward call, where the size is taken from the actual input.
        if bottom_width is None:
            linear_in_size = None
        else:
            linear_in_size = bottom_width * bottom_width * ch
        with self.init_scope():
            self.c0_0 = L.Convolution2D(3, ch // 8, 3, 1, 1, initialW=w)
            self.c0_1 = L.Convolution2D(ch // 8, ch // 4, 4, 2, 1, initialW=w)
            self.c1_0 = L.Convolution2D(ch // 4, ch // 4, 3, 1, 1, initialW=w)
            self.c1_1 = L.Convolution2D(ch // 4, ch // 2, 4, 2, 1, initialW=w)
            self.c2_0 = L.Convolution2D(ch // 2, ch // 2, 3, 1, 1, initialW=w)
            self.c2_1 = L.Convolution2D(ch // 2, ch // 1, 4, 2, 1, initialW=w)
            self.c3_0 = L.Convolution2D(ch // 1, ch // 1, 3, 1, 1, initialW=w)
            self.l4 = L.Linear(linear_in_size, 1, initialW=w)
            self.bn0_1 = L.BatchNormalization(ch // 4, use_gamma=False)
            self.bn1_0 = L.BatchNormalization(ch // 4, use_gamma=False)
            self.bn1_1 = L.BatchNormalization(ch // 2, use_gamma=False)
            self.bn2_0 = L.BatchNormalization(ch // 2, use_gamma=False)
            self.bn2_1 = L.BatchNormalization(ch // 1, use_gamma=False)
            self.bn3_0 = L.BatchNormalization(ch // 1, use_gamma=False)

    def forward(self, x):
        """Return the output of ``l4`` for the image batch ``x``.

        ``add_noise`` (defined elsewhere) is applied to the input and to
        every convolution output before the leaky ReLU. The first
        convolution has no batch normalization; all later ones do.
        """
        h = add_noise(x)
        h = F.leaky_relu(add_noise(self.c0_0(h)))
        h = F.leaky_relu(add_noise(self.bn0_1(self.c0_1(h))))
        h = F.leaky_relu(add_noise(self.bn1_0(self.c1_0(h))))
        h = F.leaky_relu(add_noise(self.bn1_1(self.c1_1(h))))
        h = F.leaky_relu(add_noise(self.bn2_0(self.c2_0(h))))
        h = F.leaky_relu(add_noise(self.bn2_1(self.c2_1(h))))
        h = F.leaky_relu(add_noise(self.bn3_0(self.c3_0(h))))
        # l4 flattens the (ch, side, side) feature map, so its input size
        # must be side * side * ch.
        return self.l4(h)

1 个答案:

答案 0 :(得分:1)

问题出在鉴别器的 bottom_width 参数上:它的值应当等于 (图像边长)/2^3,因为鉴别器中有三个步长为 2 的卷积层,每一层都会把特征图的边长减半(向下取整)。在本例中,420/2^3 = 52.5,结果不是整数,需要向下取整为整数 52,即应使用 bottom_width=52(此时 52 × 52 × 512 = 1384448,正好与报错信息中的数值一致)。