在另一个函数内部处理悬停事件

时间:2018-12-18 19:18:40

标签: python matplotlib

我正在尝试把新代码集成到别人编写的现有代码中,但遇到了一些问题。现有代码用 matplotlib 实现了一个带 GUI 的绘图工具,可以根据给定的输入文件绘制各种波形。我希望鼠标悬停在图中任意一条曲线上时,能弹出一个注释框显示这是哪条线(想象一张图上有 30 条线,根本无法把它们彼此区分开)。我在下面这个问题的回答中找到了一段代码(我参考的不是第一个答案):Possible to make labels appear when hovering over a point in matplotlib?

代码如下:

import matplotlib.pyplot as plt
import numpy as np; np.random.seed(1)

# Demo data: 15 random points, each with a one-letter name and an integer
# category in 1..4 that selects its colour.
x = np.random.rand(15)
y = np.random.rand(15)
names = np.array(list("ABCDEFGHIJKLMNO"))
c = np.random.randint(1, 5, size=15)

norm = plt.Normalize(1, 4)
cmap = plt.cm.RdYlGn

fig, ax = plt.subplots()
sc = plt.scatter(x, y, c=c, s=100, cmap=cmap, norm=norm)

# One reusable annotation, hidden until the cursor is over a point.
annot = ax.annotate("", xy=(0, 0), xytext=(20, 20), textcoords="offset points",
                    bbox=dict(boxstyle="round", fc="w"),
                    arrowprops=dict(arrowstyle="->"))
annot.set_visible(False)


def update_annot(ind):
    """Point the annotation at the first hit and list every hit point.

    ``ind`` is the details dict returned by ``sc.contains(event)``; its
    ``"ind"`` entry holds the indices of the points under the cursor.
    """
    hits = ind["ind"]
    pos = sc.get_offsets()[hits[0]]
    annot.xy = pos
    text = "{}, {}".format(" ".join(map(str, hits)),
                           " ".join(names[n] for n in hits))
    annot.set_text(text)
    # Tint the box with the colour of the first hit point.
    annot.get_bbox_patch().set_facecolor(cmap(norm(c[hits[0]])))
    annot.get_bbox_patch().set_alpha(0.4)


def hover(event):
    """motion_notify_event callback: show the annotation over a point, else hide it."""
    vis = annot.get_visible()
    if event.inaxes == ax:
        cont, ind = sc.contains(event)
        if cont:
            update_annot(ind)
            annot.set_visible(True)
            fig.canvas.draw_idle()
        else:
            # Only redraw when the annotation actually has to disappear.
            if vis:
                annot.set_visible(False)
                fig.canvas.draw_idle()


fig.canvas.mpl_connect("motion_notify_event", hover)

plt.show()

现有代码在绘图函数内部定义了 ax。如果这个函数不是那么长,我会把它整个贴在这里;下面只是其中一段(这段之前还有其他代码):

            else:
                # Single-subplot case: take the 1x1 subplot of the current
                # figure and remember it under this figure/subplot key.
                print ('The label is: %s' % label)
                ax = plt.subplot('111')
                axesDict[labelKey] = ax
            #end if
            # Hover annotation on this axes; created visible with empty text.
            annot = ax.annotate("", xy=(-20,20), xytext=(None),textcoords="offset points",
                                bbox=dict(fc="b"),
                                arrowprops=dict(arrowstyle="->"))
            annot.set_visible(True)

            # NOTE(review): matplotlib invokes the callback with only the event,
            # so `hover` cannot see `ax`/`annot` from this scope unless they are
            # bound into the callable (default arguments or functools.partial).
            fig.canvas.mpl_connect("motion_notify_event", hover)

问题在于我不知道如何把 ax 传进 hover 函数:mpl_connect 注册的回调在触发时只会收到事件对象,无法在注册时附带其他参数。

我刚接触 Python,处理这么庞大的现有代码一直很有挑战。也许我的实现思路本身就有问题,欢迎指出任何不对的地方!我相信后面还会遇到别的问题,但这是个不错的起点。提前感谢大家花时间帮忙。

编辑:下面是那个很长的绘图函数(这只是我要处理的第一部分):

    def plotData(self, refreshPlotAxes = False):
        """Plot the filtered waveforms, starting with the frequency-response figures.

        When ``self.plotFreqResp`` is set, each waveform selected by
        ``getFilteredWaveformObjectIndexList()`` is assigned to a figure and
        subplot. It is then given a legend label (optionally shortened by
        dropping the '_'-separated pieces shared by every waveform label),
        axis labels and a line/marker format. It also gets a hover annotation
        whose callback receives that waveform's own annotation and label.

        Args:
            refreshPlotAxes: not read in this part of the method; presumably
                used by the later figure-formatting stage to force the
                configured axis limits to be re-applied.
        """
        if len(self.waveformObjectList) == 0:
            print ('no waveforms to plot')
            return
        #end if

        startFigureNumber = self.startFigureNumber
        nextFigureNumber = startFigureNumber

        if self.fileDataTypeMode == 'ascii':
            markerArray = self.defaultMarkerArray
        else:
            markerArray = ['']

        waveformIndexList = self.getFilteredWaveformObjectIndexList()

        ###################### First Plot #############################

        if self.plotFreqResp:
            firstLoop = True
            markerIndex = 0
            #which labels are in each figure
            xAxisLabelDictionary = {}
            yAxisLabelDictionary = {}
            subplotDictionary = {}   #subplots for each figure
            plotAxisDictionary = {}  #plot axis for every subplot
            #at the moment, I don't support multiple figures and multiple subplots at the same time,
            #but I might someday
            logXDictDict = {}
            logYDictDict = {}

            plotFilename = 'blank_freqresp.png'

            plotAxisList = []
            numberOfFigures = 0
            numberOfSubPlots = 0
            numberOfLabels = 0

            #set up the plots
            axesDict = {}

            # '_'-separated pieces of every waveform label.
            labelList = []
            for waveformObj in self.waveformObjectList:
                label = waveformObj.label
                labelPieces = label.split('_')
                labelList.append(labelPieces)
            #end for waveformObj

            # Pieces of the first label that occur in every label; these are
            # stripped when shortened labels are enabled.
            commonLabelPieces = []
            if len(labelList) > 1:
                labelPieces0 = labelList[0]
                for labelPiece in labelPieces0:
                    isCommon = True
                    for labelPieces in labelList:
                        if labelPieces.count(labelPiece) == 0:
                            isCommon = False
                            break
                        #end if
                    #end for
                    if isCommon:
                        commonLabelPieces.append(labelPiece)
                    #end if
                #end for labelPiece
            #end if

            for waveformIndex in waveformIndexList:
                waveformObj = self.waveformObjectList[waveformIndex]

                plotFilename = waveformObj.filename
                [plotFilename, ext] = os.path.splitext(plotFilename)
                plotFilename += '_freqresp.png'

                # Start a new figure on the first waveform, or for every
                # waveform when separate plots (without subplots) are wanted.
                if firstLoop or (self.plot1SeparatePlots and not self.plot1SubPlots):
                    currentFigureNumber = nextFigureNumber
                    logXDictDict[currentFigureNumber] = {}
                    logYDictDict[currentFigureNumber] = {}
                    nextFigureNumber += 1
                    numberOfFigures += 1
                    figureTxt = 'Figure %d - %s' % (currentFigureNumber, self.appTitle)
                    fig = plt.figure(figureTxt, figsize=self.cwPlotSize)
                #end if

                label = waveformObj.getLabel(shortLabel = self.shortLabel, includeXLabel = self.showXInLabel)
                shortLabel = waveformObj.getLabel(shortLabel = True, includeXLabel = self.showXInLabel)

                if self.enableShortenedLabels:
                    # Keep only the pieces that distinguish this waveform,
                    # followed by its short label in parentheses.
                    label = waveformObj.label
                    labelPieces = label.split('_')
                    uniqueLabelPieces = []
                    for labelPiece in labelPieces:
                        if commonLabelPieces.count(labelPiece) == 0:
                            uniqueLabelPieces.append(labelPiece)
                        #end if
                    #end for
                    label = '_'.join(uniqueLabelPieces)
                    label += '(' + shortLabel + ')'
                #end if

                try:
                    if waveformObj.hasReference():
                        label += '%s%s @ %s' % (waveformObj.referenceWaveformOperation, waveformObj.referenceWaveform, waveformObj.referenceWaveformFreq)
                    #end if
                except Exception:
                    # Best effort: without usable reference info the label
                    # is left unchanged.
                    pass

                [xAxisLabel, yAxisLabel] = waveformObj.axisLabels()[0:2]
                if xAxisLabel == 'none':
                    xAxisLabel = waveformObj.getDataLabels()[0]
                if yAxisLabel == 'none' or yAxisLabel == 'mag':
                    yAxisLabel = waveformObj.getDataLabels()[1]

                # With a single subplot (the default) the subplot key is '1';
                # the axes itself is created with plt.subplot(111) below.
                subplotString = '1'
                logX = self.logHorizontalAxis
                dbY = self.dBVerticalAxis
                if self.plot1SubPlots:
                    # Route the waveform to the first subplot whose filter
                    # regexes match its short label; '000' means no match.
                    subplotString = '000'
                    for subplotNum in self.plot1SubPlotDict['filter'].keys():
                        matchList = self.plot1SubPlotDict['filter'][subplotNum]
                        for matchItem in matchList:
                            if re.search(matchItem, shortLabel):
                                subplotString = subplotNum
                                break
                            #end if
                        #end for
                    #end for

                    if subplotString == '000':
                        firstLoop = False
                        continue

                    # Optional per-subplot axis-scale overrides.
                    try:
                        logX = self.plot1SubPlotDict['xlog'][subplotString]
                    except Exception:
                        pass

                    try:
                        dbY = self.plot1SubPlotDict['ydb'][subplotString]
                    except Exception:
                        pass

                #end if

                if waveformObj.yUnits.lower().count('bits') or \
                     waveformObj.yUnits.lower().count('data'):
                    yData = waveformObj.getMagnitudeVector()
                    logY = False
                    dbY = False
                    forceLinearYAxis = True
                else:
                    forceLinearYAxis = False
                    if dbY:
                        yData = waveformObj.getNormalizeddBVector(self.absoluteValueForDB)
                        logY = False
                    else:
                        yData = waveformObj.getNormalizedMagnitudeVector()
                        logY = self.logVerticalAxis
                    #end if
                #end if

                fData = waveformObj.getFreqVector()

                labelKey = str(currentFigureNumber) + '_' + subplotString

                if labelKey not in xAxisLabelDictionary:
                    xAxisLabelDictionary[labelKey] = []
                if labelKey not in yAxisLabelDictionary:
                    yAxisLabelDictionary[labelKey] = []
                if currentFigureNumber not in subplotDictionary:
                    subplotDictionary[currentFigureNumber] = []

                xAxisLabelDictionary[labelKey].append(xAxisLabel)
                yAxisLabelDictionary[labelKey].append(yAxisLabel)

                # Find the first plot1Format regex matching this waveform;
                # `key` keeps that match after the loop.
                plot1FormatMatchesKey = False
                for key in self.plot1Format.keys():

                    if re.search(key, waveformObj.yLabel) or re.search(key, waveformObj.label):
                        plot1FormatMatchesKey = True
                        break
                    elif re.search(key, label):
                        plot1FormatMatchesKey = True
                        break
                    #end if
                #end for key

                if plot1FormatMatchesKey:
                    pltFormatText = self.plot1Format[key][0]
                    pltLineWidth = self.plot1Format[key][1]
                    pltMarkerSize = self.plot1Format[key][2]
                    allowLabel = self.plot1Format[key][3]
                    if len(self.plot1Format[key]) > 4:
                        markerColor = self.plot1Format[key][4]
                    else:
                        markerColor = -1

                    # None / negative entries fall back to the defaults.
                    if pltFormatText is None:
                        pltFormatText = markerArray[markerIndex]+'-'
                        markerIndex += 1
                    if pltLineWidth < 0:
                        pltLineWidth = self.defaultLineWidth
                    if pltMarkerSize < 0:
                        pltMarkerSize = self.defaultMarkerSize
                    if not allowLabel:
                        label = ''
                    if markerColor != -1:
                        markerEdgeColor = None
                        markerEdgeWidth = self.defaultMarkerEdgeWidth
                        markerFaceColor = markerColor
                    else:
                        markerEdgeColor = None
                        markerEdgeWidth = self.defaultMarkerEdgeWidth
                        markerFaceColor = None
                    #end if

                else:
                    pltFormatText = markerArray[markerIndex] + self.defaultLinePattern
                    markerIndex += 1
                    pltLineWidth = self.defaultLineWidth
                    pltMarkerSize = self.defaultMarkerSize
                    markerEdgeColor = None
                    markerEdgeWidth = self.defaultMarkerEdgeWidth
                    markerFaceColor = None
                #end if

                # Cycle through the available markers.
                if markerIndex >= len(markerArray):
                    markerIndex = 0

                if labelKey in axesDict:
                    # Reuse this figure/subplot's axes. Rebind `ax` so the
                    # annotation below lands on it, not on the axes left over
                    # from the previous waveform.
                    ax = axesDict[labelKey]
                    try:
                        plt.sca(ax)
                    except Exception:
                        print ('something went wrong with subplot label %s' % labelKey)
                        print ('probably due to overlapping subplots.')
                        print ('make adjustments to the figInfoDict items')
                    #end try
                elif self.plot1SubPlots:
                    gridShape = self.plot1SubPlotDict['gridShape']
                    subplotInfo = self.plot1SubPlotDict['figInfoDict'][subplotString]
                    ax = plt.subplot2grid(gridShape, subplotInfo[0], subplotInfo[1], subplotInfo[2])
                    axesDict[labelKey] = ax
                else:
                    print ("Made it inside else condition")
                    print ('The label is: %s' % label)
                    # Integer spec: the string form '111' is rejected by
                    # newer matplotlib versions.
                    ax = plt.subplot(111)
                    axesDict[labelKey] = ax
                #end if

                # Hover annotation for this waveform, on its own axes;
                # created visible with empty text.
                annot = ax.annotate("", xy=(-20,20), xytext=(None),textcoords="offset points",
                                    bbox=dict(fc="b"),
                                    arrowprops=dict(arrowstyle="->"))
                annot.set_visible(True)

                # Bind annot/label as default arguments. A plain closure would
                # look them up only when the mouse moves, i.e. after this loop
                # has finished, so every callback would see the last
                # waveform's annotation and label.
                h = lambda event, annot=annot, label=label: hover(event, annot, label)

                fig.canvas.mpl_connect("motion_notify_event", h)

格式化图

# Format the figures. For each figure created while plotting, apply grid,
# ticks, limits, vertical cursors, axis labels, title and legend to each of
# its subplots, then redraw it and optionally save it as a PNG.
#
# NOTE(review): this section uses `self` and the dictionaries built in
# plotData, so it belongs inside that method; its leading indentation was
# lost in the paste. The nesting here is reconstructed from the relative
# indentation that survived (plt.draw()/savefig sit at the level of the
# subplot loop, i.e. run once per figure) -- confirm against the real source.
for p in range(numberOfFigures):
    figureNumber = p + startFigureNumber

    figureTxt = 'Figure %d - %s' % (figureNumber, self.appTitle)
    plt.figure(figureTxt)

    if figureNumber not in subplotDictionary:
        continue

    for subplotString in subplotDictionary[figureNumber]:

        labelKey = str(figureNumber) + '_' + subplotString
        try:
            plt.sca(axesDict[labelKey])
        except Exception:
            print ('something went wrong with subplot label %s' % labelKey)
            print ('probably due to overlapping subplots.')
            print ('make adjustments to the figInfoDict items')
            continue
        #end try
        plotAxis = plotAxisDictionary[labelKey]
        plt.grid(self.plot1Grid, 'both')

        # Figure-wide defaults, possibly overridden per subplot below.
        plot1YticksList = self.plot1YticksList
        plot1XticksList = self.plot1XticksList
        plot1YLimits = self.cwPlotYLimits
        plot1XLimits = self.cwPlotXLimits
        vcursors = []

        logX = logXDictDict[figureNumber][subplotString]
        logY = logYDictDict[figureNumber][subplotString]

        enablePlotXLabel = True
        legendEnable = True

        if self.plot1SubPlots:
            # Optional per-subplot overrides: a missing entry keeps the
            # default above. Explicit ticks are dropped on log axes.
            if not logY:
                try:
                    plot1YticksList = self.plot1SubPlotDict['yticks'][subplotString]
                except Exception:
                    pass
            else:
                plot1YticksList = []
            #end if

            if not logX:
                try:
                    plot1XticksList = self.plot1SubPlotDict['xticks'][subplotString]
                except Exception:
                    pass
            else:
                plot1XticksList = []
            #end if

            try:
                plot1YLimits = self.plot1SubPlotDict['ylimits'][subplotString]
            except Exception:
                pass

            try:
                plot1XLimits = self.plot1SubPlotDict['xlimits'][subplotString]
            except Exception:
                pass

            try:
                vcursors = self.plot1SubPlotDict['vcursors'][subplotString]
            except Exception:
                pass

            try:
                enablePlotXLabel = self.plot1SubPlotDict['xLabelEnable'][subplotString]
            except Exception:
                pass

            try:
                legendEnable = self.plot1SubPlotDict['legendEnable'][subplotString]
            except Exception:
                pass
        #end if

        # A log y axis cannot show non-positive ticks or a non-positive
        # lower limit, so such settings are discarded.
        if logY:
            if any(tick <= 0 for tick in plot1YticksList):
                plot1YticksList = []
            if len(plot1YLimits) == 2 and plot1YLimits[0] <= 0:
                plot1YLimits = []
        #end if

        if len(plot1YticksList):
            plt.yticks(plot1YticksList)
        if len(plot1XticksList):
            plt.xticks(plot1XticksList)

        # Presumably (0, 1, 0, 1) means no stored axis range for this subplot
        # yet; in that case (or on request) apply the configured limits.
        if plotAxis == (0.0,1.0,0.0,1.0) or refreshPlotAxes:
            if len(plot1YLimits) == 2:
                plt.ylim(plot1YLimits)

            if len(plot1XLimits) == 2:
                plt.xlim(plot1XLimits)
        else:
            plt.axis(plotAxis)
        #end if

        # Vertical cursor lines spanning the current y range.
        if len(vcursors):
            ylimits = plt.ylim()
            for x in vcursors:
                plt.plot([x,x], ylimits, self.vcursorFormatText, linewidth = self.vcursorWidth)

        # Axis labels: the distinct labels of the traces on this subplot,
        # comma-separated ('' when there are none).
        yAxisLabel = ','.join(set(yAxisLabelDictionary[labelKey]))
        xAxisLabel = ','.join(set(xAxisLabelDictionary[labelKey]))

        # NOTE(review): forceLinearYAxis, dbY and waveformObj are not set in
        # this section; they hold whatever the last waveform processed in the
        # plotting loop left behind, not necessarily a trace on this subplot.
        if not forceLinearYAxis:
            if dbY:
                if not waveformObj.yUnits.lower().count('db'):
                    yAxisLabel += ' (dB)'
            else:
                yAxisLabel += ' (lin)'
        #end if

        plt.ylabel(yAxisLabel)
        if enablePlotXLabel:
            plt.xlabel(xAxisLabel)
        else:
            # Keep the tick positions but blank their labels.
            xtickList = plt.xticks()[0]
            plt.xticks(xtickList, '')
        #end if

        prop=matplotlib.font_manager.FontProperties(size=self.legendFontSize)
        if self.shortLabel:
            plt.title(waveformObj.filename, fontsize=12)
        #end if
        if self.cwPlotLegend and legendEnable:
            plt.legend(loc=self.plot1LegendLocation,prop=prop,borderpad=0.3,labelspacing=0.1,handletextpad=0,numpoints=self.numLegendPoints)
        #end if

    #end for subplotString

    plt.draw()

    if self.savePlotAsImage:
        # NOTE(review): plotFilename is not updated in this section, so every
        # figure is saved under the name set for the last waveform.
        plt.savefig(plotFilename, format='png')

1 个答案:

答案 0(得分:1)

您可以执行以下操作:

# Bind `ax` as a default argument so each callback keeps the axes that
# existed when it was registered. A plain closure would look `ax` up only
# when the mouse moves -- in a loop that is always the last axes created.
h = lambda event, ax=ax: hover(event, ax)
fig.canvas.mpl_connect("motion_notify_event", h)

然后把 hover 函数的定义改为:

def hover(event, ax):
    """Motion-notify callback that also receives the target axes.

    ``event`` is the object matplotlib passes when the mouse moves; ``ax`` is
    supplied by the wrapper lambda registered with ``mpl_connect``.
    """
    ...