在没有写入磁盘的情况下,在会话之间重用TensorFlow变量值

时间:2017-07-16 17:02:55

标签: python tensorflow

在sklearn中,我习惯拥有一个可以运行fit然后predict的模型。但是,使用TensorFlow时,我在调用fit时无法从predict加载学习参数。归结为我不知道如何在会话之间重用变量的值。例如,

import tensorflow as tf

x = tf.Variable(0.0)

# fit code
with tf.Session() as sess1:
    sess1.run(tf.global_variables_initializer())
    sess1.run(tf.assign(x, 1.0)) # at end of training, x = 1.0

# predict code
with tf.Session() as sess2:
    sess2.run(tf.global_variables_initializer())
    print(sess2.run(x)) # want this to be 1.0, but is 0.0

我可以想到一个解决方法,但它看起来真的太乱了,如果我想重用几个变量会很烦人:

import tensorflow as tf

x = tf.Variable(0.0)

# fit code
with tf.Session() as sess1:
    sess1.run(tf.global_variables_initializer())
    sess1.run(tf.assign(x, 1.0)) # at end of training, x = 1.0
    learned_x = sess1.run(x) # remember value of learned x at end of session

# predict code
with tf.Session() as sess2:
    sess2.run(tf.global_variables_initializer())
    sess2.run(tf.assign(x, learned_x))
    print(sess2.run(x)) # prints 1.0

如何在不写入磁盘的情况下在会话之间重用变量(即使用tf.train.Saver)?我上面写的解决方法是正确的方法吗?

1 个答案:

答案 0 :(得分:1)

要模仿sklearn的模型,只需将 Unknown column 'issues' in 'where clause' (SQL: select * from `result_test` where `result_test`.`client_id` in (342074, 160374, 596433) and `issues` = 0) 包装到一个类中,以便您可以在方法之间共享它。

session

确保手动关闭class Model: def __init__(self): self.graph = self.build_graph() self.session = tf.Session() self.session.run(tf.global_variables_initializer()) def build_graph(self): return {'x': tf.Variable(0.0)} def fit(self): self.session.run(tf.assign(self.graph['x'], 1.0)) def predict(self): print(self.session.run(self.graph['x'])) def close(self): tf.reset_default_graph() self.session.close() m = Model() m.fit() m.predict() m.close() 并相应处理异常。