所以我一直在尝试在tensorflow中训练一个单层的编码器 - 解码器网络,由于文档在解释上非常稀疏,它只是如此令人沮丧,而且我只在张量流上采用了斯坦福的CS231n。 / p>
所以这是一个直截了当的模型:
def simple_model(X,Y, is_training):
"""
a simple, single layered encoder decoder network,
that encodes X of shape (batch_size, window_len,
n_comp+1), then decodes Y of shape (batch_size,
pred_len+1, n_comp+1), of which the vector Y[:,0,
:], is simply [0,...,0,1] * batch_size, so that
it starts the decoding
"""
num_units = 128
window_len = X.shape[1]
n_comp = X.shape[2]-1
pred_len = Y.shape[1]-1
init = tf.contrib.layers.variance_scaling_initializer()
encoder_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
encoder_output, encoder_state = tf.nn.dynamic_rnn(
encoder_cell,X,dtype = tf.float32)
decoder_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
decoder_output, _ = tf.nn.dynamic_rnn(decoder_cell,
encoder_output,
initial_state = encoder_state)
# we expect the shape to be of the shape of Y
print(decoder_output.shape)
proj_layer = tf.layers.dense(decoder_output, n_comp)
return proj_layer
现在我尝试设置培训细节:
tf.reset_default_graph()
X = tf.placeholder(tf.float32, [None, 15, 74])
y = tf.placeholder(tf.float32, [None, 4, 74])
is_training = tf.placeholder(tf.bool)
y_out = simple_model(X,y,is_training)
mean_loss = 0.5*tf.reduce_mean((y_out-y[:,1:,:-1])**2)
optimizer = tf.train.AdamOptimizer(learning_rate=5e-4)
extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(extra_update_ops):
train_step = optimizer.minimize(mean_loss)
好的,现在我得到了这个愚蠢的错误
ValueError:变量rnn / basic_lstm_cell / kernel已经存在,不允许。您是不是要在VarScope中设置reuse = True或reuse = tf.AUTO_REUSE?最初定义于:
答案 0 :(得分:0)
我不确定我是否理解正确。您的图表中有两个BasicLSTMCell
个。根据{{3}},您可能应该使用MultiRNNCell
,如下所示:
encoder_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
decoder_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
rnn_layers = [encoder_cell, decoder_cell]
multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(rnn_layers)
decoder_output, state = tf.nn.dynamic_rnn(cell=multi_rnn_cell,
inputs=X,
dtype=tf.float32)
如果这不是您想要的正确架构,并且您需要单独使用这两个BasicLSTMCell
,我认为在定义{{1}时会传递不同/唯一name
}和encoder_cell
将有助于解决此错误。 decoder_cell
会将细胞放在“细胞”下面。范围。如果您未明确定义单元名称,则会导致documentation。