火炬代码:
att_h = att_h.unsqueeze(1).expand_as(att)
att_h形状为(10,512)
att形状为(10,196,512)
Keras代码:
K.expand_dims(att_h, 1).expand_as(att)
出现错误: “张量”对象没有属性“ expand_as”
不确定在keras中如何做同样的事情。
答案 0 :(得分:0)
在Keras中没有任何内置函数。但是,您可以使用:np.reshape
获得相同的结果。
因此您可以:
K.expand_dims(att_h,1).reshape(att.shape)