块状(n,1,m)至(n,m)

时间:2018-11-06 02:20:00

标签: python numpy

我正在研究一个涉及19个令牌的批次,每个令牌具有400个功能。将两个大小为(1,200)的向量连接到最终特征向量时,得到的形状为(19,1,400)。如果我将1挤出,则剩下(19,),但我尝试获得(19,400)。我尝试过转换为列表,压缩和整理,但没有任何效果。

有没有办法将此数组转换为正确的形状?

def attn_output_concat(sample):
  out_h, state_h = get_output_and_state_history(agent.model, sample)
  attns = get_attentions(state_h)
  inner_outputs = get_inner_outputs(state_h)
  if len(attns) != len(inner_outputs):
    print 'Length err'
  else:
    tokens = [np.zeros((400))] * largest
    print(tokens.shape)
    for j, (attns_token, inner_token) in enumerate(zip(attns, inner_outputs)):
      tokens[j] = np.concatenate([attns_token, inner_token], axis=1)
    print(np.array(tokens).shape)
    return tokens

1 个答案:

答案 0 :(得分:2)

最简单的方法是将标记声明为以numpy.shape =(19,400)开头的数组。这样还可以提高内存/时间效率。这是修改后的代码的相关部分...

import numpy as np
attns_token = np.zeros(shape=(1,200))
inner_token = np.zeros(shape=(1,200))
largest = 19
tokens = np.zeros(shape=(largest,400))
for j in range(largest):
    tokens[j] = np.concatenate([attns_token, inner_token], axis=1)
print(tokens.shape)

BTW ...如果您不包含独立且可运行的代码段,这将使人们很难为您提供帮助(这可能是为什么您尚未对此做出回应的原因)。最好使用上面的代码片段,这样可以帮助您获得更好的答案,因为对您要完成的目标的猜测较少。