I am trying to read the heart.csv file in batches. Following the documentation at {{3}}, I have the following code that reads it line by line:
import tensorflow as tf
filename_queue = tf.train.string_input_producer(["heart.csv"])
reader = tf.TextLineReader(skip_header_lines=1)
_, csv_row = reader.read(filename_queue)
record_defaults = [[0], [0.0], [0.0], [0.0], [""], [0], [0.0], [0.0], [0], [0]]
sbp, tobacco, ldl, adiposity, famhist, typea, obesity, alcohol, age, chd = tf.decode_csv(csv_row, record_defaults=record_defaults)
features = [sbp, tobacco, ldl, adiposity, famhist, typea, obesity, alcohol, age]
nof_examples = 10
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    while nof_examples > 0:
        nof_examples -= 1
        try:
            data_features, data_chd = sess.run([features, chd])
            # data_features[4] = 1 if data_features[4] == 'Present' else 0
            print(data_features, data_chd)
        except tf.errors.OutOfRangeError:
            coord.request_stop()
            coord.join(threads)
            break
    coord.request_stop()
    coord.join(threads)
Output:
([160, 12.0, 5.73, 23.110001, 'Present', 49, 25.299999, 97.199997, 52], 1)
([144, 0.0099999998, 4.4099998, 28.610001, 'Absent', 55, 28.870001, 2.0599999, 63], 1)
([118, 0.079999998, 3.48, 32.279999, 'Present', 52, 29.139999, 3.8099999, 46], 0)
([170, 7.5, 6.4099998, 38.029999, 'Present', 51, 31.99, 24.26, 58], 1)
([134, 13.6, 3.5, 27.780001, 'Present', 60, 25.99, 57.34, 49], 1)
([132, 6.1999998, 6.4699998, 36.209999, 'Present', 62, 30.77, 14.14, 45], 0)
([142, 4.0500002, 3.3800001, 16.200001, 'Absent', 59, 20.809999, 2.6199999, 38], 0)
([114, 4.0799999, 4.5900002, 14.6, 'Present', 62, 23.110001, 6.7199998, 58], 1)
([114, 0.0, 3.8299999, 19.4, 'Present', 49, 24.860001, 2.49, 29], 0)
([132, 0.0, 5.8000002, 30.959999, 'Present', 69, 30.110001, 0.0, 53], 1)
But when I try to read in batches as shown in the TensorFlow documentation, I get:
TypeError: Cannot convert a list containing a tensor of dtype <dtype: 'float32'> to <dtype: 'int32'> (Tensor is: <tf.Tensor 'DecodeCSV_6:1' shape=() dtype=float32>)
The batching code:
import tensorflow as tf
batch_size = 1
def read_my_file_format(filename_queue):
    reader = tf.TextLineReader(skip_header_lines=1)
    _, csv_row = reader.read(filename_queue)
    record_defaults = [[0], [0.0], [0.0], [0.0], [""], [0], [0.0], [0.0], [0], [0]]
    sbp, tobacco, ldl, adiposity, famhist, typea, obesity, alcohol, age, chd = tf.decode_csv(csv_row, record_defaults=record_defaults)
    feature = [sbp, tobacco, ldl, adiposity, famhist, typea, obesity, alcohol, age]
    label = [chd]
    return feature, label
def input_pipeline(filenames, batch_size, num_epochs=None):
    filename_queue = tf.train.string_input_producer(filenames,
                                                    num_epochs=num_epochs,
                                                    shuffle=True)
    feature, label = read_my_file_format(filename_queue)
    min_after_dequeue = 10000
    capacity = min_after_dequeue + 3 * batch_size
    feature_batch, label_batch = tf.train.shuffle_batch([feature, label],
                                                        batch_size=batch_size,
                                                        capacity=capacity,
                                                        min_after_dequeue=min_after_dequeue)
    return feature_batch, label_batch
features, labels = input_pipeline(['heart.csv'], batch_size)
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    # start populating filename queue
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    try:
        while not coord.should_stop():
            feature_batch, label_batch = sess.run([features, labels])
            print(feature_batch)
    except tf.errors.OutOfRangeError:
        print('Done training, epoch reached')
    finally:
        coord.request_stop()
        coord.join(threads)
Reading CSV files with TensorFlow seems rather cumbersome, though I'm sure it earns its keep in a library built for distributed systems. I found it confusing and spent more than 60 minutes reading up on how the reader input pipeline works for CSV files. The documentation could probably be better and include more visuals.
Answer (score: 1)
I had a look at the code, and it seems that one of the internal functions used by tf.train.shuffle_batch requires all the tensors in a row to have the same dtype (inferred from the first element, in your case tf.int32). You could decode all the columns as strings and then cast them to the correct types afterwards, but that is not very convenient.
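A rough sketch of that string-decoding workaround (untested; the helper name read_my_file_format_as_strings is just for illustration, the column order follows the question's code, and the casts in the trailing comments would be applied after tf.train.shuffle_batch):

def read_my_file_format_as_strings(filename_queue):
    reader = tf.TextLineReader(skip_header_lines=1)
    _, csv_row = reader.read(filename_queue)
    # Decode all 10 columns as strings so every tensor in the row shares one dtype.
    record_defaults = [[""]] * 10
    columns = tf.decode_csv(csv_row, record_defaults=record_defaults)
    feature = columns[:-1]  # 9 string scalars; shuffle_batch stacks them into one string tensor
    label = columns[-1]
    return feature, label

# After tf.train.shuffle_batch, cast the batched string columns back, e.g.:
# sbp = tf.string_to_number(feature_batch[:, 0], out_type=tf.float32)
# famhist = tf.cast(tf.equal(feature_batch[:, 4], "Present"), tf.float32)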
That said, since you are on TensorFlow 1.2.0, I would suggest using the new Dataset API instead, which is the new way of handling input data (see e.g. this answer).
Based on the referenced answer, here is an example using the new API:
import tensorflow as tf

def read_row(csv_row):
    record_defaults = [[0], [0.0], [0.0], [0.0], [""], [0], [0.0], [0.0], [0], [0]]
    row = tf.decode_csv(csv_row, record_defaults=record_defaults)
    return row[:-1], row[-1]
def input_pipeline(filenames, batch_size):
    # Define a `tf.contrib.data.Dataset` for iterating over one epoch of the data.
    dataset = (tf.contrib.data.TextLineDataset(filenames)
               .skip(1)
               .map(lambda line: read_row(line))
               .shuffle(buffer_size=10)  # Equivalent to min_after_dequeue=10.
               .batch(batch_size))

    # Return an *initializable* iterator over the dataset, which will allow us to
    # re-initialize it at the beginning of each epoch.
    return dataset.make_initializable_iterator()
iterator = input_pipeline(['heart.csv'], batch_size)
features, labels = iterator.get_next()
nof_examples = 10
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    sess.run(iterator.initializer)
    while nof_examples > 0:
        nof_examples -= 1
        try:
            data_features, data_labels = sess.run([features, labels])
            print(data_features)
        except tf.errors.OutOfRangeError:
            pass
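If you want to make several passes over the data, you can re-run the iterator's initializer at the start of each epoch, roughly like this (a sketch that reuses iterator, features and labels from above; num_epochs = 5 is just an illustrative value):

num_epochs = 5  # illustrative value, not from the original answer
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    for epoch in range(num_epochs):
        sess.run(iterator.initializer)  # rewind the dataset at the start of each epoch
        while True:
            try:
                data_features, data_labels = sess.run([features, labels])
            except tf.errors.OutOfRangeError:
                break  # end of this epoch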