我已经在tf.keras
中创建了一个模型,并将其转换为tf.estimator
并将其训练为:
from tensorflow.keras.layers import Conv2D, MaxPool2D, Flatten, Dense
from tensorflow.keras.layers import Input
from tensorflow.keras import Model
from tensorflow.keras.optimizers import Adam
import tensorflow as tf
def build_estimator(config, n_class=3):
input_shape = (128, 128, 3)
inputs = Input(shape=input_shape, name='inputs')
conv_1 = Conv2D(filters=32, kernel_size=(2, 2), strides=(2, 2))(inputs)
maxpool_1 = tf.keras.layers.MaxPool2D(pool_size=(2,2))(conv_1)
conv_2 = Conv2D(filters=32, kernel_size=(2, 2), strides=(2, 2))(maxpool_1)
maxpool_2 = tf.keras.layers.MaxPool2D(pool_size=(2,2))(conv_2)
flatten = Flatten()(maxpool_2)
outputs = Dense(n_class, activation='softmax', name='outputs')(flatten)
model = Model(inputs=inputs, outputs=outputs)
model.compile(optimizer=Adam(lr=1e-3), loss='categorical_crossentropy', metrics=['accuracy'])
model_estimator = tf.keras.model_to_estimator(model, config=config)
return model_estimator
run_config = tf.estimator.RunConfig()
estimator = model.build_estimator(config=run_config, n_class=3)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
开始训练后,我将检查点保存在gcs存储桶中。现在,我想选择一个特定的检查点并在测试集上运行预测。浏览完有关creating custom estimators和this stackoverflow question的文档后,我发现我需要执行以下操作:
pred_model = tf.estimator.Estimator(model_fn=model_fn, model_dir="path/to/model.ckpt")
但是我无法弄清楚model_fn
的样子。该文档说应该采用以下格式:
def model_fn(features, labels, mode):
return tf.estimator.EstimatorSpec
考虑我正在使用tf.keras.model_to_estimator
而不是低级张量流,我的model_fn
应该是什么样?
除此之外,我的input_fn
井井有条,而且我还成功地进行了数千步训练。