我一直在尝试使用Keras.js库将Python生成的基本Keras模型实现到网站中。现在,我对模型进行了培训,并将其导出到model.json
,model_weights.buf
和model_metadata.json
文件中。现在,我基本上从github页面复制并粘贴了测试代码,以查看模型是否会在浏览器中加载,但不幸的是我收到了错误。这是测试代码。 (编辑:我修正了一些错误,请参阅下面的剩余错误。)
var model = new KerasJS.Model({
filepaths: {
model: 'dist/model.json',
weights: 'dist/model_weights.buf',
metadata: 'dist/model_metadata.json'
},
gpu: true
});
model.ready()
.then(function() {
console.log("1");
// input data object keyed by names of the input layers
// or `input` for Sequential models
// values are the flattened Float32Array data
// (input tensor shapes are specified in the model config)
var inputData = {
'input_1': new Float32Array(data)
};
console.log("2 " + inputData);
// make predictions
return model.predict(inputData);
})
.then(function(outputData) {
// outputData is an object keyed by names of the output layers
// or `output` for Sequential models
// e.g.,
// outputData['fc1000']
console.log("3 " + outputData);
})
.catch(function(err) {
console.log(err);
// handle error
});
编辑:所以我改变了我的程序,以便与JS 5兼容(这对我来说是一个愚蠢的错误),现在我遇到了一个不同的错误。捕获此错误然后记录。我得到的错误是:Error: predict() must take an object where the keys are the named inputs of the model: input.
我认为出现此问题是因为我的data
变量格式不正确。我认为如果我的模型采用28x28数组,那么data
也应该是28x28数组,以便它可以正确地预测"正确的输出。但是,我相信我错过了一些东西,这就是错误被抛出的原因。 This问题与我的问题非常相似,但它在python而不是JS中。再次,任何帮助将不胜感激。
答案 0 :(得分:0)
好的,所以我想出了为什么会这样。有两个问题。首先,$ stack build
Warning: /home/matthew/backup/azara_work/platform/api/stack.yaml: Unrecognized field in NixOptsMonoid: system-ghc
Cloning into '/home/matthew/backup/azara_work/platform/api/.stack-work/downloaded/4FnxEtHDACVR'...
Permission denied (publickey).
fatal: Could not read from remote repository.
Please make sure you have the correct access rights
and the repository exists.
Process exited with ExitFailure 128: /nix/var/nix/profiles/default/bin/git clone --recursive git@github.com:seanhess/rollbar-haskell.git /home/matthew/backup/azara_work/platform/api/.stack-work/downloaded/4FnxEtHDACVR
数组需要展平,所以我写了一个快速函数来获取2D输入并将其“展平”为长度为784的一维数组。然后,因为我使用了顺序模型,所以数据名称不应该是data
,而只是'input_1'
。这摆脱了所有的错误。
现在,要获取输出信息,我们只需将其存储在如下数组中:'input'
。因为我使用了MNIST数据集,var out = outputData['output']
是一个长度为10的一维数组,其中包含每个数字是用户写入数字的概率。从那里,您可以简单地找到概率最高的数字,并将其用作模型的预测。