Tensorflow C ++等效于data.from_generator

时间:2018-07-16 17:42:40

标签: c++ tensorflow tensorflow-datasets

我的目标是训练一个在某些位置提取的图像块上运行的网络(例如,立体声块,其中一个块位于左侧图像的(x,y)处,一个块位于右侧图像的(x + d,y)处) ,我认为最有效的训练方法是将图像和随机样本补丁(x,y,d)加载到生成器中,然后使用dataset.from_generator()来提供训练数据。

但是,在测试时,我想用C ++部署网络。在C ++中有等效的from_generator()吗?

谢谢!

1 个答案:

答案 0 :(得分:1)

应该可以使用Dataset.flat_map()来实现这一点,并且该实现将完全在C ++中运行。使用Python API构建图形,并假设您对sample_x_y_d()get_patches_from_images()有自己的逻辑:

input_dataset = ...  # Dataset containing pairs of `(left_img, right_img)`

def generate_samples_fn(left_img, right_img):
  num_samples = ...

  def sample_x_y_d():
    x = ...  # Sample a value for `x`.
    y = ...  # Sample a value for `y`.
    d = ...  # Sample a value for `d`.
    return x, y, d

  def get_patches_from_images(x, y, d):
    left_patch = ...  # Slice a patch at (x, y) from `left_img`.
    right_patch = ...  # Slice a patch at (x+d, y) from `right_img`.
    return left_patch, right_patch

  return (tf.data.Dataset.range(num_samples)
          .map(lambda _: sample_x_y_d())
          .map(get_patches_from_images))

result = input_dataset.flat_map(generate_samples_fn)