典型的tensorflow模型类如下:
class Model:
def __init__(self):
build()
def build(self):
self.x = tf.placeholder()
self.y = f(self.x)
self.z = g(self.y)
如果需要稍作修改(即将self.y=f(self.x)
更改为slef.y=h(self.x)
),我们真的很想继承这个Model
类并添加一些代码来实现这一点。
但是,一旦调用build
函数,便会构建完整的图形。覆盖属性不会改变图形结构。有什么办法可以整齐地完成这项工作吗?
答案 0 :(得分:1)
您可以参数化f
和g
(或您拥有的任何参数)并将它们传递给构造函数:
class Model:
def __init__(self, f=default_f, g=default_g):
self.f = f
self.g = g
self.build()
def build(self):
self.x = tf.placeholder()
self.y = self.f(self.x)
self.z = self.g(self.y)
或者您可以使它们成为可重写的类级变量,以避免构造函数的签名变得肿,但是您不能在构造函数中隐式调用.build()
:
class Model:
f = default_f
g = default_g
def build(self):
self.x = tf.placeholder()
self.y = self.f(self.x)
self.z = self.g(self.y)
# ...
m = Model()
m.f = some_other_f
m.build()