TensorFlow.js:在训练期间保存不同的模型实例

时间:2019-02-06 14:25:13

标签: neural-network tensorflow.js

我正在NODE上运行TensorFlow.JS,我希望能够在训练过程中的某个时刻保存模型。

我试图仅将实际模型复制到全局变量,但是JavaScript对象是通过引用复制的,最后,全局变量具有与上一个训练时期相同的模型。

然后,我使用了许多不同的JavaScript方法进行深度克隆(包括lodash深度克隆),但是我在复制的模型上遇到错误,例如最终导致丢失的函数(例如model.evaluate)。

我想知道是否可以保存某个检查点的唯一方法是直接使用model.save(),或者是否还有其他方法可以将模型对象(按值引用)复制到全局或类属性。

感谢前进!

**更新**

现在对我最有效的解决方案是创建模型副本:

  const copyModel = (model) => {
    const copy = tf.sequential();
    model.layers.forEach(layer => {
      copy.add(layer);
    });
    copy.compile({ loss: model.loss, optimizer: model.optimizer });
    return copy;
  }
  • 考虑到您可能需要将其他一些设置从原始模型复制到新模型(副本)。

1 个答案:

答案 0 :(得分:1)

tf.Model对象包含权重值,通常在GPU上显示 (作为WebGL纹理)并且不容易克隆。所以这不是一个好主意 克隆tf.Model对象。您应该对其进行序列化并将其保存在某处。 有两个选项:

  1. 如果您使用的是Node.js,则应该具有足够的存储空间。只是 使用Model.save()将模型“快照”到磁盘上,然后可以将其装回 以后。
  2. 如果您不想避免遍历文件系统,则可以在内存中进行序列化和反序列化。使用方法tf.io.withSaveHandlertf.io.fromMemory()。请参见下面的示例:
const tf = require('@tensorflow/tfjs');
require('@tensorflow/tfjs-node');

(async function main() {
  const model = tf.sequential();
  model.add(tf.layers.dense({units: 1, inputShape: [3], useBias: false}));
  model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});

  const xs = tf.randomUniform([4, 3]);
  const ys = tf.randomUniform([4, 1]);

  const artifactsArray = [];

  // First save, before training.
  await model.save(tf.io.withSaveHandler(artifacts => {
    artifactsArray.push(artifacts);
  }));

  // First load.
  const model2 = await tf.loadModel(tf.io.fromMemory(
      artifactsArray[0].modelTopology, artifactsArray[0].weightSpecs,
      artifactsArray[0].weightData));

  // Do some training.
  await model.fit(xs, ys, {epochs: 5});

  // Second save, before training.
  await model.save(tf.io.withSaveHandler(artifacts => {
    artifactsArray.push(artifacts);
  }));

  // Second load.
  const model3 = await tf.loadModel(tf.io.fromMemory(
      artifactsArray[1].modelTopology, artifactsArray[1].weightSpecs,
      artifactsArray[1].weightData));

  // The two loaded models should have different weight values.
  model2.getWeights()[0].print();
  model3.getWeights()[0].print();
})();