给出3维调度批处理的三维数据批量的最后一个调整的n长度子集的numpy方法

时间:2018-05-21 10:50:44

标签: python numpy

任务:给定'values'和'ind'以最笨拙的方式获得'结果'。

输入

import numpy as np
values = np.reshape(np.array([x/100 for x in range(4*5*10)]), (4, 5, 10))
ind = np.reshape(np.array([np.random.randint(0,10) for x in range(4*5*5)]), (4, 5, 5))

样本所需的输出:

result = np.array([[[0.08, 0.02, 0.03, 0.01, 0.  ],
    [0.18, 0.15, 0.17, 0.19, 0.17],
    [0.29, 0.27, 0.24, 0.27, 0.2 ],
    [0.39, 0.37, 0.33, 0.37, 0.3 ],
    [0.46, 0.47, 0.48, 0.43, 0.49]],

   [[0.56, 0.58, 0.57, 0.55, 0.52],
    [0.63, 0.61, 0.63, 0.6 , 0.62],
    [0.77, 0.74, 0.73, 0.71, 0.7 ],
    [0.88, 0.82, 0.87, 0.82, 0.83],
    [0.96, 0.95, 0.93, 0.98, 0.94]],

   [[1.08, 1.09, 1.04, 1.02, 1.05],
    [1.18, 1.16, 1.15, 1.12, 1.17],
    [1.28, 1.29, 1.27, 1.21, 1.27],
    [1.38, 1.38, 1.31, 1.35, 1.32],
    [1.41, 1.49, 1.42, 1.48, 1.46]],

   [[1.59, 1.5 , 1.56, 1.53, 1.51],
    [1.6 , 1.69, 1.69, 1.6 , 1.68],
    [1.79, 1.73, 1.72, 1.74, 1.77],
    [1.84, 1.84, 1.83, 1.88, 1.8 ],
    [1.98, 1.99, 1.91, 1.95, 1.92]]])

编辑:我的不好,忘了指定随机种子。
编辑:非numpyic版本的代码是:

result_ = np.zeros_like(result)
for batch_idx in range(len(values)):
    for word_idx in range(len(values[0])):
        result_[batch_idx][word_idx] = values[batch_idx,word_idx, ind[batch_idx, word_idx]]

1 个答案:

答案 0 :(得分:1)

我认为你需要的是:

import numpy as np

np.random.seed(100)
values = np.reshape(np.array([x/100 for x in range(4*5*10)]), (4, 5, 10))
ind = np.reshape(np.array([np.random.randint(0,10) for x in range(4*5*5)]), (4, 5, 5))

ii = np.arange(values.shape[0])[:, np.newaxis, np.newaxis]
jj = np.arange(values.shape[1])[np.newaxis, :, np.newaxis]
result = values[ii, jj, ind]
print(result)

输出:

[[[0.08 0.08 0.03 0.07 0.07]
  [0.1  0.14 0.12 0.15 0.12]
  [0.22 0.22 0.21 0.2  0.28]
  [0.34 0.3  0.39 0.36 0.32]
  [0.44 0.41 0.45 0.43 0.44]]

 [[0.54 0.53 0.57 0.51 0.51]
  [0.67 0.67 0.6  0.62 0.69]
  [0.79 0.73 0.72 0.75 0.78]
  [0.81 0.8  0.87 0.86 0.82]
  [0.9  0.98 0.92 0.95 0.91]]

 [[1.08 1.01 1.05 1.04 1.02]
  [1.18 1.13 1.15 1.1  1.19]
  [1.23 1.26 1.23 1.24 1.27]
  [1.36 1.33 1.39 1.3  1.34]
  [1.44 1.45 1.47 1.46 1.46]]

 [[1.52 1.54 1.52 1.57 1.51]
  [1.66 1.66 1.6  1.67 1.62]
  [1.73 1.75 1.74 1.72 1.74]
  [1.83 1.87 1.89 1.8  1.8 ]
  [1.95 1.99 1.96 1.96 1.95]]]