我正在尝试使用 tf while循环在双向 lstm 中运行 cell_fw 和 cell_bw body函数使用范围的不同参数(是python字符串)并行调用 lstm 函数。
#varaible declarations
....
....
lstm_scopes =['lstm_1','lstm_2']
def cond(act_outs_tensor, _it):
# run the function twice
return tf.less(_it, tf.constant(2))
def body(act_outs_tensor, _it):
sess = tf.InteractiveSession()
# get integer value of _it to get the req scope to pass
# to pass to the function call
idx = _it.eval() # throws an error saying items inside while are not fetchable
lstm_scope = lstm_scopes[idx]
act_outs_tensor = act_outs_tensor.write(_it,
lstm(hidden_size_1,
dim, _input,
lstm_scope))
return _it + 1, act_outs_tensor
_,act_outs_tensor = tf.while_loop(cond, body, [act_outs_tensor,_it],parallel_iterations=2)
上面的代码失败,因为“在TensorFlow条件的分支中创建的任何op或TensorFlow循环的主体都标记为“不可提取”,以防止发生各种编程错误” source。是否有任何变通办法可以在每次运行body函数时更改作用域的值?
谢谢