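I am using RL4J (the reinforcement learning framework integrated into DeepLearning4J) to get a car to complete a lap on a track in a video game.
After training, I save the model with the following code: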
QLearningDiscreteConv<ScreenFrameState> dql = new QLearningDiscreteConv(mdp, RACING_NET_CONFIG, RACING_HP, RACING_QL, manager);
dql.train();
dql.getNeuralNet().save(model);
After saving the model I want to see how it performs, so I load it and let it play:
DQN load = DQN.load(model);
QLearningDiscreteConv<ScreenFrameState> dql = new QLearningDiscreteConv(mdp, load, RACING_HP, RACING_QL, manager);
dql.getPolicy().play(mdp);
But playing the loaded model fails with this error:
org.deeplearning4j.exception.DL4JInvalidInputException: Cannot do forward pass in Convolution layer (layer name = layer0, layer index = 0): input array depth does not match CNN layer configuration (data input depth = 109, [minibatch,inputDepth,height,width]=[1, 109, 150, 3]; expected input depth = 10) (layer name: layer0, layer index: 0)
at org.deeplearning4j.nn.layers.convolution.ConvolutionLayer.preOutput(ConvolutionLayer.java:294)
at org.deeplearning4j.nn.layers.convolution.ConvolutionLayer.preOutput(ConvolutionLayer.java:248)
at org.deeplearning4j.nn.layers.convolution.ConvolutionLayer.activate(ConvolutionLayer.java:392)
at org.deeplearning4j.nn.layers.AbstractLayer.activate(AbstractLayer.java:309)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.activationFromPrevLayer(MultiLayerNetwork.java:789)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.feedForwardToLayer(MultiLayerNetwork.java:929)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.feedForward(MultiLayerNetwork.java:870)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.feedForward(MultiLayerNetwork.java:861)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.silentOutput(MultiLayerNetwork.java:1906)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.output(MultiLayerNetwork.java:1898)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.output(MultiLayerNetwork.java:1871)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.output(MultiLayerNetwork.java:1952)
at org.deeplearning4j.rl4j.network.dqn.DQN.output(DQN.java:49)
at org.deeplearning4j.rl4j.policy.DQNPolicy.nextAction(DQNPolicy.java:32)
at org.deeplearning4j.rl4j.policy.DQNPolicy.nextAction(DQNPolicy.java:18)
at org.deeplearning4j.rl4j.policy.Policy.play(Policy.java:72)
at org.deeplearning4j.rl4j.policy.Policy.play(Policy.java:27)
at me.andreaiacono.racinglearning.rl.QLearning.race(QLearning.java:81)
at me.andreaiacono.racinglearning.core.player.QLearningPlayer.race(QLearningPlayer.java:19)
at me.andreaiacono.racinglearning.gui.GameWorker.doInBackground(GameWorker.java:56)
at me.andreaiacono.racinglearning.gui.GameWorker.doInBackground(GameWorker.java:11)
at javax.swing.SwingWorker$1.call(SwingWorker.java:295)
at java.util.concurrent.FutureTask.run(FutureTask.java:266)
at javax.swing.SwingWorker.run(SwingWorker.java:334)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
at java.lang.Thread.run(Thread.java:748)
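The input is correct: my screen is 150 * 109 pixels with 3 color channels. Why does the loaded model expect an input depth of 10? What am I missing?
Thanks, Andrea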
Answer 0: (score: 0)
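(data input depth = 109, [minibatch,inputDepth,height,width]=[1, 109, 150, 3]; expected input depth = 10)
It looks like you are setting inputDepth to 109, whereas it should be set to 3 (the number of channels). I'm not familiar with dl4j myself, so I'm not sure why it says "expected input depth = 10", but you could at least try switching the order in which you supply those arguments.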
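Answer 1: (score: 0)
Which version are you using? If you build against the snapshot repository, temporary bugs can sometimes appear, although the maintainers usually fix them quickly. You may have happened to pull code from a snapshot at such a moment. Use a stable release instead.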