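I'm trying to do transfer learning with a pretrained model in TensorFlow using the Estimator API. The details of the model (number of layers, neurons, etc.) will change and are not relevant.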
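get_feature_extractor() instantiates a TensorFlow Hub module. What ends up happening is that every call to .train_and_evaluate(), .predict(), etc. throws away the session and graph and starts from scratch, which reloads the feature extractor. That takes several seconds each time. Is there a clean way to persist the result of get_feature_extractor() between these calls and keep it alive in the session, at least for .predict()? Or do I have to drop down to the lower-level API to achieve this?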
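For .predict() specifically, one workaround that stays inside the Estimator API is to call predict() only once and keep the generator it returns alive, feeding it new examples from a Python queue. In TF 1.x, Estimator.predict() is a lazy generator: graph construction and session creation (including loading the Hub module) happen when it is first advanced, and the session stays open until the generator finishes. The sketch below assumes TF 1.x and tf.data.Dataset.from_generator; PersistentPredictor and make_tf_input_fn are hypothetical names, not part of TensorFlow, and this has not been checked against every TF release.

```python
import queue


class PersistentPredictor:
    """Keep one Estimator.predict() generator alive across many predictions.

    Estimator.predict() returns a lazy generator; graph construction and
    session creation (including loading the Hub module) happen when it is
    first advanced, and the session stays open until the generator finishes.
    Feeding that generator from a queue lets later predictions reuse it.
    """

    def __init__(self, estimator, make_input_fn):
        self._queue = queue.Queue()
        input_fn = make_input_fn(self._feed)
        # Nothing is built yet; that happens on the first next() in predict().
        self._predictions = estimator.predict(input_fn=input_fn)

    def _feed(self):
        # Yields whatever predict() enqueued; None ends the stream.
        while True:
            item = self._queue.get()
            if item is None:
                return
            yield item

    def predict(self, example):
        # Enqueue first so the input pipeline never waits on an empty queue.
        self._queue.put(example)
        return next(self._predictions)

    def close(self):
        # Closing the generator lets the Estimator tear its session down.
        self._predictions.close()


def make_tf_input_fn(output_types, output_shapes):
    """Hypothetical helper wiring the feed generator into a TF 1.x input_fn."""
    def make_input_fn(generator):
        def input_fn():
            import tensorflow as tf  # imported lazily; assumes TF 1.x
            dataset = tf.data.Dataset.from_generator(
                generator, output_types=output_types, output_shapes=output_shapes)
            # Batches of one, so each predict() consumes exactly one queued item.
            return dataset.batch(1)
        return input_fn
    return make_input_fn
```

Usage would look like predictor = PersistentPredictor(estimator, make_tf_input_fn(tf.float32, tf.TensorShape([224, 224, 3]))) followed by predictor.predict(image) per image; the image shape is only an example. The first call still pays the module-loading cost, and later calls reuse the open session. This does not help .train_and_evaluate(), which builds its graphs separately for each training and evaluation phase.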
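    import tensorflow as tf  # TF 1.x


    def model_fn(features, labels, mode):
        # get_feature_extractor() is defined elsewhere; it instantiates a TF Hub module.
        feature_extractor = get_feature_extractor()
        layer = feature_extractor(features)

        is_training = mode == tf.estimator.ModeKeys.TRAIN
        layer = tf.layers.batch_normalization(layer, training=is_training)
        layer = tf.layers.dense(inputs=layer, units=1280, activation=tf.nn.relu)
        layer = tf.layers.dense(inputs=layer, units=2048, activation=tf.nn.relu)
        layer = tf.layers.dense(inputs=layer, units=512, activation=tf.nn.relu)
        logits = tf.layers.dense(inputs=layer, units=2)

        if mode == tf.estimator.ModeKeys.PREDICT:
            return tf.estimator.EstimatorSpec(mode, predictions=logits)

        # Labels are assumed to be one-hot with 2 classes (they are compared to
        # the 2-unit output with MSE), so accuracy compares class indices.
        accuracy = tf.metrics.accuracy(labels=tf.argmax(labels, axis=1),
                                       predictions=tf.argmax(logits, axis=1),
                                       name='acc_op')
        metrics = {'accuracy': accuracy}
        loss = tf.losses.mean_squared_error(labels=labels, predictions=logits)

        optimizer = tf.train.AdamOptimizer()
        # Batch normalization's moving averages are updated through UPDATE_OPS.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, train_op=train_op,
            eval_metric_ops=metrics)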
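If the Hub module is only used as a frozen extractor (for example instantiated with hub.Module(..., trainable=False); this is an assumption, since get_feature_extractor() is not shown), another option for training and evaluation is to run it once over the whole dataset, keep the feature vectors, and train a head-only model_fn that starts at the batch_normalization layer. A minimal NumPy sketch of the caching step, where precompute_features and extract_batch are hypothetical names:

```python
import numpy as np


def precompute_features(extract_batch, images, batch_size=32):
    """Run a frozen feature extractor over `images` once and return the features.

    `extract_batch` is a hypothetical callable mapping a batch of images to a
    (batch, feature_dim) array, e.g. one session.run() of the Hub module.
    """
    chunks = []
    for start in range(0, len(images), batch_size):
        batch = images[start:start + batch_size]
        chunks.append(np.asarray(extract_batch(batch)))
    # One row per input image, in the original order.
    return np.concatenate(chunks, axis=0)
```

The result can be saved with np.save and, in TF 1.x, fed to the head through tf.estimator.inputs.numpy_input_fn(x={'x': features}, y=labels, shuffle=True), in which case the head's model_fn reads features['x'] instead of calling get_feature_extractor(). The trade-off is that cached features rule out fine-tuning the module and on-the-fly data augmentation.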