tensorflow2.0具有类初始化和调用的格式
例如
class MyModel(Model):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = Conv2D(32, 3, activation='relu')
self.flatten = Flatten()
self.d1 = Dense(128, activation='relu')
self.d2 = Dense(10, activation='softmax')
def call(self, x):
x = self.conv1(x)
x = self.flatten(x)
x = self.d1(x)
return self.d2(x)
model = MyModel()
我的问题是,如果我想更改
> def call(self, x):
> x = self.conv1(x)
> x = self.flatten(x)
> x = self.d1(x)
> return self.d2(x,activation='relu')
这会导致错误。
如果我想在某些过程中更改属性
我该怎么办?
答案 0 :(得分:0)
如果要根据条件更改前向通行的行为,l可以仅向call
方法中添加参数。
在您的示例中,您似乎想要更改最后一层的激活功能。因此,您可以仅使用线性激活函数定义最后一层,然后根据条件应用所需的激活函数。
class MyModel(tf.keras.Model):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = tf.keras.layers.Conv2D(32, 3, activation='relu')
self.flatten = tf.keras.layers.Flatten()
self.d1 = tf.keras.layers.Dense(128, activation='relu')
# note: no activation = linear activation
self.d2 = tf.keras.layers.Dense(10)
# Create two activation layers
self.relu = tf.keras.layers.ReLU()
self.softmax = tf.keras.layers.Softmax()
def call(self, x, condition):
x = self.conv1(x)
x = self.flatten(x)
x = self.d1(x)
x = self.d2(x)
# Change the activation depending on the condition
if condition:
tf.print("callign with activation=relu")
x = self.relu(x)
return self.softmax(x)
model = MyModel()
fake_input = tf.zeros((1, 28, 28, 1))
tf.print("condition false")
tf.print(model(fake_input, condition=False))
tf.print("condition true")
tf.print(model(fake_input, condition=True))