我运行了mnist_ae1.py(非常简单的自动编码器模型),并希望得到w_enc
和b_enc
的值。
所以,我添加了一些流程如下。
# Train
init = tf.initialize_all_variables()
with tf.Session() as sess:
sess.run(init)
print('Training...')
for i in range(10001):
batch_xs, batch_ys = mnist.train.next_batch(128)
train_step.run({x: batch_xs, y_: batch_ys})
if i % 1000 == 0:
train_loss = loss.eval({x: batch_xs, y_: batch_ys})
print(' step, loss = %6d: %6.3f' % (i, train_loss))
# generate decoded image with test data
test_fd = {x: mnist.test.images, y_: mnist.test.labels}
decoded_imgs = decoded.eval(test_fd)
print('loss (test) = ', loss.eval(test_fd))
# add
w_enc_array, b_enc_array = train_step.run([w_enc, b_enc], {x: mnist.test.images})
print("w_enc :", w_enc_array)
print("b_enc :", b_enc_array)
weight_result = np.append(w_enc_array, 0)
weight_result = np.append(weight_result, b_enc)
np.savetxt("weight_result.csv", weight_result, delimiter=",")
但是,抛出了以下错误。
w_enc_array, b_enc_array = train_step.run([w_enc, b_enc], {x: mnist.test.images})
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1449, in run
_run_using_default_session(self, feed_dict, self.graph, session)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 3664, in _run_using_default_session
if session.graph is not graph:
AttributeError: 'dict' object has no attribute 'graph'
如何获取并保存w_enc
和b_enc
的价值?
答案 0 :(得分:0)
您在train_step.run
行中调用的Operation.run的第二个参数是会话。通过传递字典,您可以将运行时混淆为将该字典视为会话。请尝试改为feed_dict={x...}
。