Keras公制问题

时间:2019-03-01 08:57:02

标签: tensorflow keras

我正在使用最新版本的tensorflow(1.13)。我正在使用Keras API来训练LSTM网络。

我不能与tf.keras.metrics.Precision()tf.keras.metrics.Recall()一起训练

正在编译。但是在训练过程中,我遇到了以下错误

  

InvalidArgumentError:断言失败:[预测必须> = 0]   [条件x> = y不按元素保存:x(dense_3 / BiasAdd:0)=]   [[[[2.72658144e-06 1.17555362e-06 1.96436554e-06 ...]] ...] [y   (metrics_3 / precision_1 / Cast / x:0)=] [0] [[{{node   metrics_3 / precision_1 / assert_greater_equal / Assert / AssertGuard / Assert}}]]

该模型非常简单,如下所示

model = Sequential()
model.add(LSTM (120,activation = "tanh", input_shape=(timesteps,dim), return_sequences=True))
model.add(LSTM(120, activation = "tanh", return_sequences=True))
model.add(LSTM(120, activation = "tanh", return_sequences=True))
model.add(LSTM(120, activation = "tanh", return_sequences=True))
model.add(LSTM(120, activation = "tanh", return_sequences=True))
model.add(LSTM(120, activation = "tanh", return_sequences=True))
model.add(Dense(dim))
model.compile(optimizer="adam", loss="mse",  metrics=[tf.keras.metrics.Precision()])

history = model.fit(data,data, 
                    epochs=100,
                    batch_size=10,
                    validation_split=0.2,
                    shuffle=True,
                    callbacks=[ch]).history

是Bug还是我做错了什么?

1 个答案:

答案 0 :(得分:1)

Precisionrecall是衡量分类效果的指标。由于您在最后一层使用mse和线性激活,因此您宁愿进行回归。

如果要分类,请确保创建的输出范围为[0,1]。这可以通过在最后一层使用sigmoidsoftmax激活来获得,具体取决于您的问题。 (二进制或n级分类)

进一步确保输出形状正确,因为您在最后一个LSTM层中获得了return_sequences=True,这可能不是您想要的。

编辑:由于您的model.fit通话,我现在可以看到,您正在尝试对数据进行自动编码。因此,precision作为度量标准在这里没有意义。