加快我的交互式 matplotlib 图形

时间:2016-12-08 11:28:45

标签: numpy matplotlib scipy

我正在将一个很大的 JPEG 图像(1300x2000)加载到 matplotlib 中,在其上绘制由 50x50 方格组成的网格,然后点击每个方格为其标注颜色。但是,我发现程序的响应远远落后于我的点击:如果以正常速度快速点击约 50 个方格,需要 30 秒才能赶上。我想知道是否有人能帮忙提高它的速度。下面是我的脚本,复制/粘贴即可直接运行(需要安装 scipy、numpy、matplotlib、pillow 和 tkinter)。

欢迎任何建议。我是一名医学科学家,如果代码写得不够好,请见谅:

import matplotlib
import matplotlib.pyplot as plt
import tkinter
import tkinter.filedialog
from matplotlib.figure import  Figure
import math, sys
import numpy as np
import scipy.io as sio
from PIL import Image
from numpy import arange, sin, pi
#from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
#matplotlib.matplotlib_fname()
import os, re

# Module-level state shared with the event handlers below; handlers that need
# to rebind one of these names declare it ``global`` inside their own body
# (a ``global`` statement at module level has no effect).
# Currently selected annotation class (1-7), changed with the number keys.
classnumber = 1



def onmove(eve):
    """Paint or clear the grid square under the cursor while dragging.

    Written as a ``motion_notify_event`` handler (its connection further down
    is commented out).  With the left button (1) held, the ``stridesize`` x
    ``stridesize`` square under the cursor is filled with the colour of the
    current ``classnumber``; with the right button (3) held it is reset to 0,
    the mask's initial value.  Plain mouse movement, other buttons, positions
    outside the axes and classes without a colour change nothing and do not
    trigger a redraw.

    Reads the module-level ``mask`` (mutated in place), ``stridesize``,
    ``classnumber``, ``im`` and ``fig``.
    """
    # RGB colour for each annotation class key.
    class_colors = {
        1: (255, 0, 0),
        2: (0, 255, 0),
        3: (0, 0, 255),
        4: (255, 0, 255),
        5: (255, 255, 0),
        6: (0, 255, 255),
        7: (100, 255, 50),
    }
    if eve.xdata is None or eve.ydata is None:
        return  # cursor is outside the image axes
    if eve.button == 1:
        color = class_colors.get(classnumber)
    elif eve.button == 3:
        color = (0, 0, 0)
    else:
        # No (relevant) button held: nothing changed, so skip the expensive
        # redraw that would otherwise run on every single motion event.
        return
    if color is None:
        return  # selected class has no colour assigned
    # Snap the cursor position to the top-left corner of its grid square.
    startX = int(eve.xdata) // stridesize * stridesize
    startY = int(eve.ydata) // stridesize * stridesize
    mask[startY:startY + stridesize, startX:startX + stridesize, :] = color
    im.set_data(mask)
    # draw_idle() defers rendering until the GUI event loop is idle instead of
    # synchronously re-rendering the whole (large) figure for each event.
    fig.canvas.draw_idle()




def onclick(event):
    """Paint or clear the grid square under a mouse click.

    Connected to ``button_press_event``.  A left click (button 1) fills the
    ``stridesize`` x ``stridesize`` square containing the click with the colour
    of the current ``classnumber``; a right click (button 3) resets it to 0,
    the mask's initial value.  Clicks outside the axes, other buttons and
    classes without a colour leave the mask untouched and do not redraw.

    Reads the module-level ``mask`` (mutated in place), ``stridesize``,
    ``classnumber``, ``im`` and ``fig``.
    """
    # RGB colour for each annotation class key (matches the figure title).
    class_colors = {
        1: (255, 0, 0),
        2: (0, 255, 0),
        3: (0, 0, 255),
        4: (255, 0, 255),
        5: (255, 255, 0),
        6: (0, 255, 255),
        7: (100, 255, 50),
    }
    if event.xdata is None or event.ydata is None:
        return  # click landed outside the image axes
    if event.button == 1:
        color = class_colors.get(classnumber)
    elif event.button == 3:
        color = (0, 0, 0)
    else:
        return  # other buttons do not edit the mask
    print(event.xdata, int(event.ydata), stridesize)
    if color is None:
        return  # selected class has no colour assigned
    # Snap the click to the top-left corner of its grid square.
    startX = int(event.xdata) // stridesize * stridesize
    startY = int(event.ydata) // stridesize * stridesize
    mask[startY:startY + stridesize, startX:startX + stridesize, :] = color
    im.set_data(mask)
    # draw_idle() schedules one redraw for when the GUI event loop is idle,
    # so rapid clicks no longer each force a full synchronous render of the
    # large image (draw() blocks until rendering is finished).
    fig.canvas.draw_idle()


def onpress(event):
    """Handle keyboard shortcuts for the annotation figure.

    Keys:
      'e'      -- erase the whole mask and refresh the overlay.
      's'      -- save the mask via ``savemask(fn)``.
      'r'      -- show the mask on its own in a new figure for review.
      '1'-'7'  -- select the annotation class used by left clicks.
    Any other key (including ``event.key is None``) is ignored.

    Rebinds the module-level ``classnumber``; also uses the module-level
    ``mask``, ``im``, ``fig`` and ``fn``.
    """
    global classnumber, mask
    key = event.key
    if key == 'e':
        print("YO")
        mask[:, :, :] = 0
        im.set_data(mask)
        fig.canvas.draw_idle()
    elif key == 's':
        savemask(fn)
    elif key == 'r':
        plt.figure()
        plt.imshow(mask)
        plt.show()
    else:
        # event.key can be a non-digit name ('q', 'shift', ...) or None, so
        # only treat it as a class number when it actually parses as one.
        try:
            key_class = int(key)
        except (TypeError, ValueError):
            return
        # Only classes 1-7 have a colour in the click handlers and the title.
        if 1 <= key_class <= 7:
            classnumber = key_class
            print(classnumber)


def onrelease(event):
    """Echo the number of the mouse button that was just released (debug aid)."""
    released_button = event.button
    print(released_button)




def savemask(fn):
    """Ask for a destination and save the module-level ``mask`` as a .mat file.

    Parameters
    ----------
    fn : str
        Path of the image being annotated; its base name without extension
        is offered as the default file name.

    The mask is stored under the variable name ``'mask'`` with compression.
    Nothing is written if the dialog is cancelled.
    """
    pre, ext = os.path.splitext(fn)
    savename_default = os.path.basename(pre)
    # asksaveasfilename only returns the chosen path.  asksaveasfile would open
    # (and truncate) that path itself, and when the name lacked ".mat" an
    # empty stray file was left next to the "<name>.mat" that savemat writes.
    name = tkinter.filedialog.asksaveasfilename(
        defaultextension='.mat',
        filetypes=[('mat files', '.mat')],
        initialdir='',
        initialfile=savename_default,
        title='Save file',
    )
    if not name:  # empty result when the dialog is closed with "cancel"
        return
    sio.savemat(name, {'mask': mask}, do_compression=True)




# --- Script: pick an image, build the label overlay, and start the UI. ---
root = tkinter.Tk()
root.withdraw()  # hide the empty Tk root window; only the dialogs are needed

options = {
    'defaultextension': '.jpg',
    'filetypes': [('Jpeg', '.jpg')],
    'initialdir': 'C:\\',
    'initialfile': '',
    'parent': root,
    'title': 'This is a title',
}

fn = tkinter.filedialog.askopenfilename(**options)
if not fn:
    # Dialog cancelled: there is no image to annotate.
    sys.exit(0)

img = Image.open(fn)
# np.array makes a writable copy; np.asarray(img) can be a read-only view
# whose WRITEABLE flag cannot be switched on with setflags() in recent NumPy.
x = np.array(img)
# RGB label overlay with the image's height and width.  Always 3 channels so
# the RGB colour assignments in the handlers also work for greyscale JPEGs.
mask = np.zeros(x.shape[:2] + (3,), 'uint8')
fig = plt.figure()
fig.suptitle(r'Key codes: 1 = Tumour, 2 = stroma-hypocellular, 3=stroma cellular (inflammatory)' '\n4 = proteinaceous, 5= red cells, 6,7: anyother,''\nRight click: clear square''\n r:  review mask, e: erase mask, o : open mask image, s : save mask image;')

im = plt.imshow(x)
im = plt.imshow(mask, alpha=.25)  # `im` is the overlay the handlers update
ax = plt.gca()

stridesize = 50  # edge length of a grid square, in pixels

plt.rcParams['keymap.save'] = ''  # free the 's' key for "save mask"
ax.set_yticks(np.arange(0, x.shape[0], stridesize))
ax.set_xticks(np.arange(0, x.shape[1], stridesize))

cid = fig.canvas.mpl_connect('button_press_event', onclick)
cod = fig.canvas.mpl_connect('key_press_event', onpress)
# cdd = fig.canvas.mpl_connect('motion_notify_event', onmove)
cdr = fig.canvas.mpl_connect('button_release_event', onrelease)

# Pass the on/off flag positionally: grid()'s keyword `b` was renamed to
# `visible` in Matplotlib 3.5 and the old name was later removed.
plt.grid(True, which='both', color='black', linestyle='-')
plt.show()  # blocks until the window is closed

1 个答案:

答案 0 :(得分:3)

首先,我建议尽可能避免使用 global 变量,可以用一个 class 来替代。下面是一个精简版本,实现了您的代码想要完成的全部功能:

import numpy as np

import matplotlib
# Select an interactive backend here only if the default one is unsuitable,
# e.g. matplotlib.use('QtAgg').  The 'Qt4Agg' backend that was forced here
# has been removed from current Matplotlib releases, so requesting it fails.

from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap

class ColorCode(object):
    """Interactively label an image by painting fixed-size grid blocks.

    Left-click paints the block under the cursor with the selected colour,
    right-click clears it, and the digit keys select the colour (an index
    into ``colors``).  Labels live in a 2-D ``int32`` mask where ``-1`` means
    "unlabelled" and ``i >= 0`` means "painted with ``colors[i]``".
    """

    def __init__(self, block_size=(50, 50), colors=('red', 'green', 'blue'), alpha=0.3):
        """Create the figure and connect the mouse/keyboard handlers.

        Parameters
        ----------
        block_size : (int, int)
            Block height and width in pixels.
        colors : sequence of Matplotlib colour specs
            Colour per label; label ``i`` is drawn with ``colors[i]``.
        alpha : float
            Opacity of the label overlay.
        """
        self.by, self.bx = block_size  # block height, block width
        self.selected = 0  # label index used by left clicks
        # Private list copy: the default is an immutable tuple so that no
        # list object is shared between instances.
        self.colors = list(colors)
        self.cmap = ListedColormap(self.colors)  # label index -> colour
        self.mask = None  # label mask, created by color_code()
        self._mask_image = None  # overlay image artist, created by color_code()
        self.alpha = alpha
        # Plots
        self.fig = plt.figure()
        self.ax = self.fig.gca()
        # Events
        self.fig.canvas.mpl_connect('button_press_event', self.on_click)
        self.fig.canvas.mpl_connect('key_press_event', self.on_key)

    def _masked(self):
        """Return the label mask with unlabelled (negative) entries masked out."""
        return np.ma.masked_less(self.mask, 0)

    def color_code(self, img):
        """Show *img*, let the user label it, and return the label mask.

        Blocks in ``plt.show`` until the figure is closed.  The returned
        array has shape ``img.shape[:2]`` with ``-1`` for unlabelled pixels.
        """
        self.imshape = img.shape[:2]
        self.mask = np.full(self.imshape, -1, np.int32)
        self.ax.imshow(img)
        # vmax=len(colors) splits [0, len] into len equal bins, so integer
        # label i maps exactly onto cmap colour i.  Keep the artist handle
        # so on_click does not depend on the order of ax.images.
        self._mask_image = self.ax.imshow(self._masked(), cmap=self.cmap,
                                          alpha=self.alpha, vmin=0,
                                          vmax=len(self.colors))
        plt.show(block=True)
        return self.mask

    def on_click(self, event):
        """Paint (left button) or clear (right button) the clicked block."""
        if not event.inaxes or self.mask is None:
            return
        if event.button == 1:
            value = self.selected
        elif event.button == 3:
            value = -1
        else:
            return  # other buttons change nothing, so skip the redraw
        py, px = int(event.ydata), int(event.xdata)
        # Block containing the click, clipped to the image border.
        ymin = py // self.by * self.by
        xmin = px // self.bx * self.bx
        ymax = min(ymin + self.by, self.imshape[0])
        xmax = min(xmin + self.bx, self.imshape[1])
        self.mask[ymin:ymax, xmin:xmax] = value
        self._mask_image.set_data(self._masked())
        self.fig.canvas.draw_idle()  # non-blocking, coalesced redraw

    def on_key(self, event):
        """Select label ``int(event.key)`` when it is a valid colour index."""
        try:
            ikey = int(event.key)
        except (TypeError, ValueError):
            return  # non-digit key (e.g. 's', 'shift') or no key name
        if 0 <= ikey < len(self.colors):
            self.selected = ikey

与您的代码的主要区别是:

  1. 它不使用全局变量,而是使用类的实例属性(self.…)。这样运行更安全,也更容易扩展/修改。

  2. 标注不再保存为着色的三维 mask,而是保存为二维 mask,其中每个像素的值在 [0, len(colors)-1] 范围内,表示它属于哪种颜色(-1 表示未标注)。然后在绘图时使用 ListedColormap 为其上色。

  3. 它先绘制图像,再在其上叠加分割蒙版。蒙版最初用 -1 填充,表示尚无标签。借助 numpy 的 masked array,可以把 mask < 0 的位置屏蔽掉,使叠加层在这些位置透明,在其他位置显示对应的颜色。

  4. 可选颜色列表作为类的参数提供。您可以通过数字键选择编号 0 到 len(colors)-1 的颜色。由于选择目前绑定到键盘上的数字键,最多支持 10 种颜色。

  5. fig.canvas.draw_idle 比 fig.canvas.draw 好得多:后者会阻塞程序,直到绘图完成;前者只是请求在控制权返回 GUI 事件循环时重绘,在此之前的多次请求只会渲染一次。

  6. 由于所有内容都属于一个类,代码看起来更清晰。

  7. 您可以将代码称为:

    >>> random_image = np.random.randn(1000,2000, 3)
    >>> result = ColorCode().color_code(random_image)
    

    result 将包含标签 mask,其中每个像素都有一个数字,表示它被标记为哪种颜色(未标记则为 -1)。此外,还可以向 ColorCode 的构造函数传递其他参数。例如,block_size=(100,100) 可使用不同的块大小;alpha=0.5 可使掩码更不透明(alpha=1 则完全不透明)。

    希望它适合你,或者至少你可以从中获取一些想法。