TensorFlow中的函数与PyTorch中的expand()相同?

时间:2018-01-12 12:22:04

标签: python tensorflow pytorch

假设我有一个2 x 3矩阵,我想创建一个6 x 2 x 3矩阵,其中第一维中的每个元素都是原始的2 x 3矩阵。

在PyTorch中,我可以这样做:

import torch
from torch.autograd import Variable
import numpy as np

x = np.array([[1, 2, 3], [4, 5, 6]])
x = Variable(torch.from_numpy(x))

# y is the desired result
y = x.unsqueeze(0).expand(6, 2, 3)

在TensorFlow中执行此操作的等效方法是什么?我知道unsqueeze()等同于tf.expand_dims(),但我没有TensorFlow具有等同于expand()的任何内容。我正在考虑在1 x 2 x 3张量列表中使用tf.concat,但我不确定这是否是最好的方法。

2 个答案:

答案 0 :(得分:3)

pytorch expand的等效功能是tensorflow tf.broadcast_to

文档:https://www.tensorflow.org/api_docs/python/tf/broadcast_to

答案 1 :(得分:0)

Tensorflow自动广播,因此通常您不需要执行任何操作。假设您的y'形状为6x2x3且x形状为2x3,那么您已经可以y'*xy'+x已经表现得像您一样扩大了它。但是如果由于某些其他原因你确实需要这样做,那么tensorflow中的命令是tile

y = tf.tile(tf.reshape(x, (1,2,3)), multiples=(6,1,1))

文档:https://www.tensorflow.org/api_docs/python/tf/tile