以下是用于在tensorflow中注册渐变和覆盖操作渐变的代码。
# Registering a gradient
some_multiplier = 0.5
@tf.RegisterGradient("AdaGrad")
def _ada_grad(op, grad):
return grad * some_multiplier
# Overriding
g = tf.get_default_graph()
with g.gradient_override_map({"Ada": "AdaGrad"}):
model.loss = tf.identity(model.loss, name="Ada")
我想在keras中复制同样的东西。在搜索了很多东西之后我找不到任何办法。
我尝试了以下代码,但它没有用。
梯度未被修改。我有和没有渐变覆盖相同的结果。我将some_multiplier
设置为零来检查它。
model = Model(...) # Keras model
model.compile(loss='sparse_categorical_crossentropy', optimizer=adadelta, metrics=['accuracy']) # Compiling Keras Model
@tf.RegisterGradient("AdaGrad")
def _ada_grad(op, grad):
return grad * some_multiplier
g = tf.get_default_graph()
with g.gradient_override_map({"Ada": "AdaGrad"}):
model.total_loss = tf.identity(model.total_loss, name="Ada")
答案 0 :(得分:0)
同样的方法应该有效,但您需要确保使用Keras模型的图表。如果您使用keras.model.Model
或tf.keras.Model
:
model = Model(...) # Keras model
model.compile(loss='sparse_categorical_crossentropy', optimizer=adadelta, metrics=['accuracy']) # Compiling Keras Model
@tf.RegisterGradient("AdaGrad")
def _ada_grad(op, grad):
return grad * some_multiplier
# with keras.model.Model
from keras import backend as K
g = K.get_session().graph
# with tf.keras.Model
g = model.graph
with g.gradient_override_map({"Ada": "AdaGrad"}):
model.total_loss = tf.identity(model.total_loss, name="Ada")
答案 1 :(得分:0)
public class MainActivity extends AppCompatActivity {
String TAG = "MainActivity";
Context context;
WebView mWebView;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
context = this;
mWebView = (WebView) findViewById(R.id.webview);
initWebView();
String ENROLLMENT_URL = "file:///android_asset/about_page.html";
mWebView.loadUrl(ENROLLMENT_URL);
}
@SuppressLint({ "SetJavaScriptEnabled" })
private void initWebView() {
mWebView.getSettings().setJavaScriptEnabled(true);
mWebView.setWebChromeClient(new WebChromeClient());
mWebView.addJavascriptInterface(new WebviewInterface(), "Interface");
}
public class WebviewInterface {
@JavascriptInterface
public void javaMehod(String val) {
Log.i(TAG, val);
Toast.makeText(context, val, Toast.LENGTH_SHORT).show();
}
}
}
不适用于大多数Keras操作。
我发现的最简单的方法是用TensorFlow实现替换Keras中的操作。
例如,假设考虑了relu激活,那么它将很简单:
<?xml version="1.0" encoding="utf-8"?>
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
tools:context="com.legendblogs.android.MainActivity">
<WebView
android:layout_width="match_parent"
android:layout_height="match_parent"
android:id="@+id/webview"/>
</RelativeLayout>
适用于大多数网络,因为在Keras模型中,通常只有第一个参数用于ReLU。 如果其他操作不匹配,则可以围绕tf模拟创建包装函数,以使参数与Keras匹配。
具有VGG16网络的ReLU示例。
注册渐变。
gradient_override_map
使用自定义渐变初始化网络。
tf.keras.activations.relu = tf.nn.relu
# <function tensorflow.python.keras.activations.relu(x, alpha=0.0, max_value=None, threshold=0)>
# <function tensorflow.python.ops.gen_nn_ops.relu(features, name=None)>
答案 2 :(得分:0)
我有同样的问题。就我而言,我正在使用“ gradient_override_map”尝试实现“引导反向传播”。
@tf.RegisterGradient("GuidedRelu")
def GuidedReluGrad(op, grad):
grad_filter = tf.cast(grad > 0, "float32")
output_filter = tf.cast(op.outputs[0] > 0, "float32")
return output_filter * grad_filter * grad
我正在使用VGG16作为预测模型。
model = VGG16(include_top=True, weights='imagenet')
predicted = np.argmax(model.predict(np.expand_dims(img, axis=0)))
我尝试了以下代码,但没有成功。
with K.get_session().graph.gradient_override_map({'Relu': 'GuidedRelu'}):
# here is implementation to get gradients
# but "GuidedRelu" is not used
因此,我在调用“ gradient_override_map”之前创建了一个新的Graph和Session,并成功将渐变函数从“ Relu”更改为“ GuidedRelu”。
new_graph = tf.Graph()
with new_graph.as_default():
new_sess = tf.Session(graph = new_graph)
with new_sess.as_default():
with new_graph.gradient_override_map({'Relu': 'GuidedRelu'}):
new_model = VGG16(include_top=True, weights='imagenet')
# here is implementation to get gradients with new graph/session
# "GuidedRelu" is used
我不知道为什么前者不起作用。但我希望这会有所帮助。