So, I'm working on a transfer learning example. My task is to take VGG-16, modify it using transfer learning, and train it so that it predicts only 2 labels ("cat" and "dog"). Here is my code:
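import java.io.File;
import java.io.IOException;
import java.util.Random;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.nn.transferlearning.TransferLearningHelper;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.zoo.PretrainedType;
import org.deeplearning4j.zoo.ZooModel;
import org.deeplearning4j.zoo.model.VGG16;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.VGG16ImagePreProcessor;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;

// Assumed: the snippet omits the class header, annotation and logger declaration
@SpringBootApplication
public class AiProjectDl4jApplication {
    private static final Logger log = LoggerFactory.getLogger(AiProjectDl4jApplication.class);
    static DataSetIterator trainIter;

    public static void main(String[] args) throws Exception {
        SpringApplication.run(AiProjectDl4jApplication.class, args);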
        int seed = 12345;
        int numClasses = 2;
        ZooModel zooModel = VGG16.builder().build();
        ComputationGraph pretrainedNet = (ComputationGraph) zooModel.initPretrained(PretrainedType.IMAGENET);
        FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder()
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .updater(new Nesterovs(5e-5))
                .seed(seed)
                .build();
        // Freeze everything up to and including "fc2"; replace the 1000-class
        // ImageNet output layer with a new 2-class softmax layer
        ComputationGraph vgg16Transfer = new TransferLearning.GraphBuilder(pretrainedNet)
                .fineTuneConfiguration(fineTuneConf)
                .setFeatureExtractor("fc2")
                .removeVertexKeepConnections("predictions")
                .addLayer("predictions",
                        new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                                .nIn(4096).nOut(numClasses)
                                .weightInit(WeightInit.XAVIER)
                                .activation(Activation.SOFTMAX).build(), "fc2")
                .build();
        // transferLearningHelper is created here but not used by the training loop
        TransferLearningHelper transferLearningHelper =
                new TransferLearningHelper(vgg16Transfer, "fc2");
        getImages();
        DataSetPreProcessor preProcessor = new VGG16ImagePreProcessor();
        trainIter.setPreProcessor(preProcessor);
        vgg16Transfer.setListeners(new ScoreIterationListener(5));
        log.info("Training starting...");
        for (int i = 0; i < 5; i++) {
            trainIter.reset(); // rewind the iterator; otherwise only the first epoch sees any data
            while (trainIter.hasNext()) {
                DataSet currentBatch = trainIter.next();
                vgg16Transfer.fit(currentBatch);
            }
        }
        log.info("If no results shown -> Summary: Failed");
    }
    static void getImages() throws IOException {
        Random rng = new Random();
        File parentDir = new File("Path To DataSet");
        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
        ImageRecordReader recordReader = new ImageRecordReader(224, 224, 3, labelMaker);
        // Pass rng so FileSplit shuffles the file order; restrict to image file extensions
        recordReader.initialize(new FileSplit(parentDir, NativeImageLoader.ALLOWED_FORMATS, rng));
        // batch size 150, label at index 1 of each record, 2 possible labels
        trainIter = new RecordReaderDataSetIterator(recordReader, 150, 1, 2);
    }
}
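The epoch loop in main() reads trainIter with hasNext()/next(). Once an iterator like this has been read to the end, hasNext() keeps returning false until reset() is called, so in a hand-written epoch loop only the first pass actually trains unless the iterator is rewound at the start of each epoch. The minimal sketch below illustrates this with a toy iterator written from scratch (it is not the DL4J DataSetIterator); the class names and the 4 batches per epoch are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

// Toy stand-in for an iterator with hasNext()/next()/reset(); this is NOT the DL4J DataSetIterator
public class EpochLoopDemo {

    static class ToyIterator {
        private final List<Integer> batches;
        private int cursor = 0;

        ToyIterator(List<Integer> batches) {
            this.batches = batches;
        }

        boolean hasNext() {
            return cursor < batches.size();
        }

        int next() {
            return batches.get(cursor++);
        }

        void reset() {
            cursor = 0; // rewind to the first batch
        }
    }

    // Runs the same nested epoch/batch loop shape as main() and counts the batches "fitted"
    public static int countFittedBatches(int epochs, int batchesPerEpoch, boolean resetEachEpoch) {
        List<Integer> data = new ArrayList<>();
        for (int b = 0; b < batchesPerEpoch; b++) {
            data.add(b);
        }
        ToyIterator iter = new ToyIterator(data);
        int fitted = 0;
        for (int epoch = 0; epoch < epochs; epoch++) {
            if (resetEachEpoch) {
                iter.reset();
            }
            while (iter.hasNext()) {
                iter.next(); // stands in for vgg16Transfer.fit(batch)
                fitted++;
            }
        }
        return fitted;
    }

    public static void main(String[] args) {
        // 5 epochs as in the outer loop of main(); 4 batches per epoch is a hypothetical number
        System.out.println("without reset: " + countFittedBatches(5, 4, false));
        System.out.println("with reset: " + countFittedBatches(5, 4, true));
    }
}
```

Running it prints `without reset: 4` and `with reset: 20`: over five epochs, the loop without reset() fits each batch once instead of five times.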
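The problem occurs when I try to train the modified model. I only get the score for the first iteration (already high, at 0.9171), and then I get this error: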
# A fatal error has been detected by the Java Runtime Environment:
#
# EXCEPTION_ACCESS_VIOLATION (0xc0000005) at pc=0x00007ffb426a3b29, pid=11708, tid=0x0000000000001a78
#
# JRE version: Java(TM) SE Runtime Environment (8.0_251-b08) (build 1.8.0_251-b08)
# Java VM: Java HotSpot(TM) 64-Bit Server VM (25.251-b08 mixed mode windows-amd64 compressed oops)
# Problematic frame:
# C [KERNELBASE.dll+0x43b29]
#
# Failed to write core dump. Minidumps are not enabled by default on client versions of Windows
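For scale, the batches configured in getImages() (150 images read by ImageRecordReader at 224×224×3) can be sized with some quick arithmetic. This is a back-of-the-envelope sketch that assumes 32-bit floats (4 bytes per value); it counts only the input batch and the output of one first-block VGG-16 convolution (224×224 with 64 channels), not every activation and gradient held during training. The class name is hypothetical.

```java
// Back-of-the-envelope size of one minibatch, assuming 32-bit floats (4 bytes per value)
public class BatchMemoryEstimate {

    static final long BYTES_PER_FLOAT = 4L;

    // Bytes needed for `batch` tensors of shape height x width x channels
    public static long bytes(long batch, long height, long width, long channels) {
        return batch * height * width * channels * BYTES_PER_FLOAT;
    }

    public static double toMiB(long bytes) {
        return bytes / (1024.0 * 1024.0);
    }

    public static void main(String[] args) {
        long batch = 150; // batch size passed to RecordReaderDataSetIterator in getImages()
        long input = bytes(batch, 224, 224, 3);       // input images, 224x224 RGB
        long block1Conv = bytes(batch, 224, 224, 64); // output of one block1 convolution
        System.out.println("input batch: " + input + " bytes (" + toMiB(input) + " MiB)");
        System.out.println("one block1 conv output: " + block1Conv + " bytes (" + toMiB(block1Conv) + " MiB)");
    }
}
```

For a batch of 150 this gives 90,316,800 bytes (about 86.1 MiB) for the input and 1,926,758,400 bytes (1,837.5 MiB) for a single block1 convolution output; VGG-16 has two such layers in its first block.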
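A few months ago I tried building and training my own network and ran into the same problem, and nobody could really help me. This time I hope you can. Do you have any idea what might be causing this?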