我希望使用TensorFlow 1.12进行语义分割。我使用from_generator()
创建了一个数据集,其中的生成器如下:
def train_sample_fetcher():
return sample_fetcher()
def val_sample_fetcher():
return sample_fetcher(is_validations=True)
def sample_fetcher(is_validations=False):
sample_names = [filename[:-4] for filename in os.listdir(DIR_DATASET + "ndarrays/")]
if not is_validations: sample_names = sample_names[:int(len(sample_names) * TRAIN_VAL_SPLIT)]
else: sample_names = sample_names[int(len(sample_names) * TRAIN_VAL_SPLIT):]
for sample_name in sample_names:
rgb = tf.image.decode_jpeg(tf.read_file(DIR_DATASET + sample_name + ".jpg"))
rgb = tf.image.resize_images(rgb, (HEIGHT, WIDTH))
#d = tf.image.decode_jpeg(tf.read_file(DIR_DATASET + "depth/" + sample_name + ".jpg"))
#d = tf.image.resize_images(d, (HEIGHT, WIDTH))
#rgbd = tf.concat([rgb,d], axis=2)
onehots = tf.convert_to_tensor(np.load(DIR_DATASET + "ndarrays/" + sample_name + ".npy"), dtype=tf.float32)
yield tf.stack([rgb, onehots])
换句话说,我有一个标签张量,其中每个像素包含一个长度为21(21类)的单热点标签矢量。但是,根据此堆栈跟踪,这是不允许的:
Traceback (most recent call last):
File "semantic_fpn.py", line 89, in <module>
callbacks=[checkpoint_full, checkpoint_weights, tensorboard])
File ".../site-packages/tensorflow/python/keras/engine/training.py", line 1574, in fit
steps=validation_steps)
File ".../site-packages/tensorflow/python/keras/engine/training.py", line 975, in _standardize_user_data
next_element = x.get_next()
File ".../site-packages/tensorflow/python/data/ops/iterator_ops.py", line 623, in get_next
return self._next_internal()
File ".../site-packages/tensorflow/python/data/ops/iterator_ops.py", line 564, in _next_internal
output_shapes=self._flat_output_shapes)
File ".../site-packages/tensorflow/python/ops/gen_dataset_ops.py", line 2266, in iterator_get_next_sync
_six.raise_from(_core._status_to_exception(e.code, message), None)
File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.UnknownError: InvalidArgumentError: Shapes of all inputs must match: values[0].shape = [512,512,3] != values[1].shape = [512,512,21] [Op:Pack] name: stack
为什么不允许这样做?我该如何规避?
答案 0 :(得分:1)
westeros
操作尝试将N个等级K张量合并为一个等级(K + 1)张量。换句话说,它试图沿着新轴连接一系列张量,因此,其他张量轴应该相同。
可以简单地从生成器中返回一对tf.stack
。