我尝试在raw_rnn中为loop_fn使用自定义函数,但这有点奇怪
"raise TypeError("loop_fn must be a callable")" # Exception thrown?
呼叫:
callable_loop_fn = loop_fn(
time=time,
previous_output=None,
previous_state=None,
previous_loop_state=None,
_W=W, _b=b,
_decoder_lengths=decoder_lengths,
_pad_step_embedded=pad_step_embedded,
_eos_step_embedded=eos_step_embedded,
_encoder_final_state=encoder_final_state)
# using the functions for the attention decoder
decoder_outputs_ta, decoder_final_state, decoder_loop_state = tf.nn.raw_rnn(decoder_cell, callable_loop_fn)
定义:
def loop_fn(time, previous_output, previous_state, previous_loop_state, _W, _b, _decoder_lengths, _pad_step_embedded, _eos_step_embedded, _encoder_final_state):
if previous_state is None:
assert previous_output is None and previous_state is None
return loop_fn_initial(_decoder_lengths, _eos_step_embedded, _encoder_final_state)
else:
return loop_fn_transition(time, previous_output, previous_state, previous_loop_state, _W, _b, _decoder_lengths, _pad_step_embedded)
有人知道这可能是什么吗?我认为我提供的功能是可调用的,还是我理解错了?
答案 0 :(得分:2)
callable_loop_fn
不是函数,因此它不可调用。
具体而言,callable_loop_fn
是loop_fn()
返回的值,而loop_fn_initial()
则返回loop_fn_initial()
的输出或loop_fn must be a callable
的输出。显然,这两个函数都没有返回函数,因此抛出了异常def loop_fn(time, cell_output, cell_state, loop_state):
...
return (
elements_finished,
next_input,
next_cell_state,
emit_output,
next_loop_state
)
。
根据TF API,你应该写:
tf.nn.raw_rnn
然后将其传递给raw_rnn(decoder_cell, loop_fn)
:
loop_fn
请注意,您应该尊重Unexpected argument
期望接收的参数的数量和顺序,否则您loop_fn
会因函数add_dependencies(my_shared1 my_common_lib)
而出错。因此,您的实现必须重新排列,只需要4个参数。