I'd like to know whether there is a way to use different learning rates for different layers, like what is done in Caffe. I am trying to modify a pre-trained model and use it for another task. What I want is to speed up the training of the newly added layers while keeping the already-trained layers at a low learning rate so that they don't get distorted. For example, I have a pre-trained model with 5 conv layers. Now I add a new conv layer and fine-tune it. The first 5 layers would have a learning rate of 0.00001 and the last one 0.001. Any idea how to achieve this?
Answer 0 (score: 76)
It can be achieved quite easily with 2 optimizers. The shortest version calls minimize(loss, var_list=...) on each optimizer separately and groups the two resulting training ops (a sketch of it is shown below). A more efficient but longer implementation calls tf.gradients(.) explicitly once, splits the list in two, and passes the corresponding gradients to each optimizer:

var_list1 = [variables from first 5 layers]
var_list2 = [the rest of variables]
opt1 = tf.train.GradientDescentOptimizer(0.00001)
opt2 = tf.train.GradientDescentOptimizer(0.0001)
grads = tf.gradients(loss, var_list1 + var_list2)
grads1 = grads[:len(var_list1)]
grads2 = grads[len(var_list1):]
train_op1 = opt1.apply_gradients(zip(grads1, var_list1))
train_op2 = opt2.apply_gradients(zip(grads2, var_list2))
train_op = tf.group(train_op1, train_op2)

You can use tf.trainable_variables() to get all trainable variables and decide which ones to select from.

Related question: Holding variables constant during optimizer
The difference is that in the shorter, minimize-based version, tf.gradients(.)
is called twice inside the optimizers, which may execute some redundant operations (for example, the gradients of the first layer can reuse some of the computation needed for the gradients of the following layers).
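A minimal sketch of that shorter variant, assuming loss, var_list1, and var_list2 are defined as above:

train_op1 = tf.train.GradientDescentOptimizer(0.00001).minimize(loss, var_list=var_list1)
train_op2 = tf.train.GradientDescentOptimizer(0.0001).minimize(loss, var_list=var_list2)
train_op = tf.group(train_op1, train_op2)

Each minimize(.) call here computes its own gradients internally, which is exactly the double computation described above.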
Answer 1 (score: 8)
Update Jan 22: the recipe below is only a good idea for GradientDescentOptimizer; other optimizers that keep a running average apply the learning rate before the parameter update, so the recipe below won't affect that part of the equation.
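To illustrate that caveat with a small experiment of my own (not part of the original answer): with AdamOptimizer, which keeps running averages of gradient statistics, scaling a gradient by a constant barely changes the resulting update, so gradient scaling does not act like a per-variable learning rate there:

import tensorflow as tf

x = tf.Variable(1.0)  # updated with its raw gradient
y = tf.Variable(1.0)  # updated with its gradient scaled by 10
loss = tf.square(x) + tf.square(y)

opt = tf.train.AdamOptimizer(learning_rate=0.1)
(gx, _), (gy, _) = opt.compute_gradients(loss, [x, y])
step = opt.apply_gradients([(gx, x), (gy * 10.0, y)])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(3):
        sess.run(step)
        # x and y stay almost identical: Adam normalizes the gradient scale away.
        print(sess.run([x, y]))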
In addition to Rafal's approach, you can use the compute_gradients / apply_gradients interface of Optimizer. For example, here is a toy network where I use 2x the learning rate for the second parameter:
x = tf.Variable(tf.ones([]))
y = tf.Variable(tf.zeros([]))
loss = tf.square(x - y)
global_step = tf.Variable(0, name="global_step", trainable=False)
opt = tf.train.GradientDescentOptimizer(learning_rate=0.1)
grads_and_vars = opt.compute_gradients(loss, [x, y])
ygrad, _ = grads_and_vars[1]
# Apply x's gradient unchanged, but double y's gradient: an effective 2x learning rate for y.
train_op = opt.apply_gradients([grads_and_vars[0], (ygrad * 2, y)], global_step=global_step)
init_op = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init_op)
for i in range(5):
    sess.run([train_op, loss, global_step])
    print(sess.run([x, y]))

You should see that y moves exactly twice as far as x on every step, since its gradient is doubled.
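Working the updates out by hand (expected values, up to float rounding), the printed pairs should be approximately:

[0.8, 0.4]
[0.72, 0.56]
[0.688, 0.624]
[0.6752, 0.6496]
[0.67008, 0.65984]

Both variables chase each other, with y closing the gap twice as fast as x.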
Answer 2 (score: 6)
TensorFlow 1.7 introduced tf.custom_gradient, which greatly simplifies setting a learning rate multiplier, in a way that is now compatible with any optimizer, including those that accumulate gradient statistics. For example,
import tensorflow as tf

def lr_mult(alpha):
    # Identity in the forward pass; scales the incoming gradient by alpha on the way back.
    @tf.custom_gradient
    def _lr_mult(x):
        def grad(dy):
            return dy * alpha * tf.ones_like(x)
        return x, grad
    return _lr_mult

x0 = tf.Variable(1.)
x1 = tf.Variable(1.)
# x1's gradient is scaled by 0.1, giving it a 10x smaller effective learning rate than x0.
loss = tf.square(x0) + tf.square(lr_mult(0.1)(x1))
step = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(loss)

sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
tf.local_variables_initializer().run()

for _ in range(5):
    sess.run([step])
    print(sess.run([x0, x1, loss]))
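In the fine-tuning scenario from the question, the same wrapper can sit between the pretrained stack and the new layer. A rough sketch (pretrained_conv_stack and new_conv_layer are hypothetical placeholders, not real APIs):

features = pretrained_conv_stack(inputs)   # the 5 pretrained conv layers
features = lr_mult(0.01)(features)         # shrink every gradient flowing back into them
logits = new_conv_layer(features)          # the new layer trains at the full learning rate

Since lr_mult scales dy before it propagates upstream, every variable behind that point effectively trains with a 100x smaller learning rate.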
Answer 3 (score: 5)
Collect the learning rate multipliers for each variable, for example:
self.lr_multipliers[var.op.name] = lr_mult
Then apply them to the gradients before calling apply_gradients, for example:
def _train_op(self):
    tf.scalar_summary('learning_rate', self._lr_placeholder)
    opt = tf.train.GradientDescentOptimizer(self._lr_placeholder)
    grads_and_vars = opt.compute_gradients(self._loss)
    grads_and_vars_mult = []
    for grad, var in grads_and_vars:
        grad *= self._network.lr_multipliers[var.op.name]
        grads_and_vars_mult.append((grad, var))
        tf.histogram_summary('variables/' + var.op.name, var)
        tf.histogram_summary('gradients/' + var.op.name, grad)
    return opt.apply_gradients(grads_and_vars_mult)
You can find the whole example here.
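How the multiplier map gets populated is up to the network code. One hedged way to do it for the question's setup (the 'pretrained/' scope name is an assumption; adjust it to your graph):

# Assumption: the 5 pretrained conv layers were created under a 'pretrained/' scope.
self.lr_multipliers = {}
for var in tf.trainable_variables():
    if var.op.name.startswith('pretrained/'):
        self.lr_multipliers[var.op.name] = 0.01  # slow down the pretrained layers
    else:
        self.lr_multipliers[var.op.name] = 1.0   # new layers train at full speed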
Answer 4 (score: 0)
A slight variation of Sergey Demyanov's answer, where you only need to specify the learning rates you want to change:
from collections import defaultdict

self.learning_rates = defaultdict(lambda: 1.0)
...
# Key the multiplier by the layer's variable names so it matches var.op.name below.
dense = tf.layers.Dense(3)
x = dense(x)
for v in dense.variables:
    self.learning_rates[v.op.name] = 2.0
...
optimizer = tf.train.MomentumOptimizer(learning_rate=1e-3, momentum=0.9)
grads_and_vars = optimizer.compute_gradients(loss)
grads_and_vars_mult = []
for grad, var in grads_and_vars:
    grad *= self.learning_rates[var.op.name]
    grads_and_vars_mult.append((grad, var))
train_op = optimizer.apply_gradients(grads_and_vars_mult, tf.train.get_global_step())
Answer 5 (score: 0)
If you happen to be using tf.slim + slim.learning.create_train_op, there is a nice example here: https://github.com/google-research/tf-slim/blob/master/tf_slim/learning.py#L65
# Create the train_op and scale the gradients by providing a map from variable
# name (or variable) to a scaling coefficient:
gradient_multipliers = {
    'conv0/weights': 1.2,
    'fc8/weights': 3.4,
}
train_op = slim.learning.create_train_op(
    total_loss,
    optimizer,
    gradient_multipliers=gradient_multipliers)
Unfortunately, it doesn't seem possible to pass a tf.Variable instead of a float value if you want to modify the multiplier gradually.