我有一个形状为(1,64,96,4)的numpy堆栈,其中(64,96)是与从屏幕(帧)捕获的图像相对应的像素,而4是堆栈的帧数。因此,基本上堆栈具有4个2d数组(64,96)。该堆栈用作Deep q网络的输入。
我想用matplotlib显示不同的帧,但是我没有设法访问正确的元素。
以下是用于创建框架堆栈的代码:
frame = self.preprocess_frame(frame) # (64,96)
state = np.stack([frame] * 4, axis=-1)
state = state.reshape(1, state.shape[0], state.shape[1], state.shape[2]) # 1*64*96*4, the extra dimension is only used for Keras library
return state
问题在于如何从堆栈中访问不同的帧,以便将其馈送到matplotlib.imshow
我已经尝试过只得到一帧,但效果不佳:
plt.imshow(state[1], state[2])
plt.show()
预先感谢您!