我正在尝试从.cpkt文件还原循环神经网络。我恢复网络的代码是:
# Rebuild the inference graph. The variables created here must match the ones
# stored in the checkpoint, so the cell configuration mirrors the training setup.
graph = tf.Graph()
with graph.as_default():
    X = tf.placeholder(tf.float32, [1, n_steps, n_inputs])
    cell = tf.contrib.rnn.OutputProjectionWrapper(
        tf.contrib.rnn.BasicRNNCell(num_units=n_neurons, activation=tf.nn.relu),
        output_size=n_outputs,
    )
    outputs, states = tf.nn.dynamic_rnn(cell, X, dtype=tf.float32)
    # The Saver has to be created while `graph` is the default graph so that
    # it collects this graph's variables rather than the global default's.
    saver = tf.train.Saver()

with tf.Session(graph=graph) as sess:
    name = "rnnMonthly2"
    # restore() expects the checkpoint *prefix* that was handed to save()
    # (ending in ".cpkt" here), never one of the physical files such as
    # "<prefix>.index" or "<prefix>.data-00000-of-00001".
    checkpoint_prefix = os.getcwd() + "//RNNConfigs//" + name + "//" + name + ".cpkt"
    saver.restore(sess, checkpoint_prefix)
    # [-30:] keeps the last 30 prices; the previous [-30:0] was always empty.
    # NOTE(review): the training code unpacks two values (X, y) from
    # priceArrayToRNNFormat, so only the first is fed here -- confirm.
    X_batch, _ = priceArrayToRNNFormat(getPriceArray(symbol="IBM")[-30:])
    # Session.run() requires the tensors to fetch as its first argument.
    y_val = sess.run(outputs, feed_dict={X: X_batch})
    print(y_val)
作为参考,保存目录下名为 checkpoint 的文本文件中记录的检查点路径如下:
model_checkpoint_path: "/home/john/Python/StockProject//RNNConfigs//rnnMonthly2//rnnMonthly2.cpkt"
all_model_checkpoint_paths: "/home/john/Python/StockProject//RNNConfigs//rnnMonthly2//rnnMonthly2.cpkt"
因此,我认为我传给 saver.restore 的文件路径应该能够正确地还原模型。但是,当我运行代码时,会收到以下错误信息:
Traceback (most recent call last):
File "/home/john/Python/StockProject/monthlyRnn1.py", line 151, in <module>
saver.restore(sess, os.getcwd() + "//RNNConfigs//" + name + "//" + name + ".cpkt.index")
File "/home/john/.local/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 1538, in restore
+ compat.as_text(save_path))
ValueError: The passed save_path is not a valid checkpoint: /home/john/Python/StockProject//RNNConfigs//rnnMonthly2//rnnMonthly2.cpkt.index
此错误的原因是什么,我该怎么解决?作为参考,这是我用来训练和保存网络的代码:
# Train the RNN on one stock series per iteration, then write a checkpoint.
saver = tf.train.Saver()
init = tf.global_variables_initializer()
with tf.Session() as sess:
    mse_list = []
    init.run()
    # Build the key list once instead of re-materializing it every iteration.
    stock_keys = list(allStocksDict.keys())
    for iteration in range(n_iterations):
        dataOrig = allStocksDict[stock_keys[iteration]]
        X_batch, y_batch = priceArrayToRNNFormat(dataOrig)
        print(X_batch, y_batch)
        sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
        # Loss is evaluated after the update step, on the same batch.
        mse = loss.eval(feed_dict={X: X_batch, y: y_batch})
        print(iteration, "\tMSE", mse)
        mse_list.append(mse)
    print(mse_list)
    name = "rnnMonthly2"
    # This string is the checkpoint prefix; TensorFlow appends ".index",
    # ".meta" and ".data-*" to it. Pass the same prefix to restore().
    save_path = os.getcwd() + "//RNNConfigs//" + name + "//" + name + ".cpkt"
    # Saver.save() does not create missing parent directories.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    # Must run while the session is still open, hence inside the `with` block.
    saver.save(sess, save_path)