除了TensorFlow tanh
中的默认BasicRNNCell
之外,我还想尝试一些其他传输功能。
原始实现如下:
class BasicRNNCell(RNNCell):
(...)
def __call__(self, inputs, state, scope=None):
"""Most basic RNN: output = new_state = tanh(W * input + U * state + B)."""
with vs.variable_scope(scope or type(self).__name__): # "BasicRNNCell"
output = tanh(linear([inputs, state], self._num_units, True))
return output, output
......我把它改成了:
class MyRNNCell(BasicRNNCell):
(...)
def __call__(self, inputs, state, scope=None):
"""Most basic RNN: output = new_state = tanh(W * input + U * state + B)."""
with tf.variable_scope(scope or type(self).__name__): # "BasicRNNCell"
output = my_transfer_function(linear([inputs, state], self._num_units, True))
return output, output
将vs.variable_scope
更改为tf.variable_scope
已成功,但linear
是>中的实施。 rnn_cell.py <并且tf
本身无法使用。
我怎样才能让它发挥作用?
我是否必须完全重新实施linear
? (我已经检查了代码,我想我也会遇到依赖问题......)
答案 0 :(得分:2)
您无需为此更改张量流实现的代码。
BasicRNNCell有一个名为激活功能的参数。您只需将其从tf.tanh更改为您想要的任何激活功能。