使用`names_to_saveables`时,Tensorflow Saver无法恢复张量(错误:不是变量)

时间:2019-06-25 15:12:53

标签: python tensorflow random-forest

在这个最小示例中,我创建了一个小型随机森林,将其保存到文件(rf_test.tfsess)中,然后在新的计算图/会话中将其恢复。我尝试通过names_to_saveables参数构造Saver,再调用Saver.restore(),把检查点中存储的变量映射到另一个变量作用域中的现有变量。

import numpy as np
import tensorflow as tf
from tensorflow.contrib.tensor_forest.python import tensor_forest
from tensorflow.python.ops import resources
from tensorflow.python.tools import inspect_checkpoint as chkp

# Seed NumPy's global RNG so the synthetic data set is identical on every run.
np.random.seed(123)

# Number of samples, mini-batch size, and the scale factors applied to the
# two feature columns.
n_max = 20000
batch_size = 100
x_max = 100
y_max = 20
# Draw an (n_max, 2) array of uniform samples in [0, 1) and multiply column 0
# by x_max and column 1 by y_max (the tuple broadcasts across the rows).
X_train = np.reshape((x_max, y_max) * np.random.random_sample((n_max, 2)), [-1, 2])
# Binary label per row (x, y): 0 when y < sqrt(x), otherwise 1.
Y_train = [(lambda x, y: 0 if y < x**0.5 else 1)(*xs) for xs in X_train]

# Rows 100..999 of X_train: the fixed sample on which predictions are taken
# once after training and once after restoring, so the two can be compared.
Y_sample_s = slice(100, 1000)
# Initial value only; rebound to the trained forest's predictions further down.
Y_sample = []

# Build, train and checkpoint the forest inside its own graph g0.
g0 = tf.Graph()
with g0.as_default(), tf.Session().as_default() as sess:
    # NOTE(review): base_label is not referenced again in the visible script.
    base_label = 0
    # Forest configuration: num_classes=2, num_features=2, num_trees=10,
    # max_nodes=100; fill() completes the remaining hyper-parameters.
    hparams = tensor_forest.ForestHParams(num_classes=2, num_features=2, num_trees=10, max_nodes=100).fill()
    rf = tensor_forest.RandomForestGraphs(hparams)
    # Initialise the global variables together with the shared resources
    # registered so far, then run that combined op.
    init_vars = tf.group(tf.global_variables_initializer(), resources.initialize_resources(resources.shared_resources()))
    sess.run(init_vars)

    # Feature and label inputs; the batch dimension is left open.
    X = tf.placeholder(tf.float32, shape=[None, rf.params.num_features], name="X")
    Y = tf.placeholder(tf.int8, shape=[None], name="Y")
    train_op = rf.training_graph(X, Y)
    loss_op = rf.training_loss(X, Y)
    # Predicted class = argmax along axis 1 of the first value returned by
    # inference_graph, cast to int8 so it can be compared with Y.
    infer_op = tf.cast(tf.argmax(rf.inference_graph(X)[0], 1, output_type=tf.int32), tf.int8, name="infer_op")
    correct_prediction = tf.equal(infer_op, Y, name="correct_prediction")
    # Fraction of rows in the fed batch whose prediction equals the label.
    accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name="accuracy_op")

    # At most one pass over the data in consecutive mini-batches. Accuracy is
    # measured on the batch that was just used for the training step.
    acc = 0.0
    for i in range(n_max // batch_size):
        s = slice(i * batch_size, (i + 1) * batch_size)
        _, l = sess.run([train_op, loss_op], feed_dict={ X: X_train[s, :], Y: Y_train[s] })
        acc = sess.run(accuracy_op, feed_dict={ X: X_train[s, :], Y: Y_train[s] })
        print('Step %i, Loss: %f, Acc: %f' % (i, l, acc))
        if acc > 0.8:
            # exit early
            break

    # For every tree index, fetch the three tensors named device_dummy_<i>,
    # stats-<i> and tree-<i> (output 0 of each op) from g0 and print them.
    for i in range(rf.params.num_trees):
        n = "device_dummy_{}".format(i)
        t = g0.get_tensor_by_name("{}:0".format(n))
        print(t)
        #print("size {}: {}".format(t.name, tf.size(t).eval()))
        n = "stats-{}".format(i)
        t = g0.get_tensor_by_name("{}:0".format(n))
        print(t)
        #print("size {}: {}".format(t.name, tf.size(t).eval()))
        n = "tree-{}".format(i)
        t = g0.get_tensor_by_name("{}:0".format(n))
        print(t)
        #print("size {}: {}".format(t.name, tf.size(t).eval()))

    #for var in tf.global_variables():
    #    print("global variable: {}".format(var.name))

    # Reference predictions of the trained forest on the fixed sample rows.
    Y_sample = sess.run(infer_op, feed_dict={ X: X_train[Y_sample_s, :] })

    # Write the checkpoint using a Saver built with default arguments.
    # Note that `s` is rebound here; it held the batch slice in the loop above.
    s = tf.train.Saver()
    s.save(sess, "rf_test.tfsess")

# Start from an empty default graph before rebuilding the model.
tf.reset_default_graph()
# NOTE(review): the context manager returned by as_default() is not entered
# with `with`, so this line appears to have no effect.
tf.get_default_graph().as_default()
# Sanity check: the fresh default graph holds no global variables yet.
assert len(tf.global_variables()) == 0

# Print every tensor stored in the checkpoint that was just written.
chkp.print_tensors_in_checkpoint_file("rf_test.tfsess", tensor_name="", all_tensors=True)

# Initial value only; rebound to the restored forest's predictions below.
Y_sample_2 = []

# Rebuild the same model in the default graph, this time under the variable
# scope "myscope", and try to restore the checkpoint into it.
with tf.Session().as_default() as sess:
    vs = "myscope"
    with tf.variable_scope(vs, reuse=tf.AUTO_REUSE):
        # Same hyper-parameters as the forest that was saved.
        hparams = tensor_forest.ForestHParams(num_classes=2, num_features=2, num_trees=10, max_nodes=100).fill()
        rf = tensor_forest.RandomForestGraphs(hparams)
        # Initialise variables and shared resources of the new graph; the
        # restore further down is meant to replace these initial values.
        init_vars = tf.group(tf.global_variables_initializer(), resources.initialize_resources(resources.shared_resources()))
        sess.run(init_vars)

        X = tf.placeholder(tf.float32, shape=[None, rf.params.num_features], name="X")
        Y = tf.placeholder(tf.int8, shape=[None], name="Y")
        train_op = rf.training_graph(X, Y)
        loss_op = rf.training_loss(X, Y)
        infer_op = tf.cast(tf.argmax(rf.inference_graph(X)[0], 1, output_type=tf.int32), tf.int8, name="infer_op")
        correct_prediction = tf.equal(infer_op, Y, name="correct_prediction")
        accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name="accuracy_op")

        # (name, shape) pairs of everything stored in the checkpoint.
        print(tf.contrib.framework.list_variables("rf_test.tfsess"))

        # Build a map from each checkpoint name to the tensor
        # "<scope>/<name>" in the current graph.
        # NOTE(review): `vars` shadows the Python builtin of the same name.
        vars = tf.contrib.framework.list_variables("rf_test.tfsess")
        var_map = {}
        for var, shape in vars:
            print(var)
            # Append ":0" only to names that do not already end in it, so the
            # lookup string always has the "<op_name>:<output_index>" form.
            suffix = "" if var.endswith(":0") else ":0"
            t = tf.get_default_graph().get_tensor_by_name("{}/{}{}".format(vs, var, suffix))
            var_map["{}".format(var)] = t
        # NOTE(review): the values in var_map are tensors fetched by name, not
        # Variable objects. The surrounding question reports a
        # "names_to_saveables ... Not a variable" TypeError for this script;
        # presumably it is raised by the Saver constructor on the next line.
        s = tf.train.Saver(var_map)
        s.restore(sess, "rf_test.tfsess")

        for var in tf.global_variables():
            print("global variable: {}".format(var.name))


    # Predictions of the rebuilt forest on the same sample rows as Y_sample.
    Y_sample_2 = sess.run(infer_op, feed_dict={ X: X_train[Y_sample_s, :] })

# Report whether the predictions taken before saving and after restoring are
# element-wise identical, then show the first ten sample rows next to both
# sets of predictions.
print()
print((Y_sample == Y_sample_2).all())
print()
print(X_train[Y_sample_s, :][:10])
print(Y_sample[:10])
print(Y_sample_2[:10])

不幸的是,以上程序导致错误:

TypeError: names_to_saveables must be a dict mapping string names to Tensors/Variables. Not a variable: Tensor("myscope/stats-0:0", shape=(), dtype=resource, device=/device:CPU:0)

有趣的是,我传入的目标对象恰恰就是错误消息所要求的张量(Tensor)。不过,我还注意到会话文件中存储的张量有些奇怪之处。随机森林图中有三种类型的张量:

Tensor("device_dummy_0:0", shape=(0,), dtype=float32_ref, device=/device:CPU:0)
Tensor("stats-0:0", shape=(), dtype=resource, device=/device:CPU:0)
Tensor("tree-0:0", shape=(), dtype=resource, device=/device:CPU:0)

森林中的每棵树都有其中一个(索引号递增)。但是,tf.contrib.framework.list_variables("rf_test.tfsess")向我显示了

[('device_dummy_0', [0]), ('device_dummy_1', [0]), ('device_dummy_2', [0]), ('device_dummy_3', [0]), ('device_dummy_4', [0]), ('device_dummy_5', [0]), ('device_dummy_6', [0]), ('device_dummy_7', [0]), ('device_dummy_8', [0]), ('device_dummy_9', [0]), ('stats-0:0', []), ('stats-1:0', []), ('stats-2:0', []), ('stats-3:0', []), ('stats-4:0', []), ('stats-5:0', []), ('stats-6:0', []), ('stats-7:0', []), ('stats-8:0', []), ('stats-9:0', []), ('tree-0:0', []), ('tree-1:0', []), ('tree-2:0', []), ('tree-3:0', []), ('tree-4:0', []), ('tree-5:0', []), ('tree-6:0', []), ('tree-7:0', []), ('tree-8:0', []), ('tree-9:0', [])]

可以看到,device_dummy_#变量没有后缀:0,而stats-#和tree-#变量则带有后缀:0。如果在构建var_map时对所有变量都不加后缀:0,错误会变为

ValueError: The name 'myscope/device_dummy_0' refers to an Operation, not a Tensor. Tensor names must be of the form "<op_name>:<output_index>".

而如果对所有变量都加上后缀:0,错误则变为

ValueError: The name 'myscope/stats-0:0:0' looks a like a Tensor name, but is not a valid one. Tensor names must be of the form "<op_name>:<output_index>".

怎样才能让上面的代码正常工作?

0 个答案:

没有答案