我想将其重塑为(3,5,3)
所以我必须像:
from sklearn import tree
import graphviz
tree_graph = tree.export_graphviz(tree_model, out_file=None, feature_names=feature_names)
graphviz.Source(tree_graph)
我尝试了reshape(3,5,3),但是没有得到想要的结果吗?
答案 0 :(得分:1)
您的输入数组的形状为(3, 3, 5)
,并且您希望将其重塑为(3, 5, 3)
。有很多方法可以做到这一点。以下是一些内容,正如评论中所述:
首先将使用接受newshape
作为参数的numpy.reshape()
:
In [77]: arr = np.arange(3*3*5).reshape(3, 3, 5)
# reshape to desired shape
In [78]: arr = arr.reshape((3, 5, 3))
In [79]: arr.shape
Out[79]: (3, 5, 3)
或者您可以像下面这样使用numpy.transpose()
:
In [80]: arr = np.arange(3*3*5).reshape(3, 3, 5)
In [81]: arr.shape
Out[81]: (3, 3, 5)
# now, we want to move the last axis which is 2 to second position
# thus our new shape would be `(3, 5, 3)`
In [82]: arr = np.transpose(arr, (0, 2, 1))
In [83]: arr.shape
Out[83]: (3, 5, 3)
另一种方法是使用numpy.moveaxis()
:
In [87]: arr = np.arange(3*3*5).reshape(3, 3, 5)
# move the last axis (-1) to 2nd position (1)
In [88]: arr = np.moveaxis(arr, -1, 1)
In [89]: arr.shape
Out[89]: (3, 5, 3)
另一种方法是只使用numpy.swapaxes()
交换轴:
In [90]: arr = np.arange(3*3*5).reshape(3, 3, 5)
In [91]: arr.shape
Out[91]: (3, 3, 5)
# swap the position of ultimate and penultimate axes
In [92]: arr = np.swapaxes(arr, -1, 1)
In [93]: arr.shape
Out[93]: (3, 5, 3)
选择对您更直观的方法,因为所有方法都会返回所需形状的新视图。
尽管以上所有方法都返回了视图,但是在时间上还是有些差异。因此,这样做(出于效率考虑)的首选方法是:
In [124]: arr = np.arange(3*3*5).reshape(3, 3, 5)
In [125]: %timeit np.swapaxes(arr, -1, 1)
456 ns ± 6.79 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In [126]: %timeit np.transpose(arr, (0, 2, 1))
458 ns ± 6.93 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In [127]: %timeit np.reshape(arr, (3, 5, 3))
635 ns ± 9.06 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In [128]: %timeit np.moveaxis(arr, -1, 1)
3.42 µs ± 79.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
numpy.swapaxes()
和numpy.transpose()
花费的时间几乎相同,其中numpy.reshape()
稍慢一些,而numpy.moveaxis
是最慢的。因此,最好使用swapaxes
或transpose
ufunc。
答案 1 :(得分:0)
我找到了一种使用List comprehension
和Numpy transpose
的方法。
代码:
import numpy as np
database = [
[
[1,2,3,4,5],
[6,7,8,9,10],
[11,12,13,14,15]
],
[
[16,17,18,19,20],
[21,22,23,24,25],
[26,27,28,29,30]
],
[
[31,32,33,34,35],
[36,37,38,39,40],
[41,42,43,44,45]
]
]
ans = [np.transpose(data) for data in database]
print(ans)
输出:
[array([[ 1, 6, 11],
[ 2, 7, 12],
[ 3, 8, 13],
[ 4, 9, 14],
[ 5, 10, 15]]),
array([[16, 21, 26],
[17, 22, 27],
[18, 23, 28],
[19, 24, 29],
[20, 25, 30]]),
array([[31, 36, 41],
[32, 37, 42],
[33, 38, 43],
[34, 39, 44],
[35, 40, 45]])]