我试图将RNNCell
对象传递给Keras lambda层,以便可以在Keras模型中使用Tensorflow层,如下所示。
conv_cell = ConvGRUCell(shape = [14, 14],
filters = 32,
kernel = [3,3],
padding = 'SAME')
def convGRU(inputs, cell, length):
output, final = tf.nn.bidirectional_dynamic_rnn(
cell, cell, x, length, dtype=tf.float32)
output = tf.concat(output, -1)
final = tf.concat(final, -1)
return [output, final]
lm = Lambda(lambda x: convGRU(x[0], x[1], x[2])([input, conv_cell, length])
但是,我得到一个错误,conv_cell
不是符号张量(它是基于Tensorflow的GRUCell的自定义层)。
是否可以将单元格传递到Lambda层?我可以将其与functools.partial一起使用,但是由于无法访问模型内部的函数,因此无法保存/加载模型。
答案 0 :(得分:0)
def convGRU(cell, length): # if length is produced by the model, use it with the inputs
def inner_func(inputs):
code...
return inner_func
lm = Lambda(convGRU(cell, length))(input)
要进行保存/加载,您需要使用custom_objects = {'convGRU': convGRU, 'cell':cell, 'length': length}
等。Keras不知道的所有内容都需要自动放入custom_objects
中,以加载保存的模型。