是否有任何文档描述Keras中哪些字符串名称映射到哪些对象?例如,下面我从tf.keras.layers
创建一个Embedding层,并且可以使用'uniform'
映射到tf.keras.initializers.RandomUniform
类。
tf.keras.layers.Embedding(1000, 64, embeddings_initializer='uniform')
但是我仅通过查看用法示例知道这一点。我认为支持的字符串形式已记录在某处,但我似乎找不到这样的文档,并且对代码的挖掘过于抽象而难以遵循。
版本:TF 1.13.1
答案 0 :(得分:1)
在TF的keras实现中没有可用的字符串常量列表(我想在原始keras中也没有)。
在initializer的情况下,'uniform'
字符串被转换为config,并且在该config上调用了一个fabric方法,并带有从初始化空间命名空间创建对象的提示(可以在此处找到{{3 }}):
config = {'class_name': str(identifier), 'config': {}}
deserialize_keras_object(
config,
module_objects=globals(),
custom_objects=custom_objects,
printable_module_name='initializer')
因此,我想不出比以下方法更好的方法:例如,列出所有初始化程序:
import tensorflow as tf
for k, v in tf.keras.initializers.__dict__.items():
if not k[0].isupper() and not k[0] == "_":
print(k)
尽管具有附加值,但输出类似于:
constant
glorot_normal
glorot_uniform
identity
ones
orthogonal
zeros
he_normal
he_uniform
lecun_normal
lecun_uniform
normal
random_normal
random_uniform
uniform
truncated_normal
deserialize
get
serialize