现在,我使用函数tf.loadFrozenModel()在主线程中加载模型,然后我要将克隆的模型克隆或转移到Webworker。我该怎么办?
我的github中的代码:
https://github.com/yiifanLu/webWorker
答案 0 :(得分:0)
最好直接在worker中下载冻结的模型。原因是在版本10和11上,没有tf.models.modelFromJSON
来加载可通过model.toJson
传递给工作人员的字符串化模型。
以下内容在主线程中定义了一个模型。该模型保存在本地服务器提供的文件中。工人可以加载并将其用于预测
<head>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.14.2/dist/tf.min.js"></script>
<script>
const worker_function = () => {
onmessage = async (event) => {
console.log('from web worker')
this.window = this
importScripts('https://cdn.jsdelivr.net/npm/setimmediate@1.0.5/setImmediate.min.js')
importScripts('https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.11.2')
tf.setBackend('cpu')
const model = await tf.loadModel('http://localhost:8080/model.json')
model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
// Generate some synthetic data for training.
const xs = tf.tensor2d([1, 2, 3, 4], [4, 1]);
const ys = tf.tensor2d([1, 3, 5, 7], [4, 1]);
// Train the model inside the worker
await model.fit(xs, ys, {epochs: 10})
const res = model.predict(tf.tensor2d([5], [1, 1]));
// send response to main thread
postMessage({res: res.dataSync(), shape: res.shape})
};
}
if (window != self)
worker_function();
</script>
<script>
const model = tf.sequential();
model.add(tf.layers.dense({units: 1, inputShape: [1]}));
const worker = new Worker(URL.createObjectURL(new Blob(["(" + worker_function.toString() + ")()"], { type: 'text/javascript' })));
(async() => {
model.save('downloads://model')
})()
worker.postMessage({model : 'model'});
worker.onmessage = (message) => {
console.log('from main thread')
const {data} = message
tf.tensor(data.res, data.shape).print()
}
</script>
</head>