我发现,即使批次中的所有数据都在内存中,当批次大小很大时,从 TensorFlow Dataset API 获取一批数据也可能非常慢。以下是一个示例。有人能解释一下原因吗?
FEATURE_NUM = 500

# TF1-style placeholders: the arrays are fed once, at iterator-init time,
# and the dataset then slices batches out of them.
tf_X = tf.placeholder(dtype=tf.float32, shape=[None, FEATURE_NUM], name="X")
tf_Y = tf.placeholder(dtype=tf.float32, shape=[None, 1], name="Y")

batch_size = 1000000

# NOTE(review): from_tensor_slices(...).batch(k) assembles each batch by
# gathering k row-slices and copying them into a fresh batch tensor — with
# batch_size = 1e6 that per-get_next() copy is the likely cause of the
# multi-second latency reported above.  prefetch(1) lets the next batch be
# assembled in the background while the current one is consumed.
dataset = (
    tf.data.Dataset.from_tensor_slices((tf_X, tf_Y))
    .batch(batch_size)
    .prefetch(1)  # overlap batch assembly with the training step
)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

se = tf.Session()
se.run(tf.global_variables_initializer())
# numpy_array_X / numpy_array_Y are assumed to be defined elsewhere in the
# surrounding code (shapes [N, FEATURE_NUM] and [N, 1]) — TODO confirm.
se.run(iterator.initializer,
       feed_dict={tf_X: numpy_array_X, tf_Y: numpy_array_Y})

# Bug fix: the original `while True` never terminated — exhausting the
# dataset raises tf.errors.OutOfRangeError, which must be caught to exit
# the loop cleanly instead of crashing.
try:
    while True:
        data = se.run(next_element)  # pulls one (X, Y) batch out of the graph
except tf.errors.OutOfRangeError:
    pass  # dataset exhausted — normal termination
finally:
    se.close()  # release the session's resources