通过张量流来预测10个数字的2维输出

时间:2017-07-11 09:00:03

标签: tensorflow tensor

我想预测10个数字中的一个数字

我想要做的是从t

预测mat

每个mat[i]都与t[i]

相对应

当然我在mat和t中有超过5行,现在只是简化了问题。

我已经编写了如下代码。

#There is target data `t` and traindata `mat[0]`,`mat[1]`,`mat[2]`....

t = [0,1,0,1,0] #answer 2 dimension

limit = 10# number of degrees
mat = [[2,-2,3,-4,2,2,3,5,3,6],   #10 degrees number of mat[0] leads t[0]
[1,3,-3,2,2,5,1,3,2,3],   #10 degrees number of mat[1] leads t[1]
[-2,3,2,-2,2,-2,1,3,4,5],   #10 degrees number of mat[2] leads t[2]
[-2,2,-1,-2,2,-2,7,3,9,2],   #10 degrees number of mat[3] leads t[3]
[-2,-3,2,-2,2,-4,1,-4,4,5],   #10 degrees number of mat[4] leads t[4]
]

x = tf.placeholder(tf.float32,[None,10])
w = tf.Variable(tf.zeros([10,5]))
y = tf.matmul(x,w)
t = tf.placeholder(tf.float32,[None,1])

loss = tf.reduce_sum(tf.square(y-t))

train_step = tf.train.AdamOptimizer().minimize(loss)
sess = tf.Session()
sess.run(tf.initialize_all_variables())

train_t = np.array(mat)
train_t = train_t.reshape([limit,5])
train_x = np.zeros([limit,5])

# initialize
for row, num in enumerate(range(1,limit + 1)):
    for col, n in enumerate(range(0,5)):
        train_x[row][col] = num**n

i = 0
for _ in range(100000):
    i += 1
    sess.run(train_step,feed_dict={x:train_x,t:train_t})
    if i % 10000 == 0:
        loss_val = sess.run(loss,feed_dict={x:train_x,t:train_t})
        print('step : %d,Loss: %f' % (i,loss_val))
        w_val = sess.run(w)
        pprint("w_val")
        pprint(w_val)

然而,这显示像这样的错误

Traceback (most recent call last):
  File "wisdom2.py", line 60, in <module>
    sess.run(train_step,feed_dict={x:train_x,t:train_t})
  File "/Users/whitebear/tensorflow/lib/python3.4/site-packages/tensorflow/python/client/session.py", line 789, in run
    run_metadata_ptr)
  File "/Users/whitebear/tensorflow/lib/python3.4/site-packages/tensorflow/python/client/session.py", line 975, in _run
    % (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))
ValueError: Cannot feed value of shape (10, 5) for Tensor 'Placeholder:0', which has shape '(?, 10)'

1 个答案:

答案 0 :(得分:1)

问题是占位符的形状和输入的形状不匹配。占位符x需要一个 N 行和10列的值,但train_x有10行和5列。同样,t应该有N行和1列,但传递的值train_t有10行5列。您应该更改占位符的形状或输入的形状。