我正在尝试在裁剪的图像上测试现有网络。我有一个相当大的数据集,所以我使用的是tensorflow数据集API。我首先创建了一个包含我感兴趣的图像的所有名称的数据集,然后使用flat_map()函数将图像名称的数据集映射到裁剪图像补丁的数据集。
所以,这是问题所在。我不知道将为这个图像生成多少个补丁,我有另一个python函数get_image_regions,它返回一个n×4的numpy数组。
所以我想使用类似的东西:
boxes = tf.py_func(get_image_regions,[im_path],[tf.float32])
获取框的集合并使用该框作为tf.image.crop_and_resize()的输入
但是,由于py_func的返回值是未知的形状和等级,因此它不能用作crop_and_resize()函数的输入。还有其他方法可以解决这个问题吗?
答案 0 :(得分:0)
如果你确定形状是这样的话,你可以tf.reshape()
张量:
boxes = tf.reshape( tf.py_func(get_image_regions, [im_path], [tf.float32]), [ n, 4 ] )
这将修复图表其余部分的形状,并允许您将张量提供给tf.image.crop_and_resize()
但如果输入错误的大小数据则会抛出错误。