如何确认我的 TensorFlow 模型已成功恢复?

时间:2018-12-21 09:03:32

标签: python tensorflow

这是我经过1000步训练后的模型中预测部分的代码。

class vandys_speak(object):
    """Restore a trained model from a checkpoint and decode predictions.

    NOTE(review): a batch size of 2 is hard-coded in several places
    (``model_trainer(batch_size=2)``, ``np.random.choice(..., 2)`` and the
    ``np.zeros((2, ...))`` dummy feeds in ``speak``); keep them in sync.
    """

    def __init__(self, session, input_mfcc, model_file, sample_id=None,
                 model_dir="./vandys_model_1_continue/"):
        """Build the inference model and restore its weights.

        Args:
            session: TensorFlow session used to restore and run the model.
            input_mfcc: MFCC features to decode. When empty (``len == 0``),
                two random rows are taken from
                ``data/train/final_mfcc_tensor.pkl`` instead.
            model_file: stored on ``self.model_file``; not otherwise used here.
            sample_id: row indices used to select the tokenized
                ``txt_from_mfcc`` data (and the training pickles when
                ``input_mfcc`` is empty). Optional when ``input_mfcc`` is empty
                (random rows are drawn); required otherwise, because the text
                rows must presumably correspond to ``input_mfcc``.
            model_dir: directory searched with ``tf.train.latest_checkpoint``.

        Raises:
            ValueError: if ``input_mfcc`` is non-empty and ``sample_id`` is not
                given, or if ``model_dir`` contains no checkpoint.
        """
        # Validate before doing any expensive graph construction. Previously
        # this path failed later with a NameError on ``sample_id``.
        if len(input_mfcc) != 0 and sample_id is None:
            raise ValueError(
                "sample_id is required when input_mfcc is provided; it selects "
                "the matching rows of the txt_from_mfcc data")

        self.model_file = model_file
        self.model = model_trainer(training_mode=False, batch_size=2, sent_max_len=10)
        self.session = session

        if len(input_mfcc) == 0:
            if sample_id is None:
                # NOTE(review): 8000 is presumably the number of training
                # utterances in the pickles below — TODO confirm.
                sample_id = np.random.choice(8000, 2)
            self.input_mfcc = pickle_read("data/train/final_mfcc_tensor.pkl")[sample_id, :, :]
            utt_id_list = pickle_read("data/train/utt_id_list.pkl")
            # NOTE(review): fancy indexing with an index array only works if
            # utt_id_list is a numpy array, not a plain list — verify.
            print("utterance id is", utt_id_list[sample_id])
        else:
            self.input_mfcc = input_mfcc

        # Fail loudly instead of silently running with un-restored weights.
        checkpoint_path = tf.train.latest_checkpoint(model_dir)
        if checkpoint_path is None:
            raise ValueError("No checkpoint found in {!r}".format(model_dir))
        self.model.saver.restore(self.session, checkpoint_path)
        print("Restored model weights from", checkpoint_path)

        # Lookup tables and local variables are not stored in the checkpoint,
        # so they still need their own initializers after the restore.
        self.session.run(tf.tables_initializer())
        self.session.run(tf.local_variables_initializer())

        text_dir = 'txt_from_mfcc'
        txt_from_mfcc = txt_tokenize(text_dir, self.model.src_max_len)
        self.txt_from_mfcc = compile_id_token_from_files(txt_from_mfcc[sample_id, :], self.model.word_list)

    def _init_unrestored_variables(self):
        """Initialize only variables that are still uninitialized.

        Variables restored from the checkpoint are left untouched. The names of
        the freshly (randomly) initialized variables are printed as a warning
        and returned. An empty list means every variable in the graph had a
        value, i.e. nothing was left at random initialization.
        """
        # report_uninitialized_variables yields the op names as bytes.
        uninitialized = set(self.session.run(tf.report_uninitialized_variables()))
        fresh_vars = [var for var in tf.global_variables() + tf.local_variables()
                      if var.op.name.encode() in uninitialized]
        fresh_names = [var.op.name for var in fresh_vars]
        if fresh_vars:
            print("WARNING: {} variable(s) were not restored from the checkpoint "
                  "and are randomly initialized: {}".format(len(fresh_vars), fresh_names))
            self.session.run(tf.variables_initializer(fresh_vars))
        return fresh_names

    def speak(self):
        """Decode the stored inputs and print one hypothesis per sample.

        Returns:
            A list with one entry per decoded sample. Each entry is the list of
            hypothesis words, truncated at the first ``"<eos>"``.
        """
        emo_train_output, _emo_q_loss, _tuning_loss, response_projection = \
            self.model.brainer_constructor()
        _logits, sample_id, _outputs, _final_context_state = \
            self.model.brain_interpreter(response_projection=response_projection,
                                         emo_train_output=emo_train_output,
                                         training=False)

        # Do NOT run tf.global_variables_initializer() here: it re-initializes
        # every variable and overwrites the weights restored in __init__ with
        # random values. Only variables that are still uninitialized (e.g.
        # ones created by the calls above) get initialized.
        self._init_unrestored_variables()

        test_id = self.session.run(sample_id,
                                   feed_dict={self.model.mfcc: self.input_mfcc,
                                              self.model.txt_from_mfcc: self.txt_from_mfcc,
                                              self.model.txt_target_input: np.zeros((2, 360)),
                                              self.model.txt_target_output: np.zeros((2, 360)),
                                              self.model.emo_target_input: np.zeros((2, 10)),
                                              self.model.emo_target_output: np.zeros((2, 10))})

        id2word = self.model.id2word_table
        test_sentence = []

        print(test_id.shape)
        for sample in test_id:
            # Column 0 is decoded as the hypothesis; the remaining columns are
            # used as the BLEU references.
            hyp_ids = sample[:, 0]
            reference = sample[:, 1:]

            sent_hypothesis = preid2sentVec(hyp_ids, id2word)
            # NOTE(review): iterating a 2-D slice yields its ROWS (one per time
            # step), not its columns. If each column is meant to be one
            # reference sentence, this should probably iterate ``reference.T``.
            sent_reference = [preid2sentVec(ref_ids, id2word) for ref_ids in reference]

            vandys_bleu = nltk.translate.bleu_score.sentence_bleu(sent_reference, sent_hypothesis)

            # Each word is prefixed with a space; the text is cut at the first
            # "<eos>" (the BLEU hypothesis above is not truncated).
            test_words = "".join(" " + word for word in sent_hypothesis)
            eos_pos = test_words.find('<eos>')
            if eos_pos != -1:
                test_words = test_words[:eos_pos]

            print("The hypothesis word is {0}, the bleu score is {1:.2f}.".format(test_words, vandys_bleu))

            test_sentence.append(test_words.split())

        return test_sentence

def preid2sentVec(input, id2word_table):
    """Map a sequence of token ids to their words.

    Args:
        input: iterable of token ids. The parameter name is kept for backward
            compatibility; it shadows the ``input`` builtin only inside here.
        id2word_table: any sequence or mapping indexable by token id.

    Returns:
        A list of ``id2word_table[token_id]``, in the order of ``input``.
    """
    return [id2word_table[token_id] for token_id in input]

if __name__ == "__main__":
    # Decode two random training utterances with the restored model. The
    # context manager closes the session, releasing its resources, on exit.
    with tf.Session() as sess:
        vandys = vandys_speak(session=sess, input_mfcc="", model_file="")
        vandys.speak()

由于目前的预测输出确实很差,我该如何判断模型是否已从检查点成功加载,而不是变量仅仅被随机初始化了?另外,我不确定这里是否应该进行变量初始化:既然要把变量的值加载到图中,这些初始化操作还有什么作用?

1 个答案:

答案 0 :(得分:0)

根据 TensorFlow 文档:

  

为确保加载已完成且之后不会再发生其他赋值,请对 restore 返回的状态对象调用 assert_consumed() 方法。

     

如果依赖关系图中有任何 Python 对象在检查点中找不到对应的值,或者检查点中有任何值找不到匹配的 Python 对象,该方法就会引发异常。

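通过 assert_consumed 可以确认模型是否已被正确还原。注意,该方法属于 tf.train.Checkpoint.restore 返回的状态对象;问题中使用的 tf.train.Saver.restore 并不返回这样的对象,因此需要改用 tf.train.Checkpoint 才能进行这项检查。下面是一个示例: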

# Example of the object-based checkpoint API (tf.train.Checkpoint), as quoted
# from the TensorFlow docs. `optimizer`, `model`, `DIRECTORY`,
# `num_training_steps` and `checkpoint_prefix` are placeholders not defined
# here, and `( ... )` stands for the elided minimize() arguments.
# NOTE: assert_consumed() lives on the status object returned by
# tf.train.Checkpoint.restore. tf.train.Saver.restore (used in the question)
# does not return such an object — verify for your TF version — so the
# checkpointing code must use tf.train.Checkpoint to get this check.
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
status = checkpoint.restore(tf.train.latest_checkpoint(DIRECTORY))
# Built before the check, presumably so the optimizer's slot variables already
# exist and can be matched against the checkpoint.
train_op = optimizer.minimize( ... )
# Raises if any object in the dependency graph has no value in the checkpoint,
# or if any checkpointed value has no matching object.
status.assert_consumed()  # Optional sanity checks.
with tf.Session() as session:
  # Use the Session to restore variables, or initialize them if
  # tf.train.latest_checkpoint returned None.
  status.initialize_or_restore(session)
  for _ in range(num_training_steps):
     session.run(train_op)
  checkpoint.save(file_prefix=checkpoint_prefix)