如何正确将数组输入TensorFlow占位符

时间:2018-11-03 12:49:22

标签: python-3.x tensorflow machine-learning neural-network deep-learning

我正在尝试将数组批量馈入Tensorflow占位符。但是即使提供了正确的形状,我也得到了InvalidArgumentError

这是我的代码的一部分:

import tensorflow as tf
import numpy as np

xdata = np.linspace(1,50, 10000)
noise = np.random.rand(len(xdata))
y_true = (1.5*xdata) + 5 + noise    #m = 1.5 and c = 5

m = tf.Variable(0.1)  #initial values
c = tf.Variable(0.2)

batch_size = 10

x = tf.placeholder(tf.float32, [batch_size])
y = tf.placeholder(tf.float32, [batch_size])

y_hat = (m*x) + c
error = tf.reduce_sum(tf.square(y-y_hat))
optimizer = tf.train.GradientDescentOptimizer(learning_rate= 0.01)
train = optimizer.minimize(error)
init = tf.global_variables_initializer()

with tf.Session() as sess:

    sess.run(init)

    n_batches = 1000

    for i in range(n_batches):

        rand_int = np.random.randint(len(xdata), size =batch_size)
        feed_dict = {x:xdata[rand_int], y: y_true[rand_int]}

        sess.run(train, feed_dict = feed_dict)
        print('Batch:',i, ' loss: ', sess.run(error))

    m_final, slope_final = sess.run([m , c])

错误是:

  

InvalidArgumentError:必须输入占位符张量的值   dtype浮动且形状为[10]的“占位符”

为什么会这样?

1 个答案:

答案 0 :(得分:1)

此行中发生错误:

print('Batch:', i, ' loss: ', sess.run(error))

为了计算张量error的值,必须输入占位符xy的值:

sess.run(error, feed_dict)