在TF 1.14和TF 2.1之间使用tf.saved_model.simple_save和tf.saved_model.load时遇到了一些麻烦
如您所见,我附上了代码,
我想看看权重(W),其值必须是节省时间时初始化的状态。
在TF 2.1下, 保存和加载tensorflow模型(pb文件)没有问题。 保存后加载时,我能够识别出相同的重量(W)值
但是,当我使用TF 1.14时, 保存模型还可以..但是,当我加载保存的模型时,结果不是我期望的。 看来tf.saved_model.load无法加载节省的重量,只能随机初始化。
我附上了下面的代码, 您可以通过切换TF_VERSION = 2.1和1.14,SAVE = True和False来运行
TF_VERSION = 2.1
# TF_VERSION = 1.14
SAVE = False
model_dir_path = "./pb"
if TF_VERSION == 1.14:
import tensorflow as tf
else:
import tensorflow.compat.v1 as tf
tf.compat.v1.disable_eager_execution()
X = tf.placeholder(tf.float32, shape=[None, 2], name='input')
# weight
weight_initer = tf.truncated_normal_initializer(mean=0.0, stddev=0.01)
W = tf.get_variable(name="Weight", dtype=tf.float32, shape=[2, 1], initializer=weight_initer)
# bias
bias_initer = tf.constant(0., shape=[1], dtype=tf.float32)
b = tf.get_variable(name="Bias", dtype=tf.float32, initializer=bias_initer)
x_w = tf.matmul(X, W, name="MatMul")
x_w_b = tf.add(x_w, b, name="Add")
#save
if SAVE:
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
x_batch = [[2, 1], [3, 5]]
feed_dict = {X: x_batch}
output = sess.run(x_w_b, feed_dict=feed_dict)
tf.saved_model.simple_save(sess, model_dir_path, inputs={"inputs": X}, outputs={"outputs": W})
# restore
else:
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
tf.saved_model.load(sess, [tf.saved_model.tag_constants.SERVING], model_dir_path)
x_batch = [[2, 1], [3, 5]]
feed_dict = {X: x_batch}
weight = sess.run(W, feed_dict=feed_dict)
print(weight)