如何恢复和执行两个独立的TensorFlow模型?

时间:2019-03-04 15:10:01

标签: python tensorflow

我正在开发基于TensorFlow的程序,并且需要两个不同的模型。根据第一个模型的输出,我执行了一些计算,并获得了第二个模型的输入。这里是部分代码:

key_pose_model = None
gesture_model = None


# Key Poses
with tf.device('/cpu:0'):
    if key_pose_model_type == 'ramdon_forest':
        key_pose_model = RandomForest(key_pose_ramdon_forest_num_steps,
                                key_pose_ramdon_forest_num_classes,
                                key_pose_ramdon_forest_num_trees,
                                key_pose_ramdon_forest_max_nodes,
                                key_pose_ramdon_forest_num_features)
        key_pose_model.read_model(key_pose_model_name)


# Gestures
with tf.device('/cpu:1'):
    if gesture_model_type == 'ramdon_forest':
        gesture_model = RandomForest(gesture_ramdon_forest_num_steps,
                                gesture_ramdon_forest_num_classes,
                                gesture_ramdon_forest_num_trees,
                                gesture_ramdon_forest_max_nodes,
                                gesture_ramdon_forest_num_features)
        gesture_model.read_model(gesture_model_name)

最近我的代码中有以下调用(输入数据是从传感器获取的):

while(True):
......
......
key_pose_model.prediction(input_data_x)
......
......
......
......
gesture_model.prediction(input_data_x_1)
.......

它在第一个模型上运行良好,然后在还原第二个模型时,出现了重复变量的错误,因此我认为我没有使用其他图形。我正在阅读TensorFlow文档,并尝试重现有关不同会话的示例,但我做不到。

g_1 = tf.Graph()
with g_1.as_default():
  # Operations created in this scope will be added to `g_1`.
  c = tf.constant("Node in g_1")

  # Sessions created in this scope will run operations from `g_1`.
  sess_1 = tf.Session()

g_2 = tf.Graph()
with g_2.as_default():
  # Operations created in this scope will be added to `g_2`.
  d = tf.constant("Node in g_2")

# Alternatively, you can pass a graph when constructing a <a href="./../api_docs/python/tf/Session"><code>tf.Session</code></a>:
# `sess_2` will run operations from `g_2`.
sess_2 = tf.Session(graph=g_2)

assert c.graph is g_1
assert sess_1.graph is g_1

assert d.graph is g_2
assert sess_2.graph is g_2

先谢谢您。

0 个答案:

没有答案