使用数据集。from_tensor_slices似乎与张量大小成线性比例关系,即使这些张量被裁剪,因此网络也总是获得相同大小的数据补丁。
以这个简化的数据集为例:
ones = tf.ones((1, 1000, 1000))
zeros = tf.zeros((1, 1000, 1000))
return tf.data.Dataset.from_tensor_slices(
(
ones,
tf.ones((1, 1))
)
).concatenate(
tf.data.Dataset.from_tensor_slices(
(
zeros,
tf.zeros((1, 1))
)
)
).map(lambda x, y: (tf.image.random_crop(x, (20, 20)), y)).repeat().batch(10)
如果将ones
和zeros
张量的形状增加到(1,10000,1000)-训练速度将降低10倍,而(1,10000,10000 )的工作速度会慢100倍...
知道为什么吗?