如何从自动编码器程序中获取重量和偏差值

时间:2017-01-06 10:54:19

标签: python python-2.7 tensorflow

我运行了mnist_ae1.py(非常简单的自动编码器模型),并希望得到w_encb_enc的值。 所以,我添加了一些流程如下。

# Train
init = tf.initialize_all_variables()

with tf.Session() as sess:
    sess.run(init)
    print('Training...')
    for i in range(10001):
        batch_xs, batch_ys = mnist.train.next_batch(128)
        train_step.run({x: batch_xs, y_: batch_ys})

        if i % 1000 == 0:
            train_loss = loss.eval({x: batch_xs, y_: batch_ys})
            print('  step, loss = %6d: %6.3f' % (i, train_loss))

    # generate decoded image with test data
    test_fd = {x: mnist.test.images, y_: mnist.test.labels}
    decoded_imgs = decoded.eval(test_fd)
    print('loss (test) = ', loss.eval(test_fd))

    # add
    w_enc_array, b_enc_array = train_step.run([w_enc, b_enc], {x: mnist.test.images})
    print("w_enc :", w_enc_array)
    print("b_enc :", b_enc_array)

    weight_result = np.append(w_enc_array, 0)
    weight_result = np.append(weight_result, b_enc)
    np.savetxt("weight_result.csv", weight_result, delimiter=",")

但是,抛出了以下错误。

w_enc_array, b_enc_array = train_step.run([w_enc, b_enc], {x: mnist.test.images})
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1449, in run
    _run_using_default_session(self, feed_dict, self.graph, session)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 3664, in _run_using_default_session
    if session.graph is not graph:
AttributeError: 'dict' object has no attribute 'graph'

如何获取并保存w_encb_enc的价值?

1 个答案:

答案 0 :(得分:0)

您在train_step.run行中调用的Operation.run的第二个参数是会话。通过传递字典,您可以将运行时混淆为将该字典视为会话。请尝试改为feed_dict={x...}