Tensorflow如何管理图表?

时间:2016-11-28 16:42:59

标签: python tensorflow loading encapsulation prediction

我已经意识到Tensorflow似乎正在管理图表的方式正在发生一些时髦的事情。

由于构建(和重建)模型非常繁琐,我决定将自定义模型包装在一个类中,以便我可以轻松地在其他地方重新实例化。

当我在训练和测试代码时(在原始位置),它可以正常工作,但是在我加载图形变量的代码中,我会得到各种奇怪的错误 - 变量重新定义和其他一切。这(从我关于类似事情的最后一个问题)是暗示一切都被调用了两次。

在进行TON跟踪之后,它归结为我使用加载代码的方式。它是在具有类似结构的类中使用的

class MyModelUser(object):
    def forecast(self):
       # .. build the model in the same way as in the training code
       # load the model checkpoint
       # call the "predict" function on the model
       # manipulate the prediction and return it

然后在一些使用MyModelUser的代码中

def test_the_model(self):
   model_user = MyModelUser()
   print(model_user.forecast())  # 1
   print(model_user.forecast())  # 2

和我(显然)预计会看到两个预测。相反,第一个预测被调用并按预期工作,但第二个调用抛出了一个 TON 变量重用ValueError,其中一个例子是:

ValueError: Variable weight_def/weights already exists, disallowed. Did you mean to set reuse=True in VarScope?

我设法通过添加一系列使用get_variable创建变量的try / except块来平息错误,然后在范围上调用reuse_variables的异常,然后{{1}没有任何东西,除了名字。这带来了一系列令人讨厌的错误,其中之一就是:

get_variable

我随心所欲地说,如果我将建模构建代码移动到tensorflow.python.framework.errors.NotFoundError: Tensor name "weight_def/weights/Adam_1" not found in checkpoint files ,那么它只构建一次会怎么样?"

我的新模特用户:

__init__

现在:

class MyModelUser(object):
    def __init__(self):
       # ... build the model in the same way as in the training code
       # load the model checkpoint


    def forecast(self):
       # call the "predict" function on the model
       # manipulate the prediction and return it

按预期工作,打印两个没有错误的预测。这让我相信我也可以摆脱变量重用的东西。

我的问题是:

为什么要修复它?理论上,图表应该在原始预测方法中每次重新安装,因此它不应该创建多个图形。即使在函数完成后,Tensorflow是否仍然保留图形?这是为什么将创建代码移动到def test_the_model(self): model_user = MyModelUser() print(model_user.forecast()) # 1 print(model_user.forecast()) # 2 的原因?这让我无可救药地困惑。

2 个答案:

答案 0 :(得分:3)

默认情况下,TensorFlow使用在您第一次调用TensorFlow API时创建的单个全局tf.Graph实例。如果未显式创建background-color,则将在该默认实例中创建所有操作,张量和变量。这意味着您的代码中tf.Graph的每次调用都会将操作添加到同一个全局图中,这有点浪费。

这里有(至少)两种可能的行动方案:

  • 理想的操作是重新构建代码,以便model_user.forecast()使用执行预测所需的所有操作构建整个MyModelUser.__init__()tf.Graph只执行{{} 1}}调用现有图表。理想情况下,您也只能创建一个MyModelUser.forecast(),因为TensorFlow会在会话中缓存有关图形的信息,并且执行效率会更高。

  • 对于sess.run()的每次调用,创建新tf.Session的侵入性较小但效率可能较低。从问题中tf.Graph方法创建了多少状态不清楚,但是您可以执行以下操作将两个调用放在不同的图中:

    MyModelUser.forecast()

答案 1 :(得分:0)

TF有一个默认图表,可以添加新的操作等。当你两次调用你的函数时,你会将相同的东西两次添加到同一个图形中。因此,要么构建一次图形并多次评估它(正如您所做的那样,这也是"正常"方法),或者,如果您想要更改内容,可以使用reset_default_graph {{3}重置图表以获得新鲜状态。