Tensorflow矩阵乘法误差(尺寸必须相等,但'Mul'为3和4)

时间:2017-11-29 15:01:37

标签: tensorflow matrix-multiplication

我正在尝试构建基本设置:

输入4个值。 然后是一层3个神经元。 公式为a = W*x + b

理论上,x将为4x1W将为3x4a则为3x1

现在我试图用tensorflow语法定义它:

import tensorflow as tf
sess = tf.Session()
W = tf.constant([
    [.2, .3, -.1, -.2],
    [.2, -.1, .7, -.1],
    [.1, .4, -.4, -.3],
], dtype=tf.float32)
print(w.get_shape()) # (3, 4)

x = tf.constant([
        [0.0],
        [1.0],
        [2.0],
        [3.0]
    ]
)
print(x.get_shape()) # (1, 4)

a = W*x

print(sess.run(a))

然后我收到此错误,

tensorflow.python.framework.errors_impl.InvalidArgumentError: 
Dimensions must be equal, but are 3 and 4 for 'Mul' (op: 'Mul') 
with input shapes: [3,4], [4,1].

我认为这是完全正常的计算((m,n) X (n,k) = (m,k))。 因此,我预计(3,1)a

我没得到什么?

1 个答案:

答案 0 :(得分:2)

我希望我上面的评论可以为您提供答案。但是,希望以下有所帮助:

*tf.multiply执行逐元素乘法。它需要尺寸匹配或可以广播。使用tf.matmul

import tensorflow as tf
sess = tf.Session()
W = tf.constant([
    [.2, .3, -.1, -.2],
    [.2, -.1, .7, -.1],
    [.1, .4, -.4, -.3],
], dtype=tf.float32)
print(W.get_shape()) # (3, 4)

x = tf.constant([
        [0.0],
        [1.0],
        [2.0],
        [3.0]
    ]
)
print(x.get_shape()) # (1, 4)

a = tf.matmul(W,x)

print(sess.run(a))

print a.shape

输出:

(3, 4)
(4, 1)
[[-0.5       ]
 [ 1.        ]
 [-1.29999995]]
(3, 1)