对于在大学的研究我正在研究牛津17花亚历山大的例子。该示例使用基于张量流的API tflearn。我的GPU上的训练效果非常好,一段时间后准确度达到了97%。
不幸的是评估单个图像在tflearn中还没有工作,我将不得不使用model.predict(...)来预测每批我的所有数据,并循环遍历我的所有测试集并自己计算准确性。
到目前为止我的培训代码:
...
import image_loader
X, Y = image_loader.load_data(one_hot=True, shuffle=False)
X = X.reshape(244,244)
# Build network
network = input_data(shape=[None, 224, 224, 3])
network = conv_2d(network, 96, 11, strides=4, activation='relu')
network = max_pool_2d(network, 3, strides=2)
network = local_response_normalization(network)
network = conv_2d(network, 256, 5, activation='relu')
network = max_pool_2d(network, 3, strides=2)
network = local_response_normalization(network)
network = conv_2d(network, 384, 3, activation='relu')
network = conv_2d(network, 384, 3, activation='relu')
network = conv_2d(network, 256, 3, activation='relu')
network = max_pool_2d(network, 3, strides=2)
network = local_response_normalization(network)
network = fully_connected(network, 4096, activation='tanh')
network = dropout(network, 0.5)
network = fully_connected(network, 4096, activation='tanh')
network = dropout(network, 0.5)
network = fully_connected(network, 17, activation='softmax')
network = regression(network, optimizer='momentum',
loss='categorical_crossentropy',
learning_rate=0.01)
# Training
model = tflearn.DNN(network, checkpoint_path='model_ba',
max_checkpoints=1, tensorboard_verbose=0)
model.fit(X, Y, n_epoch=3, validation_set=0.1, shuffle=True,
show_metric=True, batch_size=32, snapshot_step=400,
snapshot_epoch=False, run_id='ba_soccer_network')
代码正在保存检查点“model_ba”以及网络形式的.meta文件。 是否有可能加载保存的检查点并使用张量流评估单个图像?
提前致谢, 阿诺
答案 0 :(得分:0)
保存: model.save(' name.tflearn&#39)
用于加载: model.load(' name.tflearn&#39)
并且在循环测试中只需加载模型并遵循以下代码
files_path = '/your/test/images/directory/path'
img_files_path = os.path.join(files_path, '*.jpg')
img_files = sorted(glob(img_files_path))
for f in img_files:
try:
img = Image.open(f).convert('RGB')
img = ImageOps.fit(img, ((64, 64)), Image.ANTIALIAS)
img_arr = np.array(img)
img_arr = img_arr.reshape(-1, 64, 64, 3).astype("float")
pred = model.predict(img_arr)
print(" %s" % pred[0])
except:
continue