在不知道批量大小的情况下进行三维批量矩阵乘法

时间:2018-01-10 05:38:51

标签: arrays tensorflow machine-learning neural-network matrix-multiplication

我正在编写一个张量流程序,需要将一批2-D张量(形状[None,...]的3-D张量)与2-D矩阵W相乘。这需要将W转换为3-D矩阵,这需要知道批量大小。

我无法做到这一点; tf.batch_matmul不再可用,x.get_shape().as_list()[0]返回None,这对于重新整形/平铺操作无效。有什么建议?我见过有些人使用config.cfg.batch_size,但我不知道那是什么。

1 个答案:

答案 0 :(得分:2)

解决方案是使用tf.shape(在运行时返回形状 )和tf.tile(接受动态形状)的组合

x = tf.placeholder(shape=[None, 2, 3], dtype=tf.float32)
W = tf.Variable(initial_value=np.ones([3, 4]), dtype=tf.float32)
print(x.shape)                # Dynamic shape: (?, 2, 3)

batch_size = tf.shape(x)[0]   # A tensor that gets the batch size at runtime
W_expand = tf.expand_dims(W, axis=0)
W_tile = tf.tile(W_expand, multiples=[batch_size, 1, 1])
result = tf.matmul(x, W_tile) # Can multiply now!

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  feed_dict = {x: np.ones([10, 2, 3])}
  print(sess.run(batch_size, feed_dict=feed_dict))    # 10
  print(sess.run(result, feed_dict=feed_dict).shape)  # (10, 2, 4)