如何在Tensorflow中设置分层学习率?

时间:2016-01-22 11:22:11

标签: python deep-learning tensorflow

我想知道是否有一种方法可以为不同的层使用不同的学习率,就像Caffe中的那样。我正在尝试修改预先训练的模型并将其用于其他任务。我想要的是加快新增加的层的训练,并使训练好的层保持低学习率,以防止它们被扭曲。例如,我有一个5-conv层预训练模型。现在我添加一个新的转换层并对其进行微调。前5层的学习率为0.00001,最后一层的学习率为0.001。知道如何实现这个目标吗?

6 个答案:

答案 0 :(得分:76)

使用2个优化器可以很容易地实现:

var_list1 = [variables from first 5 layers]
var_list2 = [the rest of variables]
opt1 = tf.train.GradientDescentOptimizer(0.00001)
opt2 = tf.train.GradientDescentOptimizer(0.0001)
grads = tf.gradients(loss, var_list1 + var_list2)
grads1 = grads[:len(var_list1)]
grads2 = grads[len(var_list1):]
tran_op1 = opt1.apply_gradients(zip(grads1, var_list1))
train_op2 = opt2.apply_gradients(zip(grads2, var_list2))
train_op = tf.group(train_op1, train_op2)

此实现的一个缺点是它在优化器内部计算两次tf.gradients(。),因此在执行速度方面可能不是最佳的。这可以通过显式调用tf.gradients(。),将列表拆分为2并将相应的渐变传递给两个优化器来缓解。

相关问题:Holding variables constant during optimizer

编辑:增加了更有效但更长的实施:

tf.trainable_variables()

您可以使用tf.gradients(.)获取所有训练变量并决定从中进行选择。 不同之处在于,在第一个实现中,<html> <head> <meta http-equiv="Content-Type" content="text/html; charset=iso-8859-1"> <title>Real Time Graph</title> <?php include("mydb.php"); // run query $sql1 = "select to_char(WORKDATE,'dd-Mon-yyyy HH24:MI:SS') as WD,DATA from dat where to_char(WORKDATE,'dd/mm')='25/02'"; $stid1=oci_parse($conn, $sql1); // set array $arr1 = array(); if(!$stid1){ $e=oci_error($conn); trigger_error(htmlentities($e[message],ENT_QUOTES),E_USER_ERROR); } $r1=oci_execute($stid1); if(!$r1){ $e=oci_error($stid1); trigger_error(htmlentities($e[message],ENT_QUOTES),E_USER_ERROR); } // look through query while($row = oci_fetch_array($stid1,OCI_ASSOC)){ // add each row returned into an array $arr1[] = array((strtotime($row['WD'])*1000) , (float)$row['DATA']); } //oci_free_statement($stid); //oci_close($conn); ?> <?php include("mydb.php"); // run query $sql2 = "select to_char(WORKDATE,'dd-Mon-yyyy HH24:MI:SS') as WD,DATA from dat where to_char(WORKDATE,'dd/mm')='26/02'"; $stid2=oci_parse($conn, $sql2); // set array $arr2 = array(); if(!$stid1){ $e=oci_error($conn); trigger_error(htmlentities($e[message],ENT_QUOTES),E_USER_ERROR); } $r2=oci_execute($stid2); if(!$r2){ $e=oci_error($stid2); trigger_error(htmlentities($e[message],ENT_QUOTES),E_USER_ERROR); } // look through query while($row = oci_fetch_array($stid2,OCI_ASSOC)){ // add each row returned into an array $arr2[] = array((strtotime($row['WD'])*1000) , (float)$row['DATA']); } //oci_free_statement($stid); //oci_close($conn); ?> <?php include("mydb.php"); // run query $sql3 = "select to_char(WORKDATE,'dd-Mon-yyyy HH24:MI:SS') as WD,DATA from dat where to_char(WORKDATE,'dd/mm')='27/02'"; $stid3=oci_parse($conn, $sql3); // set array $arr3 = array(); if(!$stid3){ $e=oci_error($conn); trigger_error(htmlentities($e[message],ENT_QUOTES),E_USER_ERROR); } $r3=oci_execute($stid3); if(!$r3){ $e=oci_error($stid3); trigger_error(htmlentities($e[message],ENT_QUOTES),E_USER_ERROR); } // look through query while($row = oci_fetch_array($stid3,OCI_ASSOC)){ // add each row returned into an array // $arr=array_slice($arr,1,50); $arr3[] = array((strtotime($row['WD'])*1000) , (float)$row['DATA']); } //oci_free_statement($stid); //oci_close($conn); ?> <script type="text/javascript" src="jquery-1.11.3.min.js"></script> <script type="text/javascript" src="jquery.flot.js"></script> <script type="text/javascript"> //$(document).ready(function(){ var updateinterval=5000; var data1=[]; var data2=[]; var data3=[]; var data4=[]; function getdata(){ //data.shift(); data1=<?php echo json_encode($arr1); ?>; data2=<?php echo json_encode($arr2); ?>; data3=<?php echo json_encode($arr3); ?>; data4=<?php echo json_encode($arr4); ?>; } var options={      series: {          lines: {              show: true,              lineWidth: 3,              fill: true, radius: 5          }, points:{ show: "triangle" }      }, xaxis: {          mode: "time",          TickSize: [20, "seconds"], tickFormatter:function (v, axis) {              var date = new Date(v);    if (date.getSeconds() % 20 == 0) {                  var dates=date.getDate() <4 ? "0" +date.getDate() : date.getDate();                  var months=date.getMonth()< 10 ? "0" +(date.getMonth()+1) :date.getMonth(); var hours = date.getHours() < 10 ? "0" + date.getHours() : date.getHours();       var minutes = date.getMinutes() < 10 ? "0" + date.getMinutes() : date.getMinutes();                  var seconds = date.getSeconds() < 10 ? "0" + date.getSeconds() : date.getSeconds(); return dates+ "/"+ months +" "+hours + ":" + minutes + ":" + seconds;              } else {                  return "";              } },          axisLabel: "Time",          axisLabelUseCanvas: true,          axisLabelFontSizePixels: 12,          axisLabelFontFamily: 'Verdana, Arial',          axisLabelPadding: 10      },      yaxis: {                   axisLabel: "Data loading",          axisLabelUseCanvas: true,           axisLabelFontSizePixels: 12,          axisLabelFontFamily: 'Verdana, Arial',          axisLabelPadding: 6      }, legend: {                 labelBoxBorderColor: "#B0D5FF"      }, grid: { hoverable: true, clickable: true, backgroundColor: { colors: ["#B0D5FF", "#5CA8FF"] }              } }; $(document).ready(function () { getdata(); var dataset1=[ { label: "Day1", data: data1, points: { symbol: "triangle" } } ]; var dataset2=[ { label: "Day2", data: data2, points: { symbol: "cross" } } ]; var dataset3=[ { label: "Day3", data: data3, points: { symbol: "square" } } ]; var dataset4=[ { label: "Day4", data: data4, points: { symbol: "diamond" } } ]; $.plot($("#flot-container"), [dataset1], options); $.plot($("#flot-container1"), [dataset2], options); $.plot($("#flot-container2"), [dataset3], options); $.plot($("#flot-container3"), [dataset4], options); $.plot($("#flot-container4"), [dataset5], options); function update() {         $.plot($("#flot-container"), dataset1, options); $.plot($("#flot-container1"), dataset2, options); $.plot($("#flot-container2"), dataset3, options); $.plot($("#flot-container3"), dataset4, options); $.plot($("#flot-container4"), dataset5, options);     //    setTimeout(update, updateinterval); setInterval(update, updateinterval);      }       update(); }); </script> </head> <body> <center> <h3><b><u>Real-Time Chart</u></b></h3> <h2>DAY 1</h2> <div id="flot-container" aligh="right-side" style="width:840px;height:280px;"></div> <h2>DAY 2 </h2> <div id="flot-container1" style="width:840px;height:280px;"></div> <h2>DAY 3 </h2> <div id="flot-container2" style="width:990px;height:280px;"></div> </center> <div id="footer"> Copyright &copy; 2007 - 2014 </div> </body> </html> 在优化器内被调用两次。这可能导致执行一些冗余操作(例如,第一层上的梯度可以重用一些计算以用于后续层的梯度)。

答案 1 :(得分:8)

1月22日更新:下面的配方只是GradientDescentOptimizer的一个好主意,保持运行平均值的其他优化器将在参数更新前应用学习率,因此下面的配方不会影响等式的那部分

除了Rafal的方法,您还可以使用compute_gradients的{​​{1}},apply_gradients界面。例如,这是一个玩具网络,我使用2倍于第二个参数的学习率

Optimizer

你应该看到

x = tf.Variable(tf.ones([]))
y = tf.Variable(tf.zeros([]))
loss = tf.square(x-y)
global_step = tf.Variable(0, name="global_step", trainable=False)

opt = tf.GradientDescentOptimizer(learning_rate=0.1)
grads_and_vars = opt.compute_gradients(loss, [x, y])
ygrad, _ = grads_and_vars[1]
train_op = opt.apply_gradients([grads_and_vars[0], (ygrad*2, y)], global_step=global_step)

init_op = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init_op)
for i in range(5):
  sess.run([train_op, loss, global_step])
  print sess.run([x, y])

答案 2 :(得分:6)

Tensorflow 1.7引入了tf.custom_gradient,大大简化了设置学习速率乘数,其方式现在与任何优化器兼容,包括累积梯度统计数据的优化器。例如,

import tensorflow as tf

def lr_mult(alpha):
  @tf.custom_gradient
  def _lr_mult(x):
    def grad(dy):
      return dy * alpha * tf.ones_like(x)
    return x, grad
  return _lr_mult

x0 = tf.Variable(1.)
x1 = tf.Variable(1.)
loss = tf.square(x0) + tf.square(lr_mult(0.1)(x1))

step = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(loss)

sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
tf.local_variables_initializer().run()

for _ in range(5):
  sess.run([step])
  print(sess.run([x0, x1, loss]))

答案 3 :(得分:5)

收集每个变量的学习速率乘数,如:

self.lr_multipliers[var.op.name] = lr_mult

然后在应用渐变之前应用它们,如:

def _train_op(self):
  tf.scalar_summary('learning_rate', self._lr_placeholder)
  opt = tf.train.GradientDescentOptimizer(self._lr_placeholder)
  grads_and_vars = opt.compute_gradients(self._loss)
  grads_and_vars_mult = []
  for grad, var in grads_and_vars:
    grad *= self._network.lr_multipliers[var.op.name]
    grads_and_vars_mult.append((grad, var))
    tf.histogram_summary('variables/' + var.op.name, var)
    tf.histogram_summary('gradients/' + var.op.name, grad)
  return opt.apply_gradients(grads_and_vars_mult)

您可以找到整个示例here

答案 4 :(得分:0)

谢尔盖·德米亚诺夫(Sergey Demyanov)答案略有不同,您只需指定要更改的学习率

from collections import defaultdict

self.learning_rates = defaultdict(lambda: 1.0)
...
x = tf.layers.Dense(3)(x)
self.learning_rates[x.op.name] = 2.0
...
optimizer = tf.train.MomentumOptimizer(learning_rate=1e-3, momentum=0.9)
grads_and_vars = optimizer.compute_gradients(loss)
grads_and_vars_mult = []
for grad, var in grads_and_vars:
    grad *= self.learning_rates[var.op.name]
    grads_and_vars_mult.append((grad, var))
train_op = optimizer.apply_gradients(grads_and_vars_mult, tf.train.get_global_step())

答案 5 :(得分:0)

如果您碰巧正在使用tf.slim + slim.learning.create_train_op,那么这里有一个很好的示例: https://github.com/google-research/tf-slim/blob/master/tf_slim/learning.py#L65

# Create the train_op and scale the gradients by providing a map from variable
  # name (or variable) to a scaling coefficient:
  gradient_multipliers = {
    'conv0/weights': 1.2,
    'fc8/weights': 3.4,
  }
  train_op = slim.learning.create_train_op(
      total_loss,
      optimizer,
      gradient_multipliers=gradient_multipliers)

不幸的是,如果要逐渐修改乘数,似乎无法使用tf.Variable而不是float值。