我当前的张量具有(3,2)的形状,例如,
[[ 1. 2.]
[ 2. 1.]
[-2. -1.]]
我想扩展为(1,3,2)的形状,每个第二维度是整个张量的复制品,例如,
[[[ 1. 2.]
[ 2. 1.]
[ -2. -1.]]
[[ 1. 2.]
[ 2. 1.]
[ -2. -1.]]
[[ 1. 2.]
[ 2. 1.]
[ -2. -1.]]]
我尝试了以下代码,但它只复制了每一行。
tiled_vecs = tf.tile(tf.expand_dims(input_vecs, 1),
[1, 3, 1])
结果
[[[ 1. 2.]
[ 1. 2.]
[ 1. 2.]]
[[ 2. 1.]
[ 2. 1.]
[ 2. 1.]]
[[-2. -1.]
[-2. -1.]
[-2. -1.]]]
答案 0 :(得分:11)
这应该有效,
(pf.shape(A)[0],1,1] * A
# Achieved by creating a 3d matrix as shown below
# and multiplying it with A, which is `broadcast` to obtain the desired result.
[[[1.]],
[[1.]], * A
[[1.]]]
代码示例:
#input
A = tf.constant([[ 1., 2.], [ 2. , 1.],[-2., -1.]])
B = tf.ones([tf.shape(A)[0], 1, 1]) * A
#output
array([[[ 1., 2.],
[ 2., 1.],
[-2., -1.]],
[[ 1., 2.],
[ 2., 1.],
[-2., -1.]],
[[ 1., 2.],
[ 2., 1.],
[-2., -1.]]], dtype=float32)
同样使用tf.tile
,我们可以获得相同的内容:
tf.tile(tf.expand_dims(A,0),[tf.shape(A)[0],1,1]