Python: feed_dict shape problem in tf.sess.run()

Date: 2018-12-02 22:21:46

Tags: python tensorflow

I am running this model.py on the celebA dataset by calling its train method with some FLAGS:

if FLAGS.train: dcgan.train(FLAGS)

Model.py

However, I get the following error:

Reloaded modules: ops, utils
Flags are already set
{'batch_size': <absl.flags._flag.Flag object at 0x1c37f2ccc0>,
 'beta1': <absl.flags._flag.Flag object at 0xb34003cc0>,
 'checkpoint_dir': <absl.flags._flag.Flag object at 0x1c380370f0>,
 'crop': <absl.flags._flag.BooleanFlag object at 0x1c380372e8>,
 'data_dir': <absl.flags._flag.Flag object at 0x1c38037160>,
 'dataset': <absl.flags._flag.Flag object at 0x1c37f2cf98>,
 'epoch': <absl.flags._flag.Flag object at 0x11b2cd748>,
 'generate_test_images': <absl.flags._flag.Flag object at 0x1c38037400>,
 'input_fname_pattern': <absl.flags._flag.Flag object at 0x1c38037048>,
 'input_height': <absl.flags._flag.Flag object at 0x1c37f2cd68>,
 'input_width': <absl.flags._flag.Flag object at 0x1c37f2ce10>,
 'learning_rate': <absl.flags._flag.Flag object at 0xb34003c50>,
 'output_height': <absl.flags._flag.Flag object at 0x1c37f2ce80>,
 'output_width': <absl.flags._flag.Flag object at 0x1c37f2cf28>,
 'sample_dir': <absl.flags._flag.Flag object at 0x1c38037208>,
 'train': <absl.flags._flag.BooleanFlag object at 0x1c38037240>,
 'train_set_size': <absl.flags._flag.Flag object at 0x1c380374a8>,
 'train_size': <absl.flags._flag.Flag object at 0x1c37e6af28>,
 'visualize': <absl.flags._flag.BooleanFlag object at 0x1c38037358>,
 'y_dim': <absl.flags._flag.Flag object at 0x1c38037550>}
shape of data = (10000,)
There are a total of 202599 items
Attribute name is Male
We have selected 10000 items ...
First 20 attribute values: [1 0 0 1 0 1 0 0 1 0 0 0 0 1 0 1 0 0 0 1]
---------
Variables: name (type shape) [size]
---------
generator/g_h0_lin/Matrix:0 (float32_ref 140x8192) [1146880, bytes:     4587520]
generator/g_h0_lin/bias:0 (float32_ref 8192) [8192, bytes: 32768]
generator/g_bn0/beta:0 (float32_ref 512) [512, bytes: 2048]
generator/g_bn0/gamma:0 (float32_ref 512) [512, bytes: 2048]
generator/g_h1/w:0 (float32_ref 5x5x256x552) [3532800, bytes: 14131200]
generator/g_h1/biases:0 (float32_ref 256) [256, bytes: 1024]
generator/g_bn1/beta:0 (float32_ref 256) [256, bytes: 1024]
generator/g_bn1/gamma:0 (float32_ref 256) [256, bytes: 1024]
generator/g_h2/w:0 (float32_ref 5x5x128x296) [947200, bytes: 3788800]
generator/g_h2/biases:0 (float32_ref 128) [128, bytes: 512]
generator/g_bn2/beta:0 (float32_ref 128) [128, bytes: 512]
generator/g_bn2/gamma:0 (float32_ref 128) [128, bytes: 512]
generator/g_h3/w:0 (float32_ref 5x5x64x168) [268800, bytes: 1075200]
generator/g_h3/biases:0 (float32_ref 64) [64, bytes: 256]
generator/g_bn3/beta:0 (float32_ref 64) [64, bytes: 256]
generator/g_bn3/gamma:0 (float32_ref 64) [64, bytes: 256]
generator/g_h4/w:0 (float32_ref 5x5x3x104) [7800, bytes: 31200]
generator/g_h4/biases:0 (float32_ref 3) [3, bytes: 12]
discriminator/d_h0_conv/w:0 (float32_ref 5x5x43x43) [46225, bytes: 184900]
discriminator/d_h0_conv/biases:0 (float32_ref 43) [43, bytes: 172]
discriminator/d_h1_conv/w:0 (float32_ref 5x5x83x128) [265600, bytes: 1062400]
discriminator/d_h1_conv/biases:0 (float32_ref 128) [128, bytes: 512]
discriminator/d_bn1/beta:0 (float32_ref 128) [128, bytes: 512]
discriminator/d_bn1/gamma:0 (float32_ref 128) [128, bytes: 512]
discriminator/d_h2_conv/w:0 (float32_ref 5x5x168x256) [1075200, bytes: 4300800]
discriminator/d_h2_conv/biases:0 (float32_ref 256) [256, bytes: 1024]
discriminator/d_bn2/beta:0 (float32_ref 256) [256, bytes: 1024]
discriminator/d_bn2/gamma:0 (float32_ref 256) [256, bytes: 1024]
discriminator/d_h3_conv/w:0 (float32_ref 5x5x296x512) [3788800, bytes: 15155200]
discriminator/d_h3_conv/biases:0 (float32_ref 512) [512, bytes: 2048]
discriminator/d_bn3/beta:0 (float32_ref 512) [512, bytes: 2048]
discriminator/d_bn3/gamma:0 (float32_ref 512) [512, bytes: 2048]
discriminator/d_h4_lin/Matrix:0 (float32_ref 8832x1) [8832, bytes: 35328]
discriminator/d_h4_lin/bias:0 (float32_ref 1) [1, bytes: 4]
Total size of variables: 11101432
Total bytes of variables: 44405728
 [*] Reading checkpoints...
INFO:tensorflow:Restoring parameters from checkpoint/celebA_64_64_64/DCGAN.model-2
 [*] Success to read DCGAN.model-2
[*] Load SUCCESS
batch_images.shape=(64, 64, 64, 3), batch_z.shape=(64, 100), batch_labels.shape=(64,)
Traceback (most recent call last):

  File "<ipython-input-19-a592902117ec>", line 1, in <module>
    runfile('/Users/user/dcgan-argon/dcgan.py', wdir='/Users/user/dcgan-argon')

  File "/Users/user/anaconda3/lib/python3.6/site-packages/spyder_kernels/customize/spydercustomize.py", line 668, in runfile
    execfile(filename, namespace)

  File "/Users/user/anaconda3/lib/python3.6/site-packages/spyder_kernels/customize/spydercustomize.py", line 108, in execfile
    exec(compile(f.read(), filename, 'exec'), namespace)

  File "/Users/user/dcgan-argon/dcgan.py", line 89, in <module>
    main()

  File "/Users/user/dcgan-argon/dcgan.py", line 74, in main
    if FLAGS.train: dcgan.train(FLAGS)

  File "/Users/user/dcgan-argon/model.py", line 305, in train
    feed_dict={self.inputs: batch_images, self.z: batch_z, self.y: batch_labels})

  File "/Users/user/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 929, in run
    run_metadata_ptr)

  File "/Users/user/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1128, in _run
    str(subfeed_t.get_shape())))

ValueError: Cannot feed value of shape (64,) for Tensor 'y:0', which has shape '(64, 40)'

If I replace the lines:

attrs = read_attribute(self.data_dir + "/celebA_attrs.txt",
                               idxes, 20)  # 20 = Male, 15 = Eyeglass
self.attrs = attrs

with

attrs = read_attribute(self.data_dir + "/celebA_attrs.txt",
                               idxes, 20)  # 20 = Male, 15 = Eyeglass

y = np.repeat(20, batch_size)
y_ = np.zeros((batch_size, 40))
for i in range(batch_size):
   y_[np.arange(batch_size),y] = attrs[i]
self.attrs = y_

With this change I can run one iteration, but then it fails again:

batch_images.shape=(64, 64, 64, 3), batch_z.shape=(64, 100), batch_labels.shape=(64, 40)
Epoch: [ 0/25] [   0/ 156] time: 14.3307, d_loss: 9.40660000, g_loss: 4.58669901
batch_images.shape=(64, 64, 64, 3), batch_z.shape=(64, 100), batch_labels.shape=(0, 40)
Traceback (most recent call last):
 File "/Users/user/dcgan-argon/dcgan.py", line 89, in <module>
  main()
 File "/Users/user/dcgan-argon/dcgan.py", line 74, in main
  if FLAGS.train: dcgan.train(FLAGS)
 File "/Users/user/dcgan-argon/model.py", line 304, in train
  feed_dict={self.inputs: batch_images, self.z: batch_z, self.y: batch_labels})
 File "/Users/user/venv/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 887, in run
  run_metadata_ptr)
 File "/Users/user/venv/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1086, in _run
  str(subfeed_t.get_shape())))
ValueError: Cannot feed value of shape (0, 40) for Tensor 'y:0', which has shape '(64, 40)'
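
One possible reading of the (0, 40) shape, assuming the training loop in model.py (not shown here) takes each batch as self.attrs[idx * batch_size:(idx + 1) * batch_size]: the replacement code above builds an array with only batch_size rows, so the slice for the second batch is empty. The sketch below is not the original code. It uses NumPy only, a hypothetical helper attrs_to_label_matrix, and random stand-in attribute values, and it assumes the single attribute value should land in column 20 of a 40-wide row. It builds one row per selected item, so every batch slice has 64 rows.

```python
import numpy as np


def attrs_to_label_matrix(attrs, num_attrs=40, attr_index=20):
    """Put each item's 0/1 attribute value into one column of an (N, num_attrs) array.

    Hypothetical helper; num_attrs and attr_index are assumptions taken from
    the shapes and the index 20 that appear in the question.
    """
    attrs = np.asarray(attrs, dtype=np.float32)
    labels = np.zeros((attrs.shape[0], num_attrs), dtype=np.float32)
    labels[:, attr_index] = attrs  # one row per item, not one row per batch slot
    return labels


# Random stand-in for the 10000 selected attribute values (not real celebA data).
rng = np.random.default_rng(0)
attrs = rng.integers(0, 2, size=10000)
batch_size = 64

all_labels = attrs_to_label_matrix(attrs)

# Assumed slicing pattern for batch number idx; each slice keeps 64 rows.
for idx in (0, 1, 155):
    batch_labels = all_labels[idx * batch_size:(idx + 1) * batch_size]
    print(idx, batch_labels.shape)

# For contrast: an array with only batch_size rows is empty from idx = 1 onward.
short_labels = all_labels[:batch_size]
print(short_labels[1 * batch_size:2 * batch_size].shape)  # (0, 40)
```

If model.py slices the labels differently, the same check still applies: print the shape of the full label array before the loop and confirm it has as many rows as there are images.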

I have tried various answers from similar questions, for example:

batch_labels = np.expand_dims(batch_labels, axis=0)

batch_labels = np.expand_dims(batch_labels, axis=1)

but they did not help.
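
For reference, a small NumPy-only check of what those two calls produce for a (64,) array. The zero-filled batch_labels here is a stand-in, not the real labels. Neither result has 40 columns, which is consistent with the placeholder still rejecting the value.

```python
import numpy as np

# Stand-in for a batch of 64 scalar labels, shape (64,).
batch_labels = np.zeros(64)

row = np.expand_dims(batch_labels, axis=0)  # adds a leading axis
col = np.expand_dims(batch_labels, axis=1)  # adds a trailing axis

print(row.shape)  # (1, 64)
print(col.shape)  # (64, 1)

# The placeholder 'y:0' in the error message expects (64, 40).
target_shape = (64, 40)
print(row.shape == target_shape, col.shape == target_shape)  # False False
```

np.expand_dims only inserts an axis of length 1, so it cannot widen a single label column to 40 columns. The labels would have to be built with 40 columns, or the placeholder defined with a matching second dimension.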

As the error output shows:

batch_labels.shape=(64, )

Thanks!

0 answers