以下是主要代码。训练图片的尺寸为 420 × 420。
import os
import chainer
from chainer import training
from chainer.training import extensions
# --- Hardware -------------------------------------------------------------
gpu = 0  # GPU device id; a negative value skips the GPU setup below

# --- Paths ----------------------------------------------------------------
dataset = "/content/gdrive/My Drive/images/"  # image directory; '' selects CIFAR-10
out = "/content/gdrive/My Drive/DCGAN/"  # output directory handed to the trainer
resume = ""  # path of a trainer snapshot to resume from; '' starts fresh

# --- Training schedule ----------------------------------------------------
batchsize = 64  # minibatch size given to the training iterator
epoch = 10  # the trainer stops after this many epochs
seed = 0  # forwarded to out_generated_image
n_hidden = 100  # forwarded to Generator(n_hidden=...)

# --- Reporting (both counted in iterations) --------------------------------
snapshot_interval = 200
display_interval = 100
# Set up a neural network to train
#
# The training images are 420x420. The Discriminator in this file halves the
# spatial size three times (its stride-2 convolutions c0_1, c1_1, c2_1), so
# its final linear layer receives a (420 // 8) = 52 wide feature map.
# The default bottom_width=4 only fits 32x32 inputs and fails with
# "Expect: x.shape[1] == W.shape[1]  Actual: 1384448 != 8192"
# (52 * 52 * 512 vs. 4 * 4 * 512).
image_size = 420
gen = Generator(n_hidden=n_hidden)
# NOTE(review): Generator is not shown here. Its output must have the same
# spatial size as the training images, otherwise dis(gen(z)) will hit the
# same shape mismatch -- confirm, and resize the images if it cannot
# produce image_size x image_size outputs.
dis = Discriminator(bottom_width=image_size // 8)

if gpu >= 0:
    # Make a specified GPU current
    chainer.backends.cuda.get_device_from_id(gpu).use()
    gen.to_gpu()  # Copy the model to the GPU
    dis.to_gpu()
# Setup an optimizer
def make_optimizer(model, alpha=0.0002, beta1=0.5):
    """Build an Adam optimizer for ``model`` with a weight-decay hook attached.

    Args:
        model: The link whose parameters the optimizer is set up on.
        alpha: Passed to ``chainer.optimizers.Adam`` as ``alpha``.
        beta1: Passed to ``chainer.optimizers.Adam`` as ``beta1``.

    Returns:
        The Adam optimizer, already set up on ``model`` and carrying a
        ``WeightDecay(0.0001)`` hook registered under the name ``'hook_dec'``.
    """
    adam = chainer.optimizers.Adam(alpha=alpha, beta1=beta1)
    adam.setup(model)
    decay_hook = chainer.optimizer_hooks.WeightDecay(0.0001)
    adam.add_hook(decay_hook, 'hook_dec')
    return adam
# One optimizer per network, both with the same hyper-parameters.
opt_gen = make_optimizer(gen)
opt_dis = make_optimizer(dis)

if dataset == '':
    # Load the CIFAR10 dataset if no dataset directory is specified
    train, _ = chainer.datasets.get_cifar10(withlabel=False, scale=255.)
else:
    # Select images by their file extension (case-insensitive). A plain
    # substring test such as `'png' in f` would also pick up non-image
    # files like "png_notes.txt" and would miss "photo.PNG".
    image_files = [
        f for f in os.listdir(dataset)
        if f.lower().endswith(('.png', '.jpg'))
    ]
    print('{} contains {} image files'
          .format(dataset, len(image_files)))
    # Paths are relative file names; `root` makes them resolve inside the
    # dataset directory.
    train = chainer.datasets.ImageDataset(paths=image_files, root=dataset)
# Setup an iterator
train_iter = chainer.iterators.SerialIterator(train, batchsize)

# Setup an updater
updater = DCGANUpdater(
    models=(gen, dis),
    iterator=train_iter,
    optimizer={'gen': opt_gen, 'dis': opt_dis},
    device=gpu,
)

# Setup a trainer
trainer = training.Trainer(updater, (epoch, 'epoch'), out=out)

# From here on both intervals are (count, unit) trigger tuples.
snapshot_interval = (snapshot_interval, 'iteration')
display_interval = (display_interval, 'iteration')

# Periodic snapshots: the whole trainer state, then each network on its own.
trainer_snapshot = extensions.snapshot(
    filename='snapshot_iter_{.updater.iteration}.npz')
trainer.extend(trainer_snapshot, trigger=snapshot_interval)

gen_snapshot = extensions.snapshot_object(
    gen, 'gen_iter_{.updater.iteration}.npz')
trainer.extend(gen_snapshot, trigger=snapshot_interval)

dis_snapshot = extensions.snapshot_object(
    dis, 'dis_iter_{.updater.iteration}.npz')
trainer.extend(dis_snapshot, trigger=snapshot_interval)

# Logging and console output.
trainer.extend(extensions.LogReport(trigger=display_interval))
report_keys = ['epoch', 'iteration', 'gen/loss', 'dis/loss']
trainer.extend(extensions.PrintReport(report_keys), trigger=display_interval)
trainer.extend(extensions.ProgressBar(update_interval=10))

# Write generated samples at every snapshot.
# NOTE(review): out_generated_image is defined elsewhere. In the upstream
# Chainer DCGAN example its 3rd and 4th arguments are the rows and columns
# of the preview grid, not pixel dimensions -- confirm that 420, 420 is
# really intended here.
preview_extension = out_generated_image(gen, dis, 420, 420, seed, out)
trainer.extend(preview_extension, trigger=snapshot_interval)

if resume:
    # Resume from a snapshot
    chainer.serializers.load_npz(resume, trainer)

# Run the training
trainer.run()
完整的错误信息(Traceback):
Exception in main training loop:
Invalid operation is performed in: LinearFunction (Forward)
Expect: x.shape[1] == W.shape[1]
Actual: 1384448 != 8192
Traceback (most recent call last):
File "/usr/local/lib/python3.6/dist-packages/chainer/training/trainer.py", line 315, in run
update()
File "/usr/local/lib/python3.6/dist-packages/chainer/training/updaters/standard_updater.py", line 165, in update
self.update_core()
File "<ipython-input-3-1c9eda353b43>", line 37, in update_core
y_real = dis(x_real)
File "/usr/local/lib/python3.6/dist-packages/chainer/link.py", line 242, in __call__
out = forward(*args, **kwargs)
File "<ipython-input-2-8321f7283f65>", line 81, in forward
return self.l4(h)
File "/usr/local/lib/python3.6/dist-packages/chainer/link.py", line 242, in __call__
out = forward(*args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/chainer/links/connection/linear.py", line 138, in forward
return linear.linear(x, self.W, self.b, n_batch_axes=n_batch_axes)
File "/usr/local/lib/python3.6/dist-packages/chainer/functions/connection/linear.py", line 289, in linear
y, = LinearFunction().apply(args)
File "/usr/local/lib/python3.6/dist-packages/chainer/function_node.py", line 245, in apply
self._check_data_type_forward(in_data)
File "/usr/local/lib/python3.6/dist-packages/chainer/function_node.py", line 330, in _check_data_type_forward
self.check_type_forward(in_type)
File "/usr/local/lib/python3.6/dist-packages/chainer/functions/connection/linear.py", line 27, in check_type_forward
x_type.shape[1] == w_type.shape[1],
File "/usr/local/lib/python3.6/dist-packages/chainer/utils/type_check.py", line 546, in expect
expr.expect()
File "/usr/local/lib/python3.6/dist-packages/chainer/utils/type_check.py", line 483, in expect
'{0} {1} {2}'.format(left, self.inv, right))
Will finalize trainer extensions and updater before reraising the exception.
---------------------------------------------------------------------------
InvalidType Traceback (most recent call last)
<ipython-input-8-a0fb675be455> in <module>()
89
90 # Run the training
---> 91 trainer.run()
92
93
/usr/local/lib/python3.6/dist-packages/chainer/training/trainer.py in run(self, show_loop_exception_msg)
327 f.write('Will finalize trainer extensions and updater before '
328 'reraising the exception.\n')
--> 329 six.reraise(*sys.exc_info())
330 finally:
331 for _, entry in extensions:
/usr/local/lib/python3.6/dist-packages/six.py in reraise(tp, value, tb)
691 if value.__traceback__ is not tb:
692 raise value.with_traceback(tb)
--> 693 raise value
694 finally:
695 value = None
/usr/local/lib/python3.6/dist-packages/chainer/training/trainer.py in run(self, show_loop_exception_msg)
313 self.observation = {}
314 with reporter.scope(self.observation):
--> 315 update()
316 for name, entry in extensions:
317 if entry.trigger(self):
/usr/local/lib/python3.6/dist-packages/chainer/training/updaters/standard_updater.py in update(self)
163
164 """
--> 165 self.update_core()
166 self.iteration += 1
167
<ipython-input-3-1c9eda353b43> in update_core(self)
35 batchsize = len(batch)
36
---> 37 y_real = dis(x_real)
38
39 z = Variable(xp.asarray(gen.make_hidden(batchsize)))
/usr/local/lib/python3.6/dist-packages/chainer/link.py in __call__(self, *args, **kwargs)
240 if forward is None:
241 forward = self.forward
--> 242 out = forward(*args, **kwargs)
243
244 # Call forward_postprocess hook
<ipython-input-2-8321f7283f65> in forward(self, x)
79 h = F.leaky_relu(add_noise(self.bn2_1(self.c2_1(h))))
80 h = F.leaky_relu(add_noise(self.bn3_0(self.c3_0(h))))
---> 81 return self.l4(h)
/usr/local/lib/python3.6/dist-packages/chainer/link.py in __call__(self, *args, **kwargs)
240 if forward is None:
241 forward = self.forward
--> 242 out = forward(*args, **kwargs)
243
244 # Call forward_postprocess hook
/usr/local/lib/python3.6/dist-packages/chainer/links/connection/linear.py in forward(self, x, n_batch_axes)
136 in_size = functools.reduce(operator.mul, x.shape[1:], 1)
137 self._initialize_params(in_size)
--> 138 return linear.linear(x, self.W, self.b, n_batch_axes=n_batch_axes)
/usr/local/lib/python3.6/dist-packages/chainer/functions/connection/linear.py in linear(x, W, b, n_batch_axes)
287 args = x, W, b
288
--> 289 y, = LinearFunction().apply(args)
290 if n_batch_axes > 1:
291 y = y.reshape(batch_shape + (-1,))
/usr/local/lib/python3.6/dist-packages/chainer/function_node.py in apply(self, inputs)
243
244 if configuration.config.type_check:
--> 245 self._check_data_type_forward(in_data)
246
247 hooks = chainer.get_function_hooks()
/usr/local/lib/python3.6/dist-packages/chainer/function_node.py in _check_data_type_forward(self, in_data)
328 in_type = type_check.get_types(in_data, 'in_types', False)
329 with type_check.get_function_check_context(self):
--> 330 self.check_type_forward(in_type)
331
332 def check_type_forward(self, in_types):
/usr/local/lib/python3.6/dist-packages/chainer/functions/connection/linear.py in check_type_forward(self, in_types)
25 x_type.ndim == 2,
26 w_type.ndim == 2,
---> 27 x_type.shape[1] == w_type.shape[1],
28 )
29 if type_check.eval(n_in) == 3:
/usr/local/lib/python3.6/dist-packages/chainer/utils/type_check.py in expect(*bool_exprs)
544 for expr in bool_exprs:
545 assert isinstance(expr, Testable)
--> 546 expr.expect()
547
548
/usr/local/lib/python3.6/dist-packages/chainer/utils/type_check.py in expect(self)
481 raise InvalidType(
482 '{0} {1} {2}'.format(self.lhs, self.exp, self.rhs),
--> 483 '{0} {1} {2}'.format(left, self.inv, right))
484
485
InvalidType:
Invalid operation is performed in: LinearFunction (Forward)
Expect: x.shape[1] == W.shape[1]
Actual: 1384448 != 8192
鉴别器(Discriminator)的定义:
class Discriminator(chainer.Chain):
    """Convolutional discriminator producing one output value per input image.

    The input goes through seven convolutions followed by a linear layer.
    Three of the convolutions (``c0_1``, ``c1_1``, ``c2_1``) use kernel 4,
    stride 2, padding 1 and each halve the spatial size (rounding down);
    the others use kernel 3, stride 1, padding 1 and keep it unchanged.

    Args:
        bottom_width: Side length of the feature map that reaches the final
            linear layer, whose input size is fixed to
            ``bottom_width * bottom_width * ch``. For square ``N x N`` inputs
            this must equal ``N // 8`` (4 for 32x32, 52 for 420x420);
            any other value makes ``l4`` fail with a shape mismatch.
        ch: Number of channels produced by the last convolutions; earlier
            layers use ``ch // 8``, ``ch // 4`` and ``ch // 2``.
        wscale: Scale given to ``chainer.initializers.Normal`` for the
            initial weights of every convolution and of the linear layer.
    """

    def __init__(self, bottom_width=4, ch=512, wscale=0.02):
        init_w = chainer.initializers.Normal(wscale)
        super().__init__()
        # Channel widths, narrowest to widest.
        c8, c4, c2, c1 = ch // 8, ch // 4, ch // 2, ch // 1
        with self.init_scope():
            # Convolution2D(in, out, ksize, stride, pad)
            self.c0_0 = L.Convolution2D(3, c8, 3, 1, 1, initialW=init_w)
            self.c0_1 = L.Convolution2D(c8, c4, 4, 2, 1, initialW=init_w)
            self.c1_0 = L.Convolution2D(c4, c4, 3, 1, 1, initialW=init_w)
            self.c1_1 = L.Convolution2D(c4, c2, 4, 2, 1, initialW=init_w)
            self.c2_0 = L.Convolution2D(c2, c2, 3, 1, 1, initialW=init_w)
            self.c2_1 = L.Convolution2D(c2, c1, 4, 2, 1, initialW=init_w)
            self.c3_0 = L.Convolution2D(c1, c1, 3, 1, 1, initialW=init_w)
            self.l4 = L.Linear(
                bottom_width * bottom_width * ch, 1, initialW=init_w)
            self.bn0_1 = L.BatchNormalization(c4, use_gamma=False)
            self.bn1_0 = L.BatchNormalization(c4, use_gamma=False)
            self.bn1_1 = L.BatchNormalization(c2, use_gamma=False)
            self.bn2_0 = L.BatchNormalization(c2, use_gamma=False)
            self.bn2_1 = L.BatchNormalization(c1, use_gamma=False)
            self.bn3_0 = L.BatchNormalization(c1, use_gamma=False)

    def forward(self, x):
        """Apply the network to ``x`` and return the output of ``l4``."""
        # The first convolution has no batch normalization.
        h = F.leaky_relu(add_noise(self.c0_0(add_noise(x))))
        # Remaining stages: conv -> batch norm -> noise -> leaky ReLU.
        stages = (
            (self.c0_1, self.bn0_1),
            (self.c1_0, self.bn1_0),
            (self.c1_1, self.bn1_1),
            (self.c2_0, self.bn2_0),
            (self.c2_1, self.bn2_1),
            (self.c3_0, self.bn3_0),
        )
        for conv, bn in stages:
            h = F.leaky_relu(add_noise(bn(conv(h))))
        return self.l4(h)
答案 0 :(得分:1)
问题出在鉴别器的 `bottom_width` 参数上:它的值应该等于 (图片边长) / 2^3,因为鉴别器中有三个步长为 2 的卷积层,每一层都会把特征图的边长减半。
在本例中为 420 / 2^3 = 52.5;由于 `bottom_width` 必须是整数,需要向下取整,即使用 52(52 × 52 × 512 = 1384448,与报错信息中的数值一致)。