TensorFlow: change the tanh of BasicRNNCell to another op?

Time: 2016-04-29 13:56:10

Tags: python inheritance neural-network tensorflow recurrent-neural-network

Besides the default tanh in TensorFlow's BasicRNNCell, I would like to try some other transfer functions.

The original implementation looks like this:

class BasicRNNCell(RNNCell):
  (...)
  def __call__(self, inputs, state, scope=None):
    """Most basic RNN: output = new_state = tanh(W * input + U * state + B)."""
    with vs.variable_scope(scope or type(self).__name__):  # "BasicRNNCell"
      output = tanh(linear([inputs, state], self._num_units, True))
    return output, output

...and I changed it to:

class MyRNNCell(BasicRNNCell):
  (...)
  def __call__(self, inputs, state, scope=None):
    """output = new_state = my_transfer_function(W * input + U * state + B)."""
    with tf.variable_scope(scope or type(self).__name__):  # "MyRNNCell"
      output = my_transfer_function(linear([inputs, state], self._num_units, True))
    return output, output

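Changing vs.variable_scope to tf.variable_scope worked, but linear is implemented in rnn_cell.py and is not available from tf itself.

How can I get this to work?

Do I have to re-implement linear completely? (I have checked the code, and I think I would run into dependency problems there as well...)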

1 answer:

Answer 0: (score: 2)

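You do not need to change TensorFlow's implementation code for this.

BasicRNNCell takes the activation function as a constructor parameter (activation), which defaults to tanh. Just pass whatever activation function you want instead of tf.tanh, and no subclass or copy of linear is needed.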
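
To illustrate why a constructor parameter removes the need for subclassing, here is a minimal pure-Python sketch of the same pattern. It is not TensorFlow code: TinyRNNCell, relu, and the weights W, U, B are hypothetical names and values made up for this example, and the cell only mimics the formula from the docstring above (output = new_state = activation(W * input + U * state + B)).

```python
import math


class TinyRNNCell:
    """Pure-Python stand-in for an RNN cell with a pluggable activation.

    Hypothetical illustration only; this is not TensorFlow's BasicRNNCell.
    """

    def __init__(self, w, u, b, activation=math.tanh):
        self._w = w  # input weights, one row per unit
        self._u = u  # recurrent weights, one row per unit
        self._b = b  # bias, one value per unit
        self._activation = activation  # applied to each pre-activation

    def __call__(self, inputs, state):
        """output = new_state = activation(W * input + U * state + B)."""
        output = []
        for w_row, u_row, bias in zip(self._w, self._u, self._b):
            pre = bias
            pre += sum(wi * xi for wi, xi in zip(w_row, inputs))
            pre += sum(ui * si for ui, si in zip(u_row, state))
            output.append(self._activation(pre))
        return output, output


def relu(x):
    """Example replacement transfer function (an assumption, any callable works)."""
    return max(0.0, x)


# Made-up weights for a cell with 2 inputs and 2 units.
W = [[0.5, -1.0], [1.0, 1.0]]
U = [[0.0, 0.5], [0.5, 0.0]]
B = [0.0, -0.5]

default_cell = TinyRNNCell(W, U, B)                # falls back to tanh
relu_cell = TinyRNNCell(W, U, B, activation=relu)  # swapped activation

x, h = [1.0, 2.0], [0.0, 0.0]
tanh_out, tanh_state = default_cell(x, h)
relu_out, relu_state = relu_cell(x, h)
```

With the real class, the equivalent step would presumably be a constructor call along the lines of BasicRNNCell(num_units, activation=my_transfer_function), where my_transfer_function is the function from the question; the exact signature may differ between TensorFlow versions.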