恢复检查点时使用tf.estimators和tf.keras进行预测

时间:2019-07-11 21:35:47

标签: python tensorflow keras

我已经在tf.keras中创建了一个模型,并将其转换为tf.estimator并将其训练为:

from tensorflow.keras.layers import Conv2D, MaxPool2D, Flatten, Dense
from tensorflow.keras.layers import Input
from tensorflow.keras import Model
from tensorflow.keras.optimizers import Adam
import tensorflow as tf

def build_estimator(config, n_class=3):
    input_shape = (128, 128, 3)
    inputs = Input(shape=input_shape, name='inputs')
    conv_1 = Conv2D(filters=32, kernel_size=(2, 2), strides=(2, 2))(inputs)
    maxpool_1 = tf.keras.layers.MaxPool2D(pool_size=(2,2))(conv_1)
    conv_2 = Conv2D(filters=32, kernel_size=(2, 2), strides=(2, 2))(maxpool_1)
    maxpool_2 = tf.keras.layers.MaxPool2D(pool_size=(2,2))(conv_2)
    flatten = Flatten()(maxpool_2)
    outputs = Dense(n_class, activation='softmax', name='outputs')(flatten)
    model = Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer=Adam(lr=1e-3), loss='categorical_crossentropy', metrics=['accuracy'])
    model_estimator = tf.keras.model_to_estimator(model, config=config)

    return model_estimator

run_config = tf.estimator.RunConfig()
estimator = model.build_estimator(config=run_config, n_class=3)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

开始训练后,我将检查点保存在gcs存储桶中。现在,我想选择一个特定的检查点并在测试集上运行预测。浏览完有关creating custom estimatorsthis stackoverflow question的文档后,我发现我需要执行以下操作:

pred_model = tf.estimator.Estimator(model_fn=model_fn, model_dir="path/to/model.ckpt")

但是我无法弄清楚model_fn的样子。该文档说应该采用以下格式:

def model_fn(features, labels, mode):
    return tf.estimator.EstimatorSpec

考虑我正在使用tf.keras.model_to_estimator而不是低级张量流,我的model_fn应该是什么样?

除此之外,我的input_fn井井有条,而且我还成功地进行了数千步训练。

0 个答案:

没有答案