我试图创建一个能够区分10个图像类的keras模型。我为每个对象创建了12个观察点,然后将它们传递给模型。
input_shape = (12, 1, 32, 32)
model = get_test_shared_model()
input = keras.layers.Input(shape=input_shape)
views = SplitLayer(num_views)(input) # list of keras-tensors
processed_views = [] # empty list
for view in views:
x = model(view)
processed_views.append(x)
在这个阶段,我自己对每个观察点进行了计算,现在我尝试使用最大层将它们合并在一起。 Conv2D线路上发生错误。 关于如何解决这个问题的任何想法?
pooled_views = keras.layers.Maximum()(processed_views)
prediction = Conv2D(32, kernel_size=(3, 3), activation='relu')(pooled_views)
Split Layer源代码是:
class SplitLayer(keras.layers.Layer):
"""
Layer expects a tensor (multi-dimensonal array) of shape (samples, views,
...)
and returns a list of #views elements, each of shape (samples, ...)
"""
def __init__(self, num_splits, **kwargs):
self.num_splits = num_splits
super(SplitLayer, self).__init__(**kwargs)
def call(self, x):
a = [x[:, i] for i in range(self.num_splits)]
return a
def compute_output_shape(self, input_shape):
return [(input_shape[0],) + input_shape[2:]]*self.num_splits
def get_test_shared_model():
num_channels = 4 # for example
cnn = keras.models.Sequential()
cnn.add(keras.layers.Conv2D(num_channels, kernel_size=(3, 3),
activation='relu',
input_shape=(32, 32, 1)))
cnn.add(keras.layers.MaxPooling2D(pool_size=(2, 2)))
cnn.add(keras.layers.Flatten())
cnn.add(keras.layers.Dense(128, activation='relu'))
return cnn