我只想在自定义层中进行一些数值验证。
假设我们有一个非常简单的自定义层:
class test_layer(keras.layers.Layer):
def __init__(self, **kwargs):
super(test_layer, self).__init__(**kwargs)
def build(self, input_shape):
self.w = K.variable(1.)
self._trainable_weights.append(self.w)
super(test_layer, self).build(input_shape)
def call(self, x, **kwargs):
m = x * x # Set break point here
n = self.w * K.sqrt(x)
return m + n
主程序:
import tensorflow as tf
import keras
import keras.backend as K
input = keras.layers.Input((100,1))
y = test_layer()(input)
model = keras.Model(input,y)
model.predict(np.ones((100,1)))
如果在行m = x * x
上设置断点调试,则执行y = test_layer()(input)
时程序将在此处暂停,这是因为生成了图形,因此调用了call()
方法。 / p>
但是当我使用model.predict()
赋予它真正的价值,并且想要在图层内部正常工作时,它不会停留在m = x * x
行上
我的问题是:
仅在构建计算图时调用call()
方法吗? (提供实际价值时不会调用它吗?)
如何在层内部调试(或在何处插入断点)以在输入实值时查看变量的值?
答案 0 :(得分:2)
是的。 call()
方法仅用于构建计算图。
关于调试。我更喜欢使用TFDBG
,这是张量流的推荐调试工具,尽管它不提供断点功能。
对于Keras,您可以将这些行添加到脚本中以使用TFDBG
from tensorflow.python import debug as tf_debug
sess = K.get_session()
sess = tf_debug.LocalCLIDebugWrapperSession(sess)
TF.set_session(sess)
答案 1 :(得分:2)
在TensorFlow 2中,您现在可以向TensorFlow Keras模型/图层添加断点,包括使用拟合,评估和预测方法时。但是,您必须在调用model.run_eagerly = True
的 之后添加model.compile()
,以使张量的值在调试器中的断点处可用。例如,
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
class SimpleModel(Model):
def __init__(self):
super().__init__()
self.dense0 = Dense(2)
self.dense1 = Dense(1)
def call(self, inputs):
z = self.dense0(inputs)
z = self.dense1(z) # Breakpoint in IDE here. =====
return z
x = tf.convert_to_tensor([[1, 2, 3], [4, 5, 6]], dtype=tf.float32)
model0 = SimpleModel()
y0 = model0.call(x) # Values of z shown at breakpoint. =====
model1 = SimpleModel()
model1.run_eagerly = True
model1.compile(optimizer=Adam(), loss=BinaryCrossentropy())
y1 = model1.predict(x) # Values of z *not* shown at breakpoint. =====
model2 = SimpleModel()
model2.compile(optimizer=Adam(), loss=BinaryCrossentropy())
model2.run_eagerly = True
y2 = model2.predict(x) # Values of z shown at breakpoint. =====
注意:这已在TensorFlow 2.0.0-rc0
中进行了测试。