TensorFlow中的一阶递归关系

时间:2018-07-21 13:27:29

标签: python numpy tensorflow recurrence lowpass-filter

在TensorFlow中实现递归关系的有效方法是什么?在这种情况下,single-pole low-pass filter。现在,我已将过滤器实现为自定义RNNCell,这在图形构建和执行过程中非常慢。

class LowPassFilter(tf.nn.rnn_cell.RNNCell):
    def __init__(self, sampling_rate, time_constant):
        self._alpha = sampling_rate / (time_constant + sampling_rate)

    @property
    def state_size(self):
        return 1

    @property
    def output_size(self):
        return 1

    def __call__(self, inputs, state, scope=None):
        with tf.variable_scope("low-pass-filter-cell"):
            output = state + self._alpha * (inputs - state)
            return output, output

给出以下虚假数据

onset_n = 200
step_n = 1800
sequence_length = onset + step
k_width = 100
k_height = 100
onset = np.zeros((onset_n, k_width, k_height))
step = np.ones((step_n, k_width, k_height))
test_sequence = np.concatenate((onset, step)).reshape((1, sequence_length, k_width, k_height))

TensorFlow在我的系统上需要大约15秒的时间来执行图形

tf.reset_default_graph()
sequence = tf.placeholder(tf.float32, [1, sequence_length, k_width, k_height])
unstacked_sequence = tf.unstack(sequence, num = sequence_length, axis = 1)
cell = LowPassFilter(1, 200)
outputs, _ = tf.nn.static_rnn(cell, unstacked_sequence, dtype = tf.float32)
outputs = tf.reshape(tf.concat(outputs, 1), [1, sequence_length, k_width, k_height])

# -> Start timer
with tf.Session() as session:
    result = session.run(outputs, {sequence: test_sequence})
# -> End timer

简单的numpy实现大约需要0.1秒的时间

def low_pass_filter(x, sampling_rate, time_constant):
    output = np.zeros(x.shape)
    alpha = sampling_rate / (time_constant + sampling_rate)
    output[0] = x[0]
    for i in range(1, x.shape[0]):
        output[i] = output[i-1] + alpha * (x[i] - output[i-1])
    return output

# -> Start timer
result = low_pass_filter(test_sequence[0, :, :, :], 1, 200)
# -> End timer

在TensorFlow中是否有更聪明,更快速的方法来实现简单的重复关系?

0 个答案:

没有答案