如果我有一个继承自tf.Module的类(例如my_module),并且该类内部有tf.keras.Models。我应该使用my_module.variables获取模型中的所有变量吗?
尝试了tf2.0中的一些简单示例。似乎tf.module无法收集tf.keras.models中的变量。
def nn(layers_sizes):
model = tf.keras.Sequential()
for i, size in enumerate(layers_sizes):
model.add(tf.keras.layers.Dense(
units=size,
activation=tf.keras.layers.ReLU() if i < len(layers_sizes) - 1 else None,
))
return model
class Actor(tf.Module):
def __init__(self, name=None):
super(Actor, self).__init__(name=name)
self.check = tf.Variable(initial_value=np.array((1,2,3)))
self.nn = nn([3,16])
def call(self, inputs):
self.check.assign_add(np.array((1,2,3)))
self.x = self.nn(inputs)
if __name__ == "__main__":
model3 = Actor(name="test")
input = np.array((1.,2.,3.,4.)).reshape(-1,1)
model3.call(input)
print(model3.variables)
print(model3.nn.variables)
我期望model3.variables包含model3.nn.variables。
答案 0 :(得分:0)
好的,这是一个错误。已在tf2beta中修复。