如何获得中间层的输出?

时间:2019-05-23 20:13:03

标签: python tensorflow keras

我正试图了解Google's colab code。我应该如何使用此代码:

from keras import backend as K
prediction_model = lstm_model(seq_len=1, batch_size=BATCH_SIZE, stateful=True)
prediction_model.load_weights('/tmp/bard.h5')

get_test_layer_output = K.function([prediction_model.layers[0].input],
                                  [prediction_model.layers[1].output])
layer_output = get_test_layer_output([x])[0]

查看每一层之后的值?还是有其他方法可以查看值(不是形状)?

Layer (type)                 Output Shape              Param #   
=================================================================
seed (InputLayer)            (128, 100)                0         
_________________________________________________________________
embedding (Embedding)        (128, 100, 512)           131072    
_________________________________________________________________
lstm (LSTM)                  (128, 100, 512)           2099200   
_________________________________________________________________
lstm_1 (LSTM)                (128, 100, 512)           2099200   
_________________________________________________________________
time_distributed (TimeDistri (128, 100, 256)           131328    
=================================================================
Total params: 4,460,800
Trainable params: 4,460,800
Non-trainable params: 0

1 个答案:

答案 0 :(得分:0)

对于要在Keras模型层上进行的任何操作,首先,我们需要访问模型所保存的keras.layers对象的列表。

model_layers = model.layers

此列表中的每个Layer对象都有其自己的inputoutput张量(如果您使用的是TensorFlow后端)

input_tensor = model.layers[ layer_index ].input
output_tensor = model.layers[ layer_index ].output

如果您使用tf.Session.run()方法直接运行output_tensor,则会出现错误,指出在访问层的输出之前必须将输入馈送到模型中。

import tensorflow as tf
import numpy as np

layer_index = 3 # The index of the layer whose output needs to be fetched

model = tf.keras.models.load_model( 'model.h5' )
out_ten = model.layers[ layer_index ].output

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    output = sess.run(  out_ten , { model.input : np.ones((2,186))}  ) 
    print( output )

在运行模型之前,您需要使用tf.global_variables_initializer().run()初始化变量。 model.input为模型的输入提供占位符张量。