matplotlib FigureCanvasQTAgg 嵌入 QWidget 后:3D 曲面旋转缓慢,且坐标轴网格覆盖在曲面之上

时间:2012-11-23 07:24:58

标签: python matplotlib pyqt4

我有一个由 Qt Designer 生成的界面文件 optim_plotting_frame.py(代码如下):


    from PyQt4 import QtCore, QtGui

    # Qt Designer output wraps object names in _fromUtf8().  When QtCore has
    # no QString (e.g. PyQt's v2 string API), plain Python strings are used
    # directly and the conversion degenerates to the identity function.
    try:
        _fromUtf8 = QtCore.QString.fromUtf8
    except AttributeError:
        def _fromUtf8(s):
            return s

    class Ui_optim_plotting_frame(object):
        """Layout of the plotting frame, originally generated from a Qt Designer form.

        ``setupUi`` builds, on the given top-level widget, a left-hand column of
        settings controls and an empty ``self.widget`` on the right that the
        caller fills with a plot:

        * ``cmb_function`` - radial basis function name for interpolation
        * ``dspb_alpha``   - surface transparency in [0, 1], default 0.7
        * ``dspb_smooth``  - smoothness of the approximation
        * ``chb_normxy``   - show normalized x/y ticks (checked by default)

        NOTE: regenerating this module from the .ui file will overwrite manual edits.
        """

        def setupUi(self, optim_plotting_frame):
            """Create the child widgets and layouts on *optim_plotting_frame*."""

            def _tr(text):
                # every user-visible string goes through Qt's translation hook
                return QtGui.QApplication.translate(
                    "optim_plotting_frame", text, None, QtGui.QApplication.UnicodeUTF8)

            unlimited = QtCore.QSize(16777215, 16777215)  # Qt's QWIDGETSIZE_MAX

            optim_plotting_frame.setObjectName(_fromUtf8("optim_plotting_frame"))
            optim_plotting_frame.setWindowModality(QtCore.Qt.ApplicationModal)
            optim_plotting_frame.resize(700, 580)
            optim_plotting_frame.setWindowTitle(_tr("Plotting"))
            self.verticalLayout_2 = QtGui.QVBoxLayout(optim_plotting_frame)
            self.verticalLayout_2.setObjectName(_fromUtf8("verticalLayout_2"))
            self.horizontalLayout = QtGui.QHBoxLayout()
            self.horizontalLayout.setObjectName(_fromUtf8("horizontalLayout"))
            self.verticalLayout = QtGui.QVBoxLayout()
            self.verticalLayout.setSizeConstraint(QtGui.QLayout.SetMinimumSize)
            self.verticalLayout.setObjectName(_fromUtf8("verticalLayout"))

            # --- interpolation function -------------------------------------
            self.label_function = QtGui.QLabel(optim_plotting_frame)
            self.label_function.setMinimumSize(QtCore.QSize(111, 16))
            self.label_function.setMaximumSize(unlimited)
            self.label_function.setToolTip(_tr("defines radial basis function for interpolation"))
            # The designer emitted a rich-text (HTML) label here; plain text
            # renders the same caption and avoids a fragile multi-line literal.
            self.label_function.setText(_tr("Interpolation function:"))
            self.label_function.setObjectName(_fromUtf8("label_function"))
            self.verticalLayout.addWidget(self.label_function)
            self.cmb_function = QtGui.QComboBox(optim_plotting_frame)
            self.cmb_function.setMinimumSize(QtCore.QSize(111, 22))
            self.cmb_function.setMaximumSize(unlimited)
            self.cmb_function.setToolTip(_tr("defines radial basis function for interpolation"))
            self.cmb_function.setObjectName(_fromUtf8("cmb_function"))
            for index, name in enumerate(("multiquadric", "inverse", "gaussian",
                                          "linear", "cubic", "quintic", "thin_plate")):
                self.cmb_function.addItem(_fromUtf8(""))
                self.cmb_function.setItemText(index, _tr(name))
            self.verticalLayout.addWidget(self.cmb_function)

            # --- transparency -----------------------------------------------
            self.label_alpha = QtGui.QLabel(optim_plotting_frame)
            self.label_alpha.setMinimumSize(QtCore.QSize(111, 16))
            self.label_alpha.setMaximumSize(unlimited)
            self.label_alpha.setToolTip(_tr("Defines transparency: 0 - transparent, 1 - not transparent"))
            self.label_alpha.setText(_tr("Alpha:"))
            self.label_alpha.setObjectName(_fromUtf8("label_alpha"))
            self.verticalLayout.addWidget(self.label_alpha)
            self.dspb_alpha = QtGui.QDoubleSpinBox(optim_plotting_frame)
            self.dspb_alpha.setMinimumSize(QtCore.QSize(111, 0))
            self.dspb_alpha.setMaximumSize(unlimited)
            self.dspb_alpha.setToolTip(_tr("Defines transparency: 0 - transparent, 1 - not transparent"))
            self.dspb_alpha.setMaximum(1.0)
            self.dspb_alpha.setSingleStep(0.1)
            self.dspb_alpha.setProperty("value", 0.7)
            self.dspb_alpha.setObjectName(_fromUtf8("dspb_alpha"))
            self.verticalLayout.addWidget(self.dspb_alpha)

            # --- smoothness -------------------------------------------------
            self.label_smooth = QtGui.QLabel(optim_plotting_frame)
            self.label_smooth.setMinimumSize(QtCore.QSize(111, 16))
            self.label_smooth.setMaximumSize(unlimited)
            self.label_smooth.setToolTip(_tr("Smoothness of the approximation"))
            self.label_smooth.setText(_tr("Smoothness:"))
            self.label_smooth.setObjectName(_fromUtf8("label_smooth"))
            self.verticalLayout.addWidget(self.label_smooth)
            self.dspb_smooth = QtGui.QDoubleSpinBox(optim_plotting_frame)
            self.dspb_smooth.setMinimumSize(QtCore.QSize(111, 0))
            self.dspb_smooth.setMaximumSize(unlimited)
            self.dspb_smooth.setToolTip(_tr("Smoothness of the approximation"))
            self.dspb_smooth.setSingleStep(0.1)
            self.dspb_smooth.setObjectName(_fromUtf8("dspb_smooth"))
            self.verticalLayout.addWidget(self.dspb_smooth)

            # --- normalized-ticks toggle ------------------------------------
            self.chb_normxy = QtGui.QCheckBox(optim_plotting_frame)
            self.chb_normxy.setText(_tr("normalized x,y ticks"))
            self.chb_normxy.setChecked(True)
            self.chb_normxy.setObjectName(_fromUtf8("chb_normxy"))
            self.verticalLayout.addWidget(self.chb_normxy)

            # push the controls to the top of the column
            spacerItem = QtGui.QSpacerItem(20, 40, QtGui.QSizePolicy.Minimum, QtGui.QSizePolicy.Expanding)
            self.verticalLayout.addItem(spacerItem)
            self.horizontalLayout.addLayout(self.verticalLayout)

            # --- empty plot host on the right, given most of the width --------
            self.widget = QtGui.QWidget(optim_plotting_frame)
            self.widget.setFocusPolicy(QtCore.Qt.StrongFocus)
            self.widget.setObjectName(_fromUtf8("widget"))
            self.horizontalLayout.addWidget(self.widget)
            self.horizontalLayout.setStretch(1, 10)
            self.verticalLayout_2.addLayout(self.horizontalLayout)

            self.retranslateUi(optim_plotting_frame)
            QtCore.QMetaObject.connectSlotsByName(optim_plotting_frame)

        def retranslateUi(self, optim_plotting_frame):
            """Translation hook; all strings are already set in ``setupUi``."""
            pass

以及 main.py(代码如下):其中的类继承自 QWidget,使用上面生成的界面,并在其上绘制曲面(即把绘有曲面的 FigureCanvasQTAgg 添加到界面中的占位部件里)。


    import sys
    from PyQt4 import QtGui, QtCore
    from PyQt4.QtGui import QApplication, QDialog

    import numpy as np
    from scipy.interpolate import Rbf

    from optim_plotting_frame import Ui_optim_plotting_frame

    from matplotlib.backends.backend_qt4agg import FigureCanvasQTAgg as FigureCanvas
    from matplotlib.backends.backend_qt4agg import NavigationToolbar2QTAgg as NavigationToolbar
    from mpl_toolkits.mplot3d import Axes3D
    from matplotlib.figure import Figure
    from matplotlib import cm


    class optim_plotting_frame(QtGui.QWidget):
        """Plotting settings widget.

        Shows a radial-basis-function (RBF) interpolated surface of z over
        (x, y) on a matplotlib 3-D canvas placed inside ``self.ui.widget``.
        The interpolation function, transparency and smoothness come from the
        controls built by ``Ui_optim_plotting_frame``; changing any of them (or
        the "normalized x,y ticks" check box) redraws the plot.

        Available interpolation functions (scipy ``Rbf``):

            'multiquadric': sqrt((r/self.epsilon)**2 + 1)
            'inverse':      1.0/sqrt((r/self.epsilon)**2 + 1)
            'gaussian':     exp(-(r/self.epsilon)**2)
            'linear':       r
            'cubic':        r**3
            'quintic':      r**5
            'thin_plate':   r**2 * log(r)
        """

        def __init__(self, x, y, z, minx=None, maxx=None, miny=None, maxy=None, xname=None, yname=None, zname=None):
            """Validate the data, build the UI and draw the initial plot.

            Parameters:
                x: numpy.ndarray
                    set of first parameter data (first coordinate)

                y: numpy.ndarray
                    set of second parameter data (second coordinate)

                z: numpy.ndarray
                    set of data corresponding to x and y (result)

                minx, maxx: float, optional
                    minimum / maximum possible value of x (default: x.min() / x.max())

                miny, maxy: float, optional
                    minimum / maximum possible value of y (default: y.min() / y.max())

                xname, yname, zname: string, optional
                    names of the x, y and z (objective function) parameters,
                    used as axis labels (default: "")

            Raises:
                TypeError: if x, y or z is not a numpy.ndarray.
                ValueError: if x, y and z do not all have the same size.
            """
            super(optim_plotting_frame, self).__init__()

            #==========================Data validation=============================
            if not all(isinstance(a, np.ndarray) for a in (x, y, z)):
                raise TypeError("x, y, z must be of numpy.ndarray type.")

            # Compare all three explicitly: the chained form
            # `x.size != y.size != z.size` accepts x.size == y.size != z.size.
            if not (x.size == y.size == z.size):
                raise ValueError("x, y, z must be of equal size.")

            if not isinstance(minx, (float, int)):
                minx = x.min()
            if not isinstance(maxx, (float, int)):
                maxx = x.max()
            if not isinstance(miny, (float, int)):
                miny = y.min()
            if not isinstance(maxy, (float, int)):
                maxy = y.max()

            if minx > maxx:
                minx, maxx = maxx, minx
            if miny > maxy:
                miny, maxy = maxy, miny

            if not isinstance(xname, str):
                xname = ""
            if not isinstance(yname, str):
                yname = ""
            if not isinstance(zname, str):
                zname = ""
            #======================================================================

            self.ui = Ui_optim_plotting_frame()
            self.ui.setupUi(self)

            self.initialized = False

            self.z = z
            self.minx = minx
            self.maxx = maxx
            self.miny = miny
            self.maxy = maxy
            self.xname = xname
            self.yname = yname
            self.zname = zname

            # map [min, max] onto [0, 1]
            self.x = (x - minx) / self._span(minx, maxx)
            self.y = (y - miny) / self._span(miny, maxy)

            self.updateVals()
            self.create_main_frame()

            self.initialized = True
            self.plot()

        @staticmethod
        def _span(lo, hi):
            """Return ``hi - lo`` as a float, or 1.0 for a zero-width range (avoids /0)."""
            span = float(hi - lo)
            return span if span else 1.0

        @classmethod
        def _original_units_formatter(cls, lo, hi):
            """Tick formatter that maps a normalized value back onto [lo, hi]."""
            from matplotlib.ticker import FuncFormatter
            span = cls._span(lo, hi)
            return FuncFormatter(lambda value, pos: '%.1f' % (value * span + lo))

        def updateVals(self):
            """
            update alpha, function, smooth values from widget and parameters
            """
            self.alpha = self.ui.dspb_alpha.value()
            # plain str so it can be passed to Rbf and concatenated into the title
            self.function = str(self.ui.cmb_function.currentText())
            self.smooth = self.ui.dspb_smooth.value()

        def plot(self):
            """Re-interpolate with the current settings and redraw the canvas.

            The surface is an RBF interpolant evaluated on a 100 x 100 grid that
            spans the (normalized) data; the original points are scattered on
            top. If interpolation fails, the error is printed and only the
            points are drawn.
            """
            self.updateVals()
            self.axes.clear()

            try:
                # getting coordinate matrices from two coordinate vectors.
                tx = np.linspace(self.x.min(), self.x.max(), 100)
                ty = np.linspace(self.y.min(), self.y.max(), 100)
                XI, YI = np.meshgrid(tx, ty)
                # interpolating by radial basis function
                rbf = Rbf(self.x, self.y, self.z, function=self.function, smooth=self.smooth)
                # getting interpolation function results corresponding to (XI, YI)
                ZI = rbf(XI, YI)
                # plotting interpolated surface
                self.axes.plot_surface(XI, YI, ZI, cmap=cm.jet, alpha=self.alpha)
            except Exception as e:
                # Deliberate best-effort: keep the GUI alive and still show the points.
                print("Error occured! original message: " + str(e))

            # plotting initial points
            self.axes.scatter(self.x, self.y, self.z)

            self.axes.set_xlim(self.x.min(), self.x.max())
            self.axes.set_ylim(self.y.min(), self.y.max())

            self.axes.set_title('RBF interpolation ' + self.function)
            if not self.ui.chb_normxy.isChecked():
                # Label ticks in original units. A formatter follows the ticks
                # when limits change, whereas set_*ticklabels() would freeze the
                # labels of the ticks present right now. axes.clear() above
                # resets the formatters when the box is checked again.
                self.axes.xaxis.set_major_formatter(
                    self._original_units_formatter(self.minx, self.maxx))
                self.axes.yaxis.set_major_formatter(
                    self._original_units_formatter(self.miny, self.maxy))
            self.axes.set_xlabel(self.xname)
            self.axes.set_ylabel(self.yname)
            self.axes.set_zlabel(self.zname)

            ## adding colorbar (disabled; ZI is only defined when interpolation succeeded)
            #m = cm.ScalarMappable(cmap=cm.jet)
            #m.set_array(ZI)
            #self.axes.figure.colorbar(m)

            self.canvas.draw()

        def create_main_frame(self):
            """Embed figure, 3-D axes and toolbar in ``self.ui.widget``; wire the controls to ``plot``."""
            # Create the mpl Figure and FigCanvas objects: 8x4 inches, 100 dots-per-inch
            self.dpi = 100
            self.fig = Figure((8.0, 4.0), dpi=self.dpi)
            # setting diagram background
            self.fig.patch.set_facecolor('white')
            self.canvas = FigureCanvas(self.fig)
            self.canvas.setParent(self.ui.widget)

            # axes are created after the figure has its canvas
            self.axes = Axes3D(self.fig)

            # Create the navigation toolbar, tied to the canvas
            self.mpl_toolbar = NavigationToolbar(self.canvas, self.ui.widget, coordinates=False)

            # any change of a setting triggers a full re-plot
            self.ui.cmb_function.currentIndexChanged.connect(self.plot)
            self.ui.dspb_alpha.valueChanged.connect(self.plot)
            self.ui.dspb_smooth.valueChanged.connect(self.plot)
            self.ui.chb_normxy.stateChanged.connect(self.plot)

            # Vertical layout for canvas and toolbar
            vbox = QtGui.QVBoxLayout()
            vbox.addWidget(self.canvas)
            vbox.addWidget(self.mpl_toolbar)

            self.ui.widget.setLayout(vbox)
            self.ui.widget.setFocus(QtCore.Qt.MouseFocusReason)

    if __name__ == "__main__":
        # Create a Qt application (only when run as a script, not on import)
        app = QApplication(sys.argv)

        # Sample data on a 6 x 6 grid: x cycles through the grid values for
        # every y value, exactly matching the original literal arrays.
        grid = np.array([100., 300., 500., 700., 900., 1000.])

        opf = optim_plotting_frame(x = np.tile(grid, grid.size),
                                   y = np.repeat(grid, grid.size),
                                   z = np.array([374712.60107421875, 526249.09765625, 500842.119140625, 391724.2041015625, 329192.123046875, 298277.92041015625, 526249.259765625, 601555.873046875, 598078.173828125, 529956.01953125, 502884.986328125, 485526.5244140625, 500841.181640625, 598078.400390625, 587555.86328125, 530815.837890625, 495623.544921875, 474902.572265625, 391725.0869140625, 529956.8408203125, 530815.6259765625, 447601.33081054688, 402540.9443359375, 385187.92944335938, 329192.2392578125, 502885.27734375, 495623.6396484375, 402541.17431640625, 365774.16870117188, 343962.6298828125, 298277.88305664062, 485526.775390625, 474903.0673828125, 385187.75439453125, 343962.728515625, 326735.05065917969]),
                                   minx = 100,
                                   maxx = 1000,
                                   miny = 100,
                                   maxy = 1000,
                                   xname = 'width',
                                   yname = 'height',
                                   zname = 'WOPT')
        opf.show()
        sys.exit(app.exec_())

有两个问题:1)坐标轴网格线覆盖在曲面之上;2)曲面的旋转和缩放非常慢。

如果不使用 QWidget(见下面的示例),一切正常;但我需要在自己的应用程序中绘制这个曲面,因此必须把它放在 QWidget 上。欢迎任何解决此问题的建议。


    import numpy as np
    from scipy.interpolate import Rbf
    import matplotlib.pyplot as plt
    import matplotlib.colors as colors
    from matplotlib import cm
    from mpl_toolkits.mplot3d import Axes3D

    def plotSurface(x, y, z, alpha=None, function=None, smooth=None,
                    minx=None, maxx=None, miny=None, maxy=None, xname=None, yname=None, zname=None):
        """
        Create a new window showing the RBF-interpolated surface of z over (x, y).

        Parameters:
            x: numpy.ndarray
                set of first parameter data (first coordinate)

            y: numpy.ndarray
                set of second parameter data (second coordinate)

            z: numpy.ndarray
                set of data corresponding to x and y (result)

            alpha: float, optional
                Number between 0 and 1 that defines transparency: 0 - transparent,
                1 - not transparent. Default 0.7.

            function: string, optional
                defines radial basis function for interpolation, by default - 'multiquadric'
                (unknown names also fall back to 'multiquadric')

                'multiquadric': sqrt((r/self.epsilon)**2 + 1)
                'inverse':      1.0/sqrt((r/self.epsilon)**2 + 1)
                'gaussian':     exp(-(r/self.epsilon)**2)
                'linear':       r
                'cubic':        r**3
                'quintic':      r**5
                'thin_plate':   r**2 * log(r)

            smooth: float, optional
                Values greater than zero increase the smoothness of the approximation.
                0 is for interpolation (default), the function will always go through
                the nodal points in this case.

            minx, maxx: float, optional
                minimum / maximum possible value of x (default: x.min() / x.max())

            miny, maxy: float, optional
                minimum / maximum possible value of y (default: y.min() / y.max())

            xname, yname, zname: string, optional
                names of the x, y and z (objective function) parameters (default: "")

        Raises:
            TypeError: if x, y or z is not a numpy.ndarray.
            ValueError: if x, y and z do not all have the same size.
        """
        #==========================Data validation=================================
        if not all(isinstance(a, np.ndarray) for a in (x, y, z)):
            raise TypeError("x, y, z must be of numpy.ndarray type.")

        # Compare all three explicitly: the chained form
        # `x.size != y.size != z.size` accepts x.size == y.size != z.size.
        if not (x.size == y.size == z.size):
            raise ValueError("x, y, z must be of equal size.")

        if not isinstance(alpha, (float, int)):
            alpha = 0.7

        if not isinstance(function, str):
            try:
                function = str(function)
            except Exception:
                # best-effort coercion (e.g. non-ASCII unicode on Python 2)
                function = 'multiquadric'

        if function not in ('multiquadric', 'inverse', 'gaussian', 'linear',
                            'cubic', 'quintic', 'thin_plate'):
            function = 'multiquadric'

        if smooth is None:
            smooth = 0

        if not isinstance(minx, (float, int)):
            minx = x.min()
        if not isinstance(maxx, (float, int)):
            maxx = x.max()
        if not isinstance(miny, (float, int)):
            miny = y.min()
        if not isinstance(maxy, (float, int)):
            maxy = y.max()

        if minx > maxx:
            minx, maxx = maxx, minx
        if miny > maxy:
            miny, maxy = maxy, miny

        if not isinstance(xname, str):
            xname = ""
        if not isinstance(yname, str):
            yname = ""
        if not isinstance(zname, str):
            zname = ""
        #==========================================================================

        from matplotlib.ticker import FuncFormatter

        fig = plt.figure()
        # setting diagram background
        fig.patch.set_facecolor('white')
        ax = Axes3D(fig)

        # map [min, max] onto [0, 1]; a zero-width range only gets shifted
        xspan = float(maxx - minx) or 1.0
        yspan = float(maxy - miny) or 1.0
        x = (x - minx) / xspan
        y = (y - miny) / yspan

        ZI = None  # stays None if interpolation fails
        try:
            # getting coordinate matrices from two coordinate vectors.
            tx = np.linspace(x.min(), x.max(), 100)
            ty = np.linspace(y.min(), y.max(), 100)
            XI, YI = np.meshgrid(tx, ty)
            # interpolating by radial basis function
            rbf = Rbf(x, y, z, function=function, smooth=smooth)
            # getting interpolation function results corresponding to (XI, YI)
            ZI = rbf(XI, YI)
            # plotting interpolated surface
            ax.plot_surface(XI, YI, ZI, cmap=cm.jet, alpha=alpha)
        except Exception as e:
            # best-effort: still show the raw points below
            print("Error occured! original message: " + str(e))

        # plotting initial points
        ax.scatter(x, y, z)

        ax.set_xlim(x.min(), x.max())
        ax.set_ylim(y.min(), y.max())

        ax.set_title('RBF interpolation ' + function)
        # Label ticks in original units. A formatter stays correct when the
        # limits change, unlike labels fixed once with set_*ticklabels().
        ax.xaxis.set_major_formatter(FuncFormatter(lambda v, pos: '%.1f' % (v * xspan + minx)))
        ax.yaxis.set_major_formatter(FuncFormatter(lambda v, pos: '%.1f' % (v * yspan + miny)))
        ax.set_xlabel(xname)
        ax.set_ylabel(yname)
        ax.set_zlabel(zname)

        # adding colorbar (only when there is an interpolated surface to describe)
        if ZI is not None:
            m = cm.ScalarMappable(cmap=cm.jet)
            m.set_array(ZI)
            ax.figure.colorbar(m)

        # showing window with diagram
        plt.show()

    # Demo on a 6 x 6 grid: x cycles through the grid values for each y value.
    plotSurface(x=np.tile([100., 300., 500., 700., 900., 1000.], 6),
                y=np.repeat([100., 300., 500., 700., 900., 1000.], 6),
                z=np.array([
                    374712.60107421875, 526249.09765625, 500842.119140625,
                    391724.2041015625, 329192.123046875, 298277.92041015625,
                    526249.259765625, 601555.873046875, 598078.173828125,
                    529956.01953125, 502884.986328125, 485526.5244140625,
                    500841.181640625, 598078.400390625, 587555.86328125,
                    530815.837890625, 495623.544921875, 474902.572265625,
                    391725.0869140625, 529956.8408203125, 530815.6259765625,
                    447601.33081054688, 402540.9443359375, 385187.92944335938,
                    329192.2392578125, 502885.27734375, 495623.6396484375,
                    402541.17431640625, 365774.16870117188, 343962.6298828125,
                    298277.88305664062, 485526.775390625, 474903.0673828125,
                    385187.75439453125, 343962.728515625, 326735.05065917969,
                ]),
                alpha=0.7, function='multiquadric', smooth=0.0,
                minx=100, maxx=1000, miny=100, maxy=1000,
                xname='width', yname='height', zname='WOPT')

1 个答案:

答案 0 :(得分:0)

使用matplotlib进行旋转和缩放始终是一个相当慢的操作,因为它需要再次渲染整个图形。