TensorFlow: "Invalid checkpoint"

Posted: 2019-03-10 20:18:00

Tags: python tensorflow neural-network recurrent-neural-network

I am trying to restore a recurrent neural network from a .cpkt file. My code for restoring the network is:

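import os
import tensorflow as tf  # TensorFlow 1.x (uses tf.placeholder and tf.contrib)

graph = tf.Graph()
with graph.as_default():
    # Rebuild the same RNN that was trained so the variable names match the checkpoint
    X = tf.placeholder(tf.float32, [1, n_steps, n_inputs])
    cell = tf.contrib.rnn.OutputProjectionWrapper(
        tf.contrib.rnn.BasicRNNCell(num_units=n_neurons, activation=tf.nn.relu),
        output_size=n_outputs
    )
    outputs, states = tf.nn.dynamic_rnn(cell, X, dtype=tf.float32)
    saver = tf.train.Saver()
with tf.Session(graph=graph) as sess:
    name = "rnnMonthly2"
    saver.restore(sess, os.getcwd() + "//RNNConfigs//" + name + "//" + name + ".cpkt")
    # priceArrayToRNNFormat returns (X, y) in the training code; only X is needed here.
    # [-30:0] is always empty, so [-30:] is used to take the last 30 prices.
    X_batch, _ = priceArrayToRNNFormat(getPriceArray(symbol="IBM")[-30:])
    y_val = sess.run(outputs, feed_dict={X: X_batch})  # sess.run needs a fetch (here: outputs)
    print(y_val)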

For reference, the plain-text checkpoint file gives the path of the checkpoint as follows:

model_checkpoint_path: "/home/john/Python/StockProject//RNNConfigs//rnnMonthly2//rnnMonthly2.cpkt"
all_model_checkpoint_paths: "/home/john/Python/StockProject//RNNConfigs//rnnMonthly2//rnnMonthly2.cpkt"

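Because of this, I expected that passing the path I saved to into saver.restore would restore the model correctly. However, when I run the code, I get the following message: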

Traceback (most recent call last):
  File "/home/john/Python/StockProject/monthlyRnn1.py", line 151, in <module>
    saver.restore(sess, os.getcwd() + "//RNNConfigs//" + name + "//" + name + ".cpkt.index")
  File "/home/john/.local/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 1538, in restore
    + compat.as_text(save_path))
ValueError: The passed save_path is not a valid checkpoint: /home/john/Python/StockProject//RNNConfigs//rnnMonthly2//rnnMonthly2.cpkt.index
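The traceback shows a path ending in .cpkt.index being passed to saver.restore. With the default V2 checkpoint format, tf.train.Saver.save writes several files that share one prefix (prefix.index, prefix.data-00000-of-00001, prefix.meta), and restore expects that prefix rather than any one of those files. The following is a plain-Python sketch, under that assumption, that recovers the prefix from a path and reports which checkpoint files are missing; the helper names checkpoint_prefix and missing_checkpoint_files are hypothetical and not part of TensorFlow.

```python
import glob
import os
import re

# Suffixes that tf.train.Saver appends to the prefix in the (assumed) default V2 format.
_SUFFIX_RE = re.compile(r"\.(index|meta|data-\d{5}-of-\d{5})$")


def checkpoint_prefix(path):
    """Collapse '//' separators and strip a trailing .index/.meta/.data-* suffix."""
    return _SUFFIX_RE.sub("", os.path.normpath(path), count=1)


def missing_checkpoint_files(prefix):
    """List the files restore() would need for this prefix that do not exist."""
    missing = []
    if not os.path.isfile(prefix + ".index"):
        missing.append(prefix + ".index")
    # The number of data shards can vary, so match any prefix.data-* file.
    if not glob.glob(glob.escape(prefix) + ".data-*"):
        missing.append(prefix + ".data-*")
    return missing
```

Usage would be along the lines of prefix = checkpoint_prefix(path_from_traceback), then checking that missing_checkpoint_files(prefix) is empty before calling saver.restore(sess, prefix). The .meta file is not checked because restoring into a rebuilt graph does not read it.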

What is causing this error, and how can I fix it? For reference, here is the code I used to train and save the network:

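import os
import tensorflow as tf  # TensorFlow 1.x

# X, y, training_op and loss are defined in the graph-building code (not shown)
saver = tf.train.Saver()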
init = tf.global_variables_initializer()

with tf.Session() as sess:
    mse_list = []
    init.run()
    for iteration in range(n_iterations):
        dataOrig = allStocksDict[list(allStocksDict.keys())[iteration]]
        X_batch, y_batch = priceArrayToRNNFormat(dataOrig)
        print(X_batch, y_batch)
        sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
        mse = loss.eval(feed_dict={X: X_batch, y: y_batch})
        print(iteration, "\tMSE", mse)
        mse_list.append(mse)
    print(mse_list)
    name = "rnnMonthly2"
    saver.save(sess, os.getcwd() + "//RNNConfigs//" + name + "//" + name + ".cpkt")
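Both scripts build the checkpoint path by concatenating os.getcwd() with "//" separators, so save and restore only agree when both are launched from the same working directory. One option, sketched here as an assumption rather than a known fix, is a single shared helper built with os.path.join; the function checkpoint_path and its base_dir parameter are hypothetical names.

```python
import os


def checkpoint_path(name, base_dir=None):
    """Return <base_dir>/RNNConfigs/<name>/<name>.cpkt, the prefix for Saver.save/restore.

    base_dir defaults to the current working directory, matching the original
    os.getcwd() behaviour; passing the script's own directory makes the path
    independent of where the script is launched from.
    """
    if base_dir is None:
        base_dir = os.getcwd()
    return os.path.join(base_dir, "RNNConfigs", name, name + ".cpkt")
```

Both scripts could then call saver.save(sess, checkpoint_path(name)) and saver.restore(sess, checkpoint_path(name)), passing the same prefix with no .index suffix.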

0 answers