我的目标是训练一个在某些位置提取的图像块上运行的网络(例如,立体声块,其中一个块位于左侧图像的(x,y)处,一个块位于右侧图像的(x + d,y)处) ,我认为最有效的训练方法是将图像和随机样本补丁(x,y,d)加载到生成器中,然后使用dataset.from_generator()来提供训练数据。
但是,在测试时,我想用C ++部署网络。在C ++中有等效的from_generator()吗?
谢谢!
答案 0 :(得分:1)
应该可以使用Dataset.flat_map()
来实现这一点,并且该实现将完全在C ++中运行。使用Python API构建图形,并假设您对sample_x_y_d()
和get_patches_from_images()
有自己的逻辑:
input_dataset = ... # Dataset containing pairs of `(left_img, right_img)`
def generate_samples_fn(left_img, right_img):
num_samples = ...
def sample_x_y_d():
x = ... # Sample a value for `x`.
y = ... # Sample a value for `y`.
d = ... # Sample a value for `d`.
return x, y, d
def get_patches_from_images(x, y, d):
left_patch = ... # Slice a patch at (x, y) from `left_img`.
right_patch = ... # Slice a patch at (x+d, y) from `right_img`.
return left_patch, right_patch
return (tf.data.Dataset.range(num_samples)
.map(lambda _: sample_x_y_d())
.map(get_patches_from_images))
result = input_dataset.flat_map(generate_samples_fn)