Tensorflow的gradient_override_map函数

时间:2016-12-30 06:33:39

标签: python tensorflow

有人可以在TensorFlow中解释我gradient_override_map的功能吗? 我无法准确理解它的用法。

我将代码用法视为:

with G.gradient_override_map({"Floor": "Identity"}):
    return tf.reduce_mean(SomeVals) * SomeOtherVal

这到底发生了什么?什么是Identity

2 个答案:

答案 0 :(得分:5)

“Floor”和“Identity”都是操作的类型字符串,前者对应于 tf.floor ,而后者 tf.identity 因此,我想,您的代码的功能是替换 tf.identity 的反向传播梯度(简称BPG)计算机制,用于 tf.floor的BPG计算机制图G中的操作,同时传递 tf.reduce_mean 的输出。这看起来有点奇怪,因为到目前为止我发现的gradient_override_map的所有应用程序中,op_type_map的键始终与用于在上下文中生成输出的操作的类型字符串相同。我的意思是,我更熟悉返回tf.floor(SomeVals)的方案,而不是tf.reduce_mean(SomeVals)

gradient_override_map({op_A_type: op_B_type})做的是用op_B替换op_A的BPG计算机制,同时保留op_A_type的前向传播计算机制。一个常见的gradient_override_map应用程序显示在lahwran的答案中。

@tf.RegisterGradient("CustomGrad")
def _const_mul_grad(unused_op, grad):
    return 5.0 * grad

g = tf.get_default_graph()
with g.gradient_override_map({"Identity": "CustomGrad"}):
    output = tf.identity(input, name="Identity")

通过

@tf.RegisterGradient("CustomGrad")
def _const_mul_grad(unused_op, grad):
    return 5.0 * grad

装饰器tf.RegisterGradient("CustomGrad")_const_mul_grad(unused_op, grad)定义的渐变函数注册为自定义操作类型 - “CustomGrad”,

,而

g = tf.get_default_graph()
with g.gradient_override_map({"Identity": "CustomGrad"}):
    output = tf.identity(input, name="Identity") 

确保字符串类型为“Identity”( tf.identity )的所有操作的输出(在图g中)与 tf.identity 的BPG计算机制相同用字符串类型“CustomGrad”替换为BPG计算机制。

P.S。

  1. op的类型字符串对应于定义操作的proto的OpDef.name字段。要查找操作OpDef.name,请参阅明兴在this question下的回答

  2. 没有必要声明 tf.identity 操作的名称,因为 tf.identity 中的arg'name'是可选的。

答案 1 :(得分:2)

尽我所知,gradient_override_map允许你说“在这种情况下,任何时候使用X的渐变,而是使用Y的渐变”。这意味着你仍然需要Y的渐变成为你想要使用的渐变。

这是我看到的一个例子,在寻找它是如何工作时漂浮:

@tf.RegisterGradient("CustomGrad")
def _const_mul_grad(unused_op, grad):
    return 5.0 * grad

g = tf.get_default_graph()
with g.gradient_override_map({"Identity": "CustomGrad"}):
    output = tf.identity(input, name="Identity")

引用:https://stackoverflow.com/a/43948872/1102705

RegisterGradient()允许您注册您正在定义的新op的渐变,从而允许您拥有一个具有所需渐变的op,然后您可以在渐变覆盖图中使用该op。这有点笨重 - 你正在定义一个没有前锋传球的操作。

我不清楚的是名称=“身份”是否真的有必要。