我的问题是关于蜗牛图像的图像关键点识别任务。我发现尽管有许多用于分类任务的预先编写的图像增强功能(例如Keras' ImageDataGenerator),但我找不到适合此问题的任何功能,这需要更改输出关键点以匹配随机变换图片。因此,当我从TFRecord读取数据集时,我正在编写自己的数据集。
我使用的逻辑涉及一个while循环,它继续生成随机变换(旋转+移位+缩放等)并将它们应用于真实的关键点,直到它找到一组关键点适合图像的变换。这是为了避免将部分蜗牛留在图像外的转换。然后它会将相同的变换应用于图像并返回它们。
我的问题是,虽然我已经成功地使用这个扩充函数来处理单个测试关键点集,但当我使用相同的函数作为输入管道的一部分时,它不起作用,抛出以下错误:& #39;合并不能有多个有效输入' (完整的痕迹包含在最后)。我无法在任何地方找到解释。
# Defining cond argument to while loop.'ph' are placeholders to match numbers of arguments for tf.while_loop
def not_fit_in_image(landmarks, ph2, ph3, ph4, ph5, ph6):
# tf logical operators to find if landmarks fit in image
return landmarks_not_fit_in_image
def augmentation_function(image, original_landmarks):
def body(ph1, ph2, ph3, ph4, ph5, ph6):
shift = tf.random_uniform([1, 2], -shift_max, shift_max, tf.float32)
landmarks = original_landmarks + shift
# More random transformations generated and applied
return landmarks, rotation, shift, zoom, y_over_x_proportion_change, shear
# placeholders to match number of arguments
ph_a = tf.constant(0, dtype=tf.float32)
landmarks, rotation, shift, zoom, y_over_x_proportion_change, shear = tf.while_loop(not_fit_in_image, body, [original_landmarks, ph_a, ph_b, ph_a, ph_a, ph_a])
# In future, would now apply these same transformations to image.
return image, landmarks
# Setting up input data pipeline using Dataset API
train = tf.data.TFRecordDataset(train_data_tfrecords).map(parse_function)
train = train.map(augmentation_function) # Using the above augmentation function
train = train.repeat().shuffle(buffer_size).batch(batch_size)
# ... Set up handle, iterator, init ops ... all works ...
with tf.Session() as sess:
train_handle = sess.run(train_iterator.string_handle())
sess.run(train_init_op)
train_images, train_landmarks = sess.run(next_batch, feed_dict={handle: train_handle})
发生以下错误:
2017-11-10 13:08:14.449612: W C:\tf_jenkins\home\workspace\rel-win\M\windows-gpu\PY\35\tensorflow\core\framework\op_kernel.cc:1192] Internal: Merge can not have more than one valid input.
[[Node: while/Merge_5 = Merge[N=2, T=DT_FLOAT](while/Enter_5, while/NextIteration_5)]]
Traceback (most recent call last):
File "C:\Users\hanne\Anaconda3\envs\tensorflow-gpu\lib\site-packages\tensorflow\python\client\session.py", line 1323, in _do_call
return fn(*args)
File "C:\Users\hanne\Anaconda3\envs\tensorflow-gpu\lib\site-packages\tensorflow\python\client\session.py", line 1302, in _run_fn
status, run_metadata)
File "C:\Users\hanne\Anaconda3\envs\tensorflow-gpu\lib\site-packages\tensorflow\python\framework\errors_impl.py", line 473, in __exit__
c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InternalError: Merge can not have more than one valid input.
[[Node: while/Merge_5 = Merge[N=2, T=DT_FLOAT](while/Enter_5, while/NextIteration_5)]]
[[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?,384,384], [?,15,2]], output_types=[DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](IteratorFromStringHandle)]]
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "C:/Users/hanne/Documents/Tensorflow Projects/Snails/random_rotations_working_while_loop_experiments.py", line 143, in <module>
train_images, train_landmarks = sess.run(next_batch, feed_dict={handle: train_handle})
File "C:\Users\hanne\Anaconda3\envs\tensorflow-gpu\lib\site-packages\tensorflow\python\client\session.py", line 889, in run
run_metadata_ptr)
File "C:\Users\hanne\Anaconda3\envs\tensorflow-gpu\lib\site-packages\tensorflow\python\client\session.py", line 1120, in _run
feed_dict_tensor, options, run_metadata)
File "C:\Users\hanne\Anaconda3\envs\tensorflow-gpu\lib\site-packages\tensorflow\python\client\session.py", line 1317, in _do_run
options, run_metadata)
File "C:\Users\hanne\Anaconda3\envs\tensorflow-gpu\lib\site-packages\tensorflow\python\client\session.py", line 1336, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InternalError: Merge can not have more than one valid input.
[[Node: while/Merge_5 = Merge[N=2, T=DT_FLOAT](while/Enter_5, while/NextIteration_5)]]
[[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?,384,384], [?,15,2]], output_types=[DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](IteratorFromStringHandle)]]
这是我第一次询问有关堆栈溢出的问题,所以关于如何编写更好的问题的任何评论也非常受欢迎!为了简洁,我试图尽可能地删除上面的代码,因此它是最小的但不完整或可验证 - 如果我应该包含更多代码,请告诉我。
修改
我能够弄清楚出了什么问题! tf.while_loop就像一个python while循环,在每次运行&#39; body&#39;之前检查条件,其中包括非常第一次运行。论证&#39; loop_vars&#39;获取第一次检查的变量。我输入了错误格式的占位符值到&#39; loop_vars&#39;,这导致了上面的错误。围绕这一点的一个好方法,对我来说,就是输入第一次运行&#39; body&#39;到loop_vars变量,因为这可以确保它是正确的形式。