使用Keras / Tensorflow模拟PyTorch切片分配的最佳方法

时间:2018-04-10 13:45:58

标签: python tensorflow keras

我试图模仿PyTorch中的操作:

vol = Variable(torch.FloatTensor(A, B*2, C, D, E).zero_()).cuda()
for i in range(C):
  if i > 0 :
    vol[:, :B, i, :,i:] = input0[:,:,:,i:]
    vol[:, B:, i, :,i:] = input1[:,:,:,:-i]
  else:
    vol[:, :B, i, :,:] = input0
    vol[:, B:, i, :,:] = input1

到目前为止,我尝试在TF中使用以下切片赋值并将其包装在Keras Lambda图层中:

vol = tf.Variable(K.zeros((A, D, E, C, B*2)))
for i in range(C):
  if i > 0:
    vol[:, :, i:, i, :B].assign(input0[:,:,i:,:])
    vol[:, :, i:, i, B:].assign(input1[:,:,:-i,:])
  else:
    vol[:, :, :, i, :B].assign(input0)
    vol[:, :, :, i, B:].assign(input1)
return vol

我也试过vol = vol[...].assign(...)

这会正确地将值分配给vol变量,然后我可以将其转换为张量以在我的图表的其余部分中使用。但是,此操作的渐变在TF(LookupError: No gradient defined for operation 'strided_slice/_assign' (op type: StridedSliceAssign))中未定义,并且渐变不会传播到生成input0input1的先前图层,而它们会似乎是在PyTorch实现中转移的。有没有办法在TF中构建这个相同的变量,以便定义渐变并且我之前的操作没有None渐变?

2 个答案:

答案 0 :(得分:4)

你需要“手动”构建张量。假设input0input1都有形状(ADEB),您可以执行以下操作:

# Make the indexing mask with TensorFlow
in_shape = tf.shape(input0)
in_dims = 4
idx = tf.meshgrid(*[tf.range(in_shape[i]) for i in range(in_dims)], indexing='ij')[2]
idx = tf.expand_dims(idx, axis=3)
r = tf.range(C)[tf.newaxis, tf.newaxis, tf.newaxis, :, tf.newaxis]
mask = idx >= r

# If all dimensions are known at graph construction time, you can instead
# make the mask with NumPy like this to save graph computation time
idx = np.meshgrid(*[np.arange(d) for d in (A, D, E, B)], indexing='ij')[2]
idx = np.expand_dims(idx, 3)
r = np.arange(C)[np.newaxis, np.newaxis, np.newaxis, :, np.newaxis]
mask = idx >= r

# Make the tensor
input0_tile = tf.tile(tf.expand_dims(input0, 3), (1, 1, 1, C, 1))
input1_tile = tf.tile(tf.expand_dims(input1, 3), (1, 1, 1, C, 1))
zero_tile = tf.zeros_like(input0_tile)
vol0 = np.where(mask, input0_tile, zero_tile)
vol1 = np.where(mask, input1_tile, zero_tile)
vol = tf.concat([vol0, vol1], axis=-1)

请注意,您需要第一个或第二个块,然后是第三个块,而不是三个块(请参阅注释)。该代码使用tf.meshgridtf.range索引构建二进制掩码,然后使用tf.where从输入或零中选择值。

答案 1 :(得分:3)

tf.Variable是一种原始/基本类型。你不应该想要渐变来传播它们。

你想要的是构建一个输出你想要的5维张量的节点。

我会在第4维上运行连接操作来构建张量并使用结果代替vol

如果你不关心传播到input0和input1的渐变,那么我只是在tensorflow之外构建张量并将其用作初始化器。