沿轴的Matplotlib值选择器小部件

时间:2019-01-21 14:26:33

标签: python matplotlib pyqt

我想在python pyqt GUI中实现类似于我在AstroImageJ中看到的功能,您可以在其中调整图像的对比度。

（图：AstroImageJ 中调整图像对比度的示例）

我是Python的新手,还没有找到实现此目的的任何方法。也许matplotlib小部件或艺术家提供了此类功能?

如果标题令人困惑,也表示抱歉。我欢迎任何改进的建议!

1 个答案:

答案 0（得分：0）

我想我终于找到了一个解决方案，不过它可能还有很大的改进空间。下面贴出的是一个可运行示例的代码；代码比较长，其中还包含了一些用于缩放和平移图像的其他方法。欢迎大家试用并给我反馈！

有时,当我关闭窗口时,也会多次收到此错误消息:

Exception ignored in: <function WeakMethod.__new__.<locals>._cb at 0x00000193A3D7C7B8>
Traceback (most recent call last):
  File "C:\Users\mapf\Anaconda3\lib\weakref.py", line 58, in _cb
  File "C:\Users\mapf\Anaconda3\lib\site-packages\matplotlib\cbook\__init__.py", line 182, in _remove_proxy
  File "C:\Users\mapf\Anaconda3\lib\weakref.py", line 74, in __eq__
TypeError: isinstance() arg 2 must be a type or tuple of types

这是它的样子:

（图：该方案运行时的界面截图）

代码如下:

import copy
import sys

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.text import Annotation
import matplotlib.patheffects as PathEffects
from matplotlib.backends.backend_qt5agg import \
    FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
from matplotlib.patches import Rectangle

from PyQt5.QtWidgets import QDialog, QApplication, QGridLayout
from astropy.visualization import ImageNormalize, LinearStretch, ZScaleInterval


class IDAnnotation(Annotation):
    """A matplotlib ``Annotation`` that also carries a string identifier.

    The identifier is plain extra data stored alongside the text artist so
    that event handlers can tell otherwise similar annotations apart.
    """

    def __init__(
            self, text, position, ha='center', rotation=0, fontsize=15,
            picker=False, zorder=3, clip_on=True, identifier='',
            verticalalignment='baseline'
    ):
        # Everything except the identifier is forwarded unchanged to
        # Annotation; the dict keeps the original keyword order.
        text_options = dict(
            ha=ha,
            rotation=rotation,
            fontsize=fontsize,
            picker=picker,
            zorder=zorder,
            clip_on=clip_on,
            verticalalignment=verticalalignment,
        )
        super().__init__(text, position, **text_options)
        self._id = identifier

    def get_id(self):
        """Return the identifier set at construction or by ``set_id``."""
        return self._id

    def set_id(self, identifier):
        """Replace the stored identifier with *identifier*."""
        self._id = identifier


class ImageFigure:
    """Embeddable image view with wheel zoom, left-drag pan and a linked
    contrast histogram.

    The image is drawn in gray scale with an astropy ``ImageNormalize``
    (linear stretch, z-scale interval). A ``HistogramFigure`` built from the
    non-zero pixel values is stored as ``self.contrast``; mouse motion on its
    canvas updates the colour limits of the displayed image.
    """

    def __init__(self, image):
        """Create the figure, canvas, zoom/pan handlers and contrast widget.

        Parameters
        ----------
        image : numpy.ndarray
            Image to display (presumably 2-D; TODO confirm for other shapes).
            Negative values are clipped to 0 in a private copy; the caller's
            array is not modified.
        """
        # Use the object-oriented Figure API rather than pyplot: pyplot keeps
        # a global registry of the figures it creates (and, on a GUI backend,
        # a window manager for each), which an embedded canvas does not need.
        self.fig = Figure()
        self.ax = self.fig.add_subplot(111)
        self.canvas = FigureCanvas(self.fig)

        # np.clip returns a new array, so the argument is left untouched.
        self.base_image = np.clip(image, 0, None)
        self.image = self.base_image.copy()
        self.norm = ImageNormalize(
            self.image, stretch=LinearStretch(),
            interval=ZScaleInterval()
        )
        self.image_artist = self.ax.imshow(
            self.base_image, cmap='gray', interpolation='nearest',
            norm=self.norm
        )
        self.clim = self.image_artist.get_clim()

        # Zoom/pan state.
        self.base_scale = 2.0  # zoom factor applied per wheel step
        # Full-image limits; zoom and pan are clamped to stay inside them.
        self.base_xlim = self.ax.get_xlim()
        self.base_ylim = self.ax.get_ylim()
        self.new_xlim = [0, 1]
        self.new_ylim = [0, 1]
        self.x_press = 0
        self.y_press = 0
        self.fig.canvas.mpl_connect('scroll_event', self.zoom)
        self.fig.canvas.mpl_connect('button_press_event', self.pan_press)
        self.fig.canvas.mpl_connect('motion_notify_event', self.pan_move)

        # Flattened pixel values without zeros. Zeros are dropped presumably
        # because HistogramFigure builds geometric (log) bins, which cannot
        # start at 0.
        flat = np.ravel(self.base_image)
        self.hist = flat[flat != 0]
        self.contrast = HistogramFigure(self.hist, self.clim)
        self.contrast.fig.canvas.mpl_connect(
            'motion_notify_event', self.adjust_contrast
        )

    def adjust_contrast(self, event):
        """Forward a histogram motion event and apply the new colour limits."""
        self.contrast.on_move_event(event)
        low_in = self.contrast.lclim
        high_in = self.contrast.uclim

        self.image_artist.set_clim(low_in, high_in)
        self.canvas.draw_idle()

    def zoom(self, event):
        """Zoom in (wheel 'up') or out (wheel 'down') around the cursor."""
        xdata = event.xdata
        ydata = event.ydata
        if xdata is None or ydata is None:
            return  # cursor is outside the axes

        cur_xlim = self.ax.get_xlim()
        cur_ylim = self.ax.get_ylim()
        # Distances from the cursor to each edge; scaling them keeps the
        # point under the cursor fixed.
        x_left = xdata - cur_xlim[0]
        x_right = cur_xlim[1] - xdata
        y_top = ydata - cur_ylim[0]
        y_bottom = cur_ylim[1] - ydata
        if event.button == 'up':
            scale_factor = 1 / self.base_scale
        elif event.button == 'down':
            scale_factor = self.base_scale
        else:
            scale_factor = 1

        new_xlim = [
            xdata - x_left * scale_factor, xdata + x_right * scale_factor
        ]
        new_ylim = [
            ydata - y_top * scale_factor, ydata + y_bottom * scale_factor
        ]

        # Intercept new plot limits that would leave the image.
        self.new_xlim, self.new_ylim = check_limits(
            self.base_xlim, self.base_ylim, new_xlim, new_ylim
        )

        self.ax.set_xlim(self.new_xlim)
        self.ax.set_ylim(self.new_ylim)
        self.canvas.draw()

    def pan_press(self, event):
        """Remember where a left-button press happened (pan anchor)."""
        if event.button != 1:
            return
        if event.xdata is None or event.ydata is None:
            return
        self.x_press = event.xdata
        self.y_press = event.ydata

    def pan_move(self, event):
        """Shift the view while the left button is held down."""
        if event.button != 1:
            return
        xdata = event.xdata
        ydata = event.ydata
        if xdata is None or ydata is None:
            return

        cur_xlim = self.ax.get_xlim()
        cur_ylim = self.ax.get_ylim()
        dx = xdata - self.x_press
        dy = ydata - self.y_press
        new_xlim = [cur_xlim[0] - dx, cur_xlim[1] - dx]
        new_ylim = [cur_ylim[0] - dy, cur_ylim[1] - dy]

        # Intercept new plot limits that would leave the image.
        new_xlim, new_ylim = check_limits(
            self.base_xlim, self.base_ylim, new_xlim, new_ylim
        )

        self.ax.set_xlim(new_xlim)
        self.ax.set_ylim(new_ylim)
        self.canvas.draw()


class HistogramFigure:
    """Log-log histogram with two draggable markers selecting a value range.

    The lower and upper limits are ``self.lclim`` and ``self.uclim``. Bars
    inside the range are drawn cyan and bars outside it gray. A bar that
    straddles a limit is shrunk so it ends at the limit, and a separate
    "dummy" rectangle fills the rest of that bin in the other colour.
    """

    def __init__(self, image, clim):
        """Build the histogram, the two limit markers and the event hooks.

        Parameters
        ----------
        image : array_like
            1-D values to histogram. They must be strictly positive, because
            the bins are geometrically spaced between their min and max.
        clim : sequence of two floats
            Initial (lower, upper) marker positions.
        """
        # Object-oriented API instead of pyplot, since the canvas is
        # embedded in our own Qt widget.
        self.fig = Figure()
        self.ax = self.fig.add_subplot(111)
        self.canvas = FigureCanvas(self.fig)
        self.image = image
        self.clim = clim
        self.uclim = self.clim[1]
        self.lclim = self.clim[0]
        self.nbins = 20
        self.dragged = None  # marker currently being dragged, if any
        self.pick_pos = None
        # Hover-highlight state of each marker (public names kept as-is).
        self.uclim_hightlight = False
        self.lclim_hightlight = False
        # Filler rectangles for the bins cut by lclim / uclim; False until
        # first created by color_patches().
        self.dummy_patches = [False, False]
        # Indices of the bars currently shrunk at lclim / uclim.
        self.cropped_patches_index = [0, 0]
        self.canvas.setMaximumHeight(100)
        self.fig.subplots_adjust(left=0.07, right=0.98, bottom=0.1, top=0.75)
        self.ax.tick_params(
            axis="both", labelsize=6, left=True, top=True, labelleft=True,
            labeltop=True, bottom=False, labelbottom=False
        )
        self.ax.tick_params(which='minor', bottom=False, top=True)
        # np.min/np.max rather than the builtins: same values, but computed
        # in C instead of iterating the array element by element.
        self.bins = np.geomspace(
            np.min(self.image), np.max(self.image), self.nbins
        )
        _, _, self.patches = self.ax.hist(
            self.image, bins=self.bins, log=True, zorder=1
        )
        # Matplotlib 3.3 renamed ``nonposx`` to ``nonpositive`` and later
        # removed the old name; older releases reject the new one (with a
        # TypeError or ValueError depending on the version).
        try:
            self.ax.set_xscale("log", nonpositive='clip')
        except (TypeError, ValueError):
            self.ax.set_xscale("log", nonposx='clip')
        self.color_patches()

        self.ax.margins(0, 0.1)
        self.uclim_marker = IDAnnotation(
            r'$\blacktriangledown$',
            (self.uclim, self.ax.get_ylim()[1]/6),
            ha='center', fontsize=15, picker=True, zorder=3, clip_on=False,
            identifier='uclim'
        )
        self.lclim_marker = IDAnnotation(
            r'$\blacktriangle$',
            (self.lclim+self.ax.get_xlim()[0], self.ax.get_ylim()[0]*16),
            ha='center', verticalalignment='top', fontsize=15, picker=True,
            zorder=2, clip_on=False, identifier='lclim'
        )
        self.ax.add_artist(self.uclim_marker)
        self.ax.add_artist(self.lclim_marker)

        self.fig.canvas.mpl_connect('pick_event', self.on_pick_event)
        self.fig.canvas.mpl_connect(
            'motion_notify_event', self.highlight_picker
        )
        self.fig.canvas.mpl_connect(
            'button_release_event', self.on_release_event
        )
        self.fig.canvas.mpl_connect(
            'button_press_event', self.on_button_press_event
        )

        self.canvas.draw()

    def color_patches(self):
        """Recolour the bars and split the bins containing lclim / uclim.

        Callers are expected to have restored the full width of previously
        cropped bars first (see ``on_move_event``).
        """
        top_edge = np.max(self.bins)
        j = 0
        i = self.bins[j]
        overlap = False
        # Bars entirely below the lower limit are gray.
        while i < self.lclim:
            self.patches[j].set_facecolor('gray')
            j += 1
            i = self.bins[j]
        if j > 0:
            # Bar j-1 contains lclim: shrink it to end at lclim and cover the
            # remainder of the bin with a cyan dummy rectangle.
            self.cropped_patches_index[0] = j - 1
            self.patches[j - 1].set_width(self.lclim - self.bins[j - 1])
            self.patches[j - 1].set_facecolor('gray')
            if self.uclim <= self.bins[j]:
                # Both limits lie in the same bin.
                width = self.uclim - self.lclim
                overlap = True
            else:
                width = self.bins[j] - self.lclim
            if self.dummy_patches[0]:
                self.dummy_patches[0].set_xy(
                    (self.lclim, self.patches[j - 1].get_y())
                )
                self.dummy_patches[0].set_width(width)
                self.dummy_patches[0].set_height(
                    self.patches[j - 1].get_height())
            else:
                self.dummy_patches[0] = Rectangle(
                    (self.lclim, self.patches[j - 1].get_y()),
                    width=width, linewidth=0,
                    height=self.patches[j - 1].get_height(), color='c'
                )
                self.ax.add_artist(self.dummy_patches[0])
        if not overlap:
            # Bars inside the selected range are cyan.
            while i < top_edge and i < self.uclim:
                self.patches[j].set_facecolor('c')
                j += 1
                i = self.bins[j]
            # Bar j-1 contains uclim: shrink it to end at uclim.
            self.cropped_patches_index[1] = j - 1
            self.patches[j - 1].set_width(self.uclim - self.bins[j - 1])
            self.patches[j - 1].set_facecolor('c')
        # Gray dummy covering the part of the uclim bin above uclim.
        if self.dummy_patches[1]:
            self.dummy_patches[1].set_xy(
                (self.uclim, self.patches[j - 1].get_y())
            )
            self.dummy_patches[1].set_width(self.bins[j] - self.uclim)
            self.dummy_patches[1].set_height(self.patches[j - 1].get_height())
        else:
            self.dummy_patches[1] = Rectangle(
                (self.uclim, self.patches[j - 1].get_y()),
                width=self.bins[j] - self.uclim, linewidth=0,
                height=self.patches[j - 1].get_height(), color='gray'
            )
            # Add only when created; re-adding on every call would pile up
            # duplicate copies of the same rectangle in the axes.
            self.ax.add_artist(self.dummy_patches[1])
        # Bars entirely above the upper limit are gray.
        while i < top_edge:
            self.patches[j].set_facecolor('gray')
            j += 1
            i = self.bins[j]

    def add_dummy(self, j, colors, limit):
        """Crop bar ``j`` at ``limit`` and move an existing dummy rectangle.

        Not called anywhere in this file. Dummy index 0 is used when
        ``colors[0] == 'gray'``, otherwise index 1; that dummy must already
        be a Rectangle. ``colors[1]`` is currently unused.
        NOTE(review): the dummy width ``bins[j] - limit`` is negative when
        ``limit`` lies inside bar j; ``bins[j + 1] - limit`` may have been
        intended.
        """
        idx = 0 if colors[0] == 'gray' else 1
        self.cropped_patches_index[idx] = j
        self.patches[j].set_width(limit - self.bins[j])
        self.patches[j].set_facecolor(colors[0])
        self.dummy_patches[idx].set_xy((limit, self.patches[j].get_y()))
        self.dummy_patches[idx].set_width(self.bins[j]-limit)
        self.dummy_patches[idx].set_height(self.patches[j].get_height())

    def on_pick_event(self, event):
        """Store which text object was picked and where the pick occurred.

        The pick x position (data coordinates) is clamped to the x-limits.
        """
        if isinstance(event.artist, Annotation):
            self.dragged = event.artist
            inv = self.ax.transData.inverted()
            self.pick_pos = inv.transform(
                (event.mouseevent.x, event.mouseevent.y)
            )[0]
            xmin, xmax = self.ax.get_xlim()
            if self.pick_pos < xmin:
                self.pick_pos = xmin
            if self.pick_pos > xmax:
                self.pick_pos = xmax
        return True

    def on_button_press_event(self, event):
        """Record the x position of a left click that is not on exactly one marker."""
        if event.button != 1:
            return
        on_lower = self.lclim_marker.contains(event)[0]
        on_upper = self.uclim_marker.contains(event)[0]
        if on_lower == on_upper:
            inv = self.ax.transData.inverted()
            self.pick_pos = inv.transform((event.x, event.y))[0]

    def on_release_event(self, _):
        """Stop dragging when a mouse button is released."""
        self.dragged = None

    def _restore_patch_width(self, side):
        """Give the bar last cropped at limit ``side`` (0 lower, 1 upper) its full bin width."""
        k = self.cropped_patches_index[side]
        self.patches[k].set_width(self.bins[k + 1] - self.bins[k])

    def on_move_event(self, event):
        """Drag the picked marker along x while the left button is held.

        Each marker is kept within the x-limits and on its own side of the
        other marker. The bars are then recoloured and the canvas redrawn.
        Always returns True.
        """
        if event.button != 1 or self.dragged is None:
            return True

        inv = self.ax.transData.inverted()
        new_pos = inv.transform((event.x, event.y))[0]
        old_pos = self.dragged.get_position()
        marker_id = self.dragged.get_id()
        if marker_id == 'lclim':
            if new_pos < self.ax.get_xlim()[0]:
                new_pos = self.ax.get_xlim()[0]
            self.lclim = new_pos
            if self.lclim > self.uclim:
                self.lclim = self.uclim*0.999
            self.dragged.set_position((self.lclim, old_pos[1]))
            self._restore_patch_width(0)
        elif marker_id == 'uclim':
            if new_pos > self.ax.get_xlim()[1]:
                new_pos = self.ax.get_xlim()[1]
            self.uclim = new_pos
            if self.uclim < self.lclim:
                self.uclim = self.lclim*1.001
            self.dragged.set_position((self.uclim, old_pos[1]))
            self._restore_patch_width(1)

        self.color_patches()
        self.ax.figure.canvas.draw()
        return True

    def _set_marker_highlight(self, marker, highlighted, hovered):
        """Toggle the cyan glow on *marker* if its state changes; return the new state."""
        hovered = bool(hovered)
        if hovered == highlighted:
            return highlighted
        if hovered:
            effects = [PathEffects.withStroke(linewidth=2, foreground="c")]
        else:
            effects = [PathEffects.Normal()]
        marker.set_path_effects(effects)
        self.ax.figure.canvas.draw()
        return hovered

    def highlight_picker(self, event):
        """Highlight a marker while the mouse hovers over it (no left button held)."""
        if event.button == 1:
            # Leave highlights unchanged while dragging.
            return True
        self.uclim_hightlight = self._set_marker_highlight(
            self.uclim_marker, self.uclim_hightlight,
            self.uclim_marker.contains(event)[0]
        )
        self.lclim_hightlight = self._set_marker_highlight(
            self.lclim_marker, self.lclim_hightlight,
            self.lclim_marker.contains(event)[0]
        )
        return True


class MainWindow(QDialog):
    """Dialog showing a demo image above its contrast histogram."""

    def __init__(self):
        super().__init__()
        # Random demo data (500 x 500 values in [0, 1)).
        self.img = np.random.random((500, 500))
        # Not called ``layout``: an instance attribute of that name would
        # shadow the inherited QWidget.layout() method. The layout is also
        # reachable through self.layout().
        self.grid_layout = None
        self.image = None
        self.contrast = None

        self.create_widgets()

    def create_widgets(self):
        """Create the image and histogram canvases and stack them vertically."""
        self.grid_layout = QGridLayout(self)
        self.image = ImageFigure(self.img)
        self.contrast = self.image.contrast

        self.grid_layout.addWidget(self.image.canvas, 0, 0)
        self.grid_layout.addWidget(self.contrast.canvas, 1, 0)


def _clamp_span(limits, lo, hi, base_lo, base_hi):
    """Slide ``limits[lo]``..``limits[hi]`` back inside ``[base_lo, base_hi]``.

    Operates in place on *limits* and writes elements only when a correction
    is needed. A span overhanging one edge is shifted inwards by the overhang
    (keeping its width where possible). A span wider than the base range is
    cut down to exactly the base range.
    """
    if limits[lo] < base_lo:
        overhang = base_lo - limits[lo]
        limits[lo] = base_lo
        limits[hi] = min(limits[hi] + overhang, base_hi)
    if limits[hi] > base_hi:
        overhang = limits[hi] - base_hi
        limits[hi] = base_hi
        limits[lo] = max(limits[lo] - overhang, base_lo)


def check_limits(base_xlim, base_ylim, new_xlim, new_ylim):
    """Keep proposed axis limits inside the base (full-image) limits.

    *new_xlim* and *new_ylim* are adjusted in place and returned. The x axis
    is treated as increasing (``base_xlim[0] < base_xlim[1]``). The y axis is
    treated as inverted (``base_ylim[0] > base_ylim[1]``), as for an image
    shown with imshow's default upper origin.
    """
    _clamp_span(new_xlim, 0, 1, base_xlim[0], base_xlim[1])
    # Inverted y axis: element 1 is the smaller bound, element 0 the larger.
    _clamp_span(new_ylim, 1, 0, base_ylim[1], base_ylim[0])
    return new_xlim, new_ylim


if __name__ == '__main__':
    # Create the Qt application, show the dialog, then run the Qt event loop
    # and exit with its return code.
    app = QApplication(sys.argv)
    GUI = MainWindow()
    GUI.show()
    exit_code = app.exec_()
    sys.exit(exit_code)