我跟随example预先加载了整个数据集
with tf.variable_scope("input"):
X_train_tf = tf.constant(X_train, dtype=tf.float32)
data_queue = tf.train.slice_input_producer([X_train_tf], capacity=4096, shuffle=True)
data_batch = tf.train.batch(data_queue, batch_size=batch_size, capacity=4096)
...
create_model(data_batch)
...
sess.run(training_step)
其中X_train是形状(3100,10,9,9)的numpy ndarray,然后我使用data_batch,因为我之前在图中使用了tf.placeholder。
这样可行,但它实际上比我刚使用feed_dict时要慢。我的GPU使用率大约是40%,所以我猜这部分是瓶颈。
输入生成器和批处理队列隐式地添加了一个摘要op,它们告诉我批处理队列永远不会满。来自tf.train.batch的摘要op“input / batch / fraction_of_4096_full”实际上经常报告0.0,但我不确定原因。
我正在使用tf.train.MonitoredTrainingSession,它应该处理启动队列运行器并初始化所有变量。
TLDR:为什么tf.constant + tf.train.slice_input_producer + tf.train.batch导致批次大部分为空?