我想自动将1个额外的维度添加到numpy数组中。我应该如何设置?
编辑:
#TODO: This feels like it could be automated...
def add_batch(arr):
if arr.ndim == 2:
arr = np.reshape(arr, (arr.shape[0], arr.shape[1], 1))
elif arr.ndim == 3:
arr = np.reshape(arr, (arr.shape[0], arr.shape[1], arr.shape[2], 1))
答案 0 :(得分:1)
您可以使用*
定义功能
import numpy as np
def add_batch(arr):
if arr.ndim >= 2:
arr = np.reshape(arr, (*arr.shape, 1))
return arr
测试功能
arr = np.random.randint(0, 100, (5,6))
print (add_batch(arr).shape)
# (5, 6, 1)
arr = np.random.randint(0, 100, (5,6, 7))
print (add_batch(arr).shape)
# (5, 6, 7, 1)