Tensorflow - InvalidArgumentError(参见上面的回溯):不兼容的形状:[90]与[8704]

时间:2017-12-14 07:35:12

标签: python tensorflow deep-learning

Tensorflow - InvalidArgumentError(参见上面的追溯):不兼容的形状:[90]与[8704] ** ..

我有不兼容的形状错误。我正在使用Seq2Seq创建一个简单的聊天机器人。 Python版本3.5,Tensorflow版本1.3.0,Windows 10。 我希望你看到下面的代码。

CODE

.... omission ....
hidden_size = 128
layers_size = 2
num_classes = 0
batch_size = 0
vacab_dict = {}
data = []
vocab_file = "data_temp.txt"
dialogue_file = "data.txt"
dir_path = "\\chatbot\\"

.... omission ...

data = load_file(dialogue_file)
enc_input = tf.placeholder(tf.float32, [None, None, num_classes])
dec_input = tf.placeholder(tf.float32, [None, None, num_classes])
targets = tf.placeholder(tf.int64, [None, None])

max_len = 0
word_to_idxs = []
word_to_idx_encoders = []
word_to_idx_decoders = []
word_to_idx_targets = []

.... omission ...


with tf.variable_scope("encode"):
    enc_cell = rnn.BasicRNNCell(hidden_size)  
    enc_cell = rnn.DropoutWrapper(enc_cell, output_keep_prob=0.5)
    stacked_rnn = []
    for iiLyr in range(layers_size):
        stacked_rnn.append(tf.nn.rnn_cell.LSTMCell(num_units=num_classes, state_is_tuple=True))
    enc_cell = tf.nn.rnn_cell.MultiRNNCell(cells=stacked_rnn, state_is_tuple=True)
    outputs, enc_states = tf.nn.dynamic_rnn(enc_cell, enc_input, dtype=tf.float32)

with tf.variable_scope("decode"):
    dec_cell = rnn.BasicRNNCell(hidden_size)
    dec_cell = rnn.DropoutWrapper(dec_cell, output_keep_prob=0.5)
    stacked_rnn = []
    for iiLyr in range(layers_size):
        stacked_rnn.append(tf.nn.rnn_cell.LSTMCell(num_units=num_classes, state_is_tuple=True))
        dec_cell = tf.nn.rnn_cell.MultiRNNCell(cells=stacked_rnn, state_is_tuple=True)
    outputs, dec_states = tf.nn.dynamic_rnn(dec_cell, dec_input, initial_state=enc_states, dtype=tf.float32)

weights = tf.Variable(tf.ones([hidden_size, num_classes]), name="weights")
bias = tf.Variable(tf.zeros([num_classes]), name="bias")


x_for_fc = tf.reshape(outputs, [-1, hidden_size])

logit = tf.matmul(x_for_fc, weights) + bias
batch_size = tf.shape(outputs)[1]

logit = tf.reshape(logit, [-1, batch_size, num_classes])


cost = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logit, labels=targets))

sequence_loss = tf.contrib.seq2seq.sequence_loss(
   logits=outputs, targets=targets, weights=weights)
cost = tf.reduce_mean(sequence_loss)

optimizer = tf.train.AdamOptimizer(learning_rate=0.01).minimize(cost)


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    #print(sess.run(enc_onehots))

    for epoch in range(100):
        _, loss = sess.run([optimizer, cost],
                           feed_dict={enc_input: word_to_idx_encoders,
                                      dec_input: word_to_idx_decoders,
                                      targets: targets_batch})


        print("Epcch: ", "%04d" % (epoch + 1),
              "cost =", "{:.6f}".format(loss))


sess.close()

在中间,代码被省略。

错误

Traceback (most recent call last): File "C:\Users\james\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", 
line 1327, in _do_call return fn(*args) File "C:\Users\james\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", 
line 1306, in _run_fn status, run_metadata) File "C:\Users\james\Anaconda3\lib\contextlib.py", 
line 66, in exit next(self.gen) File "C:\Users\james\Anaconda3\lib\site-packages\tensorflow\python\framework\errors_impl.py", 
line 466, in raise_exception_on_not_ok_status pywrap_tensorflow.TF_GetCode(status)) 
tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes: [90] vs. [8704] [[Node: sequence_loss/mul = Mul[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/cpu:0"](sequence_loss/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits, sequence_loss/Reshape_2)]]

0 个答案:

没有答案