Tensorflow数据集API缓慢以大批量获取数据

时间:2018-10-03 18:20:59

标签: performance tensorflow

我发现,即使批次中的所有数据都在内存中,当批次大小很大时,从tensorflow数据集API获取一批也可能非常慢。以下是一个示例。有人有见识吗?

FEATURE_NUM = 500
tf_X = tf.placeholder(dtype=tf.float32, shape=[None, FEATURE_NUM], name="X")
tf_Y = tf.placeholder(dtype=tf.float32, shape=[None, 1], name="Y")

batch_size = 1000000
dataset = tf.data.Dataset.from_tensor_slices((tf_X, tf_Y)).batch(batch_size)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

se = tf.Session()
se.run(tf.global_variables_initializer())
se.run(iterator.initializer, feed_dict={tf_X : numpy_array_X, tf_Y : numpy_array_Y})

while True:
    data = se.run(next_element) # This takes more than 5 seconds per call

0 个答案:

没有答案