加载检查点并使用张量流DNN评估单个图像

时间:2016-05-05 18:57:30

标签: python tensorflow conv-neural-network

对于在大学的研究我正在研究牛津17花亚历山大的例子。该示例使用基于张量流的API tflearn。我的GPU上的训练效果非常好,一段时间后准确度达到了97%。

不幸的是评估单个图像在tflearn中还没有工作,我将不得不使用model.predict(...)来预测每批我的所有数据,并循环遍历我的所有测试集并自己计算准确性。

到目前为止我的培训代码:

...
import image_loader
X, Y = image_loader.load_data(one_hot=True, shuffle=False)

X = X.reshape(244,244)

# Build network
network = input_data(shape=[None, 224, 224, 3])

network = conv_2d(network, 96, 11, strides=4, activation='relu')
network = max_pool_2d(network, 3, strides=2)
network = local_response_normalization(network)

network = conv_2d(network, 256, 5, activation='relu')
network = max_pool_2d(network, 3, strides=2)
network = local_response_normalization(network)

network = conv_2d(network, 384, 3, activation='relu')
network = conv_2d(network, 384, 3, activation='relu')
network = conv_2d(network, 256, 3, activation='relu')
network = max_pool_2d(network, 3, strides=2)
network = local_response_normalization(network)

network = fully_connected(network, 4096, activation='tanh')
network = dropout(network, 0.5)

network = fully_connected(network, 4096, activation='tanh')
network = dropout(network, 0.5)

network = fully_connected(network, 17, activation='softmax')
network = regression(network, optimizer='momentum',
                 loss='categorical_crossentropy',
                 learning_rate=0.01)

# Training
model = tflearn.DNN(network, checkpoint_path='model_ba',
                max_checkpoints=1, tensorboard_verbose=0)
model.fit(X, Y, n_epoch=3, validation_set=0.1, shuffle=True,
      show_metric=True, batch_size=32, snapshot_step=400,
      snapshot_epoch=False, run_id='ba_soccer_network')

代码正在保存检查点“model_ba”以及网络形式的.meta文件。 是否有可能加载保存的检查点并使用张量流评估单个图像?

提前致谢, 阿诺

1 个答案:

答案 0 :(得分:0)

保存:     model.save(' name.tflearn&#39)

用于加载:     model.load(' name.tflearn&#39)

并且在循环测试中只需加载模型并遵循以下代码

files_path = '/your/test/images/directory/path'
img_files_path = os.path.join(files_path, '*.jpg')
img_files = sorted(glob(img_files_path))

for f in img_files:
    try:
        img = Image.open(f).convert('RGB')
        img = ImageOps.fit(img, ((64, 64)), Image.ANTIALIAS)

        img_arr = np.array(img)
        img_arr = img_arr.reshape(-1, 64, 64, 3).astype("float")

        pred = model.predict(img_arr)
        print(" %s" % pred[0])

    except:
        continue