WaveNet的Tensorflow实现的time_to_batch和batch_time函数是否必要?

时间:2017-06-26 18:19:30

标签: python tensorflow conv-neural-network

Deepmind最近创建了一种名为WaveNet的新型卷积体系结构。 WaveNet的主要创新是扩张的卷积。这个问题不是关于所做的建模,而是关于他们使用的特定编程技术。 可以找到WaveNet的Tensorflow实现here。 我的问题涉及'ops.py'文件中的一些代码 两个功能,特别是:

def time_to_batch(value, dilation, name=None):
    with tf.name_scope('time_to_batch'):
        shape = tf.shape(value)
        pad_elements = dilation - 1 - (shape[1] + dilation - 1) % dilation
        padded = tf.pad(value, [[0, 0], [0, pad_elements], [0, 0]])
        reshaped = tf.reshape(padded, [-1, dilation, shape[2]])
        transposed = tf.transpose(reshaped, perm=[1, 0, 2])
        return tf.reshape(transposed, [shape[0] * dilation, -1, shape[2]])


def batch_to_time(value, dilation, name=None):
    with tf.name_scope('batch_to_time'):
        shape = tf.shape(value)
        prepared = tf.reshape(value, [dilation, -1, shape[2]])
        transposed = tf.transpose(prepared, perm=[1, 0, 2])
        return tf.reshape(transposed,
                          [tf.div(shape[0], dilation), -1, shape[2]])

然后,其他代码会调用这些函数,如:

def causal_conv(value, filter_, dilation, name='causal_conv'):
    with tf.name_scope(name):
        filter_width = tf.shape(filter_)[0]
        if dilation > 1:
            transformed = time_to_batch(value, dilation)
            conv = tf.nn.conv1d(transformed, filter_, stride=1,
                                padding='VALID')
            restored = batch_to_time(conv, dilation)
        else:
            restored = tf.nn.conv1d(value, filter_, stride=1, padding='VALID')
        # Remove excess elements at the end.
        out_width = tf.shape(value)[1] - (filter_width - 1) * dilation
        result = tf.slice(restored,
                          [0, 0, 0],
                          [-1, out_width, -1])
        return result

所以很明显只有当扩张大于1时才会调用batch_to_time和time_to_batch。我不太明白这些函数是如何工作的(我仍在尝试学习Tensorflow)但是他们必须根据扩张来压缩输入,应用定期卷积,然后取消它。 但是,我注意到API提供了一个名为atrous_conv2d的函数。一个痛苦的卷积是扩张卷积的另一个名称。

最后,我准备问我的问题了!

是否需要上述重塑功能,或者是否可以使用其中一个尺寸设置为1的atrous_conv2D来实现扩张卷积?

毕竟,conv1D是通过调用conv2D来实现的。从here:“在内部,这个op重塑输入张量并调用tf.nn.conv2d”

0 个答案:

没有答案