在Tensorflow2.0中,我发现可以通过以下方式初始化模型中的变量
class MyModel(tf.keras.Model):
def __init__(self, *args, kwargs**):
""" some definition here """
self(tf.keras.Input(shape=(3,)))
def call(self, x):
""" some implementation """
但是我不能做
class MyModel(tf.keras.Model):
def __init__(self, *args, kwargs**):
""" some definition here """
self.step(tf.keras.Input(shape=(3,)))
def step(self, x):
""" some implementation """
这将导致错误
我想做第二个原因的原因是我尝试从MyModel
继承tf.Module
,而__call__
没有可用的tf.Module
---即使定义了一个,也会出现相同的错误。我想知道是否可以像在第一个代码块中那样初始化从{
"Statement": [
{
"Effect": "Allow",
"Action": [
"ec2:CreateSnapshot",
"ec2:CreateTags",
"ec2:DeleteSnapshot",
"ec2:DescribeSnapshots",
"ec2:DescribeTags"
],
"Resource": [
"*"
]
}
]
}
继承的类中的变量?
答案 0 :(得分:2)
可悲的是Keras功能/符号API仅与Keras(例如compile + fit)兼容。
您可能可以改为使用src > typings > index.d.ts
(例如src > typings > moduleA > index.d.ts
),尽管这可能会带来不良的副作用(例如,影响批次规范统计信息)。
如果您想要一个可靠的解决方案,则可能要考虑使用Sonnet 2(包含一堆常见src > typings > moduleB > index.d.ts
的库)中的实用程序功能tf.zeros
[0]。
[0] https://github.com/deepmind/sonnet/blob/v2/sonnet/src/build.py#L50