我使用张量流训练了自然网络,而在测试过程中发生了错误,这表明我训练的图和恢复的图不匹配。但是定义用于还原的网络结构和用于训练的网络的代码是相似的。
x_data, y_data = incentiveAndResponse(childSequence[ridx, :], order)
x_test, y_test = incentiveAndResponse(testChildSequence[ridx, :], order)
# define placeholder for inputs to network
xs = tf.placeholder(tf.float32, [None, order])
ys = tf.placeholder(tf.float32, [None, 1])
# add hidden layer
l1 = add_layer(xs, order, NoNode, activation_function=tf.nn.tanh)
# l2 = add_layer(l1, NoNode, NoNode, activation_function=tf.nn.tanh)
# l3 = add_layer(l2, NoNode, NoNode, activation_function=tf.nn.tanh)
# add output layer
prediction = add_layer(l1, NoNode, 1, activation_function=None)
# the error between prediciton and real data
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),
reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
saver = tf.train.Saver()
for i in range(1000):
sess.run(train_step, feed_dict={xs: x_data, ys: y_data})
if i % 50 == 0:
# to see the step improvement
print(sess.run(loss, feed_dict={xs: x_data, ys: y_data}))
saver.save(sess, './my_net/save_net.ckpt')
tf.reset_default_graph()
# test
# define placeholder for inputs to network
xsnew = tf.placeholder(tf.float32, [None, order])
ysnew = tf.placeholder(tf.float32, [None, 1])
# add hidden layer
l1 = add_layer(xsnew, order, NoNode, activation_function=tf.nn.tanh)
# l2 = add_layer(l1, NoNode, NoNode, activation_function=tf.nn.tanh)
# l3 = add_layer(l2, NoNode, NoNode, activation_function=tf.nn.tanh)
prediction = add_layer(l1, NoNode, 1, activation_function=None)
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.restore(sess, './my_net/save_net.ckpt')
res = sess.run(prediction, feed_dict={xsnew: x_test, ysnew: y_test})
shutil.rmtree('./my_net')
这是ide的结果:
InvalidArgumentError(请参阅上面的回溯):从中还原 检查点失败。这很可能是由于两者之间的不匹配 当前图和来自检查点的图。请确保 您尚未更改基于检查点的预期图形。 原始错误:
分配需要两个张量的形状匹配。 lhs shape = [2,3] rhs shape = [2,2] [[节点保存/分配(在 F:/code/python/NewIdea/WaveletCnnInFCM.py:270)]]