Keras函数(K.function)不适用于RNN(提供了代码)

时间:2019-04-30 09:16:07

标签: keras output recurrent-neural-network keras-layer

我试图在Keras上查看每个图层的输出,但是我无法获得正确的代码,所以我做了一个简单的代码。

问题:我应该如何获取在整个层中都有RNN层的每个层的输出?

您可以在下面的代码中看到我的尝试方式。

以下是有效的测试代码(1):

seq_length = 3
latent_dim = 2
inputs = Input(shape=(seq_length, latent_dim))
outputs = Dense(5)(inputs)
outputs = Flatten()(outputs)

model = Model(inputs=inputs, outputs=outputs)
model.compile(optimizer='rmsprop', loss='mse')
print(model.summary())

要查看每个图层(2)的输出:

layer_outputs = list()
for idx, l in enumerate(model.layers):
    if idx == 0:
        continue
    layer_outputs.append(l.output)
get_3rd_layer_output = K.function([model.layers[0].input],
                                  layer_outputs)
layer_output = get_3rd_layer_output([enc_input])
print('')
for l_output in layer_output:
    print(l_output[0][0])
    print('')

然后输出将类似于

  

[4.172303 -2.248884 1.397713 3.2669916 2.5788064]

     

4.172303

但是,如果我尝试使用下面使用RNN的代码测试与(2)相同的逻辑:

seq_length = 3
latent_dim = 2
inputs = Input(shape=(seq_length, latent_dim))
outputs, last_output = GRU(latent_dim, return_state=True, return_sequences=True)(inputs)

model = Model(inputs=inputs, outputs=outputs)
model.compile(optimizer='rmsprop', loss='mse')
print(model.summary())

并使用(2)进行测试,它将发出如下信号:

  

-------------------------------------------------- ---------------------------- TypeError跟踪(最近的呼叫   最后)         5 layer_outputs.append(l.output)         6 get_3rd_layer_output = K.function([model.layers [0] .input],   ----> 7 layer_outputs)         8 layer_output = get_3rd_layer_output([enc_input])         9次打印('')

     

d:\ igs_projects \ nlp_nlu \ venv \ lib \ site-packages \ keras \ backend \ tensorflow_backend.py   在函数中(输入,输出,更新,** wargs)2742
  msg ='通过TensorFlow将无效的参数“%s”传递给K.function   后端'%键2743引发ValueError(msg)   -> 2744返回函数(输入,输出,更新=更新,**假)2745 2746

     

d:\ igs_projects \ nlp_nlu \ venv \ lib \ site-packages \ keras \ backend \ tensorflow_backend.py   在 init (自身,输入,输出,更新,名称,** session_kwargs)中
  2544 self.inputs = list(inputs)2545 self.outputs =   清单(输出)   -> 2546,带有tf.control_dependencies(self.outputs):2547 updates_ops = [] 2548,用于更新:

     

d:\ igs_projects \ nlp_nlu \ venv \ lib \ site-packages \ tensorflow \ python \ framework \ ops.py   在control_dependencies(control_inputs)5002中返回   _NullContextmanager()5003其他:   -> 5004返回get_default_graph()。control_dependencies(control_inputs)5005
  5006

     

d:\ igs_projects \ nlp_nlu \ venv \ lib \ site-packages \ tensorflow \ python \ framework \ ops.py   在control_dependencies(self,control_inputs)中4541如果   isinstance(c,IndexedSlices):4542 c = c.op   -> 4543 c = self.as_graph_element(c)4544 if isinstance(c,张量):4545 c = c.op

     

d:\ igs_projects \ nlp_nlu \ venv \ lib \ site-packages \ tensorflow \ python \ framework \ ops.py   在as_graph_element中(self,obj,allow_tensor,allow_operation)3488   3489具有self._lock:   -> 3490返回self._as_graph_element_locked(obj,allow_tensor,allow_operation)3491 3492 def _as_graph_element_locked(self,   obj,allow_tensor,allow_operation):

     

d:\ igs_projects \ nlp_nlu \ venv \ lib \ site-packages \ tensorflow \ python \ framework \ ops.py   在_as_graph_element_locked中(self,obj,allow_tensor,allow_operation)   3577#我们放弃! 3578引发TypeError(“无法   将%s转换为%s。“%(type(obj)。名称,   -> 3579 types_str))3580 3581 def get_operations(self):

     

TypeError:无法将列表转换为张量或操作。

1 个答案:

答案 0 :(得分:0)

对于GRU层,layer.output本身就是一个列表。

import tkinter as tk

def callback(event):
    #print('event.widget.get():', event.widget.get())
    event.widget.insert('end', '.')  # put new text in Entry
    return 'break' # stop event so it will not put comma in Entry

root = tk.Tk()

e = tk.Entry(root)
e.pack()
e.bind('<Key-comma>', callback) # execute callback before it puts comma in Entry

root.mainloop()

layer_outputs是一个包含另一个列表的列表,因此错误“无法将列表转换为张量或运算”

>>> model.layers[1].output
[<tf.Tensor 'gru_1/transpose_1:0' shape=(?, ?, 2) dtype=float32>, <tf.Tensor 'gru_1/while/Exit_3:0' shape=(?, 2) dtype=float32>]

更新这样的代码应该可以:

>>> layer_outputs
[[<tf.Tensor 'gru_1/transpose_1:0' shape=(?, ?, 2) dtype=float32>, <tf.Tensor 'gru_1/while/Exit_3:0' shape=(?, 2) dtype=float32>]]