我是tensorflow的新手。尝试训练语言模型,但发生错误。我怎么解决这个问题。下面是我的代码。
num_iterations = input_li_size // batch_size
print("number of iterations for each epoch :", num_iterations)
epochs = args.epochs
num_steps = num_iterations * epochs + 1
with tf.Session(graph=graph) as session:
init.run()
print("Initialized - Tensorflow")
average_loss = 0
for step in range(num_steps):
batch_inputs, batch_labels = generate_batch(step, batch_size)
word_list = []
for word in batch_inputs:
word_list.append(word_to_pos_dict[word])
feed_dict = {}
for i in range(batch_size):
feed_dict[words_matrix[i]] = word_list[i]
feed_dict[train_inputs] = batch_inputs
feed_dict[train_labels] = batch_labels
_, loss_val = session.run([optimizer, loss], feed_dict=feed_dict)
average_loss += loss_val
if step % 2000 == 0:
if step > 0:
average_loss /= 2000
print("Average loss at step ", step, ": ", average_loss)
average_loss = 0
if step % 20000 == 0:
pos_embed = pos_embeddings.eval()
# Print nearest words
sim = similarity.eval()
for i in range(valid_size):
valid_pos = pos_reverse_dict[valid_examples[i]]
top_k = 8
nearest = (-sim[i, :]).argsort()[1:top_k + 1]
log_str = 'Nearest to %s:' % str(valid_pos)
for k in range(top_k):
close_word = pos_reverse_dict[nearest[k]]
log_str = '%s %s,' % (log_str, str(close_word))
print(log_str)
# Save vectors
save_model(pos_li, pos_embeddings.eval(), "pos.vec")
这是实际结果。
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-64-d994e55dced8> in <module>()
1 # Save vectors
----> 2 save_model(pos_li, pos_embeddings.eval(), "pos.vec")
~\Anaconda3\lib\site-packages\tensorflow\python\ops\variables.py in eval(self, session)
1649 A numpy `ndarray` with a copy of the value of this variable.
1650 """
-> 1651 return self._variable.eval(session=session)
1652
1653 def initialized_value(self):
~\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py in eval(self, feed_dict, session)
711
712 """
--> 713 return _eval_using_default_session(self, feed_dict, self.graph, session)
714
715
~\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py in _eval_using_default_session(tensors, feed_dict, graph, session)
5146 "`eval(session=sess)`")
5147 if session.graph is not graph:
-> 5148 raise ValueError("Cannot use the default session to evaluate tensor: "
5149 "the tensor's graph is different from the session's "
5150 "graph. Pass an explicit session to "
ValueError: Cannot use the default session to evaluate tensor: the tensor's graph is different from the session's graph. Pass an explicit session to `eval(session=sess)`.