我正在尝试在Tensorflow中实现超分辨率模型。我一直在使用以下代码
def _phase_shift(I, r):
# Helper function with main phase shift operation
bsize, a, b, c = I.get_shape().as_list()
X = tf.reshape(I, (bsize, a, b, r, r))
X = tf.transpose(X, (0, 1, 2, 4, 3)) # bsize, a, b, 1, 1
X = tf.split(1, a, X) # a, [bsize, b, r, r]
X = tf.concat(2, [tf.squeeze(x) for x in X]) # bsize, b, a*r, r
X = tf.split(1, b, X) # b, [bsize, a*r, r]
X = tf.concat(2, [tf.squeeze(x) for x in X]) # bsize, a*r, b*r
return tf.reshape(X, (bsize, a*r, b*r, 1))
def PS(X, r, color=False):
# Main OP that you can arbitrarily use in you tensorflow code
if color:
Xc = tf.split(3, 3, X)
X = tf.concat(3, [_phase_shift(x, r) for x in Xc])
else:
X = _phase_shift(X, r)
return X
来自https://github.com/tetrachrome/subpixel进行上采样。但是,由于tf.split的第二个参数必须是不变的,我决定使用tf.depth_to_space代替上面的代码。我想知道tf.depth_to_space和上面的代码是否完全相同。