I am experimenting with deeplearning4j by working through some of their tutorials. The Iris dataset is well known, and with Weka (using RandomForest or MultilayerPerceptron) I can easily push the F-measure close to 1 (0.973 in the example below):
                 TP Rate  FP Rate  Precision  Recall  F-Measure  MCC    ROC Area  PRC Area  Class
                 1.000    0.000    1.000      1.000   1.000      1.000  1.000     1.000     Iris-setosa
                 0.960    0.020    0.960      0.960   0.960      0.940  0.996     0.993     Iris-versicolor
                 0.960    0.020    0.960      0.960   0.960      0.940  0.996     0.993     Iris-virginica
Weighted Avg.    0.973    0.013    0.973      0.973   0.973      0.960  0.998     0.995
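(The weighted average in the last row is the per-class F-measure weighted by class support, and since Iris has 50 examples of each class the weights are equal. A quick sanity check in Scala, my own, not Weka output:)

// Iris has 50 examples per class, so the support-weighted mean is a plain mean.
val perClassF = Seq(1.000, 0.960, 0.960)
val weightedF = perClassF.sum / perClassF.size  // = 0.9733..., matching Weka's 0.973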
With deeplearning4j, however, I get very poor results on the same data:
Examples labeled as 0 classified by model as 0: 4 times
Examples labeled as 1 classified by model as 0: 12 times
Examples labeled as 2 classified by model as 0: 14 times
Warning: class 1 was never predicted by the model. This class was excluded from the average precision
Warning: class 2 was never predicted by the model. This class was excluded from the average precision
Accuracy: 0.1333 Precision: 0.1333 Recall: 0.3333 F1 Score: 0.1905
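These numbers are consistent with a model that predicts class 0 for all 30 test examples (20% of 150). A quick check of the reported metrics, assuming the confusion counts above:

// All 30 test examples are predicted as class 0, per the counts above.
val counts = Seq(4, 12, 14)                   // true class 0 / 1 / 2, all predicted as 0
val total = counts.sum                        // 30 test examples
val accuracy = counts(0).toDouble / total     // 4/30 = 0.1333
val precision = counts(0).toDouble / total    // only class 0 is included: 0.1333
val recall = (1.0 + 0.0 + 0.0) / 3            // macro recall over 3 classes = 0.3333
val f1 = 2 * precision * recall / (precision + recall)  // = 0.1905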
Here is the code (in Scala) I am using:
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator
import org.deeplearning4j.eval.Evaluation
import org.deeplearning4j.nn.api.{Layer, OptimizationAlgorithm}
import org.deeplearning4j.nn.conf.{Updater, NeuralNetConfiguration}
import org.deeplearning4j.nn.conf.layers.{OutputLayer, RBM}
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.deeplearning4j.nn.weights.WeightInit
import org.deeplearning4j.ui.weights.HistogramIterationListener
import org.nd4j.linalg.factory.Nd4j
import org.nd4j.linalg.lossfunctions.LossFunctions

object Main extends App {
  Nd4j.MAX_SLICES_TO_PRINT = -1
  Nd4j.MAX_ELEMENTS_PER_SLICE = -1
  Nd4j.ENFORCE_NUMERICAL_STABILITY = true

  val inputNum = 4
  val outputNum = 3
  val numSamples = 150
  val batchSize = 150
  val iterations = 1000
  val seed = 321
  val listenerFreq = iterations / 5
  val learningRate = 1e-6

  println("Load data....")
  val iter = new IrisDataSetIterator(batchSize, numSamples)
  val iris = iter.next()
  iris.shuffle()
  iris.normalizeZeroMeanZeroUnitVariance()
  val testAndTrain = iris.splitTestAndTrain(0.80)
  val train = testAndTrain.getTrain
  val test = testAndTrain.getTest

  println("Build model....")
  val RMSE_XENT = LossFunctions.LossFunction.RMSE_XENT
  val conf = new NeuralNetConfiguration.Builder()
    .seed(seed)
    .iterations(iterations)
    .learningRate(learningRate)
    .l1(1e-1).regularization(true).l2(2e-4)
    .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT)
    .useDropConnect(true)
    .list(2)
    .layer(0, new RBM.Builder(RBM.HiddenUnit.RECTIFIED, RBM.VisibleUnit.GAUSSIAN)
      .nIn(inputNum).nOut(3).k(1).activation("relu").weightInit(WeightInit.XAVIER).lossFunction(RMSE_XENT)
      .updater(Updater.ADAGRAD).dropOut(0.5)
      .build())
    .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
      .nIn(3).nOut(outputNum).activation("softmax").build())
    .build()

  val model = new MultiLayerNetwork(conf)
  model.init()
  model.setListeners(new HistogramIterationListener(listenerFreq))

  println("Train model....")
  model.fit(train.getFeatureMatrix)

  println("Evaluate model....")
  val eval = new Evaluation(outputNum)
  val output = model.output(test.getFeatureMatrix, Layer.TrainingMode.TEST)
  (0 until output.rows()).foreach { i =>
    // Compare against the test labels (not the training ones) row by row.
    val actual = test.getLabels.getRow(i).toString.trim()
    val predicted = output.getRow(i).toString.trim()
    println("actual " + actual + " vs predicted " + predicted)
  }
  eval.eval(test.getLabels, output)
  println(eval.stats())
}
Answer 0 (score: 1):
You can try removing the regularization:

.l1(1e-1).regularization(true).l2(2e-4)
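For concreteness, a minimal sketch of the question's builder chain with that line dropped; every other setting is left exactly as in the question, and whether the dropout/drop-connect settings should also go is a separate tuning question:

// The question's configuration with the suggested line removed -- a sketch, not a tuned setup.
val conf = new NeuralNetConfiguration.Builder()
  .seed(seed)
  .iterations(iterations)
  .learningRate(learningRate)
  // .l1(1e-1).regularization(true).l2(2e-4)  <- removed per the suggestion above
  .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT)
  .useDropConnect(true)
  .list(2)
  .layer(0, new RBM.Builder(RBM.HiddenUnit.RECTIFIED, RBM.VisibleUnit.GAUSSIAN)
    .nIn(inputNum).nOut(3).k(1).activation("relu").weightInit(WeightInit.XAVIER).lossFunction(RMSE_XENT)
    .updater(Updater.ADAGRAD).dropOut(0.5)
    .build())
  .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
    .nIn(3).nOut(outputNum).activation("softmax").build())
  .build()

Here is my code: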
public MultiLayerNetwork buildModel() {
    // Build the hidden (dense) layers, chaining each layer's output size
    // into the next layer's input size.
    int lowerSize = featureSize;
    List<DenseLayer> hiddenLayers = new ArrayList<DenseLayer>();
    for (int i = 0; i < hiddenLayersDim.length; i++) {
        int higherSize = hiddenLayersDim[i];
        hiddenLayers.add(new DenseLayer.Builder().nIn(lowerSize).nOut(higherSize)
                .activation(activationType).weightInit(weightInit).build());
        lowerSize = higherSize;
    }

    NeuralNetConfiguration.ListBuilder listBuilder = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .iterations(iterations)
            .optimizationAlgo(algorithm)
            .learningRate(learningRate)
            .updater(updater)
            .list();

    int layerCount = 0;
    for (int i = 0; i < hiddenLayers.size(); i++) {
        listBuilder.layer(layerCount, hiddenLayers.get(i));
        layerCount += 1;
    }

    // Softmax output layer sized to the number of labels.
    listBuilder.layer(layerCount, new OutputLayer.Builder(lossFunction)
            .weightInit(weightInit)
            .activation("softmax")
            .nIn(lowerSize).nOut(labelSize).build());

    // Plain backpropagation, no unsupervised pretraining.
    MultiLayerConfiguration conf = listBuilder.pretrain(false).backprop(true).build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    net.setListeners(new ScoreIterationListener(1));
    this.net = net;
    return net;
}