在TF 1.14和TF 2.1下恢复张量流模型问题

时间:2020-09-05 07:52:11

标签: tensorflow model save load version

在TF 1.14和TF 2.1之间使用tf.saved_model.simple_save和tf.saved_model.load时遇到了一些麻烦

如您所见,我附上了代码,

我想看看权重(W),其值必须是节省时间时初始化的状态。

在TF 2.1下, 保存和加载tensorflow模型(pb文件)没有问题。 保存后加载时,我能够识别出相同的重量(W)值

但是,当我使用TF 1.14时, 保存模型还可以..但是,当我加载保存的模型时,结果不是我期望的。 看来tf.saved_model.load无法加载节省的重量,只能随机初始化。

我附上了下面的代码, 您可以通过切换TF_VERSION = 2.1和1.14,SAVE = True和False来运行

TF_VERSION = 2.1
# TF_VERSION = 1.14
SAVE = False

model_dir_path = "./pb"


if TF_VERSION == 1.14:
    import tensorflow as tf
else:
    import tensorflow.compat.v1 as tf
    tf.compat.v1.disable_eager_execution()



X = tf.placeholder(tf.float32, shape=[None, 2], name='input')

# weight
weight_initer = tf.truncated_normal_initializer(mean=0.0, stddev=0.01)
W = tf.get_variable(name="Weight", dtype=tf.float32, shape=[2, 1], initializer=weight_initer)

# bias
bias_initer = tf.constant(0., shape=[1], dtype=tf.float32)
b = tf.get_variable(name="Bias", dtype=tf.float32, initializer=bias_initer)

x_w = tf.matmul(X, W, name="MatMul")
x_w_b = tf.add(x_w, b, name="Add")


#save
if SAVE:
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        x_batch = [[2, 1], [3, 5]]

        feed_dict = {X: x_batch}
        output = sess.run(x_w_b, feed_dict=feed_dict)

        tf.saved_model.simple_save(sess, model_dir_path, inputs={"inputs": X}, outputs={"outputs": W})


# restore
else:
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        tf.saved_model.load(sess, [tf.saved_model.tag_constants.SERVING], model_dir_path)

        x_batch = [[2, 1], [3, 5]]

        feed_dict = {X: x_batch}
        weight = sess.run(W, feed_dict=feed_dict)

        print(weight)

0 个答案:

没有答案