TensorFlow - The value of a feed cannot be a tf.Tensor object

Time: 2016-09-14 13:20:46

Tags: python tensorflow

I want to read the following data and then train a single-layer perceptron on it:

1,1,0.05,-1.05
1,1,0.1,-1.1
....

This is how I read the data:

import tensorflow as tf  # TensorFlow 0.x API (tf.pack was renamed tf.stack in 1.0)

total = 10000
def read_file_format(filename_queue):
    # read in data
    reader = tf.TextLineReader()
    key, value = reader.read(filename_queue)

    # Default values, in case of empty columns. Also specifies the type of the
    # decoded result.
    record_defaults = [tf.constant([], dtype=tf.int32),    # Column 1
                       tf.constant([], dtype=tf.int32),    # Column 2
                       tf.constant([], dtype=tf.float32),  # Column 3
                       tf.constant([], dtype=tf.float32)]  # Column 4

    col1, col2, col3, col4 = tf.decode_csv(value, record_defaults=record_defaults)
    features = tf.pack([tf.to_float(col1), tf.to_float(col2), col3])

    with tf.Session() as sess:
        # Start populating the filename queue.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        for i in range(total):
            # Retrieve a single instance. sess.run returns NumPy values,
            # so each iteration overwrites example and label.
            example, label = sess.run([features, col4])

        coord.request_stop()
        coord.join(threads)
    return example, label

filename_queue = tf.train.string_input_producer(["input.data"])
example, label = read_file_format(filename_queue)
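Outside the TensorFlow graph, the same record layout can be parsed into plain NumPy arrays, which are among the value types a feed accepts. This is a minimal sketch, assuming NumPy is available and the file fits in memory; `parse_records` is a hypothetical helper, not part of TensorFlow. Unlike `read_file_format` above, which keeps only the last row it reads, it collects every row.

```python
import numpy as np

def parse_records(lines):
    """Parse lines like '1,1,0.05,-1.05' into (features, labels) arrays.

    Mirrors the question's decode_csv setup: columns 1-2 are integers cast
    to float, column 3 is a float feature, column 4 is the label.
    """
    features, labels = [], []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        c1, c2, c3, c4 = line.split(",")
        features.append([float(int(c1)), float(int(c2)), float(c3)])
        labels.append(float(c4))
    return (np.asarray(features, dtype=np.float32),
            np.asarray(labels, dtype=np.float32))

features, labels = parse_records(["1,1,0.05,-1.05", "1,1,0.1,-1.1"])
print(features.shape, labels.shape)  # (2, 3) (2,)
```

With a real file, `parse_records(open("input.data"))` would return all rows at once.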

I create the batches with:

min_after_dequeue = 100
batch_size = 50
capacity = min_after_dequeue + 3 * batch_size
example_batch, label_batch = tf.train.shuffle_batch([example, label], batch_size=batch_size,
                                                    capacity=capacity, min_after_dequeue=min_after_dequeue)
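The `capacity` and `min_after_dequeue` parameters describe a shuffling buffer: examples are enqueued up to `capacity`, and batches are drawn at random while enough examples stay behind to keep the mix random. The following is a rough pure-Python analogue of that idea, not TensorFlow's implementation; `shuffle_batches` is a hypothetical helper, and the data is assumed to already be in memory as NumPy arrays.

```python
import random
from collections import deque

import numpy as np

def shuffle_batches(features, labels, batch_size=50, min_after_dequeue=100,
                    capacity=None, seed=0):
    """Yield shuffled (feature_batch, label_batch) pairs as NumPy arrays."""
    if capacity is None:
        capacity = min_after_dequeue + 3 * batch_size  # same formula as the question
    if capacity < min_after_dequeue + batch_size:
        raise ValueError("capacity must be >= min_after_dequeue + batch_size")
    rng = random.Random(seed)
    pending = deque(range(len(features)))  # indices not yet enqueued
    buffer = []
    while True:
        # Enqueue examples until the buffer is full or the input runs out.
        while pending and len(buffer) < capacity:
            buffer.append(pending.popleft())
        if len(buffer) < batch_size:
            return  # not enough examples left for a full batch
        # Draw a random batch. Whenever the buffer was refilled to capacity,
        # at least min_after_dequeue examples stay behind afterwards.
        rng.shuffle(buffer)
        idx, buffer = buffer[:batch_size], buffer[batch_size:]
        yield features[idx], labels[idx]

demo_x = np.arange(3000, dtype=np.float32).reshape(1000, 3)
demo_y = np.arange(1000, dtype=np.float32)
for x_batch, y_batch in shuffle_batches(demo_x, demo_y):
    pass  # x_batch: (50, 3) ndarray, y_batch: (50,) ndarray
```

Because each yielded batch is a NumPy ndarray rather than a `tf.Tensor`, it is one of the value types the error message lists as acceptable for a feed.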

and train the model with:

....
_, c = sess.run([optimizer, cost], feed_dict={x: example_batch, y: label_batch})
....
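Each training step feeds values for `x` and `y`, and per the error message those values must be concrete data (Python scalars, strings, lists, or NumPy ndarrays). Since the question elides the model, here is a NumPy-only sketch of a single-layer linear model trained on such concrete mini-batches. The linear (no activation) output, mean-squared-error cost, learning rate, synthetic data, and the helper name `train_linear_perceptron` are all assumptions for illustration; this is not the question's TensorFlow code.

```python
import numpy as np

def train_linear_perceptron(batches, n_features=3, learning_rate=0.1):
    """Fit y ~ x @ w + b by mini-batch gradient descent on mean squared error.

    `batches` is any iterable of (x_batch, y_batch) NumPy arrays.
    Returns (w, b, costs), where costs holds each batch's cost before its update.
    """
    w = np.zeros(n_features)
    b = 0.0
    costs = []
    for x, y in batches:
        err = x @ w + b - y  # prediction error, shape (batch_size,)
        costs.append(float(np.mean(err ** 2)))
        # Gradients of the mean squared error with respect to w and b.
        grad_w = 2.0 * x.T @ err / len(y)
        grad_b = 2.0 * np.mean(err)
        w -= learning_rate * grad_w
        b -= learning_rate * grad_b
    return w, b, costs

# Synthetic data (an assumption, made up for illustration).
rng = np.random.default_rng(0)
x = rng.uniform(-1.0, 1.0, size=(2000, 3))
y = -(x[:, 0] + x[:, 2])
batches = [(x[i:i + 50], y[i:i + 50]) for i in range(0, len(x), 50)]
w, b, costs = train_linear_perceptron(batches)
```

The cost recorded for later batches should be much lower than for the first one.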

But it gives me the following error:

Traceback (most recent call last):
  File "nn.py", line 90, in <module>
    _, c = sess.run([optimizer, cost], feed_dict={x: example_batch, y: label_batch})
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 710, in run
    run_metadata_ptr)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 868, in _run
    raise TypeError('The value of a feed cannot be a tf.Tensor object. '
TypeError: The value of a feed cannot be a tf.Tensor object. Acceptable feed values include Python scalars, strings, lists, or numpy ndarrays.

0 answers