张量增长的张量有效方法

时间:2016-07-06 10:46:21

标签: python tensorflow deep-learning

我在tensorflow中有两个张量,第一个张量是3-D,第二个是2D。我想像这样繁殖他们:

x = tf.placeholder(tf.float32, shape=[sequence_length, batch_size, hidden_num]) 
w = tf.get_variable("w", [hidden_num, 50]) 
b = tf.get_variable("b", [50])

output_list = []
for step_index in range(sequence_length):
    output = tf.matmul(x[step_index,  :,  :], w) + b
    output_list.append(output)
output = tf.pack(outputs_list)

我使用循环来进行乘法运算,但我觉得它太慢了。什么是使这个过程尽可能简单/干净的最佳方法?

2 个答案:

答案 0 :(得分:2)

您可以使用batch_matmul。不幸的是,batch_matmul似乎不支持批量维度广播,因此您必须平铺w矩阵。这将使用更多内存,但所有操作都将保留在TensorFlow

a = tf.ones((5, 2, 3))
b = tf.ones((3, 1))
b = tf.reshape(b, (1, 3, 1))
b = tf.tile(b, [5, 1, 1])
c = tf.batch_matmul(a, b) # use tf.matmul in TF 1.0 
sess = tf.InteractiveSession()
sess.run(tf.shape(c))

这给出了

array([5, 2, 1], dtype=int32)

答案 1 :(得分:1)

您可以使用map_fn,它会沿第一维扫描函数。

x = tf.placeholder(tf.float32, shape=[sequence_length, batch_size, hidden_num]) 
w = tf.get_variable("w", [hidden_num, 50]) 
b = tf.get_variable("b", [50])

def mul_fn(current_input):
    return tf.matmul(current_input, w) + b

output = tf.map_fn(mul_fn, x)

I used this at one point沿着序列实现softmax扫描。