结合编码器和解码器模型

时间:2018-09-03 15:11:06

标签: python keras autoencoder

有人可以告诉我,此代码有什么问题吗?为什么显示索引错误。编码器和解码器是分别定义的,当我如图所示在最后一行合并两个模型时,它将引发索引错误。

def Encoder():
    x= Conv2D(128, (3, 3),padding='same',activation='relu')(input_img)
    x= Conv2D(128, (3, 3), activation='relu', padding='same')(x)
     x= MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(x)
     x= Conv2D(64, (3, 3), activation='relu', padding='same')(x)
     x= Conv2D(64, (3, 3), activation='relu', padding='same')(x)
     x= MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(x)
     x= Conv2D(32, (3, 3), activation='relu', padding='same')(x)
     x= Conv2D(32, (3, 3), activation='relu', padding='same')(x)
     x= Conv2D(32, (3, 3), activation='relu', padding='same')(x)
     x= MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(x)
     x= Conv2D(16, (3, 3), activation='relu', padding='same')(x)
     x= Conv2D(16, (3, 3), activation='relu', padding='same')(x)
     x= Conv2D(8, (3, 3), activation='relu', padding='same')(x)
     x= Conv2D(4, (3, 3), activation='relu', padding='same')(x)
     encoded =Flatten()(x)
     return Model(input_img,encoded)

     # Encoder.add(Flatten())
     # autoencoder.add(Reshape((28, 28, 4)))

 def Decoder():
     encoded=Input(shape=(28, 28, 4))
     #encoded = Reshape((28, 28, 4))(encoded)
     x= Conv2D(4, (3, 3), activation='relu',padding='same')(encoded)
     x= Conv2D(8, (3, 3), activation='relu',padding='same')(x)
     x= Conv2D(16, (3, 3), activation='relu', padding='same')(x)
     x= Conv2D(16, (3, 3), activation='relu', padding='same')(x)
     x= UpSampling2D(size=(2, 2))(x)
     x= Conv2D(32, (3, 3), activation='relu', padding='same')(x)
     x= Conv2D(32, (3, 3), activation='relu', padding='same')(x)
     x= Conv2D(32, (3, 3), activation='relu', padding='same')(x)
     x= UpSampling2D(size=(2, 2))(x)
     x= Conv2D(64, (3, 3), activation='relu', padding='same')(x)
     x= Conv2D(64, (3, 3), activation='relu', padding='same')(x)
     x= UpSampling2D(size=(2, 2))(x)
     x= Conv2D(128, (3, 3), activation='relu', padding='same')(x)
     x= Conv2D(128, (3, 3), activation='relu', padding='same')(x)
     decoded= Conv2D(3, (3, 3), activation='sigmoid', padding='same')(x)
     return Model(encoded,decoded)

 x = Input(shape=(224, 224, 3))

 autoencoder = Model(x, Decoder()(Encoder()(x)))

堆栈跟踪为:

IndexError Traceback (most recent call last) in ()
6
7 # make the model:
 ----> 8 autoencoder = Model(x, Decoder()(Encoder()(x)))
9
10

 ~/VirtualEnvs/deep/lib/python3.5/site- 
 packages/keras/engine/base_layer.py in call(self, inputs, **kwargs)
 455 # Actually call the layer,
 456 # collecting output(s), mask(s), and shape(s).
  --> 457 output = self.call(inputs, **kwargs)
  458 output_mask = self.compute_mask(inputs, previous_mask)
  459

  ~/VirtualEnvs/deep/lib/python3.5/site- 
    packages/keras/engine/network.py in call(self, inputs, mask)
   568 return self._output_tensor_cache[cache_key]
  569 else:
  --> 570 output_tensors, _, _ = self.run_internal_graph(inputs, 
  masks)
  571 return output_tensors
  572

  ~/VirtualEnvs/deep/lib/python3.5/site- 
  packages/keras/engine/network.py in run_internal_graph(self, inputs, 
  masks)
  722 if 'mask' not in kwargs:
  723 kwargs['mask'] = computed_mask
  --> 724 output_tensors = to_list(layer.call(computed_tensor, 
  **kwargs))
  725 output_masks = layer.compute_mask(computed_tensor,
  726 computed_mask)

   ~/VirtualEnvs/deep/lib/python3.5/site- 
   packages/keras/layers/convolutional.py in call(self, inputs)
   166 padding=self.padding,
  167 data_format=self.data_format,
   --> 168 dilation_rate=self.dilation_rate)
   169 if self.rank == 3:
   170 outputs = K.conv3d(

 ~/VirtualEnvs/deep/lib/python3.5/site- 
 packages/keras/backend/tensorflow_backend.py in conv2d(x, kernel, 
 strides, padding, data_format, dilation_rate)
 3563 strides=strides,
 3564 padding=padding,
 -> 3565 data_format=tf_data_format)
 3566
 3567 if data_format == 'channels_first' and tf_data_format == 'NHWC':

  ~/VirtualEnvs/deep/lib/python3.5/site- 
 packages/tensorflow/python/ops/nn_ops.py in convolution(input, 
 filter, padding, strides, dilation_rate, name, data_format)
  777 dilation_rate=dilation_rate,
  778 name=name,
  --> 779 data_format=data_format)
 780 return op(input, filter)
 781

  ~/VirtualEnvs/deep/lib/python3.5/site- 
 packages/tensorflow/python/ops/nn_ops.py in init(self, input_shape, 
 filter_shape, padding, strides, dilation_rate, name, data_format)
826
827 if data_format is None or not data_format.startswith("NC"):
--> 828 input_channels_dim = input_shape[num_spatial_dims + 1]
829 spatial_dims = range(1, num_spatial_dims + 1)
830 else:

~/VirtualEnvs/deep/lib/python3.5/site- 
 packages/tensorflow/python/framework/tensor_shape.py in getitem(self, 
 key)
 613 return TensorShape(self._dims[key])
 614 else:
 --> 615 return self._dims[key]
 616 else:
 617 if isinstance(key, slice):

 **IndexError: list index out of range

2 个答案:

答案 0 :(得分:2)

形状不匹配。 encoder返回编码图像的扁平化表示,即数据的形状为(BS, -1)。在您的解码器中,您正在定义这样的输入

encoded=Input(shape=(28, 28, 4))

即您假设输入的形状(BS, 28, 28, 4)与编码器的输出不兼容。

答案 1 :(得分:0)

您需要更改输入的形状以使其与解码器输出兼容。使用 Reshape((28, 28, 4)) 使其匹配。