TensorFlowJS中的单位和inputShape

时间:2018-05-10 09:17:10

标签: javascript tensorflow tensorflow.js

我是TensorflowJS和ML的新手。在API参考中,以下代码为there

const model = tf.sequential();

// First layer must have an input shape defined.
model.add(tf.layers.dense({units: 32, inputShape: [50]}));

// Afterwards, TF.js does automatic shape inference.
model.add(tf.layers.dense({units: 4}));

// Inspect the inferred shape of the model's output, which equals
// `[null, 4]`. The 1st dimension is the undetermined batch dimension; the
// 2nd is the output size of the model's last layer.
console.log(JSON.stringify(model.outputs[0].shape));

我想知道的是,

什么是inputShape

什么是自动形状?

由于unit引用了数据集的属性,为什么unitmodel.add(tf.layers.dense({units: 4}))行中设置为4。 (该层在unit中将model.add(tf.layers.dense({units: 32, inputShape: [50]}))定义为32)由于sequential()的一层输出是下一层的输入,因此单位必须相同?< / p>

2 个答案:

答案 0 :(得分:2)

  

什么是inputShape

它是一个包含张量尺寸的数组,在运行神经网络时用作输入。

  

什么是自动形状?

它之前只使用了图层的输出形状。在这种情况下[32],因为之前的图层是一个具有32个单位的密集图层。

  

由于单位是指数据集的属性,为什么单位设置为   model.add(tf.layers.dense({units: 4}))行中的4。 (定义的图层   单位为model.add(tf.layers.dense({units: 32, inputShape: [50]})))中的32,因为一层的连续()输出是输入   下一层,单位必须相同?

单位定义密集层的输出形状。在这种情况下,神经元应该有4个输出,所以最后一层必须有4个单位。输出和输入形状不必相同,因为每个神经元的输出(其数量是输出形状)是基于前一层的所有神经元(输出)计算的。 (如果是密集层)

答案 1 :(得分:-1)

我总是喜欢一个有效的例子。这是我能做的一个简单的例子。

我曾经在我的网站上找到超过50个TFJS示例的链接,但似乎放置链接被认为是垃圾邮件,所以我不会分享它。

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.10.3"> </script> 




<input type="number" id="myAsk" value="5"><br> 

<input id="myButton123" type="button" value="Keras Layers Train and Test" onclick="{
       document.getElementById('myButton123').style.backgroundColor = 'red'                                                                              

    model = tf.sequential(); // no const so that it is a global variable 

    model.add(tf.layers.dense({ units: 10,  inputShape: [1] }) );  
    model.add(tf.layers.dense({ units: 10 }) );  
    model.add(tf.layers.dense({ units:  1 }) );  

   model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});

   // Generate some synthetic data for training.
   const xs = tf.tensor2d([[1], [2], [3], [4]], [4, 1]);
   const ys = tf.tensor2d([[1], [3], [5], [7]], [4, 1]);


  (async function () {   // inline async so we can use promises and await

    for (let myLoop = 1; myLoop <= 100; myLoop++) {                                                                                 
        var myFit = await model.fit(xs, ys, { epochs: 10 });
        if (myLoop % 20 == 0){   
             await tf.nextFrame();   // This allows the GUI to update but only every 20 batches      
             document.getElementById('myDiv123').innerHTML  =  'Loss after Batch ' + myLoop + ' : ' + myFit.history.loss[0] +'<br><br>'                                                                           
        }

    }                                                                                    


    const myPredictArray = await  model.predict(tf.tensor2d([document.getElementById('myAsk').value.split(',')], [1, 1]))  

    document.getElementById('myDiv123').innerHTML += 'Input '+document.getElementById('myAsk').value+', Output = ' + await myPredictArray.data() +'<br>'
    document.getElementById('myButton123').style.backgroundColor = 'lightgray'                                                                                

  })() // end the inline async funciton                                                                        


}" style="background-color: red;">   


<input id="myButton123b" type="button" value="re-Test" onclick="{
   (async function () {                                                                
   const myPredictArray = await  model.predict(tf.tensor2d([document.getElementById('myAsk').value.split(',')], [1, 1]))  

   document.getElementById('myDiv123').innerHTML = 'Input '+document.getElementById('myAsk').value+', Output = ' + await myPredictArray.data() +'<br>'
   })() // end the inline async funciton                                                                                     

 }"><br><br>

<div id='myDiv123'>...</div><br>