TensorFlow重用可变范围

时间:2017-10-30 17:05:39

标签: tensorflow

这是张量流中的代码,我已经定义了一个Bi-LSTM,对于某个任务,我需要遍历我的图形。虽然我在Scope Variable中设置了reuse = True,但它产生了代码下面提到的错误。

for run in range(0, 2):


   with tf.variable_scope("LSTM", reuse= True) as scope:

     def LSTM(input_data):



        LSTM_cell_fw= tf.contrib.rnn.BasicLSTMCell(num_units= hidden_size)
        LSTM_cell_bw= tf.contrib.rnn.BasicLSTMCell(num_units= hidden_size)          
        output, states = tf.nn.bidirectional_dynamic_rnn(LSTM_cell_fw, LSTM_cell_bw, inputs= input_data, dtype=tf.float32)
        output_1= output[0]
        output_2= output[1]
        output_1= output_1[-1, -1, :]
        output_1= tf.reshape(output_1, shape= (1, hidden_size))
        output_2= output_2[-1, -1, :]
        output_2= tf.reshape(output_2, shape= (1, hidden_size))
        fin_output= tf.concat((output_1, output_2), axis=1)

        return fin_output

,错误是:ValueError:变量bidirectional_rnn / fw / basic_lstm_cell / kernel已经存在,不允许。你的意思是在VarScope中设置reuse = True吗?最初定义于:

LSTM中的文件“alpha-rep.py”,第65行     output,states = tf.nn.bidirectional_dynamic_rnn(LSTM_cell_fw,LSTM_cell_bw,inputs = input_data,dtype = tf.float32)   文件“alpha-rep.py”,第77行,in     out = LSTM(input_data)

1 个答案:

答案 0 :(得分:4)

要重复使用首先必须定义的变量,只有在此之后才能重复使用。

定义用于定义变量的函数:

def LSTM(input_data):
    LSTM_cell_fw= tf.contrib.rnn.BasicLSTMCell(num_units= hidden_size)
    LSTM_cell_bw= tf.contrib.rnn.BasicLSTMCell(num_units= hidden_size)          
    output, states = tf.nn.bidirectional_dynamic_rnn(LSTM_cell_fw, LSTM_cell_bw, inputs= input_data, dtype=tf.float32)
    output_1= output[0]
    output_2= output[1]
    output_1= output_1[-1, -1, :]
    output_1= tf.reshape(output_1, shape= (1, hidden_size))
    output_2= output_2[-1, -1, :]
    output_2= tf.reshape(output_2, shape= (1, hidden_size))
    return tf.concat((output_1, output_2), axis=1)

然后第一次调用它来定义变量并将其置于所需的范围内:

with tf.variable_scope("LSTM", reuse=False) as scope:
    first = LSTM(your_input_here)

现在,您可以在同一范围内定义其他图层,重复使用已定义的变量:

with tf.variable_scope("LSTM", reuse=True) as scope:
    second = LSTM(your_other_input_here)