我目前正在尝试训练Google的草图识别模型,该模型只是链接中的模型:Github。但是我最近遇到了困扰我很长一段时间的问题。
问题如下: 我已使用链接中的代码和quickdraw中的数据来完成培训。我现在有一个带有三个文件(.meta,.index,.data)的训练模型,现在我想为345个类别的训练模型计算混淆矩阵。但是,由于我从未使用过tensorflow的“估计器”,所以我不知道如何将经过训练的模型文件加载到代码中并对其进行测试(不进行训练),以及如何在softmax层之后获得分类分数(用于计算混淆矩阵)。
“估算器” API确实让我很困惑。请根据链接中的代码解决我的问题:
def create_estimator_and_specs(run_config):
"""Creates an Experiment configuration based on the estimator and input fn."""
model_params = tf.contrib.training.HParams(
num_layers=FLAGS.num_layers,
num_nodes=FLAGS.num_nodes,
batch_size=FLAGS.batch_size,
num_conv=ast.literal_eval(FLAGS.num_conv),
conv_len=ast.literal_eval(FLAGS.conv_len),
num_classes=get_num_classes(),
learning_rate=FLAGS.learning_rate,
gradient_clipping_norm=FLAGS.gradient_clipping_norm,
cell_type=FLAGS.cell_type,
batch_norm=FLAGS.batch_norm,
dropout=FLAGS.dropout)
estimator = tf.estimator.Estimator(
model_fn=model_fn,
config=run_config,
params=model_params)
train_spec = tf.estimator.TrainSpec(
input_fn=get_input_fn(
mode=tf.estimator.ModeKeys.TRAIN,
tfrecord_pattern=FLAGS.training_data,
batch_size=FLAGS.batch_size),
max_steps=FLAGS.steps)
eval_spec = tf.estimator.EvalSpec(
input_fn=get_input_fn(
mode=tf.estimator.ModeKeys.EVAL,
tfrecord_pattern=FLAGS.eval_data,
batch_size=FLAGS.batch_size)
)
return estimator, train_spec, eval_spec
def main(unused_args):
estimator, train_spec, eval_spec = create_estimator_and_specs(
run_config=tf.estimator.RunConfig(
model_dir=FLAGS.model_dir,
save_checkpoints_secs=300,
save_summary_steps=100)
)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
我想将训练有素的模型加载到上面的代码中,并计算345个类别的混淆矩阵。
答案 0 :(得分:1)
您可以使用库函数tf.confusion_matrix
tf.confusion_matrix(
labels,
predictions,
num_classes=None,
dtype=tf.int32,
name=None,
weights=None
)
根据预测和标签计算混淆矩阵。
tf.confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
[[0 0 0 0 0]
[0 0 1 0 0]
[0 0 1 0 0]
[0 0 0 0 0]
[0 0 0 0 1]]
针对您的情况,以下代码可能会帮助您:
labels = list(test_set.target)
predictions = list(estimator.predict(input_fn=test_input_fn))
confusion_matrix = tf.confusion_matrix(labels, predictions)
答案 1 :(得分:1)
我不知道如何将训练有素的模型文件加载到代码中, 测试
将数据集用于估算器
tf.data
模块包含一个类的集合,这些类使您可以轻松地加载,操作数据并将其通过管道传递到模型中。
softmax层后如何获得分类分数(用于 计算混淆矩阵)
使用tf.keras
(一种高级API)在TensorFlow中构建和训练模型
test_dataset = keras.datasets.test_dataset
(train_images, train_labels), (test_images, test_labels) = test_dataset.load_data()