在TensorFlow中对张量进行批处理操作

时间:2020-05-16 02:10:28

标签: python tensorflow batch-processing tensor

我有一批图像,它们的形状为(2,1024,1024,1),形状为(B,H,W,C),形状为B的TensorFlow张量(例如Tensor 1),其中H是批量大小2,W C是图像尺寸1024,(a,b)是通道数。张量的每个元素(即每个像素)存储一个元组a,其中b0都在255(256,256)范围内。

我有另一个0形状的张量(例如张量2),每个元素在255(1,200,500,1)之间存储一个值。

鉴于此设置,我有以下问题。

我希望将Tensor 1中的每个元素值替换为Tensor 2中的相应元素值。例如,假设张量1中由索引(100,20)给出的元素包含值(100,200)。我想在像素位置(1,200,500,1)的Tensor 2中查找存储的值,并使用该值修改{{1}}处的条目。

我如何以最有效的方式对整个批次执行此操作?

如果有不清楚的地方,请告诉我。我是TensorFlow的初学者,因此将感谢您的帮助。

1 个答案:

答案 0 :(得分:0)

尝试将Tensor 1中的元素值替换为Tensor 2中的相应元素值

tensor1 = tf.Variable(tf.zeros((2,1024,1024,1)), trainable=False)
tensor2 = tf.constant(np.random.randint(0,255, (255,255,1)), dtype='float32')

tensor12 = tensor1[1,200,500].assign(tensor2[100,200])

在此示例中,您将Tensor 1位置(1,200,500,1)中的元素替换为Tensor 2位置(100,200)中的元素


batch_dim = 2
tensor1 = tf.Variable(tf.zeros((batch_dim,1024,1024,1)), trainable=False)
tensor2 = tf.constant(np.random.randint(0,255, (1,255,255,1)), dtype='float32')

tensor12 = tensor1[:,:tensor2.shape[1],:tensor2.shape[2]].assign(tf.repeat(tensor2, batch_dim, axis=0))

使用这些行,将所有批次样品的tensor2的所有值分配给tensor1中的第一个(255,255)位置