Jupyter笔记本交互式悬停,可显示许多数字

时间:2019-11-08 16:05:05

标签: python matplotlib jupyter-notebook

我正在对较小的MNIST数据子集使用降维算法,我想从图中检查每个图像的位置。可以根据以下StackOverflow主题的答案来完成此操作:Python show image upon hovering over a point

我想用jupyter笔记本来实现这一点,以便一次运行就可以得到所有的数字。问题是我可能需要定义悬停功能的图形编号才能使其正常工作。我完全不确定是否有可能。我将通过以下简单示例演示我的问题:

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.decomposition import PCA
from sklearn.manifold import MDS, Isomap
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
import matplotlib as mpl
plt.rcParams['figure.figsize'] = (15, 10)
%matplotlib notebook

# Load MNIST data
data, Y = datasets.fetch_openml('mnist_784', version=1, return_X_y=True)
data, Y = np.array(data, 'int16'), np.array(Y, 'int')
data, Y = data[:100], Y[:100]

def transform_data(data):
    """
    This function makes dimensional rediction with PCA, MDS and ISOMAP and returns them in numpy list
    """
    data_pca = PCA(n_components=2).fit_transform(data)
    data_mds = MDS(n_jobs=-1).fit_transform(data)
    data_isomap = Isomap().fit_transform(data)
    transformed_datas = np.array([data_pca, data_mds, data_isomap])
    return transformed_datas

def hover(event):
    if line.contains(event)[0]:
        inds = line.contains(event)[1]['ind']
        ind = inds[0]
        w,h = fig.get_size_inches()*fig.dpi
        ws = (event.x > w/2.)*-1 + (event.x <= w/2.) 
        hs = (event.y > h/2.)*-1 + (event.y <= h/2.)
        ab.xybox = (xybox[0]*ws, xybox[1]*hs)
        ab.set_visible(True)
        ab.xy =(x_coords[ind], y_coords[ind])
        offset_image.set_data(arr[ind,:,:])
    else:
        ab.set_visible(False)
    fig.canvas.draw_idle()


transformed_datas = transform_data(data)
algorithms_text = ['PCA','MDS','ISOMAP']
colors = np.unique(Y, return_inverse=True)[1].tolist()
arr = np.reshape(data, (100,28,28))

for i in range(transformed_datas.shape[0]):
    fig, ax = plt.subplots(figsize=(10,6))
    x_coords = transformed_datas[i][:,0]
    y_coords = transformed_datas[i][:,1]
    line = plt.scatter(x_coords,y_coords, s=30, c=colors,  cmap='jet', edgecolor='k')

    offset_image = OffsetImage(arr[0,:,:], zoom=2, cmap=plt.cm.gray_r)
    xybox = (40, 40)
    ab = AnnotationBbox(offset_image, (0,0), xybox=xybox, xycoords='data', boxcoords='offset points',  pad=0.3,  arrowprops=dict(arrowstyle='->'))
    ax.add_artist(ab)
    ab.set_visible(False)

    plt.title('2D-visualization of MNIST data with {} algorithm'.format(algorithms_text[i]), fontsize=10)
    ax.set_yticklabels([])
    ax.set_xticklabels([])
    fig.canvas.mpl_connect('motion_notify_event', hover)
    plt.show()

0 个答案:

没有答案