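I trained a model for binary classification with tflearn. It reaches 97% training accuracy.
I want to use model.load() in another program to predict the class of some test inputs.
However, model.load() only works when I pass the argument weights_only=True. When I omit that argument from model.load(), it throws an error: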
NotFoundError (see above for traceback): Key is_training not found in checkpoint
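When I load the model and run predictions on my small test set, the classifications look odd. The model predicts a perfect 1 at index 1 every single time. To me, that should not happen if the model trained to such a high accuracy. Here is what the predictions look like (expected output on the right):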
[[ 5.59889193e-22 1.00000000e+00] [0, 1]
[ 4.25160435e-22 1.00000000e+00] [0, 1]
[ 6.65333618e-23 1.00000000e+00] [0, 1]
[ 2.07748895e-21 1.00000000e+00] [0, 1]
[ 1.77639440e-21 1.00000000e+00] [0, 1]
[ 5.77486922e-18 1.00000000e+00] [1, 0]
[ 2.70562403e-19 1.00000000e+00] [1, 0]
[ 2.78288828e-18 1.00000000e+00] [1, 0]
[ 6.10306495e-17 1.00000000e+00] [1, 0]
[ 2.35787162e-19 1.00000000e+00]] [1, 0]
Note: this test data is the same data that was used to train the model, so it should be classified correctly with high accuracy.
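To put a number on the mismatch, here is a minimal sketch in plain Python that takes the argmax of each row and compares it with the expected label. The values are copied from the output above; the helper name `argmax` and the variable names are only illustrative.

```python
# Predicted probabilities, copied from the printed output above.
predictions = [
    [5.59889193e-22, 1.00000000e+00],
    [4.25160435e-22, 1.00000000e+00],
    [6.65333618e-23, 1.00000000e+00],
    [2.07748895e-21, 1.00000000e+00],
    [1.77639440e-21, 1.00000000e+00],
    [5.77486922e-18, 1.00000000e+00],
    [2.70562403e-19, 1.00000000e+00],
    [2.78288828e-18, 1.00000000e+00],
    [6.10306495e-17, 1.00000000e+00],
    [2.35787162e-19, 1.00000000e+00],
]

# Expected one-hot labels, copied from the right-hand column above.
expected = [[0, 1]] * 5 + [[1, 0]] * 5


def argmax(row):
    """Return the index of the largest value in a sequence."""
    return max(range(len(row)), key=lambda i: row[i])


# One class index per row, for both predictions and expected labels.
predicted_classes = [argmax(row) for row in predictions]
expected_classes = [argmax(row) for row in expected]

# Fraction of rows where the predicted class matches the expected class.
correct = sum(p == e for p, e in zip(predicted_classes, expected_classes))
accuracy = correct / len(predictions)

print(predicted_classes)
print(accuracy)
```

Every row is predicted as class 1, so only the five rows whose expected label is [0, 1] match.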
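Code for training the model:
# Imports inferred from the names used below.
import numpy as np
import pandas as pd
import tensorflow as tf
import tflearn
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.conv import conv_2d, max_pool_2d
from tflearn.layers.estimator import regression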
tf.reset_default_graph()
train = pd.read_csv("/Users/darrentaggart/Library/Mobile Documents/com~apple~CloudDocs/Uni Documents/MEE4040 - Project 4/Coding Related Stuff/Neural Networks/modeltraindata_1280.csv")
test = pd.read_csv("/Users/darrentaggart/Library/Mobile Documents/com~apple~CloudDocs/Uni Documents/MEE4040 - Project 4/Coding Related Stuff/Neural Networks/modeltestdata_320.csv")
X = train.iloc[:,1:].values.astype(np.float32)
Y = np.array([np.array([int(i == l) for i in range(2)]) for l in
train.iloc[:,:1].values])
test_x = test.iloc[:,1:].values.astype(np.float32)
test_y = np.array([np.array([int(i == l) for i in range(2)]) for l in
test.iloc[:,:1].values])
X = X.reshape([-1, 16, 16, 1])
test_x = test_x.reshape([-1, 16, 16, 1])
convnet = input_data(shape=[None, 16, 16, 1], name='input')
initialization = tf.contrib.layers.variance_scaling_initializer(factor=1.0, mode='FAN_IN', uniform=False)
convnet = conv_2d(convnet, 32, 2, activation='elu',
weights_init=initialization)
convnet = max_pool_2d(convnet, 2)
convnet = tflearn.layers.normalization.batch_normalization(convnet, beta=0.0, gamma=1.0, epsilon=1e-05,
decay=0.9, stddev=0.002, trainable=True, restore=True, reuse=False, scope=None, name='BatchNormalization')
convnet = conv_2d(convnet, 64, 2, activation='elu',
weights_init=initialization)
convnet = max_pool_2d(convnet, 2)
convnet = tflearn.layers.normalization.batch_normalization(convnet, beta=0.0, gamma=1.0, epsilon=1e-05,
decay=0.9, stddev=0.002, trainable=True, restore=True, reuse=False, scope=None, name='BatchNormalization')
convnet = fully_connected(convnet, 254, activation='elu', weights_init=initialization)
convnet = dropout(convnet, 0.8)
convnet = tflearn.layers.normalization.batch_normalization(convnet, beta=0.0, gamma=1.0, epsilon=1e-05,
decay=0.9, stddev=0.002, trainable=True, restore=True, reuse=False, scope=None, name='BatchNormalization')
convnet = fully_connected(convnet, 2, activation='softmax')
adam = tflearn.optimizers.Adam(learning_rate=0.00065, beta1=0.9, beta2=0.999, epsilon=1e-08)
convnet = regression(convnet, optimizer=adam, loss='categorical_crossentropy', name='targets')
model = tflearn.DNN(convnet, tensorboard_dir='/Users/darrentaggart/Library/Mobile Documents/com~apple~CloudDocs/Uni Documents/MEE4040 - Project 4/Coding Related Stuff/Neural Networks/latest logs',
tensorboard_verbose=3)
model.fit({'input': X}, {'targets': Y}, n_epoch=100, batch_size=16,
validation_set=({'input': test_x}, {'targets': test_y}), snapshot_step=10, show_metric=True, run_id='1600 - ConvConvFC254 LR0.00065decay BN VSinit 16batchsize 100epochs')
model.save('tflearncnn.model')
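The nested list comprehensions that build Y and test_y in the training code above amount to one-hot encoding of the label column. A minimal sketch of the same idea in plain Python, assuming integer labels 0 and 1; the function name `one_hot` is hypothetical:

```python
def one_hot(labels, num_classes=2):
    """Turn integer class labels into one-hot rows, e.g. 1 -> [0, 1]."""
    return [[int(i == label) for i in range(num_classes)] for label in labels]


# Class 0 becomes [1, 0] and class 1 becomes [0, 1].
encoded = one_hot([0, 1, 1])
print(encoded)
```

This matches the layout of the expected outputs shown earlier, where [0, 1] stands for class 1 and [1, 0] for class 0.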
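Code for loading the model and generating predictions:
# Imports inferred from the names used below.
import os
import numpy as np
import pandas as pd
import tensorflow as tf
import tflearn
from tflearn.layers.core import input_data, fully_connected
from tflearn.layers.conv import conv_2d, max_pool_2d
from tflearn.layers.estimator import regression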
test = pd.read_csv("/Users/darrentaggart/Library/Mobile Documents/com~apple~CloudDocs/Uni Documents/MEE4040 - Project 4/Coding Related Stuff/Neural Networks/modelpredictiondata.csv")
X = test.iloc[:,1:].values.astype(np.float32)
sess=tf.InteractiveSession()
tflearn.is_training(False)
convnet = input_data(shape=[None, 16, 16, 1], name='input')
initialization = tf.contrib.layers.variance_scaling_initializer(factor=1.0, mode='FAN_IN', uniform=False)
convnet = conv_2d(convnet, 32, 2, activation='elu', weights_init=initialization)
convnet = max_pool_2d(convnet, 2)
convnet = tflearn.layers.normalization.batch_normalization(convnet, beta=0.0, gamma=1.0, epsilon=1e-05,
decay=0.9, stddev=0.002, trainable=True, restore=True, reuse=False, scope=None, name='BatchNormalization')
convnet = conv_2d(convnet, 64, 2, activation='elu', weights_init=initialization)
convnet = max_pool_2d(convnet, 2)
convnet = tflearn.layers.normalization.batch_normalization(convnet, beta=0.0, gamma=1.0, epsilon=1e-05,
decay=0.9, stddev=0.002, trainable=True, restore=True, reuse=False, scope=None, name='BatchNormalization')
convnet = fully_connected(convnet, 254, activation='elu', weights_init=initialization)
convnet = tflearn.layers.normalization.batch_normalization(convnet, beta=0.0, gamma=1.0, epsilon=1e-05,
decay=0.9, stddev=0.002, trainable=True, restore=True, reuse=False, scope=None, name='BatchNormalization')
convnet = fully_connected(convnet, 2, activation='softmax')
adam = tflearn.optimizers.Adam(learning_rate=0.00065, beta1=0.9, beta2=0.999, epsilon=1e-08)
convnet = regression(convnet, optimizer=adam, loss='categorical_crossentropy', name='targets')
model = tflearn.DNN(convnet)
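if os.path.exists('{}.meta'.format('tflearncnn.model')):
    model.load('tflearncnn.model', weights_only=False)
    print('model loaded!')

# Reshape once to the network's input shape, then predict the whole batch.
X = X.reshape([-1, 16, 16, 1])
model_out = model.predict(X)
print(model_out)

# Take the argmax per row so that each sample gets its own label.
for probabilities in model_out:
    if np.argmax(probabilities) == 1:
        str_label = 'Boss'
    else:
        str_label = 'Slot'
    print(str_label)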
I know this is a long shot, but I thought someone might be able to shed some light on it. Thanks.
Answer 0 (score: 0)
Have you tried model.load(<path-to-saved-model>)?
For example: model.load("./model.tflearn")
I think this will solve your problem.
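Answer 1 (score: 0)
It has been a year and a half since this question was asked, but sharing still matters. I used tflearn and Alexnet for binary classification of images.
The trick is to normalize after converting to an nparray. Don't forget to change the directory paths.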
import torch


def train(model, optimizer, criterion, BATCH_SIZE, train_loader, clip):
    model.train(True)
    total_loss = 0
    hidden = model._init_hidden(BATCH_SIZE)
    for i, (batch_of_data, batch_of_labels) in enumerate(train_loader, 1):
        # Detach the hidden state so gradients do not flow across batches.
        hidden = hidden.detach()
        model.zero_grad()
        output, hidden = model(batch_of_data, hidden)
        # Compare the output against this batch's labels.
        loss = criterion(output, batch_of_labels)
        total_loss += loss.item()
        loss.backward()
        # Clip gradients to a maximum norm of `clip` before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
    return total_loss / len(train_loader.dataset)
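The normalization step this answer recommends can be sketched as follows. This is only one common option (min-max scaling), and it assumes the scaling constants are computed on the training data and reused unchanged at prediction time. The function names are hypothetical, and plain Python lists stand in for the NumPy array to keep the sketch dependency-free.

```python
def fit_min_max(rows):
    """Find the smallest and largest value in the training data."""
    lo = min(min(row) for row in rows)
    hi = max(max(row) for row in rows)
    return lo, hi


def apply_min_max(rows, lo, hi):
    """Scale values to [0, 1] using previously fitted constants."""
    span = hi - lo
    if span == 0:
        # Constant data carries no scale information; map everything to 0.
        return [[0.0 for _ in row] for row in rows]
    return [[(value - lo) / span for value in row] for row in rows]


# Toy pixel values in the 0..255 range, standing in for real image data.
train_rows = [[0.0, 128.0], [64.0, 255.0]]
new_rows = [[32.0, 255.0]]

# Fit on the training data only, then apply the same constants everywhere.
lo, hi = fit_min_max(train_rows)
train_scaled = apply_min_max(train_rows, lo, hi)
new_scaled = apply_min_max(new_rows, lo, hi)

print(train_scaled)
print(new_scaled)
```

Whatever scaling is chosen, applying the same transformation in the training script and in the prediction script keeps the inputs the network sees consistent.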