我想使用TensorFlow' s batch_rot90(batch_of_images)
创建一个函数tf.image.rot90()
,后者一次只拍摄一张图像,前者应该一次拍摄一批n张图像(shape = [N,X,Y,F])。
很自然地,人们应该仔细检查批处理中的所有图像并逐个旋转它们。在numpy中,这看起来像:
def batch_rot90(batch):
for i in range(batch.shape[0]):
batch_of_images[i] = rot90(batch[i,:,:,:])
return batch
TensorFlow是如何完成的?
使用tf.while_loop
我得到了他的远:
batch = tf.placeholder(tf.float32, shape=[2, 256, 256, 4])
def batch_rot90(batch, k, name=''):
i = tf.constant(0)
def cond(batch, i):
return tf.less(i, tf.shape(batch)[0])
def body(im, i):
batch[i] = tf.image.rot90(batch[i], k)
i = tf.add(i, 1)
return batch, i
r = tf.while_loop(cond, body, [batch, i])
return r
但是不允许分配给im[i]
,我对使用r返回的内容感到困惑。
我意识到使用tf.batch_to_space()
可能有针对此特定情况的解决方法,但我相信它也应该可以使用某种循环。
答案 0 :(得分:2)
tf中有一个map函数,可以工作:
def batch_rot90(batch, k, name=''):
fun = lambda x: tf.images.rot90(x, k = 1)
return = tf.map_fn(fun, batch)
答案 1 :(得分:0)
更新的答案:
x = tf.placeholder(tf.float32, shape=[2, 3])
def cond(batch, output, i):
return tf.less(i, tf.shape(batch)[0])
def body(batch, output, i):
output = output.write(i, tf.add(batch[i], 10))
return batch, output, i + 1
# TensorArray is a data structure that support dynamic writing
output_ta = tf.TensorArray(dtype=tf.float32,
size=0,
dynamic_size=True,
element_shape=(x.get_shape()[1],))
_, output_op, _ = tf.while_loop(cond, body, [x, output_ta, 0])
output_op = output_op.stack()
with tf.Session() as sess:
print(sess.run(output_op, feed_dict={x: [[1, 2, 3], [0, 0, 0]]}))
我认为您应该考虑使用 tf.scatter_update
更新批处理中的一个图像,而不是使用batch[i] = ...
。有关详细信息,请参阅this link。在你的情况下,我建议将身体的第一行改为:
tf.scatter_update(batch, i, tf.image.rot90(batch[i], k))