从另一个函数中运行session.run(function)挂起

时间:2019-06-11 21:40:29

标签: python tensorflow

我正在用Tensorflow用Python编写一个递归神经网络,当我尝试在通过“会话”作为参数之后从测试函数中调用session.run(tf.math.greater(...))时,程序挂起。

我只曾使用Tensorflow进行过培训,并且刚刚为MNIST编写了MLP,所以我对此很陌生。每当我在程序挂起后停止它时,它表明键盘中断发生在行if session.run(tf.math.greater(output[j], output[arg_max_index])):

def RNN_test(s_0, _weights, _biases, length, session):
    recurrent_input = tf.zeros([n_recurrent])
    output = [0.0 for i in range(char2idx[s_0])] + [1.0] + [0.0 for i in range(vocab_size - char2idx[s_0] - 1)]
    outs = s_0
    for i in range(length):
        recurrent_input_weighted = tf.math.multiply(recurrent_input, _weights["w_ri"])
        recurrent = tf.add(tf.matmul(tf.convert_to_tensor([output]), _weights["w_r"]), tf.add(recurrent_input_weighted, _biases["b_r"]))
        recurrent_input = tf.nn.tanh(recurrent)
        recurrent = tf.nn.relu(recurrent)
        hidden = tf.nn.relu(tf.add(tf.matmul(recurrent, _weights["w_h"]), _biases["b_h"]))
        output = tf.add(tf.matmul(hidden, _weights["w_o"]), _biases["b_o"])[0]
        outc_one_hot = [0.0 for j in range(vocab_size)]
        arg_max_index = 0
        for j in range(output.get_shape().as_list()[0]):
            if session.run(tf.math.greater(output[j], output[arg_max_index])):
                arg_max_index = j
        outs += idx2char[arg_max_index]
        outc_one_hot[arg_max_index] = 1.0
        output = tf.convert_to_tensor(outc_one_hot)
    return outs

我期望此函数打印出它生成字符串的尝试,但是它什么也不做,中断程序会产生错误消息:

<ipython-input-1-800a434a66d5> in RNN_test(s_0, _weights, _biases, length, session)
     72         arg_max_index = 0
     73         for j in range(output.get_shape().as_list()[0]):
---> 74             if session.run(tf.math.greater(output[j], output[arg_max_index])):
     75                 arg_max_index = j
     76         outs += idx2char[arg_max_index]
Keyboard Interrupt:

0 个答案:

没有答案