如何保存和恢复tf.estimator模型

时间:2018-06-24 07:33:07

标签: python-3.x tensorflow

我想保存并恢复我的tf.estimator模型。尽管我尝试关注stackoverflow上的其他相关问题,但我无法成功。以下input_fn可以提供要预测的数据。但是我不知道如何使用它来保存和还原模型以进行预测。 顺便说一句,我的返回数据集的形状为 [batch_size,dim] ,其中dtype为float32

def predict_input_fn(path, dim, batch_size):

    dataset = ds.get_dataset(path,
                             dim)

    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(1) 
    return dataset

到目前为止,我一直在尝试以下操作,但是它没有按预期工作,请您帮我保存和恢复这种模型吗?

试用

  def serving_input_receiver_fn():

        features = tf.placeholder(
            dtype=tf.float32, shape=[None, batch_size])

        fn = lambda x : precict_input_fn(path, dim, batch_size)

        mapped_fn = tf.map_fn(fn, features)
        return tf.estimator.export.ServingInputReceiver(mapped_fn, features)

    estimator.export_savedmodel(model_save_path, serving_input_receiver_fn)

错误:

Failed to convert object of type <class 'tensorflow.python.data.ops.dataset_ops.PrefetchDataset'> to Tensor. Contents: <PrefetchDataset shapes: (?, 1024), types: tf.float32>. Consider casting elements to a supported type

0 个答案:

没有答案