How to move a model loaded with tf.loadFrozenModel in the main thread into a web worker

Date: 2019-01-29 12:36:46

Tags: tensorflow web-worker tensorflow.js

Right now I load the model in the main thread with tf.loadFrozenModel(), and I would like to clone or transfer the loaded model to a web worker. How can I do that?
  The code is in my GitHub: https://github.com/yiifanLu/webWorker

1 Answer:

Answer 0 (score: 0)

It is best to download the frozen model directly in the worker. The reason is that with tfjs versions 0.10 and 0.11 there is no tf.models.modelFromJSON with which a worker could rebuild a stringified model passed to it from the main thread via model.toJSON().
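If the goal is simply to run the frozen (GraphDef) model off the main thread, the worker can fetch it itself with tf.loadFrozenModel rather than receive it from the main thread. Below is a minimal sketch of such a worker script, assuming the converter output (a tensorflowjs_model.pb plus weights_manifest.json) is served by a local server, that the tfjs bundle pulled in via importScripts includes the converter, and that the main thread posts the input values and shape; all URLs are placeholders.

// worker.js -- sketch only; URLs and the message format are assumptions
this.window = this   // tfjs expects a global `window` inside the worker
importScripts('https://cdn.jsdelivr.net/npm/setimmediate@1.0.5/setImmediate.min.js')
importScripts('https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.14.2/dist/tf.min.js')

onmessage = async (event) => {
    // Fetch the frozen model directly in the worker instead of transferring it
    const model = await tf.loadFrozenModel(
        'http://localhost:8080/tensorflowjs_model.pb',    // placeholder model URL
        'http://localhost:8080/weights_manifest.json')    // placeholder weights manifest URL
    const input = tf.tensor(event.data.values, event.data.shape)
    const res = model.predict(input)
    // Send the raw prediction data and its shape back to the main thread
    postMessage({res: res.dataSync(), shape: res.shape})
}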

The full example below defines a model in the main thread. The model is saved to files served by a local server; the worker then loads it and uses it for prediction.

<head>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.14.2/dist/tf.min.js"></script>
    <script>
        // Body of the web worker; it is stringified below and run inside a Blob worker
        const worker_function = () => {
            onmessage = async (event) => {
                console.log('from web worker')
                // tfjs expects a global `window`; alias the worker global scope to it
                this.window = this
                importScripts('https://cdn.jsdelivr.net/npm/setimmediate@1.0.5/setImmediate.min.js')
                importScripts('https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.11.2')
                tf.setBackend('cpu')

                // Load the model saved by the main thread and served by the local server
                const model = await tf.loadModel('http://localhost:8080/model.json')
                model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});

                // Generate some synthetic data for training.
                const xs = tf.tensor2d([1, 2, 3, 4], [4, 1]);
                const ys = tf.tensor2d([1, 3, 5, 7], [4, 1]);

                // Train the model inside the worker
                await model.fit(xs, ys, {epochs: 10})
                const res = model.predict(tf.tensor2d([5], [1, 1]));

                // Send the prediction back to the main thread as raw data plus shape
                postMessage({res: res.dataSync(), shape: res.shape})
            };
        }
        // No-op on the main page (window === self); the worker is created from the Blob below
        if (window != self)
            worker_function();
    </script>
    <script>
        // Define a simple model in the main thread
        const model = tf.sequential();
        model.add(tf.layers.dense({units: 1, inputShape: [1]}));

        // Create the worker from the stringified worker_function
        const worker = new Worker(URL.createObjectURL(new Blob(["(" + worker_function.toString() + ")()"], { type: 'text/javascript' })));

        // Save the model topology and weights as downloaded files
        (async () => {
            await model.save('downloads://model')
        })()

        // Ask the worker to load, train and predict; the message payload itself is not used
        worker.postMessage({model: 'model'});
        worker.onmessage = (message) => {
            console.log('from main thread')
            const {data} = message
            // Rebuild the prediction tensor from the transferred data and shape
            tf.tensor(data.res, data.shape).print()
        }
    </script>
</head>
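Note that model.save('downloads://model') only triggers browser downloads of the model topology and weights (typically model.json and model.weights.bin); those files still have to be copied into the directory served at http://localhost:8080/ before the worker's tf.loadModel call can fetch them. Loading the model over HTTP inside the worker, rather than posting a serialized copy from the main thread, is what works around the missing modelFromJSON in these versions.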