在TensorFlow中没有广播tf.matmul

时间:2016-06-27 10:06:24

标签: tensorflow broadcasting

我有一个问题,我一直在努力。它与tf.matmul()及其缺少广播有关。

我在https://github.com/tensorflow/tensorflow/issues/216上发现了类似的问题,但tf.batch_matmul()对我的案例来说似乎不是一个解决方案。

我需要将输入数据编码为4D张量: X = tf.placeholder(tf.float32, shape=(None, None, None, 100)) 第一个维度是批次的大小,第二个维度是批次中的条目数量。 您可以将每个条目想象为多个对象的组合(第三维)。最后,每个对象由100个浮点值的向量描述。

请注意,我对第二维和第三维使用了None,因为实际大小可能会在每个批次中发生变化。但是,为简单起见,让我们用实际数字来形成张量: X = tf.placeholder(tf.float32, shape=(5, 10, 4, 100))

这些是我计算的步骤:

  1. 计算100个浮点值的每个向量的函数(例如,线性函数) W = tf.Variable(tf.truncated_normal([100, 50], stddev=0.1)) Y = tf.matmul(X, W) 问题tf.matmul()没有广播,tf.batch_matmul()没有成功 预期的Y形状:(5,10,4,50)

  2. 为批处理的每个条目应用平均池(通过每个条目的对象): Y_avg = tf.reduce_mean(Y, 2) 预期的Y_avg形状:(5,10,50)

  3. 我预计tf.matmul()会支持广播。然后我找到了tf.batch_matmul(),但看起来它似乎不适用于我的情况(例如,W需要至少有3个维度,不清楚原因)。

    BTW,上面我使用了一个简单的线性函数(其权重存储在W中)。但在我的模型中,我有一个深层网络。因此,我遇到的更普遍的问题是自动计算张量的每个切片的函数。这就是我期望tf.matmul()会有广播行为的原因(如果是这样,可能甚至不需要tf.batch_matmul()。)

    期待向您学习! Alessio的

2 个答案:

答案 0 :(得分:7)

您可以通过重塑X来塑造[n, d]来实现这一目标,其中d是计算的单个“实例”的维度(在您的示例中为100)和{{1} }是多维对象中的那些实例的数量(在您的示例中为n)。重塑后,您可以使用5*10*4=200然后重新塑造回所需的形状。前三个维度可以变化的事实使得这一点很棘手,但您可以使用tf.matmul来确定运行时的实际形状。最后,您可以执行计算的第二步,该步骤应该是相应维度上的简单tf.shape。总而言之,它看起来像这样:

tf.reduce_mean

答案 1 :(得分:2)

根据您链接的GitHub issue的重命名标题,您应该使用tf.tensordot()。它可以根据Numpy的tensordot()使两个张量之间的轴对收缩。对于您的情况:

X = tf.placeholder(tf.float32, shape=(5, 10, 4, 100))
W = tf.Variable(tf.truncated_normal([100, 50], stddev=0.1))
Y = tf.tensordot(X, W, [[3], [0]])  # gives shape=[5, 10, 4, 50]