I am trying to create a Bayesian RNN by modifying the LSTMCell class, but I cannot retrieve the KL loss of each layer; I believe this is related to the RNN being built inside a while-loop context.

In short: I created a new class based on LSTMCell and overrode the required methods. Shown below is the `call` method, which performs the necessary computation (the KL loss), together with a getter function to retrieve that KL loss, which I am unable to add to the final loss. I am not sure whether my problem is related to https://github.com/openai/gradient-checkpointing/issues/7, nor whether that issue is generic.
```python
class BayesianLSTMCell(LSTMCell):

    # ... other bits not shown for brevity

    def get_kl(self):
        """
        :returns: the KL loss for this LSTM's weights and bias
        """
        return self.kl_loss

    def call(self, inputs, state):
        # overriding the call method from LSTMCell
        # ...
        if self.isTraining:
            with tf.variable_scope("KL_loss_" + self.layer_name, reuse=True):
                kl_loss = self.compute_KL_univariate_prior(self.prior, self.theta_w, self.w)
                kl_loss += self.compute_KL_univariate_prior(self.prior, self.theta_b, self.b)
                self.kl_loss = kl_loss
                print("Compute KL loss for LSTM: " + self.layer_name)
                print(kl_loss)
        # ...
```
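`compute_KL_univariate_prior` is not shown above; assuming it evaluates the closed-form KL divergence between the variational posterior N(mu_q, sigma_q) and a univariate Gaussian prior N(mu_p, sigma_p), summed over the weights, the per-weight term it would compute is sketched below (the function and argument names here are illustrative, not the actual code):

```python
import math

def kl_univariate_gaussians(mu_q, sigma_q, mu_p, sigma_p):
    """Analytic KL( N(mu_q, sigma_q) || N(mu_p, sigma_p) ) for one weight."""
    return (math.log(sigma_p / sigma_q)
            + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2.0 * sigma_p ** 2)
            - 0.5)

# Sanity checks: KL of a distribution with itself is zero,
# and shifting the posterior mean by 1 against a unit prior gives 0.5.
print(kl_univariate_gaussians(0.0, 1.0, 0.0, 1.0))  # -> 0.0
print(kl_univariate_gaussians(1.0, 1.0, 0.0, 1.0))  # -> 0.5
```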
```python
# create the LSTM layers
for i in range(len(lstm_sizes)):
    self.lstms.append(BayesianLSTMCell(lstm_sizes[i], self.prior, self.isTraining, 'lstm' + str(i)))

# stack up multiple LSTM layers
cell = tf.contrib.rnn.MultiRNNCell(self.lstms)

# get an initial state of all zeros
initial_state = cell.zero_state(batch_size, tf.float32)

# perform dynamic unrolling of the network, for variable-length sequences
lstm_outputs, final_state = tf.nn.dynamic_rnn(cell, embed_input, initial_state=initial_state)
```
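As I understand it, `tf.nn.dynamic_rnn` wraps the cell's `call` in a `tf.while_loop`, so every tensor created inside `call` (including the cached `kl_loss`) lives in the loop's control-flow context, and the only values that legally leave the loop are the outputs and the final state. A plain-Python analogy of the sanctioned way out, namely threading an extra accumulator through the loop-carried state (all names here are illustrative):

```python
def toy_cell(x, state):
    # state carries (hidden, kl_acc): anything that must survive the loop
    # travels in the state, mirroring how tf.while_loop only exports its
    # loop variables.
    hidden, kl_acc = state
    hidden = hidden + x   # stand-in for the LSTM state update
    step_kl = 0.25        # stand-in for a per-step KL term
    return hidden, (hidden, kl_acc + step_kl)

def toy_dynamic_rnn(cell, inputs, initial_state):
    # unrolls the cell over the inputs, like dynamic_rnn's while loop
    state = initial_state
    outputs = []
    for x in inputs:
        out, state = cell(x, state)
        outputs.append(out)
    return outputs, state

outputs, (final_hidden, total_kl) = toy_dynamic_rnn(toy_cell, [1, 2, 3], (0, 0.0))
print(total_kl)  # -> 0.75 (three steps of 0.25 each)
```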
Then I try to access the KL loss:

```python
with tf.variable_scope('rnn_loss', reuse=True):
    # use cross-entropy as the classification loss
    self.loss = tf.losses.softmax_cross_entropy(onehot_labels=self.groundtruths, logits=self.logits)

    if self.isTraining:
        self.KL = self.compute_KL_univariate_prior(self.prior, (self.softmax_w_mean, self.softmax_w_std), self.softmax_w)
        self.KL += self.compute_KL_univariate_prior(self.prior, (self.softmax_b_mean, self.softmax_b_std), self.softmax_b)

        for i in range(len(self.lstm_sizes)):
            self.LSTM_KL = self.lstms[i].get_kl()
            print(self.LSTM_KL)
            self.KL += self.LSTM_KL
```
The last line fails with the error below; any pointers/suggestions/solutions would be greatly appreciated. Ideally I would like to avoid using static_rnn or raw_rnn, as suggested here.
```
ValueError: Cannot use 'rnn_cell/rnn/while/rnn/multi_rnn_cell/cell_0/bayesian_lstm_cell/add_2' as input to 'rnn_loss/add_1' because 'rnn_cell/rnn/while/rnn/multi_rnn_cell/cell_0/bayesian_lstm_cell/add_2' is in a while loop. See info log for more details.
```
The full traceback is:
```
File "/home/jehill/python/NeuralNetworks/models/BNN/Sentiment.py", line 214, in <module>
    model = SentimentAnalysisMultiLayerLSTM(training=True)
  File "/home/jehill/python/NeuralNetworks/models/BNN/Sentiment.py", line 62, in __init__
    self.KL +=self.LSTM_KL
  File "/usr/local/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/math_ops.py", line 812, in binary_op_wrapper
    return func(x, y, name=name)
  File "/usr/local/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/gen_math_ops.py", line 365, in add
    "Add", x=x, y=y, name=name)
  File "/usr/local/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 788, in _apply_op_helper
    op_def=op_def)
  File "/usr/local/anaconda3/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/usr/local/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3300, in create_op
    op_def=op_def)
  File "/usr/local/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1838, in __init__
    self._control_flow_post_processing()
  File "/usr/local/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1847, in _control_flow_post_processing
    control_flow_util.CheckInputFromValidContext(self, input_tensor.op)
  File "/usr/local/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_util.py", line 335, in CheckInputFromValidContext
    raise ValueError(error_msg + " See info log for more details.")
ValueError: Cannot use 'rnn_cell/rnn/while/rnn/multi_rnn_cell/cell_0/bayesian_lstm_cell/KL_loss_lstm0/add' as input to 'rnn_loss/add_1' because 'rnn_cell/rnn/while/rnn/multi_rnn_cell/cell_0/bayesian_lstm_cell/KL_loss_lstm0/add' is in a while loop. See info log for more details.
```
TensorFlow version is 1.13.1. Note: I am aware some of the API used above will change in the new version 2.0.0.
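One direction I am considering, since the KL term depends only on the weight posteriors and not on the inputs: have `get_kl` rebuild the KL expression from the stored parameters (`self.prior`, `self.theta_w`, `self.theta_b`) each time it is called, instead of returning the tensor that was cached inside `call`. Ops created that way would be built in the caller's outer context rather than inside the while loop. A plain-Python sketch of the idea only (all names are placeholders, not the real API):

```python
class ToyBayesianCell:
    def __init__(self, theta):
        self.theta = theta        # variational parameters, created once

    def call(self, x):
        # per-step computation: in the real graph, anything built here
        # would sit inside dynamic_rnn's while loop
        return x * self.theta

    def get_kl(self):
        # rebuild the KL expression from the parameters on demand, so
        # (in TensorFlow terms) the resulting ops would be created in
        # the caller's context, outside the loop
        return self.theta ** 2    # stand-in for the analytic KL term

cell = ToyBayesianCell(theta=3.0)
print(cell.call(2.0))   # -> 6.0
print(cell.get_kl())    # -> 9.0
```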