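Hi, I am trying to extract the output of the second-to-last layer from a pre-trained model (ResNet-152) in MXNet. Because the script has to work with a Java application, I chose Scala as the language.
I followed the steps described at https://mxnet.incubator.apache.org/tutorials/python/predict_image.html and adapted them in my script. Here is the function that loads the model.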
def loadResnetModel(modelPath: String): Module = {
  // modelFileNumber (the checkpoint epoch) is defined elsewhere in the script
  val (net, argParams, auxParams) = Model.loadCheckpoint(modelPath, modelFileNumber)
  val allLayer = net.getInternals()
  val secondLastLayer = allLayer.get("flatten0_output")
  val mod = new Module(symbolVar = secondLastLayer, contexts = Context.cpu(), labelNames = null)
  val dataShape = ListMap("data" -> Shape(1, 3, 224, 224))
  mod.bind(dataShapes = dataShape, forTraining = false)
  mod.setParams(argParams, auxParams, allowMissing = true)
  mod
}
When I try to run the script, I get the following error.
Exception in thread "main" java.lang.IllegalArgumentException: requirement failed: Find name fc1_bias that is not in the arguments
[java] at scala.Predef$.require(Predef.scala:224)
[java] at org.apache.mxnet.Executor$$anonfun$copyParamsFrom$1.apply(Executor.scala:274)
[java] at org.apache.mxnet.Executor$$anonfun$copyParamsFrom$1.apply(Executor.scala:270)
[java] at scala.collection.immutable.HashMap$HashMap1.foreach(HashMap.scala:221)
[java] at scala.collection.immutable.HashMap$HashTrieMap.foreach(HashMap.scala:428)
[java] at scala.collection.immutable.HashMap$HashTrieMap.foreach(HashMap.scala:428)
[java] at scala.collection.immutable.HashMap$HashTrieMap.foreach(HashMap.scala:428)
[java] at org.apache.mxnet.Executor.copyParamsFrom(Executor.scala:270)
[java] at org.apache.mxnet.module.DataParallelExecutorGroup$$anonfun$setParams$1.apply(DataParallelExecutorGroup.scala:452)
[java] at org.apache.mxnet.module.DataParallelExecutorGroup$$anonfun$setParams$1.apply(DataParallelExecutorGroup.scala:452)
[java] at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)
[java] at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:186)
[java] at org.apache.mxnet.module.DataParallelExecutorGroup.setParams(DataParallelExecutorGroup.scala:452)
[java] at org.apache.mxnet.module.Module.setParams(Module.scala:201)
P.S.: I am new to MXNet and Scala. Am I missing something obvious?
Answer 0 (score: 1)
You need to change the last line of the function:
You need to call mod.setParams(argParams, auxParams)
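For context, the error message points at the cause: fc1_bias belongs to the fully connected layer that comes after flatten0_output, so the truncated symbol has no argument with that name, while the argParams loaded from the checkpoint still contain it. One possible way around this, sketched here under assumptions, is to drop the entries the truncated symbol does not declare before calling setParams. ParamFilter and keepKnown are hypothetical names introduced for illustration and are not part of MXNet; the toy maps stand in for the real Map[String, NDArray] values so the snippet runs without MXNet installed.

```scala
// Hypothetical helper (not part of MXNet): keeps only the parameters whose
// names the truncated symbol still declares.
object ParamFilter {
  def keepKnown[V](params: Map[String, V], knownNames: Seq[String]): Map[String, V] = {
    val names = knownNames.toSet
    params.filter { case (name, _) => names.contains(name) }
  }
}

object ParamFilterDemo {
  def main(args: Array[String]): Unit = {
    // Toy stand-ins: in the real script the values would be NDArray instances.
    val argParams = Map("conv0_weight" -> 1, "fc1_weight" -> 2, "fc1_bias" -> 3)
    // Stand-in for secondLastLayer.listArguments(); "data" is an input, not a weight.
    val knownArgs = Seq("data", "conv0_weight")
    // The fc1_* entries are dropped because the truncated symbol does not list them.
    println(ParamFilter.keepKnown(argParams, knownArgs).keySet)
  }
}
```

In the function from the question this would look roughly like mod.setParams(ParamFilter.keepKnown(argParams, secondLastLayer.listArguments()), ParamFilter.keepKnown(auxParams, secondLastLayer.listAuxiliaryStates()), allowMissing = true). This is untested against the checkpoint in the question, so treat it as a starting point rather than a confirmed fix.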