使用自定义类替换lambda函数时,构建模型失败。
我以前在lambda函数中具有此代码,并且可以正常工作,但无法保存模型。我需要保存正在构建的模型,该模型依赖于此代码片段。
import keras
import tensorflow as tf
class ShapePositionLayer(keras.layers.Layer):
def call(self, x):
assert isinstance(x, list)
a, b = x
return keras.backend.gather(keras.backend.shape(a), b)
def compute_output_shape(self, input_shape):
return (1)
captions = keras.layers.Input(shape=[5,1024], name='captions')
batch_size = ShapePositionLayer()([captions,tf.constant(0,
dtype=tf.int32)])
model = keras.models.Model(inputs=[captions], outputs=[batch_size])
我希望能够建立一个模型。
接收错误: AttributeError:“ NoneType”对象没有属性“ _inbound_nodes”