在Keras中复制RegisterGradient和gradient_override_map

时间:2018-04-17 10:44:04

标签: python tensorflow keras

以下是用于在tensorflow中注册渐变和覆盖操作渐变的代码。

# Registering a gradient
some_multiplier = 0.5 

@tf.RegisterGradient("AdaGrad")
def _ada_grad(op, grad):
    return grad * some_multiplier 

# Overriding 
g = tf.get_default_graph()
with g.gradient_override_map({"Ada": "AdaGrad"}):
    model.loss = tf.identity(model.loss, name="Ada")

我想在keras中复制同样的东西。在搜索了很多东西之后我找不到任何办法。

我尝试了以下代码,但它没有用。 梯度未被修改。我有和没有渐变覆盖相同的结果。我将some_multiplier设置为零来检查它。

model = Model(...) # Keras model
model.compile(loss='sparse_categorical_crossentropy', optimizer=adadelta, metrics=['accuracy']) # Compiling Keras Model

@tf.RegisterGradient("AdaGrad")
def _ada_grad(op, grad):
    return grad * some_multiplier 

g = tf.get_default_graph()
with g.gradient_override_map({"Ada": "AdaGrad"}):
    model.total_loss = tf.identity(model.total_loss, name="Ada")

3 个答案:

答案 0 :(得分:0)

同样的方法应该有效,但您需要确保使用Keras模型的图表。如果您使用keras.model.Modeltf.keras.Model

,如何检索图表会有所改变
model = Model(...) # Keras model
model.compile(loss='sparse_categorical_crossentropy', optimizer=adadelta, metrics=['accuracy']) # Compiling Keras Model

@tf.RegisterGradient("AdaGrad")
def _ada_grad(op, grad):
    return grad * some_multiplier 

# with keras.model.Model
from keras import backend as K
g = K.get_session().graph
# with tf.keras.Model
g = model.graph

with g.gradient_override_map({"Ada": "AdaGrad"}):
    model.total_loss = tf.identity(model.total_loss, name="Ada")

答案 1 :(得分:0)

TensorFlow的

public class MainActivity extends AppCompatActivity { String TAG = "MainActivity"; Context context; WebView mWebView; @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); setContentView(R.layout.activity_main); context = this; mWebView = (WebView) findViewById(R.id.webview); initWebView(); String ENROLLMENT_URL = "file:///android_asset/about_page.html"; mWebView.loadUrl(ENROLLMENT_URL); } @SuppressLint({ "SetJavaScriptEnabled" }) private void initWebView() { mWebView.getSettings().setJavaScriptEnabled(true); mWebView.setWebChromeClient(new WebChromeClient()); mWebView.addJavascriptInterface(new WebviewInterface(), "Interface"); } public class WebviewInterface { @JavascriptInterface public void javaMehod(String val) { Log.i(TAG, val); Toast.makeText(context, val, Toast.LENGTH_SHORT).show(); } } } 不适用于大多数Keras操作。 我发现的最简单的方法是用TensorFlow实现替换Keras中的操作。

例如,假设考虑了relu激活,那么它将很简单:

<?xml version="1.0" encoding="utf-8"?>
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
tools:context="com.legendblogs.android.MainActivity">

<WebView
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    android:id="@+id/webview"/>


</RelativeLayout>

适用于大多数网络,因为在Keras模型中,通常只有第一个参数用于ReLU。 如果其他操作不匹配,则可以围绕tf模拟创建包装函数,以使参数与Keras匹配。

具有VGG16网络的ReLU示例。

注册渐变。

gradient_override_map

使用自定义渐变初始化网络。

tf.keras.activations.relu = tf.nn.relu
# <function tensorflow.python.keras.activations.relu(x, alpha=0.0, max_value=None, threshold=0)>
# <function tensorflow.python.ops.gen_nn_ops.relu(features, name=None)>

答案 2 :(得分:0)

我有同样的问题。就我而言,我正在使用“ gradient_override_map”尝试实现“引导反向传播”。

@tf.RegisterGradient("GuidedRelu")
def GuidedReluGrad(op, grad):
    grad_filter = tf.cast(grad > 0, "float32")
    output_filter = tf.cast(op.outputs[0] > 0, "float32")
    return output_filter * grad_filter * grad

我正在使用VGG16作为预测模型。

model = VGG16(include_top=True, weights='imagenet')
predicted = np.argmax(model.predict(np.expand_dims(img, axis=0)))

我尝试了以下代码,但没有成功。

with K.get_session().graph.gradient_override_map({'Relu': 'GuidedRelu'}):
    # here is implementation to get gradients
    # but "GuidedRelu" is not used

因此,我在调用“ gradient_override_map”之前创建了一个新的Graph和Session,并成功将渐变函数从“ Relu”更改为“ GuidedRelu”。

new_graph = tf.Graph()
with new_graph.as_default():
    new_sess = tf.Session(graph = new_graph)
    with new_sess.as_default():
        with new_graph.gradient_override_map({'Relu': 'GuidedRelu'}):
            new_model = VGG16(include_top=True, weights='imagenet')
            # here is implementation to get gradients with new graph/session
            # "GuidedRelu" is used

我不知道为什么前者不起作用。但我希望这会有所帮助。