我想使用此处所述的转移学习方法: https://www.tensorflow.org/tutorials/images/transfer_learning
问题是我要用作基本模型的模型不是已知的内置Keras模型(例如MobileNetV2)之一。因此,我想我需要执行以下第一步(步骤1),以便能够完成本教程中提到的迁移学习(步骤2-6)。
1。从包含Saved_Model文件的目录中加载模型。
2.冻结模型(使其可训练的参数不变)
3.制作一个单独的层并将其堆叠在冻结模型的顶部
4.训练生成的模型。
5.保存新训练的模型。
6.使用新训练的模型进行预测。
我的问题与第一步有关。尝试使用以下Python代码/脚本加载模型时,出现一个我不明白如何解决的错误:
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from matplotlib import pyplot as plt
import numpy as np
tf.saved_model.load(
export_dir='/dir_to_the_model_files/', tags=None
)
错误是:
OSError: Cannot parse file b'/dir_to_the_model_files/saved_model.pbtxt': 1:1 : Message type "tensorflow.SavedModel" has no field named "node"..
我还认为可能有一种方法可以将包括(saved_model.ckpt-0.data-00000-of-00001)的TensorFlow文件转换为Keras API可读的文件(例如h5py.File格式)与所提到的教程类似,这可以促进转移学习。因此,我可以对以下方法应用类似的方法,以提取基本模型并执行后续步骤。
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
include_top=False,
weights='imagenet')
或者最好使用https://www.tensorflow.org/api_docs/python/tf/keras/models/load_model中的以下方法:
tf.keras.models.load_model(
filepath, custom_objects=None, compile=True
)
更新:我尝试了以下方法,但不起作用(使用兼容版本import tensorflow.compat.v1. as tf
导入了tf):
with tf.Session() as sess:
saver = tf.train.import_meta_graph('/dir_to_the_model_files/saved_model.ckpt-0.meta')
saver.restore(sess, "/dir_to_the_model_files/saved_model.ckpt-0")
loaded = tf.saved_model.load(sess,tags=None,export_dir="/dir_to_the_model_files",import_scope=None)
它返回以下警告和错误:
WARNING:tensorflow:The saved meta_graph is possibly from an older release:
'metric_variables' collection should be of type 'byte_list', but instead is of type 'node_list'.
INFO:tensorflow:Restoring parameters from /dir_to_the_model_files/saved_model.ckpt-0
<tensorflow.python.training.saver.Saver object at 0x2aaab4824a50>
WARNING:tensorflow:From <ipython-input-3-b8fd24f6b841>:9: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0.
OSError: Cannot parse file b'/dir_to_the_model_files/saved_model.pbtxt': 1:1 : Message type "tensorflow.SavedModel" has no field named "node"..
答案 0 :(得分:0)
tf.saved_model.load
的TensorFlow文档可能会有所帮助:
来自tf.estimator.Estimator或1.x SavedModel API的SavedModels具有一个 平面图,而不是tf.function对象。这些SavedModels将具有 与.signatures中的签名相对应的函数 属性,但也有一个.prune方法,可让您提取 新子图的功能。这等效于导入 SavedModel并在TensorFlow的会话中命名提要和获取 1.x。
您可能必须使用不推荐使用的v1 api调用 https://www.tensorflow.org/api_docs/python/tf/compat/v1/saved_model/load