我想将Keras模型中的变量与TensorFlow检查点中的变量进行比较。我可以这样获得TF变量:
vars_in_checkpoint = tf.train.list_variables(os.path.join("./model.ckpt"))
如何从model
中获取Keras变量进行比较?
答案 0 :(得分:0)
您可以通过model.weights
(tf.Variable
实例列表)获得Keras模型的变量。
答案 1 :(得分:0)
要获取变量的名称,您需要从模型层的权重属性访问它。像这样:
names = [weight.name for layer in model.layers for weight in layer.weights]
并获得权重的形状:
weights = [weight.shape for weight in model.get_weights()]