在以下情况下,我可以使用tf.matmul(A, B)
进行批矩阵乘法:
A.shape == (..., a, b)
和B.shape == (..., b, c)
, ...
相同的地方。
但是我想要一个额外的广播:
A.shape == (a, b, 2, d)
和 B.shape == (a, 1, d, c)
result.shape == (a, b, 2, c)
我希望结果是a x b
和(2, d)
之间的(d, c)
个矩阵乘法。
该怎么做?
测试代码:
import tensorflow as tf
import numpy as np
a = 3
b = 4
c = 5
d = 6
x_shape = (a, b, 2, d)
y_shape = (a, d, c)
z_shape = (a, b, 2, c)
x = np.random.uniform(0, 1, x_shape)
y = np.random.uniform(0, 1, y_shape)
z = np.empty(z_shape)
with tf.Session() as sess:
for i in range(b):
x_now = x[:, i, :, :]
z[:, i, :, :] = sess.run(
tf.matmul(x_now, y)
)
print(z)
答案 0 :(得分:1)
tf.einsum
-任意维度的张量之间的广义收缩将是您遇到这种问题的朋友。请参阅tf文档here。
关于stackoverflow有一个很棒的教程:(Understanding NumPy's einsum)。
import tensorflow as tf
import numpy as np
a = 3
b = 4
c = 5
d = 6
x_shape = (a, b, 2, d)
y_shape = (a, d, c)
z_shape = (a, b, 2, c)
x = tf.constant(np.random.uniform(0, 1, x_shape))
y = tf.constant(np.random.uniform(0, 1, y_shape))
z = tf.constant(np.empty(z_shape))
v = tf.einsum('abzd,adc->abzc', x, y)
print z.shape, v.shape
with tf.Session() as sess:
print sess.run(v)
RESULT:
(3, 4, 2, 5) (3, 4, 2, 5)
[[[[ 1.8353901 1.29175219 1.49873967 1.78156638 0.79548786]
[ 2.32836196 2.01395003 1.53038244 2.51846521 1.65700572]]
[[ 1.76139921 1.78029925 1.22302866 2.18659201 1.51694413]
[ 2.32021949 1.98895703 1.7098903 2.21515966 1.33412172]]
[[ 2.13246675 1.63539287 1.64610271 2.16745158 1.02269943]
[ 1.75559616 1.6715972 1.26049591 2.14399714 1.34957603]]
[[ 1.80167636 1.91194534 1.3438773 1.9659323 1.25718317]
[ 1.4379158 1.31033243 0.71024123 1.62527415 1.31030634]]]
[[[ 2.04902039 1.59019464 1.32415689 1.59438659 2.02918951]
[ 2.23684642 1.27256603 1.63474052 1.73646679 2.42958829]]
....
....
答案 1 :(得分:0)
仅需要tf.reshape
和tf.matmul
。无需移调。
import tensorflow as tf
import numpy as np
jit_scope = tf.contrib.compiler.jit.experimental_jit_scope
a = 3
b = 4
c = 5
d = 6
x_shape = (a, b, 2, d)
y_shape = (a, d, c)
x = tf.constant(np.random.uniform(0, 1, x_shape))
y = tf.constant(np.random.uniform(0, 1, y_shape))
x2 = tf.reshape(x, (a, b * 2, d))
with jit_scope():
z = tf.reshape(tf.matmul(x2, y), (a, b, 2, c))
z2 = x @ (y[:, np.newaxis, :, :])
z3 = tf.einsum('abzd, adc -> abzc', x, y)
with tf.Session() as sess:
z_, z2_, z3_ = sess.run([z, z2, z3])
assert np.allclose(z_, z2_)
assert np.allclose(z_, z3_)