传递将dict返回给Keras的tf.data.Dataset

时间:2018-08-30 12:25:33

标签: python tensorflow keras

背景

据我所知,Keras将成为编写渴望/懒惰不可知模型(source)的首选方式。 Keras(source)似乎也支持tf.data,这很棒,因为这是我当前的tf.estimator项目使用的ETL。

问题

我想保留我的tf.data管道,但用tf.estimator替换我的tf.keras.Model,但是在移植图形之后,我被

绊倒了
ValueError: Please do not pass a dictionary as model inputs.

调用model.fit(dataset, steps_per_epoch=n)对我来说似乎很奇怪,因为tf.data / tf.estimator是为与(features: dict, labels: dict)一起使用而设计的,其中dict中的值为张量。

摘要

如何将tf.keras.Model与字典配合使用?

1 个答案:

答案 0 :(得分:1)

一种临时的解决方法是使用张量元组而不是字典,例如通过调用

def flatten(dataset):
    return dataset.map(lambda features, labels: (tuple(features.values()), tuple(labels.values())))

model.fit之前,同时确保Model.call通过索引而不是键(例如return tuple(outputs.values())等)。

不幸的是: 上面的技巧仅适用于培训数据。对于model.fit(validation_data=flatten(validation_dataset), ...),由于某些内部错误而无法使用:https://github.com/tensorflow/tensorflow/issues/21729