This is the prediction code for my model after 1000 training steps.
import numpy as np
import tensorflow as tf
import nltk

# model_trainer, pickle_read, txt_tokenize and compile_id_token_from_files
# come from the project's own training modules (their imports are omitted here).


class vandys_speak(object):
    def __init__(self, session, input_mfcc, model_file):
        self.model_file = model_file
        self.model = model_trainer(training_mode=False, batch_size=2, sent_max_len=10)
        self.session = session
        if len(input_mfcc) == 0:
            # No input given: pick two random training utterances.
            sample_id = np.random.choice(8000, 2)
            self.input_mfcc = pickle_read("data/train/final_mfcc_tensor.pkl")[sample_id, :, :]
            utt_id_list = pickle_read("data/train/utt_id_list.pkl")
            print("utterance id is", utt_id_list[sample_id])
        else:
            self.input_mfcc = input_mfcc
        mdl_dir = "./vandys_model_1_continue/"
        # Restore the trained weights, then initialize lookup tables and local variables.
        self.model.saver.restore(self.session, tf.train.latest_checkpoint(mdl_dir))
        self.session.run(tf.tables_initializer())
        self.session.run(tf.local_variables_initializer())
        text_dir = 'txt_from_mfcc'
        txt_from_mfcc = txt_tokenize(text_dir, self.model.src_max_len)
        # NOTE: sample_id is only defined when input_mfcc is empty.
        self.txt_from_mfcc = compile_id_token_from_files(txt_from_mfcc[sample_id, :], self.model.word_list)

    def speak(self):
        emo_train_output, emo_Q_loss, tuning_loss, response_projection = \
            self.model.brainer_constructor()
        logits, sample_id, outputs, final_context_state = \
            self.model.brain_interpreter(response_projection=response_projection,
                                         emo_train_output=emo_train_output,
                                         training=False)
        # NOTE: this re-initializes every global variable and therefore
        # overwrites whatever saver.restore() loaded in __init__.
        self.session.run(tf.global_variables_initializer())
        test_id = self.session.run(sample_id,
                                   feed_dict={self.model.mfcc: self.input_mfcc,
                                              self.model.txt_from_mfcc: self.txt_from_mfcc,
                                              self.model.txt_target_input: np.zeros((2, 360)),
                                              self.model.txt_target_output: np.zeros((2, 360)),
                                              self.model.emo_target_input: np.zeros((2, 10)),
                                              self.model.emo_target_output: np.zeros((2, 10))})
        test_words = ""
        test_sentence = []
        print(test_id.shape)
        for sample in test_id:
            sent = sample[:, 0]
            reference = sample[:, 1:]
            for id in sent:
                test_words += " " + self.model.id2word_table[id]
            sent_hypothesis = preid2sentVec(sent, self.model.id2word_table)
            sent_reference = [preid2sentVec(ref, self.model.id2word_table) for ref in reference]
            vandys_bleu = nltk.translate.bleu_score.sentence_bleu(sent_reference, sent_hypothesis)
            if "<eos>" in test_words:
                test_words = test_words[:test_words.find('<eos>')]
            print("The hypothesis word is {0}, the bleu score is {1:.2f}.".format(test_words, vandys_bleu))
            test_sentence.append(test_words.split())
            test_words = ""
        return test_sentence


def preid2sentVec(input, id2word_table):
    sentVec = [id2word_table[id] for id in input]
    return sentVec


if __name__ == "__main__":
    sess = tf.Session()
    vandys = vandys_speak(session=sess, input_mfcc="", model_file="")
    vandys.speak()
Since the prediction output is currently very poor, how can I tell whether my model was actually loaded successfully, rather than the variables just being initialized from a random distribution? Also, I am really not sure whether I should be running the variable initializers here at all. What purpose do they serve when I am trying to load the saved values into the graph?
Answer (score 0):
To make sure the restore has completed and that no further assignments will take place, use the assert_consumed() method of the status object returned by restore.
It raises an exception if any Python objects in the dependency graph were not found in the checkpoint, or if any checkpointed values do not have a matching Python object.
With assert_consumed you can therefore verify whether the model was restored correctly.
Here is an example:
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
status = checkpoint.restore(tf.train.latest_checkpoint(DIRECTORY))
train_op = optimizer.minimize( ... )
status.assert_consumed()  # Optional sanity checks.
with tf.Session() as session:
    # Use the Session to restore variables, or initialize them if
    # tf.train.latest_checkpoint returned None.
    status.initialize_or_restore(session)
    for _ in range(num_training_steps):
        session.run(train_op)
    checkpoint.save(file_prefix=checkpoint_prefix)
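
The snippet above uses the object-based tf.train.Checkpoint API. If you prefer to keep the tf.train.Saver-based loading from the question, a rough sanity check is to compare what the checkpoint actually contains with the variables in your graph, and to confirm that restoring really changes their values. This is only a minimal sketch under that assumption: it presumes the model graph has already been built (as in vandys_speak.__init__), the directory name is taken from the question, and the "probe" variable is an arbitrary choice.

import numpy as np
import tensorflow as tf

mdl_dir = "./vandys_model_1_continue/"
ckpt = tf.train.latest_checkpoint(mdl_dir)
print("latest checkpoint:", ckpt)  # None means nothing can be restored

# Variable names/shapes stored in the checkpoint vs. variables in the current graph.
ckpt_vars = dict(tf.train.list_variables(ckpt))
graph_vars = {v.op.name: v for v in tf.global_variables()}
print("graph variables missing from checkpoint:", set(graph_vars) - set(ckpt_vars))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # random initialization first ...
    probe = next(iter(graph_vars.values()))      # any variable will do as a probe
    before = sess.run(probe)
    tf.train.Saver().restore(sess, ckpt)         # ... then restore on top of it
    after = sess.run(probe)
    # If the restore worked, the probe variable no longer holds its random init.
    print("restore changed the variable:", not np.allclose(before, after))

Note that the order matters: in speak() the code runs tf.global_variables_initializer() after saver.restore() has already been called in __init__, which resets every variable to fresh random values, so the restored weights are discarded before the first prediction.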