以元素方式计算中的矢量批量广播

时间:2016-10-13 19:52:59

标签: tensorflow

我正在尝试使用一批向量并使用广播进行元素明智的减法,以获得所有组合之间不同的矩阵。我可以使用一批长度为1的工作,但是当我尝试增加样本数量时,我会遇到各种形状匹配错误,并且不相信它会再次播放。下面是获得单个批处理工作的示例代码,以及我尝试过的一些其他输入但没有成功获得一批2个工作:

import tensorflow as tf

#initx = [[1.0, 2.0, 3.0, 4.0],[1.0, 2.0, 3.0, 4.0]]
#initx = [[[1.0, 2.0, 3.0, 4.0]],[[1.0, 2.0, 3.0, 4.0]]]
initx = [[1.0, 2.0, 3.0, 4.0]]

x = tf.placeholder(dtype=tf.float32)

deltas = tf.sub(x,tf.transpose(x))

reshaped_deltas = tf.reshape(deltas,[-1])

with tf.Session('') as session:
  session.run(tf.initialize_all_variables())   

  print "Delta:",session.run([deltas],feed_dict={x:initx })
  print "Flattened Output:",session.run([reshaped_deltas],feed_dict={x:initx })

我得到了单个例子的预期结果:

Delta: [array([[ 0.,  1.,  2.,  3.],
       [-1.,  0.,  1.,  2.],
       [-2., -1.,  0.,  1.],
       [-3., -2., -1.,  0.]], dtype=float32)]
Flattened Output: [array([ 0.,  1.,  2.,  3., -1.,  0.,  1.,  2., -2., -1.,  0.,  1., -3.,
       -2., -1.,  0.], dtype=float32)]

我无法弄清楚如何让“tf.sub()”函数与批次一起使用,并且仍然为每批次正确地播放[1,4]向量。

有谁知道怎么做?我知道有一个tf.batch_matmul()但不是batch_sub()可能会解决问题。

编辑:根据Yaroslav Bulatov的反馈更新了解决问题的脚本

import tensorflow as tf

initx = [[1.5, 2.0, 3.0, 4.0],[1.0, 2.0, 3.0, 4.0]]
#initx = [[1.0, 2.0, 3.0, 4.0]]

VectorSize = len(initx[1])

x = tf.placeholder(dtype=tf.float32)

batch1 = tf.reshape(x, (-1,VectorSize, 1))
deltas = tf.sub(batch1, tf.transpose(batch1, (0, 2, 1)))

reshaped_deltas = tf.reshape(deltas,[-1])

with tf.Session('') as session:
  session.run(tf.initialize_all_variables())   

  print "Delta:",session.run([deltas],feed_dict={x:initx })
  print "Flattened Output:",session.run([reshaped_deltas],feed_dict={x:initx })

1 个答案:

答案 0 :(得分:2)

假设您的批量大小为n且数据大小为k。如果您对sub的输入具有形状n, k, 1n, 1, k,则广播将填充单个维度以输出形状n, k, k的结果,这是您想要的。因此,可以使用tf.reshape将原始文件转换为n, k, 1tf.transpose(..., perm=(0, 2, 1))以获得n, 1, k形状。 IE,就像这样

x1 = tf.constant([1,2,3])
x2 = tf.constant([4,4,5])
batch = tf.pack([x1,x2])
n = 2
k = 3
batch1 = tf.reshape(batch, (n, k, 1))
sess = tf.Session()
sess.run(tf.sub(batch1, tf.transpose(batch1, (0, 2, 1))))

Out[] = array([[[ 0, -1, -2],
        [ 1,  0, -1],
        [ 2,  1,  0]],

       [[ 0,  0, -1],
        [ 0,  0, -1],
        [ 1,  1,  0]]], dtype=int32)