Which arguments should I use in `.fit` of an LSTM TensorFlow.js model to output the hidden and cell states?

Posted: 2020-08-15 16:15:14

Tags: javascript tensorflow lstm tensorflow.js

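I'm trying to change the model used in the official lstm-text-generation example from tfjs-examples so that it also outputs its hidden state. To do this, I changed the `createModel` function in `model.js` as follows: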

```javascript
import * as tf from '@tensorflow/tfjs';

export function createModel(sampleLen, charSetSize, lstmLayerSizes) {
  const inputs = tf.input({ shape: [sampleLen, charSetSize] });

  let outputs = inputs;
  for (let i = 0; i < lstmLayerSizes.length; ++i) {
    const lstmLayerSize = lstmLayerSizes[i];
    const layer = tf.layers.lstm({
      units: lstmLayerSize,
      // Not sure if this is necessary
      returnSequences: i < lstmLayerSizes.length - 1,
      // As far as I understood, this returns the state
      returnState: i === lstmLayerSizes.length - 1,
    });

    outputs = layer.apply(outputs);
  }

  // Destructure the hidden state (optionally also the cell state in the next step)
  const [lstmOutput, hiddenState] = outputs;

  outputs = tf.layers
    .dense({ units: charSetSize, activation: 'softmax' })
    .apply(lstmOutput);

  // Return the hidden state as part of the model outputs
  return tf.model({ inputs, outputs: [outputs, hiddenState] });
}
```
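A rough illustration of what the last layer returns: with `returnSequences: false` and `returnState: true`, an LSTM layer gives `[output, hiddenState, cellState]`, and `output` is simply the hidden state at the final time step, so `lstmOutput` and `hiddenState` in the code above hold the same values. The plain-JavaScript sketch below is not TF.js code: it uses a single unit with made-up weights, and `lstmStep` and `runLstm` are hypothetical helpers, just to show how the cell state `c` and hidden state `h` are updated at each step.

```javascript
// Illustrative single-unit LSTM in plain JavaScript; the weights are made-up numbers,
// not values or internals taken from TF.js.
const sigmoid = (z) => 1 / (1 + Math.exp(-z));

// One time step. Each gate uses hypothetical scalar weights [inputWeight, recurrentWeight, bias].
function lstmStep(x, h, c, w) {
  const pre = ([wx, wh, b]) => wx * x + wh * h + b;
  const i = sigmoid(pre(w.i)); // input gate
  const f = sigmoid(pre(w.f)); // forget gate
  const g = Math.tanh(pre(w.g)); // candidate cell value
  const o = sigmoid(pre(w.o)); // output gate
  const cNext = f * c + i * g; // new cell state
  const hNext = o * Math.tanh(cNext); // new hidden state, which is also this step's output
  return { h: hNext, c: cNext };
}

// Runs a whole sequence, starting from zero states.
function runLstm(xs, w) {
  let h = 0;
  let c = 0;
  const sequence = []; // one output per step, like returnSequences: true
  for (const x of xs) {
    ({ h, c } = lstmStep(x, h, c, w));
    sequence.push(h);
  }
  return { sequence, h, c }; // final hidden and cell states, like returnState: true
}

const weights = {
  i: [0.5, 0.1, 0.0],
  f: [0.4, 0.2, 1.0],
  g: [0.9, -0.3, 0.0],
  o: [0.6, 0.1, 0.0],
};
const { sequence, h, c } = runLstm([1, 0, 1], weights);
// With returnSequences: false, the layer output is only the last element of the sequence.
const lastOutput = sequence[sequence.length - 1];
console.log(lastOutput === h); // true
console.log(sequence.length); // 3
```

Per-character visualizations like the one in the Stanford lecture typically look at the hidden state at every time step, which corresponds to the whole `sequence` array here (what `returnSequences: true` returns) rather than to the single final state.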

I don't understand how I need to change the arguments of `.fit` in the `fitModel` function in `model.js`.

The trained model seems to work fine. However, when I visualize the hidden states, as is done in this Stanford lecture, they don't make any sense.

Does what I'm doing make sense?

0 answers