我正在尝试制作一个在水平轴上翻转图像并将其添加到批处理尺寸的图层。代码如下:
class FlipLayer(keras.layers.Layer):
def __init__(self, input_layer):
super(FlipLayer, self).__init__()
def get_output_shape(self, input_shape):
return (2 * input_shape[0],) + input_shape[1:]
def get_output(self, input):
return keras.layers.Concatenate([
input,
flipim(input)
], axis=0)
其中“ flipim”只是在所需轴上翻转numpy数组的函数。使用此功能编译模型时,Keras不会给出任何错误,但是它什么也没做。当我将该层用作最后一层并检查输出时,与前一层相比,它在批处理维度中的大小仍然相同。