张量流中的广播点积

时间:2017-06-06 06:59:40

标签: python tensorflow

在tensorflow中,我有以下问题。

我有一个形状[batch_size,dim_a,dim_b]的张量 m 和形状[batch_size,dim_b]的矩阵 u

M = tf.constant(shape=[batch_size, sequence_size, embed_dim])
U = tf.constant(shape=[batch_size, embed_dim])

我要实现的是我的批次的每个索引的[i,dim_a,dim_b] x [i,dim_b]的点积。

P[i] = tf.matmul(M[i, :, :], tf.expand_dims(U[i, :], 1)) for each i.

基本上,在批次轴上广泛使用点积。这是可能的,我该如何实现呢?

1 个答案:

答案 0 :(得分:3)

这可以通过tf.einsum()实现:

import tensorflow as tf
import numpy as np

batch_size = 2
sequence_size = 3
embed_dim = 4

M = tf.constant(range(batch_size * sequence_size * embed_dim), shape=[batch_size, sequence_size, embed_dim])
U = tf.constant(range(batch_size, embed_dim), shape=[batch_size, embed_dim])

prod = tf.einsum('bse,be->bs', M, U)

with tf.Session():
  print "M"
  print M.eval()
  print
  print "U"
  print U.eval()
  print
  print "einsum result"
  print prod.eval()
  print

  print "numpy, example 0"
  print np.matmul(M.eval()[0], U.eval()[0])
  print
  print "numpy, example 1"
  print np.matmul(M.eval()[1], U.eval()[1])

输出:

M
[[[ 0  1  2  3]
  [ 4  5  6  7]
  [ 8  9 10 11]]

 [[12 13 14 15]
  [16 17 18 19]
  [20 21 22 23]]]

U
[[2 3 3 3]
 [3 3 3 3]]

einsum result
[[ 18  62 106]
 [162 210 258]]

numpy, example 0
[ 18  62 106]

numpy, example 1
[162 210 258]