我对Tensorflow估算器非常陌生。我正在尝试在训练后保存 pix2pix 模型。使用Estimators以后如何在本地计算机上导出和使用模型基本上是我的问题(就像我们对tf.keras模型所做的那样,即加载权重并进行预测)。我已经尝试了几乎所有看到的解决方案。
对于 serving_input_fn 和 tf.estimator.ModeKeys.PREDICT 案例,我大多感到困惑我写的> model_fn :
^[[3J^[[H^[[2J
serving_input_fn stackoverflow source。
我正在使用 TFRECORD 进行训练,但希望使用 .png 图片进行预测。我的 model_fn 是:
def serving_input_receiver_fn(FLAGS):
def decode_and_resize(image_str_tensor):
image = tf.image.decode_jpeg(image_str_tensor,
channels=FLAGS.NB_CHANNELS)
image = tf.expand_dims(image, 0)
image = tf.image.resize_bilinear(image,
[FLAGS.IMAGE_DIM, FLAGS.IMAGE_DIM],
align_corners=False)
image = tf.squeeze(image,
squeeze_dims=[0])
image = tf.cast(image,
dtype=tf.uint8)
return image
input_ph = tf.placeholder(tf.string,
shape=[None],
name='image_binary')
images_tensor = tf.map_fn(decode_and_resize,
input_ph,
back_prop=False,
dtype=tf.uint8)
images_tensor = tf.image.convert_image_dtype(images_tensor,
dtype=tf.float32)
return tf.estimator.export.ServingInputReceiver({'images': images_tensor},
{'bytes': input_ph}
)
serving_input_fn = partial(serving_input_receiver_fn,FLAGS=FLAGS)
估算器训练得很好,检查点和其他文件被顺利保存在 GCS-Bucket 中。只有当我尝试导出模型时,它才会给出错误:
def model_fn(features, labels, mode, params):
if (mode != tf.estimator.ModeKeys.PREDICT):
loss = loss_fn(features, labels)
learning_rate = tf.train.exponential_decay(FLAGS.LEARNING_RATE,
tf.train.get_global_step(),
decay_steps=100000,
decay_rate=0.96)
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
train_op=optimizer.minimize(loss,
tf.train.get_global_step())
return tf.contrib.tpu.TPUEstimatorSpec(mode=mode,
loss=loss,
train_op=train_op)
else:
predictions=generator_fn(features)
return tf.contrib.tpu.TPUEstimatorSpec(mode=mode,
predictions={"predictions": predictions})