Tensorflow - 如何检查加载的权重

时间:2017-12-16 20:51:45

标签: tensorflow

我有一个在测试时加载的张量流预训练模型。我的问题是,如何验证我的架构中的所有权重都已更新?

  1. 如果我的代码中有额外的权重,tensorflow是否会引发错误?
  2. 如果我的代码中权重较小,tensorflow是否会引发错误?
  3. 以下是一个简单的代码段

    n_classes = 2
    batch_size=1000
    
    x = tf.placeholder(tf.float32, [None, 10, embedding_size], name='embedding') 
    keep_prob = tf.placeholder(tf.float32, name='keep_prob')
    
    # weights - fc
    fc1_w = tf.get_variable("fc1_w", shape=[1024, 256])
    fc2_w = tf.get_variable("fc2_w", shape=[256, 256])
    clf_w = tf.get_variable("clf_w", shape=[256, 2])
    
    fc1_b = tf.get_variable("fc1_b", shape=[256])
    fc2_b = tf.get_variable("fc2_b", shape=[256])
    clf_b = tf.get_variable("clf_b", shape=[2])
    
    # weights - lstm 
    lstm  = tf.nn.rnn_cell.LSTMCell(num_units = 1024, state_is_tuple=True)
    lstm_state = lstm.zero_state(batch_size, tf.float32)
    
    sess = tf.Session()
    saver = tf.train.Saver()
    saver.restore(sess, "./checkpoints/model-24000")
    

2 个答案:

答案 0 :(得分:1)

  

如果我的代码中有额外的权重,tensorflow是否会引发错误   ?

  

如果我的代码中权重较小,tensorflow是否会引发错误?

没有

var_list将检查代码中的所有(可保存或可训练)变量,并在预训练模型中为它​​们分配相同名称的值。

您也可以指定tf.train.Saver(var_list=a_list_of_variables),例如 Uri video = Uri.parse("Your link should be in this place "); mVideoView.setVideoURI(video); mVideoView.setZOrderOnTop(true); //Very important line, add it to Your code mVideoView.setOnPreparedListener(new MediaPlayer.OnPreparedListener() { @Override public void onPrepared(MediaPlayer mediaPlayer) { // here write another part of code, which provides starting the video }} ,以强制它检查列表中的部分变量。

答案 1 :(得分:0)

是的,如果您使用的图表与正在加载的权重之间存在任何差异,则tensorflow会通知您缺少的/其他变量。