将非符号张量传递给Keras Lambda层

时间:2019-09-18 21:54:38

标签: python tensorflow keras

我试图将RNNCell对象传递给Keras lambda层,以便可以在Keras模型中使用Tensorflow层,如下所示。

conv_cell = ConvGRUCell(shape = [14, 14],
                       filters = 32,
                       kernel = [3,3],
                       padding = 'SAME')

def convGRU(inputs, cell, length):
    output, final = tf.nn.bidirectional_dynamic_rnn(
            cell, cell, x, length, dtype=tf.float32)
    output = tf.concat(output, -1)
    final = tf.concat(final, -1)
    return [output, final]

lm = Lambda(lambda x: convGRU(x[0], x[1], x[2])([input, conv_cell, length])

但是,我得到一个错误,conv_cell不是符号张量(它是基于Tensorflow的GRUCell的自定义层)。

是否可以将单元格传递到Lambda层?我可以将其与functools.partial一起使用,但是由于无法访问模型内​​部的函数,因此无法保存/加载模型。

1 个答案:

答案 0 :(得分:0)

def convGRU(cell, length): # if length is produced by the model, use it with the inputs    
    def inner_func(inputs):
        code...
    return inner_func

lm = Lambda(convGRU(cell, length))(input)

要进行保存/加载,您需要使用custom_objects = {'convGRU': convGRU, 'cell':cell, 'length': length}等。Keras不知道的所有内容都需要自动放入custom_objects中,以加载保存的模型。