训练导入到deeplearning4j的CNN Conv1D Keras模型时产生异常

时间:2019-10-23 11:00:22

标签: keras deeplearning4j

我在Jupyter中定义了一个非常简单的CNN

model = Sequential()

model.add(Conv1D(32, 12, activation='relu', padding='same', input_shape=(X_train.shape[1],X_train.shape[2])))
model.add(Conv1D(64, 12, activation='relu', padding='same'))
model.add(Dropout(0.5))
model.add(Dense(1, activation='sigmoid'))
​
model.summary()

model.compile(loss='mse', optimizer='sgd')
  

_________________________________________________________________图层(类型)输出形状参数#
  ================================================== ============== conv1d_7(Conv1D)(无,675,32)416
  _________________________________________________________________ conv1d_8(Conv1D)(无,675,64)24640
  _________________________________________________________________ dropout_4(Dropout)(None,675,64)0
  _________________________________________________________________ density_4(密集)(无,675,1)65
  ================================================== ===============总参数:25,121可训练参数:25,121非可训练参数:0

     

我已将模型(和权重)另存为.h5文件,可以将其导入到我的Java应用程序中。这可以正常工作,我可以使用此模型生成预测。但是,我也想用Java重新训练模型。

使用此代码段

    MultiLayerNetwork model = KerasModelImport.importKerasSequentialModelAndWeights(modelConf.getAbsolutePath());
    int nSamp = normFbe.length; // 3545
    int nChan = normFbe[0].length; // 675;
    INDArray X = Nd4j.create(normFbe);

    // Create the ground truth
    INDArray y_truth = generateTruthData(nSamp, nChan);

    for (int i = 0; i < nSamp; i++) {
        INDArray X_train = X.getRow(i);
        INDArray y_train = y_truth.getRow(i);
        X_train = X_train.reshape(1, 1, 675);
        y_train = y_train.reshape(1, 1, 675);
        model.fit(X_train, y_train);
    }

但引发异常并生成错误消息:

  

错误19/10/19 11:55:25,412 [DxS Worker-0]   org.nd4j.linalg.cpu.nativecpu.ops.NativeOpExecutioner-失败   执行运算乘法。尝试执行2个输入,1个输出,0个   targs,0 bargs和0 iargs。输入:[(FLOAT,[1,32,675,1],c),   (FLOAT,[1,32,675],c)]。输出:[(FLOAT,[1,32,675,1],c)]。 tArgs:-。   iArgs:-。 bArgs:-。运营商自己的名字:   “ d4e50f04-f0b2-4eb4-9276-439b871087ad”-请参见以上消息   (从c ++打印)以获取可能的错误原因。错误23/10/19   11:55:27,568 [DxS Worker-0] DxS-执行工具机器学习   java.lang.RuntimeException:Op [multiply]执行失败

有什么想法吗?在我看来,数据形状在输入或各层之间是不匹配的,但我已经没有足够的想法来修复它。任何帮助表示赞赏。

David Robb

0 个答案:

没有答案
相关问题