如何使用Tensorflow的Dataset API在Keras中进行预测

时间:2018-12-23 16:14:16

标签: tensorflow keras

我试图使用Tensorflow的Dataset API在Keras中构建模型。我成功地能够在喀拉拉邦训练模型。但是用于预测测试数据。它必须位于numpy数组中。

https://keras.io/models/model/

x必须是numpy数组。所以我做了这样的事情

x = input_fn('test.tfrecords')
model = models.load_model("model/model-40-0.35.hdf5")

with tf.Session()) as sess:  
          x_out = np.asarray(sess.run(x))
pred = model.predict(x_out,batch_size=BATCH_SIZE, verbose=1)
print(pred)

它成功地做出了预测,但是我在考虑是否有任何方法可以将张量插入到预测函数中。而不是首先将其转换为numpy数组。

我找到了这个https://www.tensorflow.org/api_docs/python/tf/keras/models/Model#predict

但是即使我输入张量,也会出现此错误

  

TypeError:predict()缺少1个必需的位置参数:'x'

1 个答案:

答案 0 :(得分:0)

alias some_alias 'cd /some folder' alias another_alias 'frob -x' # ... 方法的第一个参数(tensorflow.python.keras.engine.training.Model实例的方法)可以是以下任意一个:

  • 一个Numpy数组(或类似数组的数组)或数组列表(如果模型具有多个输入)。
  • TensorFlow张量或张量列表(如果模型具有多个输入)。
  • predict数据集或数据集迭代器。
  • 生成器或tf.data实例。

因此它可以是张量。在您的情况下,keras.utils.Sequence参数可以采用张量。