我试图找出如何编写正确的input_fn。我使用Tensorflow Estimators和Datasets重建了keras UNet实现。它工作正常,但它比Keras实现慢得多。 我的输入管道直接从SSD读取文件,而Keras实现将整个DataSet加载到内存中。
我想在DataSet中做同样的事情,遗憾的是ds.cache()
的每次中断都会清除estimator.train
的缓存。我必须打断它以验证模型,因为我试图只使用一个GPU。
我查看了Dataset.cache
的实现,我猜测缓存绑定到包含数据集的图形,每次训练中断时都会释放该图形。
我有哪些选项让我的输入管道保持在张量流图中?我想避免在python和tensorflow之间来回传递整个训练数据。
我可以在另一张图中使用张量吗?所以我可以用input_fn创建一个外部(到Estimator的API)图形,然后以某种方式在estimator.train创建的图形中使用它。
这方面的事情:
train_graph = tf.Graph()
with training_graph.as_default():
train_get_next = input_fn(training_set, ...) () # tensor returned by it.get_next()
def train_input_fn_with_ext_cache():
batch = # some how obtain result of train_get_next
return batch
train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn_with_ext_cache, max_steps=3000)
eval_spec = ...
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)