
Time: 2019-05-12 13:37:26

Tags: python python-3.x list numpy reshape

I have the following NumPy array: [image]



from sklearn import tree
import graphviz

tree_graph = tree.export_graphviz(tree_model, out_file=None, feature_names=feature_names)


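2 Answers:

Answer 0 (score: 1)

Your input array has shape (3, 3, 5), and you want to turn it into shape (3, 5, 3). There are several ways to do this. Here are a few, explained in the code comments. Note that reshape only produces the target shape and keeps the flat element order, whereas transpose, moveaxis, and swapaxes actually exchange the axes, so the resulting values are arranged differently.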


In [77]: arr = np.arange(3*3*5).reshape(3, 3, 5)

# reshape to desired shape
In [78]: arr = arr.reshape((3, 5, 3))

In [79]: arr.shape
Out[79]: (3, 5, 3)
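
To make the difference between reshaping and exchanging axes concrete, here is a minimal sketch using the same `arr` as in the snippets. The names `reshaped` and `transposed` are only illustrative.

```python
import numpy as np

arr = np.arange(3 * 3 * 5).reshape(3, 3, 5)

# reshape keeps the flat element order and only regroups it
reshaped = arr.reshape(3, 5, 3)
# transpose exchanges axes 1 and 2, so rows become columns in each block
transposed = np.transpose(arr, (0, 2, 1))

same_shape = reshaped.shape == transposed.shape
same_values = np.array_equal(reshaped, transposed)
print(same_shape, same_values)   # True False

print(reshaped[0, 0])     # [0 1 2]
print(transposed[0, 0])   # [ 0  5 10]
```

If the goal is to transpose each 3x5 block, the axis-based functions are the ones to use; reshape is only appropriate when the flat order of the data should stay as it is.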


In [80]: arr = np.arange(3*3*5).reshape(3, 3, 5)

In [81]: arr.shape
Out[81]: (3, 3, 5)

# now, we want to move the last axis which is 2 to second position
# thus our new shape would be `(3, 5, 3)`
In [82]: arr = np.transpose(arr, (0, 2, 1))

In [83]: arr.shape
Out[83]: (3, 5, 3)


In [87]: arr = np.arange(3*3*5).reshape(3, 3, 5)

# move the last axis (-1) to 2nd position (1)
In [88]: arr = np.moveaxis(arr, -1, 1)

In [89]: arr.shape
Out[89]: (3, 5, 3)


In [90]: arr = np.arange(3*3*5).reshape(3, 3, 5)

In [91]: arr.shape
Out[91]: (3, 3, 5)

# swap the position of ultimate and penultimate axes
In [92]: arr = np.swapaxes(arr, -1, 1)

In [93]: arr.shape
Out[93]: (3, 5, 3)
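
As a quick sanity check, the sketch below compares the three axis-based calls on the same input. For this 3-D array they should give identical results and return views rather than copies. The variable names are illustrative.

```python
import numpy as np

arr = np.arange(3 * 3 * 5).reshape(3, 3, 5)

a = np.transpose(arr, (0, 2, 1))
b = np.moveaxis(arr, -1, 1)
c = np.swapaxes(arr, -1, 1)

# all three produce the same values for a 3-D input
all_equal = np.array_equal(a, b) and np.array_equal(a, c)
# each result is a view onto the original buffer, not a copy
shares_memory = all(np.shares_memory(arr, x) for x in (a, b, c))
print(all_equal, shares_memory)   # True True

# with more than three axes, moveaxis and swapaxes no longer coincide
arr4 = np.zeros((2, 3, 4, 5))
print(np.moveaxis(arr4, -1, 1).shape)   # (2, 5, 3, 4)
print(np.swapaxes(arr4, -1, 1).shape)   # (2, 5, 4, 3)
```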



In [124]: arr = np.arange(3*3*5).reshape(3, 3, 5)

In [125]: %timeit np.swapaxes(arr, -1, 1)
456 ns ± 6.79 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

In [126]: %timeit np.transpose(arr, (0, 2, 1))
458 ns ± 6.93 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

In [127]: %timeit np.reshape(arr, (3, 5, 3))
635 ns ± 9.06 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

In [128]: %timeit np.moveaxis(arr, -1, 1)
3.42 µs ± 79.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

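numpy.swapaxes() and numpy.transpose() take almost the same time, numpy.reshape() is slightly slower, and numpy.moveaxis() is the slowest. All four return views for this input, so these numbers mostly reflect function-call overhead rather than data movement. Keep in mind that reshape arranges the values differently from the other three, so it is not a drop-in alternative. Of the axis-based options, swapaxes or transpose is therefore the better choice. (These are ordinary NumPy functions, not ufuncs.)

Answer 1 (score: 0)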
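
The `%timeit` measurements above need IPython. Outside IPython, a rough equivalent can be sketched with the standard-library `timeit` module. The repeat and call counts below are arbitrary choices, and the absolute numbers will vary by machine and NumPy version.

```python
import timeit
import numpy as np

arr = np.arange(3 * 3 * 5).reshape(3, 3, 5)

# only the axis-based calls are compared, since they give the same result
candidates = {
    "swapaxes": lambda: np.swapaxes(arr, -1, 1),
    "transpose": lambda: np.transpose(arr, (0, 2, 1)),
    "moveaxis": lambda: np.moveaxis(arr, -1, 1),
}

number = 10_000  # calls per repeat (arbitrary)
timings = {}
for name, fn in candidates.items():
    # take the best of 5 repeats and convert to seconds per call
    best = min(timeit.repeat(fn, number=number, repeat=5))
    timings[name] = best / number

for name, per_call in sorted(timings.items(), key=lambda kv: kv[1]):
    print(f"{name:<10} {per_call * 1e9:8.0f} ns per call")
```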

I found a way to do this using a list comprehension and NumPy's transpose.


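import numpy as np
# the original literal was truncated; these values are reconstructed from the output below
database = list(np.arange(1, 46).reshape(3, 3, 5))

# transpose each 3x5 block into a 5x3 block
ans = [np.transpose(data) for data in database]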


[array([[ 1,  6, 11],
       [ 2,  7, 12],
       [ 3,  8, 13],
       [ 4,  9, 14],
       [ 5, 10, 15]]), 
 array([[16, 21, 26],
       [17, 22, 27],
       [18, 23, 28],
       [19, 24, 29],
       [20, 25, 30]]), 
 array([[31, 36, 41],
       [32, 37, 42],
       [33, 38, 43],
       [34, 39, 44],
       [35, 40, 45]])]
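
If the blocks are stored in a single 3-D array instead of a list, the same result can be obtained without the list comprehension. This is a minimal sketch: building `database` from the values 1 to 45 is an assumption based on the printed output, and `ans_list` and `ans_array` are illustrative names.

```python
import numpy as np

# assumed input: values 1..45 in three 3x5 blocks, matching the output above
database = np.arange(1, 46).reshape(3, 3, 5)

# per-block transpose, as in the list comprehension
ans_list = [np.transpose(data) for data in database]
# single call that exchanges the last two axes of the whole array
ans_array = np.transpose(database, (0, 2, 1))

print(ans_array.shape)                                  # (3, 5, 3)
print(np.array_equal(np.stack(ans_list), ans_array))    # True
```

The single-call version returns one array that is a view of the input, while the list comprehension returns a Python list of three separate 2-D views.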