The last LSTM layer in the Tensorflow MNIST example

Date: 2016-11-07 18:15:03

Tags: tensorflow lstm

I am working with the Tensorflow LSTM example on the MNIST dataset. I don't understand why a logistic regression is used in the last layer. Isn't the last output of the LSTM network a better estimator than the outputs of the earlier timesteps? How can I use the last output of the LSTM network for classification?

#  Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

"""
This example builds an RNN network for MNIST data.
Borrowed structure from here: https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/3%20-%20Neural%20Networks/recurrent_network.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from sklearn import metrics, preprocessing

import tensorflow as tf
from tensorflow.contrib import learn

# Parameters
learning_rate = 0.1
training_steps = 3000
batch_size = 128

# Network Parameters
n_input = 28  # MNIST data input (img shape: 28*28)
rnn_timesteps = 28  # timesteps
n_hidden = 128  # hidden layer num of features
n_classes = 10  # MNIST total classes (0-9 digits)

### Download and load MNIST data.
mnist = learn.datasets.load_dataset('mnist')


X_train = mnist.train.images
y_train = mnist.train.labels
X_test = mnist.test.images
y_test = mnist.test.labels

print(X_train.shape) # (55000, 784)
print(y_train.shape) # (55000,)

# It's useful to scale to ensure Stochastic Gradient Descent will do the right thing

scaler = preprocessing.StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)  # reuse the training-set statistics; do not refit on test data


def rnn_model(X, y):
  X = tf.reshape(X, [-1, rnn_timesteps, n_input])  # (batch_size, rnn_timesteps, n_input)
  # Permute rnn_timesteps and batch_size
  X = tf.transpose(X, [1, 0, 2])
  # Reshape to prepare input to hidden activation
  X = tf.reshape(X, [-1, n_input])  # (rnn_timesteps*batch_size, n_input)
  # Split data because the rnn cell needs a list of inputs for the RNN inner loop
  X = tf.split(0, rnn_timesteps, X)  # rnn_timesteps * (batch_size, n_input)

  # Define an LSTM cell with tensorflow
  lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden)
  # Run the RNN; discard the per-timestep outputs and keep only the final state
  _, encoding = tf.nn.rnn(lstm_cell, X, dtype=tf.float32)

  return learn.models.logistic_regression(encoding, y)


classifier = learn.TensorFlowEstimator(model_fn=rnn_model, n_classes=n_classes,
                                       batch_size=batch_size,
                                       steps=training_steps,
                                       learning_rate=learning_rate)

classifier.fit(X_train, y_train, logdir="/tmp/mnist_rnn")
score = metrics.accuracy_score(y_test, classifier.predict(X_test))
print('Accuracy: {0:f}'.format(score))

2 answers:

Answer 0 (score: 0)

The logistic regression layer is used to convert the continuous, multi-dimensional output into "classes". Conceptually, it transforms its input into an index (the class label).

The intermediate outputs convey more information about the data and can be used for other tasks, but to classify a sample, the logistic regression layer should be used.
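
As a rough illustration, here is a minimal sketch of what such a logistic regression layer computes, assuming it amounts to a linear projection plus a softmax (the function name and initializer values below are made up for this example; this is not the actual learn.models.logistic_regression source):

import tensorflow as tf

def logistic_regression_sketch(encoding, labels, n_hidden=128, n_classes=10):
  # Linearly project the continuous n_hidden-dim encoding to n_classes logits.
  W = tf.Variable(tf.truncated_normal([n_hidden, n_classes], stddev=0.1))
  b = tf.Variable(tf.zeros([n_classes]))
  logits = tf.matmul(encoding, W) + b
  # Softmax turns the logits into class probabilities; the argmax of these
  # probabilities is the predicted index (class label).
  probs = tf.nn.softmax(logits)
  # TF 0.x-era positional signature: logits first, then one-hot labels.
  loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits, labels))
  return probs, loss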

Answer 1 (score: 0)

You are using the last state of the RNN. According to the documentation of tf.nn.rnn, the second return value is the state of the RNN after the computation has been performed.

The logistic regression is then used to project the real-valued state of the LSTM onto a class, as well as to define the loss function.
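
To answer the "how" part of the question: take the first return value of tf.nn.rnn instead of the second, and feed its last element into the classifier. A minimal sketch of the change, assuming the same TF 0.x-era API as the question's code (rnn_model_last_output is a name invented for this example):

def rnn_model_last_output(X, y):
  X = tf.reshape(X, [-1, rnn_timesteps, n_input])  # (batch_size, rnn_timesteps, n_input)
  X = tf.transpose(X, [1, 0, 2])                   # make the data time-major
  X = tf.reshape(X, [-1, n_input])
  X = tf.split(0, rnn_timesteps, X)                # list of rnn_timesteps tensors

  lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden)
  # tf.nn.rnn returns (outputs, state); outputs is a Python list holding one
  # (batch_size, n_hidden) tensor per timestep, so outputs[-1] is the output
  # of the last timestep.
  outputs, _ = tf.nn.rnn(lstm_cell, X, dtype=tf.float32)
  return learn.models.logistic_regression(outputs[-1], y)

Note that for a BasicLSTMCell the last output is the hidden (h) part of the final state, while the returned state also carries the cell memory (c), so the two variants feed closely related information into the logistic regression.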