Tensorflow LSTM RNN输出激活函数

时间:2016-06-13 18:17:17

标签: python machine-learning neural-network artificial-intelligence tensorflow

我有一个输入图像,其灰阶值不等于25000 - 35000。我正在进行二进制像素分类,因此输出基础事实是0's1's的矩阵。

有谁知道默认输出激活功能是什么?我的问题是,它是一个ReLu?我希望它是一个SoftMax功能。在这种情况下,每个预测值都在01之间(显然接近我的地面实况数据)。

我正在使用here中的示例代码,我已调整该代码以便为我的数据工作。

我有一个正在训练的工作网络,但是现在的小批量损失约为425,准确度为0.0,而对于LSTM MNIST示例代码(链接),小批量损失约为0.1,而累计约为1.0。我希望如果我可以更改激活功能以使用SoftMax功能,我可以改善结果。

1 个答案:

答案 0 :(得分:6)

查看the codeBasicLSTMCell的默认激活功能为tf.tanh()。您可以通过在构造activation对象时指定可选的BasicLSTMCell参数,并传递任何需要单个输入并生成相同形状的单个输出的TensorFlow操作来自定义激活函数。例如:

# Defaults to using `tf.tanh()`.
lstm_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)

# Uses  `tf.relu()`.
lstm_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0, activation=tf.nn.relu)

# Uses  `tf.softmax()`.
lstm_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0, activation=tf.nn.softmax)