TensorFlow import_meta_graph for multiple graphs with duplicate/conflicting names

Asked: 2019-06-19 15:06:19

Tags: python tensorflow random-forest

Disclaimer: this question is a follow-up to a previous question.

Is it possible to import multiple TensorFlow graphs whose variables share the same names? As far as I understand, existing variables are overwritten by default when calling tf.train.import_meta_graph(). The example in the answer to another question shows how to restore variables that have unique names:

import tensorflow as tf

# The variables v1 and v2 that we want to restore
v1 = tf.Variable(tf.zeros([1]), name="v1")
v2 = tf.Variable(tf.zeros([1]), name="v2")

# saver1 will only look for v1
saver1 = tf.train.Saver([v1])
# saver2 will only look for v2
saver2 = tf.train.Saver([v2])
with tf.Session() as sess:
    saver1.restore(sess, "tmp/v1.ckpt")
    saver2.restore(sess, "tmp/v2.ckpt")
    print(sess.run(v1))
    print(sess.run(v2))

The variables v1 and v2 were previously saved from different graphs and are now both available in the TensorFlow default graph.
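For reference, this is a minimal sketch of how I assume the two checkpoints were written in the first place, each from its own graph (the paths tmp/v1.ckpt and tmp/v2.ckpt are taken from the example above):

import tensorflow as tf

# Assumed setup for the example above: v1 and v2 are saved from two
# independent graphs into separate checkpoint files.
for name, value, path in [("v1", 1.0, "tmp/v1.ckpt"), ("v2", 2.0, "tmp/v2.ckpt")]:
    g = tf.Graph()
    with g.as_default():
        v = tf.Variable(value * tf.ones([1]), name=name)
        saver = tf.train.Saver([v])
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            saver.save(sess, path)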

However, when using tensor_forest.RandomForestGraphs(), the variables have fixed names (e.g. device_dummy_1). When I try to import multiple such graphs, I sometimes get the following error:

NotFoundError: Key device_dummy_100 not found in checkpoint
     [[{{node save/RestoreV2}}]]
     [[{{node GroupCrossDeviceControlEdges_0/save/restore_all}}]]

My understanding of this is: I have several RFs (three in this case) that each have 130 trees, plus one RF with only 100 trees. When the last (smaller) one is imported, trees 101 to 130 are not found in its file, and the importer complains about these missing variables. I therefore have to assume that importing the second RF overwrites the previous one. Is that correct?
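One way to double-check which trees a given file actually contains would be to list the checkpoint keys, for example with tf.train.list_variables (sketch; the path is a placeholder for the checkpoint of my smaller RF):

import tensorflow as tf

# List the variable keys stored in a checkpoint; "rf_small.ckpt" is a
# placeholder for the checkpoint of the 100-tree forest.
for name, shape in tf.train.list_variables("rf_small.ckpt"):
    print(name, shape)
# I would expect device_dummy_0 .. device_dummy_99 here and no
# device_dummy_100, which would match the NotFoundError above.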

To summarize, I am facing the following problems:

  • tensor_forest.RandomForestGraphs() does not allow, for example, prefixing the internal variable names, so different RFs share the same variable names
  • importing a graph with identical variable names overwrites the existing variables

Is there any way to change (prefix) all variable names of one RF, either before exporting or during import? Or is there another solution to this?
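For example, if the import_scope argument of tf.train.import_meta_graph does what I hope (I have not verified this with tensor_forest graphs), something along these lines would already cover the prefixing part; the checkpoint prefixes "rf_a.ckpt" and "rf_b.ckpt" are placeholders:

import tensorflow as tf

# Sketch: prefix everything at import time via import_scope, so the two RFs
# end up under "rf_a/" and "rf_b/" within a single default graph.
with tf.Session() as sess:
    saver_a = tf.train.import_meta_graph("rf_a.ckpt.meta", import_scope="rf_a")
    saver_a.restore(sess, "rf_a.ckpt")
    saver_b = tf.train.import_meta_graph("rf_b.ckpt.meta", import_scope="rf_b")
    saver_b.restore(sess, "rf_b.ckpt")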

EDIT: Using variable scopes seemed promising, but I created this minimal example to show the problem I am still facing:

import tensorflow as tf
from tensorflow.contrib.tensor_forest.python import tensor_forest
from tensorflow.python.ops import resources

num_trees = {
    0: 3,
    1: 2,
}

# create two RFs, the first one with 3 trees, the second one with 2 trees
# resulting RFs are stored to two files separately

g0 = tf.Graph()
with g0.as_default():
    base_label = 0
    hparams = tensor_forest.ForestHParams(num_classes=2, num_features=2, num_trees=num_trees[base_label], max_nodes=100).fill()
    rf0 = tensor_forest.RandomForestGraphs(hparams)
    init_vars = tf.group(tf.global_variables_initializer(), resources.initialize_resources(resources.shared_resources()))
    sess = tf.Session()
    sess.run(init_vars)
    X = tf.placeholder(tf.float32, shape=[None, rf0.params.num_features], name="X")
    Y = tf.placeholder(tf.int8, shape=[None], name="Y")
    infer_op = tf.cast(tf.argmax(rf0.inference_graph(X)[0], 1, output_type=tf.int32), tf.int8, name="infer_op")
    for var in tf.global_variables():
        print("RF0: global variable: {}".format(var.name))
    s = tf.train.Saver()
    s.save(sess, "rf0.tfsess")

g1 = tf.Graph()
with g1.as_default():
    base_label = 1
    hparams = tensor_forest.ForestHParams(num_classes=2, num_features=2, num_trees=num_trees[base_label], max_nodes=100).fill()
    rf1 = tensor_forest.RandomForestGraphs(hparams)
    init_vars = tf.group(tf.global_variables_initializer(), resources.initialize_resources(resources.shared_resources()))
    sess = tf.Session()
    sess.run(init_vars)
    X = tf.placeholder(tf.float32, shape=[None, rf1.params.num_features], name="X")
    Y = tf.placeholder(tf.int8, shape=[None], name="Y")
    infer_op = tf.cast(tf.argmax(rf1.inference_graph(X)[0], 1, output_type=tf.int32), tf.int8, name="infer_op")
    for var in tf.global_variables():
        print("RF1: global variable: {}".format(var.name))
    s = tf.train.Saver()
    s.save(sess, "rf1.tfsess")


# re-create/import both RFs into one graph, "subgraph" using variable scope

tf.reset_default_graph()
assert len(tf.global_variables()) == 0

# first RF
base_label = 0
vs = "{}".format(base_label)
with tf.variable_scope(vs):
    hparams = tensor_forest.ForestHParams(num_classes=2, num_features=2, num_trees=num_trees[base_label], max_nodes=100).fill()
    rf0 = tensor_forest.RandomForestGraphs(hparams)
    init_vars = tf.group(tf.global_variables_initializer(), resources.initialize_resources(resources.shared_resources()))
    sess0 = tf.Session()
    sess0.run(init_vars)
    X0 = tf.placeholder(tf.float32, shape=[None, rf0.params.num_features], name="X")
    Y0 = tf.placeholder(tf.int8, shape=[None], name="Y")
    infer_op0 = tf.cast(tf.argmax(rf0.inference_graph(X0)[0], 1, output_type=tf.int32), tf.int8, name="infer_op")

# second RF
base_label = 1
vs = "{}".format(base_label)
with tf.variable_scope(vs):
    hparams = tensor_forest.ForestHParams(num_classes=2, num_features=2, num_trees=num_trees[base_label], max_nodes=100).fill()
    rf1 = tensor_forest.RandomForestGraphs(hparams)
    init_vars = tf.group(tf.global_variables_initializer(), resources.initialize_resources(resources.shared_resources()))
    sess1 = tf.Session()
    sess1.run(init_vars)
    X1 = tf.placeholder(tf.float32, shape=[None, rf1.params.num_features], name="X")
    Y1 = tf.placeholder(tf.int8, shape=[None], name="Y")
    infer_op1 = tf.cast(tf.argmax(rf1.inference_graph(X1)[0], 1, output_type=tf.int32), tf.int8, name="infer_op")

# check that there are only 5 variables (3 "0/device_dummy_#" and 2 "1/device_dummy_#")
for var in tf.global_variables():
    print("global variable: {}".format(var.name))


# create input map for both graphs and import

# first RF
base_label = 0
vs = "{}".format(base_label)
input_map = {}
for i in range(num_trees[base_label]):
    t = tf.get_default_graph().get_tensor_by_name("{}/device_dummy_{}:0".format(vs, i))
    input_map["device_dummy_{}:0".format(i)] = t
    #print(input_map["device_dummy_{}:0".format(i)])
input_map["X:0"] = X0
input_map["Y:0"] = Y0
input_map["infer_op"] = infer_op0
print(input_map)
# {'device_dummy_0:0': <tf.Tensor '0/device_dummy_0:0' shape=(0,) dtype=float32_ref>, 'device_dummy_1:0': <tf.Tensor '0/device_dummy_1:0' shape=(0,) dtype=float32_ref>, 'device_dummy_2:0': <tf.Tensor '0/device_dummy_2:0' shape=(0,) dtype=float32_ref>, 'X:0': <tf.Tensor '0/X:0' shape=(?, 2) dtype=float32>, 'Y:0': <tf.Tensor '0/Y:0' shape=(?,) dtype=int8>, 'infer_op': <tf.Tensor '0/infer_op:0' shape=(?,) dtype=int8>}
s = tf.train.import_meta_graph("{}.meta".format("rf0.tfsess"), input_map=input_map)
s.restore(sess, "rf0.tfsess")

# second RF
base_label = 1
vs = "{}".format(base_label)
input_map = {}
for i in range(num_trees[base_label]):
    t = tf.get_default_graph().get_tensor_by_name("{}/device_dummy_{}:0".format(vs, i))
    input_map["device_dummy_{}:0".format(i)] = t
    #print(input_map["device_dummy_{}:0".format(i)])
input_map["X:0"] = X1
input_map["Y:0"] = Y1
input_map["infer_op"] = infer_op1
print(input_map)
# {'device_dummy_0:0': <tf.Tensor '1/device_dummy_0:0' shape=(0,) dtype=float32_ref>, 'device_dummy_1:0': <tf.Tensor '1/device_dummy_1:0' shape=(0,) dtype=float32_ref>, 'X:0': <tf.Tensor '1/X:0' shape=(?, 2) dtype=float32>, 'Y:0': <tf.Tensor '1/Y:0' shape=(?,) dtype=int8>, 'infer_op': <tf.Tensor '1/infer_op:0' shape=(?,) dtype=int8>}
s = tf.train.import_meta_graph("{}.meta".format("rf1.tfsess"), input_map=input_map)
s.restore(sess, "rf1.tfsess")

for var in tf.global_variables():
    print("global variable: {}".format(var.name))
# global variable: 0/device_dummy_0:0
# global variable: 0/device_dummy_1:0
# global variable: 0/device_dummy_2:0
# global variable: 1/device_dummy_0:0
# global variable: 1/device_dummy_1:0
# global variable: device_dummy_0:0
# global variable: device_dummy_1:0
# global variable: device_dummy_2:0
# global variable: device_dummy_0:0
# global variable: device_dummy_1:0

As can be seen, the variables imported at the end are not remapped onto the scoped ones. import_meta_graph() still creates the variables as plain var_name (without a scope) instead of mapping them to scope/var_name. In this example no error occurs (I do not know why), but in my application a NotFoundError is raised when a smaller RF is imported after a larger one. That is another strange point, because the variables at the end really do appear to have identical names (duplicates of device_dummy_{0,1}:0).
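For completeness, the fallback I can think of is to not merge the RFs at all and instead keep each one in its own graph and session, which sidesteps the name clash rather than solving it (sketch; "rf0.tfsess" and "rf1.tfsess" are the files saved in the example above):

import tensorflow as tf

# Fallback: one graph and one session per RF, so identical variable names
# never meet in the same graph.
sessions, inputs, infer_ops = [], [], []
for ckpt in ["rf0.tfsess", "rf1.tfsess"]:
    g = tf.Graph()
    with g.as_default():
        sess = tf.Session(graph=g)
        saver = tf.train.import_meta_graph("{}.meta".format(ckpt))
        saver.restore(sess, ckpt)
        sessions.append(sess)
        inputs.append(g.get_tensor_by_name("X:0"))
        infer_ops.append(g.get_tensor_by_name("infer_op:0"))

# predictions = [s.run(op, {x: features}) for s, x, op in zip(sessions, inputs, infer_ops)]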

0 Answers:

No answers yet.