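I am trying to load a CSV file using an input pipeline. I followed some documentation online, but I could not reproduce its results because of the following error: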
InvalidArgumentError (see above for traceback): Expect 6 fields but have 751 in record 0
[[Node: DecodeCSV_1 = DecodeCSV[OUT_TYPE=[DT_STRING, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_STRING], field_delim=",", _device="/job:localhost/replica:0/task:0/cpu:0"](ReaderReadV2_1:1, DecodeCSV_1/record_defaults_0, DecodeCSV_1/record_defaults_1, DecodeCSV_1/record_defaults_2, DecodeCSV_1/record_defaults_3, DecodeCSV_1/record_defaults_4, DecodeCSV_1/record_defaults_5)]]
It looks like I am running into a newline-delimiter problem. I would appreciate any feedback. The steps below reproduce the issue.
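As a rough sanity check on the newline suspicion, the number 751 is what you would get if the whole file were read as a single record. The sketch below rests on assumptions I have not verified: that the file holds the 150 data rows of iris, that the rows are separated by a bare "\r" that the reader does not split on, and that fields can be counted naively by commas (DecodeCSV's quoting rules are not modelled). The row values are placeholders copied from the first line of the file.

```python
# Assumption: 150 data rows (the size of the iris dataset), joined by a bare
# carriage return instead of "\n", so the reader sees one long record.
NUM_ROWS = 150

# Placeholder rows: only the number of commas per row matters here.
rows = ['"%d",5.1,3.5,1.4,0.2,"setosa"' % i for i in range(1, NUM_ROWS + 1)]
merged_record = "\r".join(rows)

# Each row contributes 5 commas; fields = commas + 1 for a single record.
field_count = merged_record.count(",") + 1
print(field_count)  # 751
```

If that reasoning holds, the last field of each row runs into the first field of the next, which would explain why the count is 150 * 5 + 1 rather than 150 * 6.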
I downloaded the iris dataset to my local machine from https://vincentarelbundock.github.io/Rdatasets/csv/datasets/iris.csv and removed the header row.
The CSV looks like this:
"1",5.1,3.5,1.4,0.2,"setosa"
"2",4.9,3,1.4,0.2,"setosa"
"3",4.7,3.2,1.3,0.2,"setosa"
My code is as follows:
import tensorflow as tf

def read_my_file_format(filename_queue):
    # Read one line of text at a time from the files in the queue.
    reader = tf.TextLineReader(skip_header_lines=0)
    key, value = reader.read(filename_queue)
    # One default per column: row index, four measurements, species name.
    record_defaults = [[""], [0.0], [0.0], [0.0], [0.0], [""]]
    index, slength, swidth, plength, pwidth, species = tf.decode_csv(value, record_defaults=record_defaults, field_delim=',')
    features = tf.stack([slength, swidth, plength, pwidth])
    return features, [species]

def input_pipeline(filepaths, batch_size, num_epochs=None):
    filename_queue = tf.train.string_input_producer(filepaths, num_epochs=num_epochs, shuffle=True)
    features, label = read_my_file_format(filename_queue)
    min_after_dequeue = 10000
    capacity = min_after_dequeue + 3 * batch_size
    example_batch, label_batch = tf.train.shuffle_batch([features, label], batch_size=batch_size, capacity=capacity, min_after_dequeue=min_after_dequeue)
    return example_batch, label_batch

example_batch, label_batch = input_pipeline(filepaths=["/Users/iiskin/Downloads/iris.csv"], batch_size=10, num_epochs=10)

with tf.Session() as sess:
    # num_epochs creates a local variable, so local variables must be initialized too.
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    sess.run(init_op)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        while not coord.should_stop():
            features, label = sess.run([example_batch, label_batch])
            print(features)
    except tf.errors.OutOfRangeError:
        print('Done -- epoch limit reached')
    finally:
        coord.request_stop()
    coord.join(threads)
I am using TensorFlow version 1.0.0-rc2.
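To test the newline suspicion on the file itself, something like the following could count which line terminators it actually contains. This is only a minimal sketch: `count_line_endings` is a hypothetical helper of my own, not part of TensorFlow, and the sample bytes are made up to show the bare-"\r" case.

```python
def count_line_endings(data):
    """Count CRLF, bare LF, and bare CR terminators in a bytes object."""
    crlf = data.count(b"\r\n")
    lf = data.count(b"\n") - crlf   # "\n" not preceded by "\r"
    cr = data.count(b"\r") - crlf   # "\r" not followed by "\n"
    return {"crlf": crlf, "lf": lf, "cr": cr}

# Made-up sample with bare carriage returns only.
sample = b'"1",5.1,3.5,1.4,0.2,"setosa"\r"2",4.9,3,1.4,0.2,"setosa"\r'
print(count_line_endings(sample))

# On the real file, the bytes would be read in binary mode, for example:
#   with open("/Users/iiskin/Downloads/iris.csv", "rb") as f:
#       print(count_line_endings(f.read()))
```

A result with a nonzero "cr" count and zero "lf" would support the idea that the reader never sees a line break and treats the whole file as record 0.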