TensorFlow: a working tf.while_loop fails as part of a Dataset API input pipeline

Time: 2017-11-10 13:14:20

Tags: tensorflow

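My problem concerns a keypoint-recognition task on images of snails. Although there are many pre-written image augmentation functions for classification tasks (e.g. Keras' ImageDataGenerator), I couldn't find any that suit this problem, which requires the output keypoints to be changed to match the randomly transformed image. So I am writing my own augmentation, applied as I read the dataset from TFRecords.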

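The logic I use involves a while loop that keeps generating random transformations (rotation + shift + zoom, etc.) and applying them to the ground-truth keypoints until it finds a transformation under which all keypoints fit inside the image. This avoids transformations that would leave part of the snail outside the image. It then applies the same transformation to the image and returns both.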
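The rejection-sampling idea in the paragraph above can be sketched in plain Python, outside TensorFlow, to make the logic explicit. This is a minimal sketch, not the asker's code: the function names, the parameter ranges, and the rotate-and-zoom-about-the-image-centre convention are all assumptions, and the `max_tries` cap is an added safeguard (an uncapped loop never terminates when no transform can fit).

```python
import math
import random


def random_transform(rng, shift_max, max_angle_deg, zoom_range):
    """Draw one random (angle, shift, zoom) triple; ranges are hypothetical."""
    angle = math.radians(rng.uniform(-max_angle_deg, max_angle_deg))
    shift = (rng.uniform(-shift_max, shift_max), rng.uniform(-shift_max, shift_max))
    zoom = rng.uniform(*zoom_range)
    return angle, shift, zoom


def apply_transform(points, angle, shift, zoom, center):
    """Rotate and zoom each (x, y) point about `center`, then translate by `shift`."""
    cx, cy = center
    cos_a, sin_a = math.cos(angle), math.sin(angle)
    moved = []
    for x, y in points:
        dx, dy = x - cx, y - cy
        moved.append((cx + zoom * (cos_a * dx - sin_a * dy) + shift[0],
                      cy + zoom * (sin_a * dx + cos_a * dy) + shift[1]))
    return moved


def fits_in_image(points, width, height):
    """True if every keypoint lies inside a width x height image."""
    return all(0 <= x < width and 0 <= y < height for x, y in points)


def sample_fitting_transform(points, width, height, rng, shift_max=40.0,
                             max_angle_deg=30.0, zoom_range=(0.8, 1.2),
                             max_tries=1000):
    """Rejection sampling: redraw the transform until all keypoints stay inside."""
    center = (width / 2.0, height / 2.0)
    for _ in range(max_tries):
        angle, shift, zoom = random_transform(rng, shift_max, max_angle_deg, zoom_range)
        moved = apply_transform(points, angle, shift, zoom, center)
        if fits_in_image(moved, width, height):
            # The same (angle, shift, zoom) would then be applied to the image.
            return moved, (angle, shift, zoom)
    raise RuntimeError("no fitting transform found; loosen the ranges or raise max_tries")


# Usage: three hypothetical keypoints on a 384x384 image.
rng = random.Random(0)
keypoints = [(150.0, 160.0), (200.0, 220.0), (240.0, 180.0)]
moved, (angle, shift, zoom) = sample_fitting_transform(keypoints, 384, 384, rng)
```

Returning the accepted parameters alongside the moved keypoints mirrors the question's design, where the loop outputs `rotation`, `shift`, `zoom`, etc. so the same transform can later be applied to the image.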

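My problem is that although I have successfully used this augmentation function on a single test set of keypoints, it does not work when I use the same function as part of the input pipeline. It throws the following error: 'Merge can not have more than one valid input' (the full trace is included below). I couldn't find an explanation anywhere.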

import tensorflow as tf

# Defining the cond argument to tf.while_loop. The 'ph' arguments are placeholders to match the number of loop variables.

def not_fit_in_image(landmarks, ph2, ph3, ph4, ph5, ph6):
    # tf logical operators to find if landmarks fit in image
    return landmarks_not_fit_in_image

def augmentation_function(image, original_landmarks):

    def body(ph1, ph2, ph3, ph4, ph5, ph6):

        shift = tf.random_uniform([1, 2], -shift_max, shift_max, tf.float32)
        landmarks = original_landmarks + shift
        # More random transformations generated and applied

        return landmarks, rotation, shift, zoom, y_over_x_proportion_change, shear

    # placeholders to match number of arguments
    ph_a = tf.constant(0, dtype=tf.float32)

    landmarks, rotation, shift, zoom, y_over_x_proportion_change, shear = tf.while_loop(not_fit_in_image, body, [original_landmarks, ph_a, ph_b, ph_a, ph_a, ph_a])

    # In future, would now apply these same transformations to image.

    return image, landmarks


# Setting up input data pipeline using Dataset API
train = tf.data.TFRecordDataset(train_data_tfrecords).map(parse_function)

train = train.map(augmentation_function) # Using the above augmentation function

train = train.repeat().shuffle(buffer_size).batch(batch_size)

# ... Set up handle, iterator, init ops ... all works ...

with tf.Session() as sess:
    train_handle = sess.run(train_iterator.string_handle())
    sess.run(train_init_op)
    train_images, train_landmarks = sess.run(next_batch, feed_dict={handle: train_handle})

The following error occurs:

2017-11-10 13:08:14.449612: W C:\tf_jenkins\home\workspace\rel-win\M\windows-gpu\PY\35\tensorflow\core\framework\op_kernel.cc:1192] Internal: Merge can not have more than one valid input.
     [[Node: while/Merge_5 = Merge[N=2, T=DT_FLOAT](while/Enter_5, while/NextIteration_5)]]
Traceback (most recent call last):
  File "C:\Users\hanne\Anaconda3\envs\tensorflow-gpu\lib\site-packages\tensorflow\python\client\session.py", line 1323, in _do_call
    return fn(*args)
  File "C:\Users\hanne\Anaconda3\envs\tensorflow-gpu\lib\site-packages\tensorflow\python\client\session.py", line 1302, in _run_fn
    status, run_metadata)
  File "C:\Users\hanne\Anaconda3\envs\tensorflow-gpu\lib\site-packages\tensorflow\python\framework\errors_impl.py", line 473, in __exit__
    c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InternalError: Merge can not have more than one valid input.
     [[Node: while/Merge_5 = Merge[N=2, T=DT_FLOAT](while/Enter_5, while/NextIteration_5)]]
     [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?,384,384], [?,15,2]], output_types=[DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](IteratorFromStringHandle)]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "C:/Users/hanne/Documents/Tensorflow Projects/Snails/random_rotations_working_while_loop_experiments.py", line 143, in <module>
    train_images, train_landmarks = sess.run(next_batch, feed_dict={handle: train_handle})
  File "C:\Users\hanne\Anaconda3\envs\tensorflow-gpu\lib\site-packages\tensorflow\python\client\session.py", line 889, in run
    run_metadata_ptr)
  File "C:\Users\hanne\Anaconda3\envs\tensorflow-gpu\lib\site-packages\tensorflow\python\client\session.py", line 1120, in _run
    feed_dict_tensor, options, run_metadata)
  File "C:\Users\hanne\Anaconda3\envs\tensorflow-gpu\lib\site-packages\tensorflow\python\client\session.py", line 1317, in _do_run
    options, run_metadata)
  File "C:\Users\hanne\Anaconda3\envs\tensorflow-gpu\lib\site-packages\tensorflow\python\client\session.py", line 1336, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InternalError: Merge can not have more than one valid input.
     [[Node: while/Merge_5 = Merge[N=2, T=DT_FLOAT](while/Enter_5, while/NextIteration_5)]]
     [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?,384,384], [?,15,2]], output_types=[DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](IteratorFromStringHandle)]]

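This is my first time asking a question on Stack Overflow, so any comments on how to write better questions are very welcome! For brevity I have stripped the code above down as much as possible, so it is minimal but not complete or verifiable. Please let me know if I should include more code.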

Edit

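I was able to figure out what went wrong! tf.while_loop behaves like a Python while loop: it checks the condition before every run of 'body', including the very first one. The 'loop_vars' argument supplies the variables for that first check. I was passing placeholder values in the wrong format to 'loop_vars', and that caused the error above. A good way around this, for me, was to pass the outputs of a first run of 'body' as loop_vars, since that guarantees they have the correct form.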
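The fix described in the edit (seed `loop_vars` with one run of `body`) can be illustrated with a toy, pure-Python stand-in for the tf.while_loop contract. The `while_loop` helper below is hypothetical, not TensorFlow's implementation: real tf.while_loop compares the nested structure, dtypes and shapes of tensors, while this sketch only approximates "structure" with a (type, length) signature. The 384x384 image size matches the `output_shapes` in the traceback; the shift range is an assumption.

```python
import random


def while_loop(cond, body, loop_vars):
    """Toy pure-Python stand-in for the tf.while_loop contract (hypothetical helper).

    As with tf.while_loop, `cond` is checked before every run of `body`,
    including the first, and `body` must return values with the same
    structure as `loop_vars`. Here "structure" is approximated by a
    (type, length) signature per loop variable.
    """
    def signature(values):
        return [(type(v), len(v) if isinstance(v, (list, tuple)) else None)
                for v in values]

    expected = signature(loop_vars)
    while cond(*loop_vars):
        new_vars = body(*loop_vars)
        if signature(new_vars) != expected:
            raise ValueError(f"body returned {signature(new_vars)}, "
                             f"but loop_vars had {expected}")
        loop_vars = new_vars
    return loop_vars


rng = random.Random(1)
original = [(150.0, 160.0), (240.0, 180.0)]  # hypothetical ground-truth landmarks


def not_fit_in_image(landmarks, shift):
    """Loop condition: keep going while any landmark is outside a 384x384 image."""
    return not all(0 <= x < 384 and 0 <= y < 384 for x, y in landmarks)


def body(landmarks, shift):
    """Draw a fresh random shift and apply it to the original landmarks."""
    shift = (rng.uniform(-300.0, 300.0), rng.uniform(-300.0, 300.0))
    return [(x + shift[0], y + shift[1]) for x, y in original], shift


# Seed loop_vars with one run of body: every loop variable now has exactly the
# structure body returns, and the loop behaves like a do-while.
landmarks, shift = while_loop(not_fit_in_image, body, body(original, (0.0, 0.0)))
```

Passing a scalar placeholder such as `([(-1.0, -1.0)], 0.0)` instead makes the toy helper raise `ValueError` on the first iteration, because `body` returns a tuple where the placeholder was a float. The seeding trick also avoids a quieter pitfall: with placeholder initial values, if the original landmarks already fit, `body` never runs and the loop returns the placeholders instead of real transform parameters.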
