U net多类分割图像输入数据集错误

时间:2020-11-12 14:06:35

标签: python tensorflow tensorflow-datasets multiclass-classification unity3d-unet

我正在尝试使用U-net进行多类细分。在以前的试验中,我尝试了二进制分割,并且可以正常工作。但是当我尝试进行多类学习时,我会遇到此错误。

**ValueError: 'generator yielded an element of shape (128,192,1) where an element of shape (128,192,5) was expected**   

这5表示类的数量。这就是我定义输出层的方式。 output:Tensor("output/sigmoid:0",shape(?,128,192,5),dtype=float32)

由于灰度图像,我的裁剪尺寸为input_shape:(128,192,1)label_shape:(128,192,5)

数据被加载到tensorflow数据集中并使用tf.iterator。 生成器从tf.dataset产生数据。

def get_datapoint_generator(self):
  def generator():
   for i in itertools.count(1):
    datapoint_dict=self._get_next_datapoint()
    yield datapoint_dict['image'],datapoint_dict['mask']

_get_next_datapoint_函数从ram获取下一个数据点,并处理裁剪和扩充。

现在,与输出形状不匹配的地方哪里出错了?

1 个答案:

答案 0 :(得分:0)

您可以尝试使用此实现吗?我正在使用这个,但是它在Keras中

def sparse_crossentropy(y_true, y_pred):
    nb_classes = K.int_shape(y_pred)[-1]
    y_true = K.one_hot(tf.cast(y_true[:, :, 0], dtype=tf.int32), nb_classes + 1)
    return K.categorical_crossentropy(y_true, y_pred)