如何在TensorFlow中连续恢复多个会话

时间:2018-08-09 07:38:13

标签: tensorflow model neural-network

我想恢复具有相同结构但学习率不同的多个模型。

我遇到的问题是我无法连续运行两次还原功能。如果我评论tt1并运行tt2,反之亦然,则可以得到所需的预测,但是如果我将它们同时运行,则不可能。

def predict(data, features, submodel_type, ckpt):
    n_input, weights, biases, X_test, Y_test = init(data, submodel_type, features)
    x = tf.placeholder("float", [None, n_input])
    pred = multilayer_perceptron(x, weights, biases)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, ckpt)
        X_test_scale = preprocessing.scale(X_test)
        pred_y = sess.run(pred, feed_dict={x: X_test_scale})
        gMAE, gMRE = evaluate('TT', pred_y, Y_test)
        print("GMRE:", gMRE)
        print("GMAE:", gMAE)
        for v1, v2 in zip(pred_y, Y_test):
            print('PV: %.2f, TV: %.2f, ERR: %.d' % (v1, v2, abs(v1 - v2)))
    sess.close()
    return pred_y, Y_test


data = np.genfromtxt('/home/simeonv/PycharmProjects/TotalTime/data/TEST SET DO NOT USE/test.csv', delimiter=',',
                     dtype=float)
features_0000 = [6,...89]
features_0001 = [0, .... 5,89]
submodel_0000 = '0000'
submodel_0001 = '0001'
ckpt_0000 = '/../TT_0000_2018_7_27_134233_MAE_1395.0954.ckpt'
ckpt_0001 = '/../TT_0001_2018_7_27_153715_MAE_1526.3000.ckpt'

tt1 = predict(data, features_0000, submodel_0000, ckpt_0000)
tt2 = predict(data, features_0001, submodel_0001, ckpt_0001)

我如何使其起作用?我尝试了sess.close,但由于我仍无法进行会话,因此无法正常工作,因为该会话仍在使用一些剩余值运行。

同时运行tt1和tt2时出现的错误是:

ValueError: Variable h1 already exists, disallowed. 
Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope? Originally defined at:

File "/home/TT/TT_0001/TT_0001.py", line 26,
 in init'h1': tf.nn.l2_normalize(tf.get_variable(name='h1', shape=[n_input, n_hidden_1], initializer=init), axis=[0]),
File "/home/TT/TT_0001/TT_0001.py", line 94,
 in predict n_input, weights, biases, X_test, Y_test = init(data, submodel_type, features)
File "/home/TT/TT_0001/TT_0001.py", line 124,
 in <module> tt1 = predict(data, features_0000, submodel_0000, ckpt_0000)

1 个答案:

答案 0 :(得分:0)

我找到了解决此问题的方法,我不知道这是否是一个好的解决方法,但是它可以工作。

tt1 = predict(data, features_0000, submodel_0000, ckpt_0000)
tf.reset_default_graph()
tt2 = predict(data, features_0001, submodel_0001, ckpt_0001)

通过重置图形,我可以毫无问题地开始新的会话。