我有一个由4列组成的3D数组。我只想提取最后一列(从0.626到0.022)以进行绘制。
yhat = estimator.predict(Xtest, verbose=0)
array([[[ 0.02373935, -0.00453718, 0.01974867, 0.62695163]],
[[ 0.02106621, -0.00397644, 0.01761295, 0.62469435]],
[[ 0.02013615, -0.00363547, 0.01690874, 0.62409896]],
...,
[[-0.00133965, -0.0020067 , 0.00370775, 0.02289007]],
[[-0.00133965, -0.0020067 , 0.00370775, 0.02289007]],
[[-0.00133965, -0.0020067 , 0.00370775, 0.02289007]]],
dtype=float32)
“ shapes {} and {}”。format(x.shape,y.shape))
ValueError :x和y不能大于2-D,但形状为(912,)和(912,1,4)
答案 0 :(得分:0)
您可以使用numpy多维数组切片:
last_column = arr[:, :, -1]
答案 1 :(得分:0)
来自
last_column = yhat[:, :, -1]
我发生了一些奇怪的事情:如果我在编辑器中打印last_column,则会得到正确的顺序:
array([[ 6.2695163e-01],
[ 6.1686254e-01],
...
[ 2.2890069e-02],
[ 2.2890069e-02]], dtype=float32)
但是,如果我从终端打印,则会得到多行:
...
[0.57803965 0.5793313 0.5800704 ... 0.28363833 0.28206974 0.2802861 ]
[0.57803965 0.5793313 0.5800704 ... 0.28363833 0.28206974 0.2802861 ]
[0.57803965 0.5793313 0.5800704 ... 0.28363833 0.28206974 0.2802861 ]]
在图中,我得到多条曲线。