有人可以告诉我,此代码有什么问题吗?为什么显示索引错误。编码器和解码器是分别定义的,当我如图所示在最后一行合并两个模型时,它将引发索引错误。
def Encoder():
x= Conv2D(128, (3, 3),padding='same',activation='relu')(input_img)
x= Conv2D(128, (3, 3), activation='relu', padding='same')(x)
x= MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(x)
x= Conv2D(64, (3, 3), activation='relu', padding='same')(x)
x= Conv2D(64, (3, 3), activation='relu', padding='same')(x)
x= MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(x)
x= Conv2D(32, (3, 3), activation='relu', padding='same')(x)
x= Conv2D(32, (3, 3), activation='relu', padding='same')(x)
x= Conv2D(32, (3, 3), activation='relu', padding='same')(x)
x= MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(x)
x= Conv2D(16, (3, 3), activation='relu', padding='same')(x)
x= Conv2D(16, (3, 3), activation='relu', padding='same')(x)
x= Conv2D(8, (3, 3), activation='relu', padding='same')(x)
x= Conv2D(4, (3, 3), activation='relu', padding='same')(x)
encoded =Flatten()(x)
return Model(input_img,encoded)
# Encoder.add(Flatten())
# autoencoder.add(Reshape((28, 28, 4)))
def Decoder():
encoded=Input(shape=(28, 28, 4))
#encoded = Reshape((28, 28, 4))(encoded)
x= Conv2D(4, (3, 3), activation='relu',padding='same')(encoded)
x= Conv2D(8, (3, 3), activation='relu',padding='same')(x)
x= Conv2D(16, (3, 3), activation='relu', padding='same')(x)
x= Conv2D(16, (3, 3), activation='relu', padding='same')(x)
x= UpSampling2D(size=(2, 2))(x)
x= Conv2D(32, (3, 3), activation='relu', padding='same')(x)
x= Conv2D(32, (3, 3), activation='relu', padding='same')(x)
x= Conv2D(32, (3, 3), activation='relu', padding='same')(x)
x= UpSampling2D(size=(2, 2))(x)
x= Conv2D(64, (3, 3), activation='relu', padding='same')(x)
x= Conv2D(64, (3, 3), activation='relu', padding='same')(x)
x= UpSampling2D(size=(2, 2))(x)
x= Conv2D(128, (3, 3), activation='relu', padding='same')(x)
x= Conv2D(128, (3, 3), activation='relu', padding='same')(x)
decoded= Conv2D(3, (3, 3), activation='sigmoid', padding='same')(x)
return Model(encoded,decoded)
x = Input(shape=(224, 224, 3))
autoencoder = Model(x, Decoder()(Encoder()(x)))
堆栈跟踪为:
IndexError Traceback (most recent call last) in ()
6
7 # make the model:
----> 8 autoencoder = Model(x, Decoder()(Encoder()(x)))
9
10
~/VirtualEnvs/deep/lib/python3.5/site-
packages/keras/engine/base_layer.py in call(self, inputs, **kwargs)
455 # Actually call the layer,
456 # collecting output(s), mask(s), and shape(s).
--> 457 output = self.call(inputs, **kwargs)
458 output_mask = self.compute_mask(inputs, previous_mask)
459
~/VirtualEnvs/deep/lib/python3.5/site-
packages/keras/engine/network.py in call(self, inputs, mask)
568 return self._output_tensor_cache[cache_key]
569 else:
--> 570 output_tensors, _, _ = self.run_internal_graph(inputs,
masks)
571 return output_tensors
572
~/VirtualEnvs/deep/lib/python3.5/site-
packages/keras/engine/network.py in run_internal_graph(self, inputs,
masks)
722 if 'mask' not in kwargs:
723 kwargs['mask'] = computed_mask
--> 724 output_tensors = to_list(layer.call(computed_tensor,
**kwargs))
725 output_masks = layer.compute_mask(computed_tensor,
726 computed_mask)
~/VirtualEnvs/deep/lib/python3.5/site-
packages/keras/layers/convolutional.py in call(self, inputs)
166 padding=self.padding,
167 data_format=self.data_format,
--> 168 dilation_rate=self.dilation_rate)
169 if self.rank == 3:
170 outputs = K.conv3d(
~/VirtualEnvs/deep/lib/python3.5/site-
packages/keras/backend/tensorflow_backend.py in conv2d(x, kernel,
strides, padding, data_format, dilation_rate)
3563 strides=strides,
3564 padding=padding,
-> 3565 data_format=tf_data_format)
3566
3567 if data_format == 'channels_first' and tf_data_format == 'NHWC':
~/VirtualEnvs/deep/lib/python3.5/site-
packages/tensorflow/python/ops/nn_ops.py in convolution(input,
filter, padding, strides, dilation_rate, name, data_format)
777 dilation_rate=dilation_rate,
778 name=name,
--> 779 data_format=data_format)
780 return op(input, filter)
781
~/VirtualEnvs/deep/lib/python3.5/site-
packages/tensorflow/python/ops/nn_ops.py in init(self, input_shape,
filter_shape, padding, strides, dilation_rate, name, data_format)
826
827 if data_format is None or not data_format.startswith("NC"):
--> 828 input_channels_dim = input_shape[num_spatial_dims + 1]
829 spatial_dims = range(1, num_spatial_dims + 1)
830 else:
~/VirtualEnvs/deep/lib/python3.5/site-
packages/tensorflow/python/framework/tensor_shape.py in getitem(self,
key)
613 return TensorShape(self._dims[key])
614 else:
--> 615 return self._dims[key]
616 else:
617 if isinstance(key, slice):
**IndexError: list index out of range
答案 0 :(得分:2)
形状不匹配。 encoder
返回编码图像的扁平化表示,即数据的形状为(BS, -1)
。在您的解码器中,您正在定义这样的输入
encoded=Input(shape=(28, 28, 4))
即您假设输入的形状(BS, 28, 28, 4)
与编码器的输出不兼容。
答案 1 :(得分:0)
您需要更改输入的形状以使其与解码器输出兼容。使用 Reshape((28, 28, 4))
使其匹配。