我已经尝试使用 plt.axis('off') 和 plt.grid(False)，但这两种方法都没有在图像上得到预期的效果。我想去掉每张示例图像上的白色网格。
# Run Session
# Trains for a fixed number of epochs, plots predictions for the first 25
# test images in a 5x5 grid of subplots, and saves a checkpoint.
with tf.Session() as sess:
    # Initialize Variables (global and local) and the input iterator.
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    sess.run(iterator.initializer)

    # Train the Model
    for epoch in range(7):
        # One step per full batch; floor division drops a trailing
        # partial batch, same as the truncating int(len/batch) form.
        prog_bar = tqdm(range(len(trainX) // batch_size))
        for step in prog_bar:
            _, cost = sess.run([train_op, loss])
            prog_bar.set_description("cost: {:.3f}".format(cost))
        # NOTE(review): accuracy is evaluated once per epoch, after the
        # batch loop -- confirm this matches the intended nesting.
        accuracy = sess.run(acc_op)
        print('\nEpoch {} Accuracy: {:.3f}'.format(epoch+1, accuracy))

    # Show Sample Predictions
    predictions = sess.run(tf.argmax(CNN(testX[:25], weights, biases), axis=1))
    f, axarr = plt.subplots(5, 5, figsize=(20,20))
    for idx in range(25):
        ax = axarr[idx // 5, idx % 5]
        # np.squeeze drops size-1 dimensions before display
        # (presumably a singleton channel axis -- TODO confirm).
        ax.imshow(np.squeeze(testX[idx]), cmap='gray')
        ax.set_title(str(predictions[idx]), fontsize=20)
        # plt.grid(False) / plt.axis('off') act only on the *current* Axes,
        # so with a grid of subplots the grid has to be disabled on each
        # Axes object individually.
        ax.grid(False)

    # Save Model
    saver = tf.train.Saver()
    saver.save(sess, './model.ckpt')