如何在tf.data.Dataset.map()中使用Keras的predict_on_batch?

时间:2019-04-03 02:01:11

标签: python tensorflow keras tensorflow-datasets tensorflow2.0

我想找到一种在predict_on_batch的{​​{1}}内使用Keras的tf.data.Dataset.map()的方法

假设我有一个numpy数据集

TF2.0.

和一个n_data = 10**5 my_data = np.random.random((n_data,10,1)) my_targets = np.random.randint(0,2,(n_data,1)) data = ({'x_input':my_data}, {'target':my_targets}) 模型

tf.keras

我可以使用创建一个x_input = Input((None,1), name = 'x_input') RNN = SimpleRNN(100, name = 'RNN')(x_input) dense = Dense(1, name = 'target')(RNN) my_model = Model(inputs = [x_input], outputs = [dense]) my_model.compile(optimizer='SGD', loss = 'binary_crossentropy') 的批处理

dataset

其中dataset = tf.data.Dataset.from_tensor_slices(data) dataset = dataset.batch(10) prediction_dataset = dataset.map(transform_predictions) 是用户定义的函数,可从transform_predictions获得预测

predict_on_batch

这会导致来自def transform_predictions(inputs, outputs): predictions = my_model.predict_on_batch(inputs) # predictions = do_transformations_here(predictions) return predictions 的错误:

predict_on_batch

据我了解,AttributeError: 'Tensor' object has no attribute 'numpy'期望一个numpy数组,并且它正在从数据集中获取张量对象。

似乎一种可能的解决方案是将predict_on_batch包装在`tf.py_function中,尽管我也无法使其正常工作。

有人知道该怎么做吗?

1 个答案:

答案 0 :(得分:0)

Dataset.map()返回<class 'tensorflow.python.framework.ops.Tensor'>,它没有numpy()方法。

遍历数据集返回 <class 'tensorflow.python.framework.ops.EagerTensor'>具有numpy()方法。

提供一个渴望张量的predict()系列方法很好。

您可以尝试这样的事情:

dataset = tf.data.Dataset.from_tensor_slices(data)
dataset = dataset.batch(10)

for x,y in dataset:
    predictions = my_model.predict_on_batch(x['x_input'])
    #or 
    predictions = my_model.predict_on_batch(x)