使用Keras.js将Keras模型实施到网站

时间:2017-08-02 20:09:33

标签: javascript keras

我一直在尝试使用Keras.js库将Python生成的基本Keras模型实现到网站中。现在,我对模型进行了培训,并将其导出到model.jsonmodel_weights.bufmodel_metadata.json文件中。现在,我基本上从github页面复制并粘贴了测试代码,以查看模型是否会在浏览器中加载,但不幸的是我收到了错误。这是测试代码。 (编辑:我修正了一些错误,请参阅下面的剩余错误。)

var model = new KerasJS.Model({
    filepaths: {
        model: 'dist/model.json',
        weights: 'dist/model_weights.buf',
        metadata: 'dist/model_metadata.json'
  },
  gpu: true
});

    model.ready()
  .then(function() {
    console.log("1");
    // input data object keyed by names of the input layers
    // or `input` for Sequential models
    // values are the flattened Float32Array data
    // (input tensor shapes are specified in the model config)
    var inputData = {
      'input_1': new Float32Array(data)
    };
    console.log("2 " + inputData);
    // make predictions
    return model.predict(inputData);
  })
  .then(function(outputData) {
    // outputData is an object keyed by names of the output layers
    // or `output` for Sequential models
    // e.g.,
    // outputData['fc1000']
    console.log("3 " + outputData);
  })
  .catch(function(err) {
    console.log(err);
    // handle error
  });

编辑:所以我改变了我的程序,以便与JS 5兼容(这对我来说是一个愚蠢的错误),现在我遇到了一个不同的错误。捕获此错误然后记录。我得到的错误是:Error: predict() must take an object where the keys are the named inputs of the model: input.我认为出现此问题是因为我的data变量格式不正确。我认为如果我的模型采用28x28数组,那么data也应该是28x28数组,以便它可以正确地预测"正确的输出。但是,我相信我错过了一些东西,这就是错误被抛出的原因。 This问题与我的问题非常相似,但它在python而不是JS中。再次,任何帮助将不胜感激。

1 个答案:

答案 0 :(得分:0)

好的,所以我想出了为什么会这样。有两个问题。首先,$ stack build Warning: /home/matthew/backup/azara_work/platform/api/stack.yaml: Unrecognized field in NixOptsMonoid: system-ghc Cloning into '/home/matthew/backup/azara_work/platform/api/.stack-work/downloaded/4FnxEtHDACVR'... Permission denied (publickey). fatal: Could not read from remote repository. Please make sure you have the correct access rights and the repository exists. Process exited with ExitFailure 128: /nix/var/nix/profiles/default/bin/git clone --recursive git@github.com:seanhess/rollbar-haskell.git /home/matthew/backup/azara_work/platform/api/.stack-work/downloaded/4FnxEtHDACVR 数组需要展平,所以我写了一个快速函数来获取2D输入并将其“展平”为长度为784的一维数组。然后,因为我使用了顺序模型,所以数据名称不应该是data,而只是'input_1'。这摆脱了所有的错误。

现在,要获取输出信息,我们只需将其存储在如下数组中:'input'。因为我使用了MNIST数据集,var out = outputData['output']是一个长度为10的一维数组,其中包含每个数字是用户写入数字的概率。从那里,您可以简单地找到概率最高的数字,并将其用作模型的预测。