TensorFlow FixedLengthRecordReader占位符数据类型错误

时间:2017-02-27 01:03:40

标签: python tensorflow

我正在尝试使用FixedLengthRecordReader创建一个使用TensorFlow的CNN。

import tensorflow as tf
import numpy as np

# Number of examples per training batch.
BATCH_SIZE = 100
# Images are IMAGE_SIZE x IMAGE_SIZE pixels with IMAGE_DEPTH channels.
IMAGE_SIZE = 30
IMAGE_DEPTH = 3
# Number of image values per record. Derived from IMAGE_DEPTH rather than a
# hard-coded 3 so it stays consistent with the reshape to
# [IMAGE_DEPTH, IMAGE_SIZE, IMAGE_SIZE] applied to the decoded record.
image_data_len = IMAGE_SIZE * IMAGE_SIZE * IMAGE_DEPTH

class Record(object):
    """Empty attribute container; fields are attached to instances ad hoc."""

def init_weights(shape):
    """Return a tf.Variable of the given shape.

    The initial value is sampled with tf.random_normal using stddev=0.01.
    """
    initial_value = tf.random_normal(shape, stddev=0.01)
    return tf.Variable(initial_value)

def generate_batch(image, label, min_queue_examples, batch_size):
    """Build shuffled batches from a single (image, label) example pair.

    Args:
        image: tensor holding one example image.
        label: tensor holding that example's label.
        min_queue_examples: forwarded as ``min_after_dequeue``; the queue
            capacity is this value plus three batches.
        batch_size: number of examples per batch.

    Returns:
        A tuple ``(images, labels)`` where ``labels`` is the label batch
        reshaped to ``[batch_size]``.
    """
    queue_capacity = min_queue_examples + 3 * batch_size
    image_batch, label_batch = tf.train.shuffle_batch(
        [image, label],
        batch_size=batch_size,
        num_threads=16,
        capacity=queue_capacity,
        min_after_dequeue=min_queue_examples)
    flat_labels = tf.reshape(label_batch, [batch_size])
    return image_batch, flat_labels

def model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden):
    """Assemble the network: three conv/pool stages, one hidden layer, output.

    Args:
        X: input tensor fed to the first convolution.
        w, w2, w3: convolution kernels for stages one to three.
        w4: weight matrix of the hidden layer; its first dimension is used
            as the flattened feature size.
        w_o: weight matrix of the output layer.
        p_keep_conv: dropout keep probability applied after each conv stage.
        p_keep_hidden: dropout keep probability applied after the hidden layer.

    Returns:
        The result of the final ``tf.matmul`` with ``w_o`` (no activation).
    """
    unit_strides = [1, 1, 1, 1]
    pool_window = [1, 2, 2, 1]

    def conv_relu_pool(inputs, kernel):
        # Stride-1 SAME convolution, ReLU, then 2x2 max pooling with stride 2.
        activated = tf.nn.relu(
            tf.nn.conv2d(inputs, kernel, strides=unit_strides, padding='SAME'))
        return tf.nn.max_pool(
            activated, ksize=pool_window, strides=pool_window, padding='SAME')

    stage1 = tf.nn.dropout(conv_relu_pool(X, w), p_keep_conv)
    stage2 = tf.nn.dropout(conv_relu_pool(stage1, w2), p_keep_conv)

    # The third stage is flattened before dropout so it can feed the matmul.
    pooled3 = conv_relu_pool(stage2, w3)
    flat_size = w4.get_shape().as_list()[0]
    flattened = tf.reshape(pooled3, [-1, flat_size])
    flattened = tf.nn.dropout(flattened, p_keep_conv)

    hidden = tf.nn.relu(tf.matmul(flattened, w4))
    hidden = tf.nn.dropout(hidden, p_keep_hidden)

    return tf.matmul(hidden, w_o)

# --- Input pipeline -------------------------------------------------------
# Each fixed-length record is 1 label byte followed by image_data_len bytes.
reader = tf.FixedLengthRecordReader(record_bytes=(1 + image_data_len))
filenames = ['train.bin']
filename_queue = tf.train.string_input_producer(filenames)
result = Record()
result.key, value = reader.read(filename_queue)
# Decode the raw record string into a vector of uint8 values.
record_bytes = tf.decode_raw(value, tf.uint8)
# The first value is the label, cast to int32.
result.label = tf.cast(tf.slice(record_bytes, [0], [1]), tf.int32)
# The remaining values are viewed as [IMAGE_DEPTH, IMAGE_SIZE, IMAGE_SIZE] ...
depth_major = tf.reshape(tf.slice(record_bytes, [1], [image_data_len]), [IMAGE_DEPTH, IMAGE_SIZE, IMAGE_SIZE])
# ... and transposed so the depth axis comes last.
result.uint8image = tf.transpose(depth_major, [1, 2, 0])
float_image = tf.image.per_image_standardization(tf.cast(result.uint8image, tf.float32))

min_fraction_of_examples_in_queue = 0.4
# NOTE(review): 10000 is presumably the number of examples in train.bin -- confirm.
min_queue_examples = int(10000 * min_fraction_of_examples_in_queue)
batch_images, batch_labels = generate_batch(float_image, result.label, min_queue_examples, BATCH_SIZE)

# --- Model graph ----------------------------------------------------------
# X and Y are placeholders: they are NOT connected to batch_images /
# batch_labels above, so values must be supplied through feed_dict.
X = tf.placeholder("float", [None, IMAGE_SIZE, IMAGE_SIZE, IMAGE_DEPTH])
Y = tf.placeholder("float", [None, 1])

w = init_weights([3, 3, IMAGE_DEPTH, 32])   # 3x3 conv, IMAGE_DEPTH inputs, 32 outputs
w2 = init_weights([3, 3, 32, 64])    # 3x3 conv, 32 inputs, 64 outputs
w3 = init_weights([3, 3, 64, 128])   # 3x3 conv, 64 inputs, 128 outputs
w4 = init_weights([128 * 4 * 4, 625])  # hidden layer; first dim is the flatten size used in model()
w_o = init_weights([625, 1])         # output layer, 625 inputs, 1 output

# Dropout keep probabilities are placeholders as well and also need feeding.
p_keep_conv = tf.placeholder("float")
p_keep_hidden = tf.placeholder("float")
py_x = model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden)

cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=py_x, labels=Y))
train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)
predict_op = tf.argmax(py_x, 1)

init = tf.global_variables_initializer()

# --- Training loop --------------------------------------------------------
sess = tf.Session()
sess.run(init)
# Start the threads that fill the input queues.
tf.train.start_queue_runners(sess=sess)
for step in range(1000):
  # NOTE(review): no feed_dict is passed here although train_op and cost
  # depend on the placeholders X, Y, p_keep_conv and p_keep_hidden.
  _, loss_value = sess.run([train_op, cost])

在最后一行,我得到以下内容:

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'Placeholder' with dtype float
[[Node: Placeholder = Placeholder[dtype=DT_FLOAT, shape=[], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]

有什么想法吗?

2 个答案:

答案 0 :(得分:1)

您需要通过sess.run()调用的feed_dict参数,传入形状为[?, IMAGE_SIZE, IMAGE_SIZE, IMAGE_DEPTH]的图片(一批图片)和形状为[?, 1]的标签(一批标签)。

由于您有一批名为batch_images的图片和一批名为batch_labels的标签,因此您的最后一行应为:

# Run the input pipeline first so the batch is available as NumPy arrays.
images, labels = sess.run([batch_images, batch_labels])
_, loss_value = sess.run(
    [train_op, cost],
    feed_dict={
        X: images,
        # batch_labels has shape [batch_size]; Y is declared as [None, 1].
        Y: labels.reshape(-1, 1),
        # The dropout keep probabilities are placeholders too and must be
        # fed; 0.8 / 0.5 are example values, tune as needed.
        p_keep_conv: 0.8,
        p_keep_hidden: 0.5,
    })

答案 1 :(得分:0)

您已经用tf.placeholder()定义了X和Y,但在调用sess.run()时,并未向这些占位符提供任何值。请执行以下操作:

for step in range(1000):
    # Fetch one batch as NumPy arrays; feed_dict values must be concrete
    # data, not the batch_images / batch_labels tensors themselves.
    img, lbl = sess.run([batch_images, batch_labels])
    _, loss_value = sess.run(
        [train_op, cost],
        feed_dict={
            X: img,
            # batch_labels has shape [batch_size]; Y is declared as [None, 1].
            Y: lbl.reshape(-1, 1),
            # Dropout keep probabilities are placeholders and must be fed as
            # well; 0.8 / 0.5 are example values, tune as needed.
            p_keep_conv: 0.8,
            p_keep_hidden: 0.5,
        })