How to get intermediate tensor outputs of a pretrained model in TF2

Posted: 2020-07-23 09:31:45

Tags: python tensorflow keras tensorflow2

I want to extract the attention scores from an attention model, but I can't find a TF2 API for it. Simplified code:

import tensorflow as tf

model = train_model()
func = tf.function(model)
tensor_specs1 = tf.TensorSpec.from_tensor(model.input)

call = func.get_concrete_function(tensor_specs1)
graph = call.graph

# list the names of all nodes in the traced graph
tensor_names = [n.name for n in graph.as_graph_def().node]
for name in tensor_names:
    print(name)

outputs = graph.get_tensor_by_name('model_1/word_encoder_time/word_attention/Softmax:0')

pred_model = tf.keras.models.Model(model.input,outputs)

results = pred_model(tensor_specs1)

print(results)

But it raises an exception:

    raise ValueError("Tensor %s is not an element of this graph." % obj)
ValueError: Tensor Tensor("model_1/word_encoder_time/model/word_attention/BiasAdd:0", shape=(?, 10), dtype=float32) is not an element of this graph

This works, but it is not what I want:

outputs = [model.get_layer(name=output).get_output_at(0) for output in output_layers]
pred_model = tf.keras.models.Model(model.input,outputs)

I want to get the intermediate tensors, not the layers' outputs.
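As far as I know, tensors created inside a layer's call() (such as the Softmax inside an attention layer) are not part of the Keras functional graph in TF2, so they cannot be looked up by name and wired into a new Model. A common workaround is to have the layer return that tensor as an extra output. The sketch below assumes a hypothetical AttentionPooling layer (additive attention over the time axis); the layer name, units and shapes are illustrative and not taken from the original model.

```python
import numpy as np
import tensorflow as tf


class AttentionPooling(tf.keras.layers.Layer):
    """Hypothetical additive-attention pooling layer that also returns its scores."""

    def __init__(self, units=16, **kwargs):
        super().__init__(**kwargs)
        self.proj = tf.keras.layers.Dense(units, activation="tanh")
        self.score = tf.keras.layers.Dense(1)

    def call(self, inputs):
        # inputs: (batch, timesteps, features)
        logits = self.score(self.proj(inputs))            # (batch, timesteps, 1)
        weights = tf.nn.softmax(logits, axis=1)           # the intermediate "Softmax" tensor
        pooled = tf.reduce_sum(weights * inputs, axis=1)  # (batch, features)
        # Return the scores as a second output so they are reachable outside the layer.
        return pooled, tf.squeeze(weights, axis=-1)       # scores: (batch, timesteps)


inputs = tf.keras.Input(shape=(10, 8))
pooled, scores = AttentionPooling(name="word_attention")(inputs)
outputs = tf.keras.layers.Dense(3, activation="softmax")(pooled)
model = tf.keras.Model(inputs, outputs)  # the model you would train

# Second model sharing the same (trained) layers, exposing predictions and scores.
score_model = tf.keras.Model(inputs, [outputs, scores])
preds, att = score_model(np.random.rand(2, 10, 8).astype("float32"))
print(att.numpy())
```

Because score_model reuses the layer objects of model, it sees the trained weights without any copying.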

1 answer:

Answer 0 (score: 0)

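To evaluate the output of an arbitrary layer in a Keras model, you can use a Keras function (tf.keras.backend.function), which avoids having to work with sessions and graphs directly.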

import tensorflow as tf
print('TensorFlow: ', tf.__version__, end='\n\n')

input_layer = tf.keras.Input(shape=[100])
x = tf.keras.layers.Dense(16, activation='relu')(input_layer)
x = tf.keras.layers.Dense(64, activation='relu')(x)
x = tf.keras.layers.Dense(32, activation='relu')(x)
x = tf.keras.layers.Dense(10, activation='relu')(x)
output_layer = tf.keras.layers.Dense(5, activation='softmax')(x)

model = tf.keras.Model(inputs=[input_layer], outputs=[output_layer])

a = model.layers[3].output  # layers[0] is the InputLayer, so this is the Dense(32) layer
print(a)

fn = tf.keras.backend.function(input_layer, a)  # create a Keras Function
random_input = tf.random.normal([1, 100])  # random noise

a_eval = fn(random_input)
print('\nLayer Output:\n', a_eval)

Output:

TensorFlow:  2.3.0-dev20200611

Tensor("dense_73/Identity:0", shape=(None, 32), dtype=float32)

Layer Output:
 [[0.         0.         0.46475422 0.0961322  0.         0.
  0.23016977 0.         0.         0.05861767 0.03298267 0.11953808
  0.         0.         0.97043467 0.         0.         0.6384926
  0.         0.         0.         0.2346505  0.1822727  0.0145395
  0.08411474 0.         0.         0.37601566 0.         0.
  0.29435986 0.44069782]]
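An alternative that avoids tf.keras.backend.function (which may not be available in newer releases where tf.keras refers to Keras 3) is to build a second Model that shares the trained layers and maps model.inputs to the output of the layer you want, much like the get_layer snippet in the question. The layer name hidden_32 is an assumption added here so the layer can be looked up by name instead of by index.

```python
import numpy as np
import tensorflow as tf

# Same toy architecture as in the answer; "hidden_32" is a name added for lookup.
input_layer = tf.keras.Input(shape=(100,))
x = tf.keras.layers.Dense(16, activation="relu")(input_layer)
x = tf.keras.layers.Dense(64, activation="relu")(x)
x = tf.keras.layers.Dense(32, activation="relu", name="hidden_32")(x)
x = tf.keras.layers.Dense(10, activation="relu")(x)
output_layer = tf.keras.layers.Dense(5, activation="softmax")(x)
model = tf.keras.Model(inputs=input_layer, outputs=output_layer)

# Feature extractor: model inputs -> output of the named intermediate layer.
# It shares weights with `model`, so it reflects any training done on `model`.
extractor = tf.keras.Model(inputs=model.inputs,
                           outputs=model.get_layer("hidden_32").output)

random_input = np.random.normal(size=(1, 100)).astype("float32")
hidden = extractor(random_input)  # eager call, returns a tf.Tensor
print(hidden.numpy())
```

Passing a list such as [layer.output for layer in model.layers[1:]] as outputs returns every layer's activation in one call.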