Tensorflow - 转移学习实现(语义分割)

时间:2016-07-25 10:36:55

标签: python tensorflow image-segmentation pre-trained-model

我正在努力实现CNN架构(FCN-8s模型,带有预训练的VGG16模型),用于对我自己的数据进行语义分割(2类,因此,每像素二进制分类)

我打算如何解决这个问题:

  1. 使用权重加载预先训练的模型
  2. 添加/删除其他更高层以转换为FCN
  3. 冻结预训练模型的较低层(在训练阶段不更新)
  4. 在特定数据集上训练网络
  5. 假设这是正确的,我如何在tensorflow模型上冻结下层? (我正在寻找具体的实现细节)我看过TensorFlow上的Inception重新训练教程,但我还不太确定。

    这是我想到的工作流程:

    1. 通过现有的预训练模型运行我的数据,并提取功能输出,无需训练它。 (怎么样?)

    2. 将这些功能输出提供给包含更高层的另一个网络 - 并开始训练它。

    3. 任何建议都会有所帮助!

      否则,如果我错了,我该怎么想呢?

      更新:

      我在下面提到了chasep255的建议,并尝试使用tf.stop_gradient以“冻结”我模型中的较低层。很明显,我的实施有问题。可能的替代方案/建议?

      该模型基于FCN(用于语义分割)论文构建。我从模型体系结构中提取logits,即我的特征,我最初直接将其提供给loss函数,以使用softmax分类器将其最小化。 (每像素分类)deconv_1是我的logits张量,形状[batch, h, w, num_classes] = [1, 750, 750, 2]实现:

      logits = vgg_fcn.deconv_1
      
      stopper = tf.stop_gradient(logits, 'stop_gradients')
      
      loss = train_func.loss(stopper, labels_placeholder, 2)
      
      with tf.name_scope('Optimizer'):
          train_op = train_func.training(loss, FLAGS.learning_rate)
      
          with tf.name_scope('Accuracy'):
              eval_correct = train_func.accuracy_eval(logits, labels_placeholder)
              accuracy_summary = tf.scalar_summary('Accuracy', eval_correct)
      

      然后我按如下方式运行这些Graph操作:

      _, acc, loss_value = sess.run([train_op,eval_correct, loss], feed_dict=feed_dict)
      

      当我运行训练周期时,没有优化损失值,绝大多数是因为我引入了tf.stop_gradient Op。

      有关详细信息,请参阅下面的损失函数:

      def loss(logits, labels, num_classes):
      
          logits = tf.reshape(logits, [-1, num_classes])
          #epsilon = tf.constant(value=1e-4)
          #logits = logits + epsilon
      
          labels = tf.to_int64(tf.reshape(labels, [-1]))
          print ('shape of logits: %s' % str(logits.get_shape()))
          print ('shape of labels: %s' % str(labels.get_shape()))
      
          cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels, name='Cross_Entropy')
          cross_entropy_mean = tf.reduce_mean(cross_entropy, name='xentropy_mean')
          tf.add_to_collection('losses', cross_entropy_mean)
      
          loss = tf.add_n(tf.get_collection('losses'), name='total_loss')
          return loss
      

1 个答案:

答案 0 :(得分:1)

您可以将预训练模型的输出传递给sess.run(pretrained_output,...)并捕获预训练模型的输出。保存输出后,您可以将其输入模型。在这种情况下,优化器将无法将渐变传播到预训练模型。

您还可以将预训练的模型正常附加到模型上,然后通过tf.stop_graidents()传递预训练的输出,这将阻止优化器将渐变传播回预训练模型。

最后,您可以浏览预训练模型中的所有变量,并将其从可训练变量列表中删除。