Custom LSTMCell with RNN in TF 2.0: always outputs a single class

Posted: 2019-09-11 12:27:22

Tags: python tensorflow lstm tensorflow2.0

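Implementation on TF 2.0:

I have implemented a custom LSTMCell and passed it to the RNN wrapper layer (tf.keras.layers.RNN).

My goal is to get the same behavior as tf.keras.layers.LSTM. For now, I am not trying to optimize it for the GPU.

The model should predict 7 output classes, but it always predicts a single class.

The dataset is correct in every respect and works with the standard tf.keras.layers.LSTM.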
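For comparison, the working baseline is presumably the same layer stack as `MyModel` with the built-in `tf.keras.layers.LSTM(100)` in place of the custom cell. The sketch below assumes inputs shaped (batch, timesteps, 6), matching the cell's default `n_input=6`; `make_reference_model` and the random input are hypothetical.

```python
import numpy as np
import tensorflow as tf

n_classes = 7  # seven output classes, as in the question
n_input = 6    # feature count assumed from LSTMCell_Layer's default

def make_reference_model(units=100):
    # Same stack as MyModel, but with the built-in LSTM layer
    return tf.keras.Sequential([
        tf.keras.layers.LSTM(units),
        tf.keras.layers.Dense(100, activation='relu'),
        tf.keras.layers.Dense(n_classes, activation='softmax'),
    ])

model = make_reference_model()
# Hypothetical batch: 16 sequences, 10 timesteps, 6 features each
x = np.random.rand(16, 10, n_input).astype('float32')
probs = model(x)  # softmax probabilities, one row per sequence
```

Training both models on the same data and comparing their confusion matrices isolates the custom cell as the only difference.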

Here is what the confusion matrix looks like:

[Figure: confusion matrix]
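The symptom can also be checked numerically: when a classifier collapses to one class, every prediction lands in a single column of the confusion matrix. A small NumPy sketch, assuming `probs` is the (n_samples, 7) softmax output and `y_true` holds integer labels (both function names are hypothetical):

```python
import numpy as np

def predicted_class_counts(probs, n_classes=7):
    """Count how often each class is the argmax of the softmax output."""
    preds = np.argmax(np.asarray(probs), axis=1)
    return np.bincount(preds, minlength=n_classes)

def confusion_matrix(y_true, y_pred, n_classes=7):
    """Rows are true classes, columns are predicted classes."""
    cm = np.zeros((n_classes, n_classes), dtype=np.int64)
    # Unbuffered add, so repeated (true, pred) pairs are all counted
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm

# Toy example: every prediction is class 2, as with a collapsed model
y_true = np.array([0, 1, 2, 3, 4, 5, 6])
y_pred = np.full(7, 2)
cm = confusion_matrix(y_true, y_pred)
```

With a trained model, `y_pred = np.argmax(model.predict(x_test), axis=1)` (with `x_test` standing in for the evaluation inputs) gives the predictions; `tf.math.confusion_matrix(y_true, y_pred, num_classes=7)` computes the same matrix in TensorFlow.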

There is probably an obvious mistake that I can't figure out.

```python
import tensorflow as tf
from tensorflow.keras.layers import Layer

random_stddev = 0.01  # assumed value; not shown in the question

class LSTMCell_Layer(Layer):

    def __init__(self, units=100, n_input=6, batch_size=16, n_classes=7):
        super(LSTMCell_Layer, self).__init__()
        self.units = units
        self.state_size = units  # only one state (h) is declared to the RNN wrapper

        w_init = tf.random_normal_initializer(stddev=random_stddev)

        # Input and recurrent weights for the four gates, stacked along the columns
        self.Wxh = tf.Variable(initial_value=w_init(shape=(n_input, 4 * units), dtype='float32'), trainable=True, name='Wxh')

        self.Whh = tf.Variable(initial_value=w_init(shape=(units, 4 * units), dtype='float32'), trainable=True, name='Whh')

        b_init = tf.zeros_initializer()
        self.bias = tf.Variable(initial_value=b_init(shape=(4 * units,), dtype='float32'), trainable=True, name='bias')

        # Hidden and cell state held as trainable variables with a fixed batch size
        self.h = tf.Variable(initial_value=w_init(shape=(batch_size, units), dtype='float32'), trainable=True, name='h')

        self.forget_bias = 1.0
        self.c = tf.Variable(initial_value=w_init(shape=(batch_size, units), dtype='float32'), trainable=True, name='c')

    def call(self, X, state):

        self.h = state[0]
        W_full = tf.concat([self.Wxh, self.Whh], 0)
        concat = tf.concat([X, self.h], 1)  # concat for speed
        concat = tf.matmul(concat, W_full) + self.bias
        i, j, f, o = tf.split(concat, 4, 1)
        g = tf.tanh(j)
        new_c = self.c * tf.sigmoid(f + self.forget_bias) + tf.sigmoid(i) * g  # previous c read from self.c
        self.c = new_c
        new_h = tf.tanh(new_c) * tf.sigmoid(o)
        return new_c, [new_h]
```
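Written out, the step that `call` computes is the usual LSTM update, with the gate order `i, j, f, o` produced by `tf.split`, $W$ the stacked `W_full = [Wxh; Whh]`, $b$ the `bias` and $b_f$ the `forget_bias` of 1.0:

$$
\begin{aligned}
[\,i,\ j,\ f,\ o\,] &= [\,x_t,\ h_{t-1}\,]\,W + b \\
c_t &= c_{t-1} \odot \sigma(f + b_f) + \sigma(i) \odot \tanh(j) \\
h_t &= \tanh(c_t) \odot \sigma(o)
\end{aligned}
$$

In this form $c_{t-1}$ is the previous step's cell state, just as $h_{t-1}$ is the previous hidden state. In `LSTMCell_Layer`, $h_{t-1}$ comes from `state[0]`, while $c_{t-1}$ is read from the attribute `self.c`, and the value returned as the step output is $c_t$ rather than $h_t$.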

Followed by the model definition:

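```python
import tensorflow as tf
from tensorflow.keras.layers import RNN, Dense

n_classes = 7  # seven output classes

class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.b1 = RNN(LSTMCell_Layer(100))                # custom cell wrapped in the RNN layer
        self.b2 = Dense(100, activation='relu')
        self.b3 = Dense(n_classes, activation='softmax')  # class probabilities

    def call(self, inputs):
        x = self.b1(inputs)
        x = self.b2(x)
        x = self.b3(x)
        return x
```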
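One possible rework of the cell, sketched under the assumption that the goal is to mirror `tf.keras.layers.LSTM` without GPU-specific tricks: it declares two states `[h, c]` (the same order `tf.keras.layers.LSTMCell` uses), reads $c_{t-1}$ from `states` instead of `self.c`, returns $h_t$ as the step output, and creates its weights in `build`. Whether this alone fixes the single-class predictions is not verified; the class name `LSTMCellSketch` and the `stddev=0.01` default are hypothetical.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Layer, RNN

class LSTMCellSketch(Layer):
    """Minimal LSTM cell for tf.keras.layers.RNN that carries both h and c as state."""

    def __init__(self, units=100, forget_bias=1.0, stddev=0.01, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.state_size = [units, units]  # [h, c], as in tf.keras.layers.LSTMCell
        self.output_size = units
        self.forget_bias = forget_bias
        self.stddev = stddev  # assumed initialization scale

    def build(self, input_shape):
        n_input = input_shape[-1]  # feature size read from the data, not hard-coded
        w_init = tf.keras.initializers.RandomNormal(stddev=self.stddev)
        self.Wxh = self.add_weight(name='Wxh', shape=(n_input, 4 * self.units), initializer=w_init)
        self.Whh = self.add_weight(name='Whh', shape=(self.units, 4 * self.units), initializer=w_init)
        self.bias = self.add_weight(name='bias', shape=(4 * self.units,), initializer='zeros')
        super().build(input_shape)

    def call(self, X, states):
        h_prev, c_prev = states[0], states[1]  # previous hidden and cell state
        z = tf.matmul(X, self.Wxh) + tf.matmul(h_prev, self.Whh) + self.bias
        i, j, f, o = tf.split(z, 4, axis=1)  # same gate order as the question's cell
        c = c_prev * tf.sigmoid(f + self.forget_bias) + tf.sigmoid(i) * tf.tanh(j)
        h = tf.tanh(c) * tf.sigmoid(o)
        # Emit h as the output and hand both h and c to the next timestep
        return h, [h, c]
```

To try it, replace `LSTMCell_Layer(100)` with `LSTMCellSketch(100)` in `MyModel.__init__`; the rest of the model stays the same. Because no state lives on the layer, nothing depends on a fixed `batch_size`, and the RNN wrapper supplies zero initial states for each batch, as the built-in LSTM does.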
