仅训练张量流中的一些变量

时间:2016-08-17 05:57:28

标签: python python-2.7 tensorflow

我使用tensorflow进行渐变体面分类。

train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

此处cost是我在优化中使用的成本函数。 在会话中启动图表后,图表可以按以下方式输入:

sess.run(train_op, feed_dict)

这样,成本函数中的所有变量都将被更新,以便最大限度地降低成本。

这是我的问题。如何在训练时只更新成本函数中的一些变量..?有没有办法将创建的变量转换为常量或其他东西..?

1 个答案:

答案 0 :(得分:3)

有几个好的答案,这个主题应该已经关闭:  stackoverflow   Quora

为了避免再次点击此处的人:

tensorflow优化器的最小化函数为此目的采用var_list参数:

first_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                     "scope/prefix/for/first/vars")
first_train_op = optimizer.minimize(cost, var_list=first_train_vars)

second_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                      "scope/prefix/for/second/vars")                     
second_train_op = optimizer.minimize(cost, var_list=second_train_vars)

我从mrry

按原样拍摄

要获取您应该使用的名称列表而不是"scope/prefix/for/second/vars",您可以使用:

tf.get_default_graph().get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)