这是我正在尝试运行的tensorflow模型的一个片段,但不断出现错误。我使用的是tensorflow 1.14.0,我相信代码的作者使用1.08(如果有所不同)。
def select(parameters, length):
"""Select the last valid time step output as the sentence embedding
:params parameters: [batch, seq_len, hidden_dims]
:params length: [batch]
:Returns : [batch, hidden_dims]
"""
shape = tf.shape(parameters)
idx = tf.range(shape[0])
idx = tf.stack([idx, length - 1], axis = 1)
return tf.gather_nd(parameters, idx)
hidden_size_lstm_1 = 200
hidden_size_lstm_2 = 200
tags = 39
word_dim = 300
proj1 = 200
proj2 = 100
words = 20001
batchSize = 2
log_dir = "train"
model_dir = "DAModel"
model_name = "ckpt"
class DAModel():
def __init__(self):
with tf.variable_scope("placeholder"):
self.dialogue_lengths = tf.placeholder(tf.int32, shape = [None], name = "dialogue_lengths")
self.word_ids = tf.placeholder(tf.int32, shape = [None,None,None], name = "word_ids")
self.utterance_lengths = tf.placeholder(tf.int32, shape = [None, None], name = "utterance_lengths")
# self.utterance_lengths = tf.placeholder(tf.float32, shape = [None, None], name = "utterance_lengths")
self.labels = tf.placeholder(tf.int32, shape = [None, None], name = "labels")
self.clip = tf.placeholder(tf.float32, shape = [], name = 'clip')
with tf.variable_scope("embeddings"):
_word_embeddings = tf.get_variable(
name = "_word_embeddings",
dtype = tf.float32,
shape = [words, word_dim],
initializer = tf.random_uniform_initializer()
)
word_embeddings = tf.nn.embedding_lookup(_word_embeddings,self.word_ids, name="word_embeddings")
self.word_embeddings = tf.nn.dropout(word_embeddings, 0.8)
with tf.variable_scope("utterance_encoder"):
s = tf.shape(self.word_embeddings)
batch_size = s[0] * s[1]
time_step = s[-2]
word_embeddings = tf.reshape(self.word_embeddings, [batch_size, time_step, word_dim])
length = tf.reshape(self.utterance_lengths, [batch_size])
fw = tf.nn.rnn_cell.LSTMCell(hidden_size_lstm_1, forget_bias=0.8, state_is_tuple= True)
bw = tf.nn.rnn_cell.LSTMCell(hidden_size_lstm_1, forget_bias=0.8, state_is_tuple= True)
output, _ = tf.nn.bidirectional_dynamic_rnn(fw, bw, word_embeddings,sequence_length=length, dtype = tf.float32)
output = tf.concat(output, axis = -1) # [batch_size, time_step, dim]
# Select the last valid time step output as the utterance embedding,
# this method is more concise than TensorArray with while_loop
output = select(output, self.utterance_lengths) # [batch_size, dim]
output = tf.reshape(output, s[0], s[1], 2 * hidden_size_lstm_1)
output = tf.nn.dropout(output, 0.8)
我不断收到此错误:
Traceback (most recent call last):
File "HBLSTM-CRF.py", line 308, in <module>
main()
File "HBLSTM-CRF.py", line 251, in main
model = DAModel()
File "HBLSTM-CRF.py", line 134, in __init__
output = select(output, self.utterance_lengths) # [batch_size, dim]
File "HBLSTM-CRF.py", line 82, in select
idx = tf.stack([idx, length - 1], axis = 1)
File "C:\Anaconda3\lib\site-packages\tensorflow\python\util\dispatch.py", line 180, in wrapper
return target(*args, **kwargs)
File "C:\Anaconda3\lib\site-packages\tensorflow\python\ops\array_ops.py", line 1046, in stack
return gen_array_ops.pack(values, axis=axis, name=name)
File "C:\Anaconda3\lib\site-packages\tensorflow\python\ops\gen_array_ops.py", line 6971, in pack
"Pack", values=values, axis=axis, name=name)
File "C:\Anaconda3\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 788, in _apply_op_helper
op_def=op_def)
File "C:\Anaconda3\lib\site-packages\tensorflow\python\util\deprecation.py", line 507, in new_func
return func(*args, **kwargs)
File "C:\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py", line 3616, in create_op
op_def=op_def)
File "C:\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py", line 2027, in __init__
control_input_ops)
File "C:\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py", line 1867, in _create_c_op
raise ValueError(str(e))
ValueError: Shapes must be equal rank, but are 1 and 2
From merging shape 0 with other shapes. for 'utterance_encoder/stack' (op: 'Pack') with input shapes: [?], [?,?].
我试图通过将'utterance_lengths'更改为float32来解决此问题,以为在尝试堆叠float32和int32时出现错误,但这没有用。任何建议,将不胜感激。如果可以帮助诊断问题,请在下面的链接中找到整个代码:
https://github.com/YanWenqiang/HBLSTM-CRF/blob/master/HBLSTM-CRF.py
谢谢!