获取Keras变量列表

时间:2018-10-30 17:56:34

标签: python-3.x tensorflow keras

我想将Keras模型中的变量与TensorFlow检查点中的变量进行比较。我可以这样获得TF变量:

vars_in_checkpoint = tf.train.list_variables(os.path.join("./model.ckpt"))

如何从model中获取Keras变量进行比较?

2 个答案:

答案 0 :(得分:0)

您可以通过model.weightstf.Variable实例列表)获得Keras模型的变量。

答案 1 :(得分:0)

要获取变量的名称,您需要从模型层的权重属性访问它。像这样:

names = [weight.name for layer in model.layers for weight in layer.weights]

并获得权重的形状:

weights = [weight.shape for weight in model.get_weights()]