我有两个张量x
和s
形状:
> x.shape
TensorShape([Dimension(None), Dimension(3), Dimension(5), Dimension(5)])
> s.shape
TensorShape([Dimension(None), Dimension(12), Dimension(5), Dimension(5)])
我想通过尺寸x
在s
和1
之间广播点积,如下所示:
> x_s.shape
TensorShape([Dimension(None), Dimension(4), Dimension(5), Dimension(5)])
,其中
x_s[i, 0, k, l] = sum([x[i, j, k, l] * s[i, j, k, l] for j in range (3)])
x_s[i, 1, k, l] = sum([x[i, j-3, k, l] * s[i, j, k, l] for j in range (3, 6)])
x_s[i, 2, k, l] = sum([x[i, j-6, k, l] * s[i, j, k, l] for j in range (6, 9)])
x_s[i, 3, k, l] = sum([x[i, j-9, k, l] * s[i, j, k, l] for j in range (9, 12)])
我有这个实现:
s_t = tf.transpose(s, [0, 2, 3, 1]) # [None, 5, 5, 12]
x_t = tf.transpose(x, [0, 2, 3, 1]) # [None, 5, 5, 3]
x_t = tf.tile(x_t, [1, 1, 1, 4]) # [None, 5, 5, 12]
x_s = x_t * s_t # [None, 5, 5, 12]
x_s = tf.reshape(x_s, [tf.shape(x_s)[0], 5, 5, 4, 3]) # [None, 5, 5, 4, 3]
x_s = tf.reduce_sum(x_s, axis=-1) # [None, 5, 5, 4]
x_s = tf.transpose(x_s, [0, 3, 1, 2]) # [None, 4, 5, 5]
我知道由于tile
,这在内存中效率不高。此外,reshape
,transpose
&#39} element-wise
和reduce_sum
的操作可能会损害较大张量的性能。有没有其他选择让它变得更干净?
答案 0 :(得分:1)
你有证据表明reshape
贵吗?以下使用重塑和维度广播:
x_s = tf.reduce_sum(tf.reshape(s, (-1, 4, 3, 5, 5)) *
tf.expand_dims(x, axis=1), axis=2)
答案 1 :(得分:0)
只是一些建议,也许不比你的快。首先将s
与tf.split
分成四个张量,然后使用tf.tensordot
得到最终结果,就像这样
splits = tf.split(s, [3] * 4, axis=1)
splits = map(lambda split: tf.tensordot(split, x, axes=[[1], [1]]), splits)
x_s = tf.stack(splits, axis=1)