可视化来自Keras的自定义模型中的中间层

时间:2018-10-17 09:25:35

标签: python tensorflow keras visualization

我目前正在Keras官方存储库中运行ConvLSTM模型。关于模型的理解,最困难的部分之一是难以理解过程的每个阶段正在发生的事情。可以找到here的官方代码。 从序列模型来看,批处理规范和卷积LSTM层的多次出现不是显而易见的选择,并且很难理解它们的实际作用。

seq = Sequential()
seq.add(ConvLSTM2D(filters=40, kernel_size=(3, 3),
               input_shape=(None, 40, 40, 1),
               padding='same', return_sequences=True))
seq.add(BatchNormalization())

seq.add(ConvLSTM2D(filters=40, kernel_size=(3, 3),
               padding='same', return_sequences=True))
seq.add(BatchNormalization())

seq.add(ConvLSTM2D(filters=40, kernel_size=(3, 3),
               padding='same', return_sequences=True))
seq.add(BatchNormalization())

seq.add(ConvLSTM2D(filters=40, kernel_size=(3, 3),
               padding='same', return_sequences=True))
seq.add(BatchNormalization())

对此的一个模糊的理解是,当涉及到图像时,它有助于利用卷积的作用来提取帧的特征。 [输入是视频的帧]。 如果我可以可视化每层的输出,这将非常有帮助,可以帮助我更好地了解卷积LSTM层在每个阶段的作用。对此有任何建议或朝这个方向的指针,我们深表感谢。谢谢您的宝贵时间。

1 个答案:

答案 0 :(得分:0)

如果您只熟悉Keras样式图的构造,建议您创建另一个Sequential,该Sequential与origianl Sequential共享一些图层,例如

seq = tf.keras.Sequential()
d1 = layers.Dense(units=1, use_bias=False, 
                  kernel_initializer=tf.initializers.constant(2.0))
d2 = layers.Dense(units=1, use_bias=False, 
                  kernel_initializer=tf.initializers.constant(3.0))
seq.add(d1)
seq.add(d2)

seq2 = tf.keras.Sequential()
seq2.add(d1)

print (seq.predict(np.ones(shape=[1,1])))
print (seq2.predict(np.ones(shape=[1,1])))

在上述情况下,您可以获得中间层d1的值。

如果您经常使用原始的tf.Session(),则可以为模型进行自定义的call()方法

class MultiOut(tf.keras.Model):
    def __init__(self, name="original"):
        super().__init__(name=name)
        self.d1 = layers.Dense(units=1, use_bias=False, 
                               kernel_initializer=tf.initializers.constant(2.0))
        self.d2 = layers.Dense(units=1, use_bias=False, 
                               kernel_initializer=tf.initializers.constant(3.0))

    def call(self, inputs, multiout=False):
        d1 = self.d1(inputs)
        d2 = self.d2(d1)

        if not multiout:
            return d2
        else:
            return d1, d2

model = MultiOut()

input = np.ones(shape=[1,1])
print (model.predict(input))

sess = tf.keras.backend.get_session()
ts_input = tf.constant(input, dtype=tf.float32)

print (sess.run(model(ts_input, multiout=True)))