我正在用Tensorflow用Python编写一个递归神经网络,当我尝试在通过“会话”作为参数之后从测试函数中调用session.run(tf.math.greater(...))时,程序挂起。
我只曾使用Tensorflow进行过培训,并且刚刚为MNIST编写了MLP,所以我对此很陌生。每当我在程序挂起后停止它时,它表明键盘中断发生在行if session.run(tf.math.greater(output[j], output[arg_max_index])):
def RNN_test(s_0, _weights, _biases, length, session):
recurrent_input = tf.zeros([n_recurrent])
output = [0.0 for i in range(char2idx[s_0])] + [1.0] + [0.0 for i in range(vocab_size - char2idx[s_0] - 1)]
outs = s_0
for i in range(length):
recurrent_input_weighted = tf.math.multiply(recurrent_input, _weights["w_ri"])
recurrent = tf.add(tf.matmul(tf.convert_to_tensor([output]), _weights["w_r"]), tf.add(recurrent_input_weighted, _biases["b_r"]))
recurrent_input = tf.nn.tanh(recurrent)
recurrent = tf.nn.relu(recurrent)
hidden = tf.nn.relu(tf.add(tf.matmul(recurrent, _weights["w_h"]), _biases["b_h"]))
output = tf.add(tf.matmul(hidden, _weights["w_o"]), _biases["b_o"])[0]
outc_one_hot = [0.0 for j in range(vocab_size)]
arg_max_index = 0
for j in range(output.get_shape().as_list()[0]):
if session.run(tf.math.greater(output[j], output[arg_max_index])):
arg_max_index = j
outs += idx2char[arg_max_index]
outc_one_hot[arg_max_index] = 1.0
output = tf.convert_to_tensor(outc_one_hot)
return outs
我期望此函数打印出它生成字符串的尝试,但是它什么也不做,中断程序会产生错误消息:
<ipython-input-1-800a434a66d5> in RNN_test(s_0, _weights, _biases, length, session)
72 arg_max_index = 0
73 for j in range(output.get_shape().as_list()[0]):
---> 74 if session.run(tf.math.greater(output[j], output[arg_max_index])):
75 arg_max_index = j
76 outs += idx2char[arg_max_index]
Keyboard Interrupt: