如何为tf.data.Dataset

时间:2018-03-03 09:17:05

标签: python tensorflow tensorflow-datasets tensorflow-estimator

我试图找出如何编写正确的input_fn。我使用Tensorflow Estimators和Datasets重建了keras UNet实现。它工作正常,但它比Keras实现慢得多。 我的输入管道直接从SSD读取文件,而Keras实现将整个DataSet加载到内存中。

我想在DataSet中做同样的事情,遗憾的是ds.cache()的每次中断都会清除estimator.train的缓存。我必须打断它以验证模型,因为我试图只使用一个GPU。

我查看了Dataset.cache的实现,我猜测缓存绑定到包含数据集的图形,每次训练中断时都会释放该图形。

我有哪些选项让我的输入管道保持在张量流图中?我想避免在python和tensorflow之间来回传递整个训练数据。

我可以在另一张图中使用张量吗?所以我可以用input_fn创建一个外部(到Estimator的API)图形,然后以某种方式在estimator.train创建的图形中使用它。

这方面的事情:

train_graph = tf.Graph()
with training_graph.as_default():
    train_get_next = input_fn(training_set, ...) () # tensor returned by it.get_next()

def train_input_fn_with_ext_cache():
    batch = # some how obtain result of train_get_next 
    return batch

train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn_with_ext_cache,  max_steps=3000)
eval_spec = ...

tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

0 个答案:

没有答案