Using TF-SLIM to read 6-channel numpy arrays back from TFRecords

Posted: 2017-11-19 16:37:20

Tags: tensorflow tf-slim

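I'm new to TensorFlow and am currently trying to learn how to parse data with the TF-SLIM dataset classes. I combine two images into a single 640x480x6 numpy array (6 channels because I concatenate the RGB channels of both images), serialize it, and save it to a .tfrecords file. Here is the code: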

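import numpy as np
import tensorflow as tf
from PIL import Image

# Inside the loop over samples (images, labels, i and tfrecord_writer are defined elsewhere)
img_pair = combine_images(images[i][0], images[i][1])  # height x width x 6, uint8 for 8-bit RGB
img_flo = read_flo_file(labels[i][0])                  # height x width x 2 flow field

height = img_pair.shape[0]
width = img_pair.shape[1]

# tobytes() (the newer name for tostring()) keeps only the raw buffer:
# shape and dtype are NOT stored and must be restored when reading.
img = img_pair.tobytes()
flo = img_flo.tobytes()

example = image_to_tfexample(img, height, width, flo)
tfrecord_writer.write(example.SerializeToString())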

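# bytes_feature and int64_feature are the helpers from TF-Slim's
# datasets/dataset_utils.py (or equivalent wrappers around tf.train.Feature).
def image_to_tfexample(image_data, height, width, flo):
  return tf.train.Example(features=tf.train.Features(feature={
      'image/img_pair': bytes_feature(image_data),
      'image/flo': bytes_feature(flo),
      'image/height': int64_feature(height),
      'image/width': int64_feature(width),
  }))

def combine_images(img1, img2):
  # Load both RGB frames and stack their channels: (H, W, 3) + (H, W, 3) -> (H, W, 6).
  img1 = np.array(Image.open(img1))
  img2 = np.array(Image.open(img2))
  return np.concatenate((img1, img2), axis=-1)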

Here img_pair is a 640x480x6 numpy array and flo is a 640x480x2 numpy array.
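The code calls read_flo_file without showing it. Assuming the labels are Middlebury-format .flo files (a float32 magic value that reads as "PIEH" in ASCII, then int32 width and int32 height, then interleaved u/v float32 values row by row), a minimal sketch of such a reader could look like this. It is a hypothetical stand-in, not the author's function; note that it returns float32, which matters when the bytes are decoded again.

```python
import numpy as np

# Magic number at the start of every Middlebury .flo file (bytes b"PIEH").
FLO_MAGIC = 202021.25


def read_flo_file(path):
    """Read a Middlebury .flo file into a float32 array of shape (height, width, 2)."""
    with open(path, "rb") as f:
        raw = f.read()
    if len(raw) < 12:
        raise ValueError("%s is too short to be a .flo file" % path)
    # Header: float32 magic, then int32 width and int32 height (little-endian).
    magic = np.frombuffer(raw, dtype="<f4", count=1)[0]
    if magic != FLO_MAGIC:
        raise ValueError("%s has an invalid .flo magic number" % path)
    width, height = (int(v) for v in np.frombuffer(raw, dtype="<i4", count=2, offset=4))
    # Body: u and v components interleaved per pixel, stored row by row.
    flow = np.frombuffer(raw, dtype="<f4", count=2 * width * height, offset=12)
    # astype() copies, so the result is writable and in native float32.
    return flow.reshape(height, width, 2).astype(np.float32)
```

Because the result is float32, flo.tobytes() produces 4 bytes per value, so the reader side has to decode image/flo as float32 rather than uint8.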

Now I want to read these examples back. This is what I have so far, adapted to my data from the TF-Slim flowers.py example:

def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
  if split_name not in SPLITS_TO_SIZES:
    raise ValueError('split name %s was not recognized.' % split_name)

  if not file_pattern:
    file_pattern = _FILE_PATTERN
  file_pattern = os.path.join(dataset_dir, file_pattern % split_name)

  # Allowing None in the signature so that dataset_factory can use the default.
  if reader is None:
    reader = tf.TFRecordReader

  keys_to_features =  {
      'image/width': tf.FixedLenFeature([], tf.int64),
      'image/height': tf.FixedLenFeature([], tf.int64),
      'image/img_pair': tf.FixedLenFeature([], tf.string),
      'image/flo': tf.FixedLenFeature([], tf.string)
  }

  items_to_handlers = {
    'image': slim.tfexample_decoder.Tensor('image/img_pair'),
    'label': slim.tfexample_decoder.Tensor('image/flo'),
  }
  decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)

  return slim.dataset.Dataset(
      data_sources=file_pattern,
      reader=reader,
      decoder=decoder,
      num_samples=SPLITS_TO_SIZES[split_name],
      items_to_descriptions=_ITEMS_TO_DESCRIPTIONS)

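The problem is that image/img_pair and image/flo are binary strings. As far as I understand, they first need to be converted to tensors before they can be exposed through items_to_handlers, like this: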

items_to_handlers = {
    'image': slim.tfexample_decoder.Tensor('image/img_pair'),
    'label': slim.tfexample_decoder.Tensor('image/flo'),
  }

But I don't know how to parse them back into tensors with their original shapes, i.e. 640x480x6 for img_pair and 640x480x2 for flo.
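Since tobytes()/tostring() stores only the raw buffer, getting the shapes back takes two steps: reinterpret the bytes with the dtype that was written, then reshape using the stored height and width. For 480x640 frames the pair buffer must be 480 x 640 x 6 = 1,843,200 bytes and the float32 flow 480 x 640 x 2 x 4 = 2,457,600 bytes. The sketch below shows those steps in NumPy; decode_pair_and_flow and _from_raw are hypothetical names, the dict stands in for the parsed features, and the dtypes (uint8 pair, float32 flow) are assumptions that must match the writer. In the TF-Slim graph the same steps would presumably be tf.decode_raw(keys_to_tensors['image/img_pair'], tf.uint8) followed by tf.reshape to [height, width, 6], inside a function passed to slim.tfexample_decoder.ItemHandlerCallback(['image/img_pair', 'image/height', 'image/width'], func); that TF variant is untested here.

```python
import numpy as np

# Channel counts and dtypes used when the records were written (assumed:
# 8-bit RGB frames and float32 flow).
PAIR_CHANNELS, PAIR_DTYPE = 6, np.uint8
FLOW_CHANNELS, FLOW_DTYPE = 2, np.float32


def _from_raw(raw, dtype, height, width, channels):
    """NumPy analogue of decode_raw followed by reshape, with a size check."""
    expected = height * width * channels * np.dtype(dtype).itemsize
    if len(raw) != expected:
        raise ValueError("expected %d bytes, got %d" % (expected, len(raw)))
    return np.frombuffer(raw, dtype=dtype).reshape(height, width, channels)


def decode_pair_and_flow(keys_to_values):
    """Hypothetical callback: rebuild both arrays from one parsed example.

    keys_to_values plays the role of the dict an ItemHandlerCallback
    receives, but holds bytes/ints instead of tensors.
    """
    height = int(keys_to_values["image/height"])
    width = int(keys_to_values["image/width"])
    img_pair = _from_raw(keys_to_values["image/img_pair"], PAIR_DTYPE,
                         height, width, PAIR_CHANNELS)
    flo = _from_raw(keys_to_values["image/flo"], FLOW_DTYPE,
                    height, width, FLOW_CHANNELS)
    return img_pair, flo


# Usage: round-trip one synthetic 480x640 example.
h, w = 480, 640
pair = np.random.randint(0, 256, size=(h, w, 6), dtype=np.uint8)
flow = np.random.rand(h, w, 2).astype(np.float32)
record = {"image/img_pair": pair.tobytes(), "image/flo": flow.tobytes(),
          "image/height": h, "image/width": w}
img_pair, flo = decode_pair_and_flow(record)
print(img_pair.shape, flo.shape)  # (480, 640, 6) (480, 640, 2)
```

The explicit byte-count check is a design choice: decoding float32 bytes as uint8 (or vice versa) would otherwise fail later with a less obvious reshape error.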

This is the error I get:

Will save model to /tmp/tfslim_model/
Traceback (most recent call last):
  File "main.py", line 16, in <module>
    images, _, labels = helpers.load_batch(dataset)
  File "/home/muazzam/mywork/python/thesis/SceneflowTensorflow/new_stuff/helpers.py", line 36, in load_batch
    common_queue_min=8)
  File "/home/muazzam/anaconda3/envs/tensorflow/lib/python3.5/site-packages/tensorflow/contrib/slim/python/slim/data/dataset_data_provider.py", line 97, in __init__
    tensors = dataset.decoder.decode(data, items)
  File "/home/muazzam/anaconda3/envs/tensorflow/lib/python3.5/site-packages/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py", line 424, in decode
    outputs.append(handler.tensors_to_item(keys_to_tensors))
  File "/home/muazzam/anaconda3/envs/tensorflow/lib/python3.5/site-packages/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py", line 321, in tensors_to_item
    return self._decode(image_buffer, image_format)
  File "/home/muazzam/anaconda3/envs/tensorflow/lib/python3.5/site-packages/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py", line 350, in _decode
    pred_fn_pairs, default=decode_image, exclusive=True)
  File "/home/muazzam/anaconda3/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/ops/control_flow_ops.py", line 3169, in case
    case_seq = _build_case()
  File "/home/muazzam/anaconda3/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/ops/control_flow_ops.py", line 3151, in _build_case
    strict=strict, name="If_%d" % i)
  File "/home/muazzam/anaconda3/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/util/deprecation.py", line 296, in new_func
    return func(*args, **kwargs)
  File "/home/muazzam/anaconda3/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/ops/control_flow_ops.py", line 1819, in cond
    orig_res_t, res_t = context_t.BuildCondBranch(true_fn)
  File "/home/muazzam/anaconda3/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/ops/control_flow_ops.py", line 1694, in BuildCondBranch
    original_result = fn()
  File "/home/muazzam/anaconda3/envs/tensorflow/lib/python3.5/site-packages/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py", line 338, in decode_image
    return image_ops.decode_image(image_buffer, self._channels)
  File "/home/muazzam/anaconda3/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/ops/image_ops_impl.py", line 1346, in decode_image
    raise ValueError('channels must be in (None, 0, 1, 3, 4)')
ValueError: channels must be in (None, 0, 1, 3, 4)
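The traceback ends in image_ops.decode_image, reached through the Image handler in tfexample_decoder.py (its tensors_to_item calls self._decode), not through the Tensor handler shown in the code above, so the script that raised it presumably used slim.tfexample_decoder.Image with channels=6. decode_image only accepts channels in (None, 0, 1, 3, 4), and the stored bytes are raw numpy buffers rather than encoded PNG/JPEG data anyway, so that path cannot decode img_pair. If an image-decoding path is wanted, one option (an untested assumption) is to write each frame as its own 3-channel encoded image and concatenate after decoding. In the other direction, a decoded 6-channel pair can be split back into two RGB frames by slicing the channel axis; split_pair below is a hypothetical helper:

```python
import numpy as np


def split_pair(img_pair):
    """Split an (H, W, 6) frame pair into two (H, W, 3) RGB frames."""
    if img_pair.ndim != 3 or img_pair.shape[-1] != 6:
        raise ValueError("expected shape (H, W, 6), got %s" % (img_pair.shape,))
    # Channels 0-2 came from the first image, 3-5 from the second.
    return img_pair[..., :3], img_pair[..., 3:]


# Usage: undo the concatenation done by combine_images.
first = np.zeros((480, 640, 3), dtype=np.uint8)
second = np.full((480, 640, 3), 255, dtype=np.uint8)
a, b = split_pair(np.concatenate((first, second), axis=-1))
```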

Can someone help me?
