以下是我要尝试运行的代码 -
import tensorflow as tf
import numpy as np
import input_data
filename_queue = tf.train.string_input_producer(["cs-training.csv"])
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
record_defaults = [[1], [1], [1], [1], [1], [1], [1], [1], [1], [1], [1]]
col1, col2, col3, col4, col5, col6, col7, col8, col9, col10, col11 = tf.decode_csv(
value, record_defaults=record_defaults)
features = tf.concat(0, [col2, col3, col4, col5, col6, col7, col8, col9, col10, col11])
with tf.Session() as sess:
# Start populating the filename queue.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for i in range(1200):
# Retrieve a single instance:
print i
example, label = sess.run([features, col1])
try:
print example, label
except:
pass
coord.request_stop()
coord.join(threads)
此代码返回以下错误。
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-23-e42fe2609a15> in <module>()
7 # Retrieve a single instance:
8 print i
----> 9 example, label = sess.run([features, col1])
10 try:
11 print example, label
/root/anaconda/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in run(self, fetches, feed_dict)
343
344 # Run request and get response.
--> 345 results = self._do_run(target_list, unique_fetch_targets, feed_dict_string)
346
347 # User may have fetched the same tensor multiple times, but we
/root/anaconda/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _do_run(self, target_list, fetch_list, feed_dict)
417 # pylint: disable=protected-access
418 raise errors._make_specific_exception(node_def, op, e.error_message,
--> 419 e.code)
420 # pylint: enable=protected-access
421 raise e_type, e_value, e_traceback
InvalidArgumentError: Field 1 in record 0 is not a valid int32: 0.766126609
它后面有很多信息,我认为这与问题无关。显然问题是我提供给程序的很多数据都不是dtype int32。它主要是浮点数。我已经尝试了一些改变dtype的方法,例如在dtype=float
和tf.decode_csv
中明确设置tf.concat
参数。都没有奏效。这是一个无效的论点。最重要的是,我不知道这段代码是否真的会对数据进行预测。我希望它能够预测col1是1还是0,并且我不会在代码中看到任何暗示它会实际进行预测的内容。也许我会将这个问题保存为不同的主题。非常感谢任何帮助!
答案 0 :(得分:20)
tf.decode_csv()
的界面有点棘手。每列的dtype
由record_defaults
参数的相应元素确定。代码中record_defaults
的值被解释为每个列都有tf.int32
作为其类型,这会在遇到浮点数据时导致错误。
假设您有以下CSV数据,其中包含三个整数列,后跟一个浮点列:
4, 8, 9, 4.5
2, 5, 1, 3.7
2, 2, 2, 0.1
假设所有列都是必需,您将按如下方式构建record_defaults
:
value = ...
record_defaults = [tf.constant([], dtype=tf.int32), # Column 0
tf.constant([], dtype=tf.int32), # Column 1
tf.constant([], dtype=tf.int32), # Column 2
tf.constant([], dtype=tf.float32)] # Column 3
col0, col1, col2, col3 = tf.decode_csv(value, record_defaults=record_defauts)
assert col0.dtype == tf.int32
assert col1.dtype == tf.int32
assert col2.dtype == tf.int32
assert col3.dtype == tf.float32
record_defaults
中的空值表示该值是必需的。或者,如果允许(例如)第2列具有缺失值,则可以按如下方式定义record_defaults
:
record_defaults = [tf.constant([], dtype=tf.int32), # Column 0
tf.constant([], dtype=tf.int32), # Column 1
tf.constant([0], dtype=tf.int32), # Column 2
tf.constant([], dtype=tf.float32)] # Column 3
您的问题的第二部分涉及如何构建和训练一个模型,该模型可以预测输入数据中某列的值。目前,该程序并不是:它只是将列连接成一个称为features
的张量。您需要定义并训练一个解释该数据的模型。最简单的方法之一是线性回归,您可能会发现linear regression in TensorFlow上的本教程适合您的问题。
答案 1 :(得分:1)
更改dtype的答案是只改变默认值 -
record_defaults = [[1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.]]
执行此操作后,如果您打印出col1,您将收到此消息。
Tensor("DecodeCSV_43:0", shape=TensorShape([]), dtype=float32)
但是您还会遇到另一个错误,which has been answered here.要回顾一下答案,解决方法是将tf.concat
更改为tf.pack
。
features = tf.pack([col2, col3, col4, col5, col6, col7, col8, col9, col10, col11])