在tensorflow中重塑批处理变量

时间:2017-10-19 15:08:01

标签: tensorflow reshape transpose

在Tensorflow中,我想重塑一个表格的批量变量 A [100, 32, 32]形式为[10, 10, 32, 32] B 。但天真地使用重塑失去了订单信息。

例如,A[10,:,:]B[1,0,:,:]不同,其中1表示小批量10s的下一行。

我希望32*32的订单自图像以来保持不变,而我想以下列方式重塑其中的100张图片:

1 - > 2 - > ... - > 10

11 - > 12 - > ... - > 20

... ...

91 - > 92 - > ... - > 100

我怎么能在张量流中做到这一点?

2 个答案:

答案 0 :(得分:0)

tf.reshape做你想要的,如果我理解你想要的。例如:

data = np.array(range(100*32*32)).reshape((100,32,32))
A = tf.constant(data)
B = tf.reshape(A, [10, 10, 32, 32])

with tf.Session() as sess:
    print("A")
    print(sess.run(A)[10, :, :])
    print()
    print("B")
    print(sess.run(B)[1, 0, :, :])

输出:

A
 [[10240 10241 10242 ..., 10269 10270 10271]
  [10272 10273 10274 ..., 10301 10302 10303]
  [10304 10305 10306 ..., 10333 10334 10335]
  ...,
  [11168 11169 11170 ..., 11197 11198 11199]
  [11200 11201 11202 ..., 11229 11230 11231]
  [11232 11233 11234 ..., 11261 11262 11263]]

B
 [[10240 10241 10242 ..., 10269 10270 10271]
  [10272 10273 10274 ..., 10301 10302 10303]
  [10304 10305 10306 ..., 10333 10334 10335]
  ...,
  [11168 11169 11170 ..., 11197 11198 11199]
  [11200 11201 11202 ..., 11229 11230 11231]
  [11232 11233 11234 ..., 11261 11262 11263]]

答案 1 :(得分:0)

您可以使用tf.unstack,然后使用tf.stack

例如:

A = tf.random_uniform([100, 32, 32])
temp = tf.unstack(t)
B = tf.stack([temp[pos:pos+10] for pos in range(0, len(temp), 10)]

或更透明和灵活:

def chunker(seq, size):
    return [seq[pos:pos+size] for pos in range(0, len(seq), size)]

B = tf.stack(chunker(temp,10))