我正在使用TensorFlow v0.8,奇怪的是打印第二个print time.time()
需要大约5分钟。我认为tf.decode_csv()
只是简单地在图形中添加一个操作而不进行任何计算。
为什么拨打tf.decode_csv()
需要这么长时间?
def main(argv=None):
# deal with arguments
with tf.device("/cpu:0"):
filename_queue = tf.train.string_input_producer(tf.train.match_filenames_once(train_set_filename + "*"))
reader = tf.TextLineReader()
_, line = reader.read(filename_queue)
default = [[-1.0] for x in range(image_size * image_size * channels + 1)]
print time.time()
line = tf.decode_csv(line, record_defaults=default)
print time.time()
label = line[0]
feature = tf.pack(list(line[1:]))
...
答案 0 :(得分:1)
tf.decode_csv(line, record_defaults=default)
需要花费很多时间,因为你使用了很多列
我不知道你的image_size,但是如果它在200
左右,你试图将120,001
列设置为你的csv,这是巨大的。你是对的,TensorFlow没有做任何计算,但它必须正确地构建图形,并且需要很多时间才能使用那么多列!
我强烈建议您不要使用csv格式的图片。相反,您应该以JPEG格式存储图像,并使用tf.image.decode_jpeg()
。