在张量流中预加载数据,性能不佳

时间:2017-04-18 13:13:30

标签: python performance tensorflow profiling

我跟随example预先加载了整个数据集

with tf.variable_scope("input"):
    X_train_tf = tf.constant(X_train, dtype=tf.float32)
    data_queue = tf.train.slice_input_producer([X_train_tf], capacity=4096, shuffle=True)
    data_batch = tf.train.batch(data_queue, batch_size=batch_size, capacity=4096)
...
create_model(data_batch)
...
sess.run(training_step)

其中X_train是形状(3100,10,9,9)的numpy ndarray,然后我使用data_batch,因为我之前在图中使用了tf.placeholder。

这样可行,但它实际上比我刚使用feed_dict时要慢。我的GPU使用率大约是40%,所以我猜这部分是瓶颈。

输入生成器和批处理队列隐式地添加了一个摘要op,它们告诉我批处理队列永远不会满。来自tf.train.batch的摘要op“input / batch / fraction_of_4096_full”实际上经常报告0.0,但我不确定原因。

我正在使用tf.train.MonitoredTrainingSession,它应该处理启动队列运行器并初始化所有变量。

TLDR:为什么tf.constant + tf.train.slice_input_producer + tf.train.batch导致批次大部分为空?

0 个答案:

没有答案