tensorflow.js lstm掩盖了不正确的时间步

时间:2019-05-09 18:36:52

标签: javascript tensorflow lstm masking

我已经使用tensorflow js编写了一个简单的序列到序列编码器。输入形状为[10,1326],输出形状为[10]。在输入中,可能有多达8个空时间步长,只有1个空位,全为0。

这是模型:

Sub Email_Budget()

   Dim objOutlook As Object
   Set objOutlook = CreateObject("Outlook.Application")

   Dim objEmail As Object
   Set objEmail = objOutlook.CreateItem(olMailItem)

   Dim CaseCount As Long
   CaseCount = WorksheetFunction.CountA(Range("B6:B500"))
   'Debug.Print CaseCount

   Dim i As Integer

   With objEmail
      .To = "abc@xyz.com"
      .Subject = "TEST1: May 2019 Budget"
      .HTMLBody = "Karen,<br><br>"
      .HTMLBody = .HTMLBody & "The potential " & MonthName(Month(ActiveSheet.Range("A2"))) & " invoices are below.<br><br>"
      For i = 1 To CaseCount
        If ActiveSheet.Cells(i + 5, 4).Value = "Yes" Then
            .HTMLBody = .HTMLBody & "<ul style='list-style-type:disc;'>" & "<li>" & ActiveSheet.Cells(i + 5, 2).Value & " - " & Format(ActiveSheet.Cells(i + 5, 6).Value, "Currency") _
            & " (" & Format(ActiveSheet.Cells(i + 5, 8).Value, "Currency") & " without budget or invoicing)." & "</li>" & "</ul>"


      End If
      Next i
      .HTMLBody = .HTMLBody & "<br>Thank you,<br>Kurt"
      .Display
   End With
End Sub

经过一些训练后,我看到了这样的预测:

const tf = require('@tensorflow/tfjs-node')

module.exports = function() {
  const model = tf.sequential();
  model.add(tf.layers.masking({maskValue:0, inputShape:[10, 1326]}))
  model.add(tf.layers.lstm({units:20, returnSequences:true}))
  model.add(tf.layers.lstm({units:15, returnSequences:false}))
  model.add(tf.layers.dense({units:10}))
  model.add(tf.layers.reLU({maxValue:1}))
  const optimizer = tf.train.adadelta()
  model.compile({
    optimizer:optimizer,
    loss:tf.losses.absoluteDifference,
  })
  return model
}

我期望的地方:

Tensor
    [[0.2750083, 0.2593362, 0.1763182, 0        , 0.0915875, 0.0430228, 0, 0, 0, 0],
     [0.2601015, 0.231798 , 0.144048 , 0        , 0.0815957, 0.056561 , 0, 0, 0, 0],
     [0.256667 , 0.2420369, 0.1736434, 0        , 0.0854579, 0.0473421, 0, 0, 0, 0],
     [0.2556586, 0.273939 , 0.1369745, 0        , 0.113734 , 0.0677839, 0, 0, 0, 0],
     [0.1967069, 0.1931839, 0.1047193, 0.0016383, 0.0509892, 0.0433681, 0, 0, 0, 0],
     [0.2441588, 0.2343057, 0.1448116, 0        , 0.0733367, 0.0689584, 0, 0, 0, 0],
     [0.2288964, 0.2493394, 0.1462133, 0        , 0.1020615, 0.0668219, 0, 0, 0, 0],
     [0.2435157, 0.2482093, 0.1485323, 0        , 0.1007352, 0.0541206, 0, 0, 0, 0],
     [0.272615 , 0.2631502, 0.1571562, 0        , 0.1078788, 0.0649559, 0, 0, 0, 0],
     [0.273352 , 0.2620186, 0.1684739, 0        , 0.0936071, 0.0462594, 0, 0, 0, 0]]

输出中的第4个以及最后4个特征几乎始终为0。有什么明显的我想念的地方吗?任何建议都很棒

0 个答案:

没有答案