How to implement a GRU layer in tensorflow.js

Time: 2019-07-06 23:02:39

Tags: tensorflow.js

I am trying to use a GRU model in tensorflow.js, but when I add a GRU layer to the model I get either a reference error or an input-incompatibility error.

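I am using Node.js with tensorflow.js. I have narrowed the problem down to user error, because honestly I don't know how to pass data into a GRU layer, and the examples TensorFlow provides are either vague or overly complicated. The code below comes straight from the TensorFlow API docs; I spent some time adding the GRU layer to it (taken from the API example).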

const tf = require('@tensorflow/tfjs-node');

const xArray = [
  [1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1],
];
const yArray = [1, 1, 1, 1];
// Create a dataset from the JavaScript array.
const xDataset = tf.data.array(xArray);
const yDataset = tf.data.array(yArray);
// Zip combines the `x` and `y` Datasets into a single Dataset, the
// iterator of which will return an object containing two tensors,
// corresponding to `x` and `y`.  The call to `batch(4)` will bundle
// four such samples into a single object, with the same keys now pointing
// to tensors that hold 4 examples, organized along the batch dimension.
// The call to `shuffle(4)` causes each iteration through the dataset to
// happen in a different order.  The size of the shuffle window is 4.
const xyDataset = tf.data.zip({xs: xDataset, ys: yDataset})
    .batch(4)
    .shuffle(4);

const cell = tf.layers.gru({units: 2, returnSequences: true});
const input = tf.input({shape: [4, 9]});
const output = cell.apply(input);

const model = tf.sequential();
model.add(tf.layers.dense({units: 1, inputShape: [9]}));
model.add(tf.layers.gru({
    units: 2,
    input,
    dropout: 0,
    recurrentDropout: 0
}));
model.add(tf.layers.dense({units: 1}));

//const model = tf.sequential({
//  layers: [tf.layers.dense({units: 1, inputShape: [9]})]
//});
model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
//const history = model.fitDataset(xyDataset, {
//  epochs: 4,
//  callbacks: {onEpochEnd: (epoch, logs) => console.log(logs.loss)}
//});

async function test(model) {
    const history = await model.fitDataset(xyDataset, {
        epochs: 4,
        callbacks: {onEpochEnd: (epoch, logs) => console.log(logs.loss)}
    });
}

console.log(model.predict(tf.ones([4, 9])).dataSync());

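The error here is that the input is incompatible with the GRU layer. A model made only of dense layers works fine, but dense layers are not suited to the kind of data I am feeding the network.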
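The Dense-only model runs while the GRU model fails because of tensor rank: a Dense layer declared with `inputShape: [9]` takes rank-2 input `[batch, 9]`, whereas recurrent layers such as `tf.layers.gru` take rank-3 input `[batch, timeSteps, features]`. In the code above the Dense layer comes first, so the GRU receives the Dense layer's `[batch, 1]` output, which has no time axis. The plain-JavaScript sketch below (no tfjs needed) shows how the 4-by-9 `xArray` could be given one. `shapeOf` and `toSequences` are hypothetical helpers written for illustration, and how the 9 values should be split into time steps is an assumption that depends on what the data actually means.

```javascript
// Plain JavaScript, no tfjs required. shapeOf and toSequences are
// hypothetical helpers written for illustration, not tfjs APIs.

// Infer the shape of a regular nested array by following the first element.
function shapeOf(value) {
  const shape = [];
  let current = value;
  while (Array.isArray(current)) {
    shape.push(current.length);
    current = current[0];
  }
  return shape;
}

// Split each flat row into `timeSteps` consecutive chunks of `features` values,
// turning [batch, timeSteps * features] into [batch, timeSteps, features].
function toSequences(rows, timeSteps, features) {
  return rows.map((row, i) => {
    if (row.length !== timeSteps * features) {
      throw new Error(
        `row ${i} has ${row.length} values, expected ${timeSteps * features}`);
    }
    const steps = [];
    for (let t = 0; t < timeSteps; t++) {
      steps.push(row.slice(t * features, (t + 1) * features));
    }
    return steps;
  });
}

const xArray = [
  [1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1],
];

// Rank 2: fits a Dense layer with inputShape [9], but has no time axis for a GRU.
console.log(shapeOf(xArray));                    // shape [4, 9]
// Rank 3, reading each row as 9 time steps of 1 feature (GRU inputShape [9, 1]).
console.log(shapeOf(toSequences(xArray, 9, 1))); // shape [4, 9, 1]
// Rank 3, reading each row as 3 time steps of 3 features (GRU inputShape [3, 3]).
console.log(shapeOf(toSequences(xArray, 3, 3))); // shape [4, 3, 3]
```

A nested array produced this way can be passed directly to `tf.tensor3d`, and the GRU's `inputShape` is then the last two entries of its shape.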

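If anyone could help get this code working with a GRU layer, I would be very grateful, since it would give a simple example to people who are stuck like me.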

Thanks in advance!

0 answers:

No answers yet.