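How can I draw labels on the lines of a line plot?

Time: 2017-04-23 16:23:08

Tags: matplotlib seaborn

I would like to draw the labels directly on the lines of a line plot in matplotlib.

Minimal example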

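#!/usr/bin/env python
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
# gaussian_filter1d lives in scipy.ndimage; scipy.ndimage.filters is deprecated
from scipy.ndimage import gaussian_filter1d

sns.set_style("whitegrid")
sns.set_palette(sns.color_palette("Greens", 8))

for i in range(8):
    # Create data: a smoothed random walk, rolled by a random offset
    y = np.roll(np.cumsum(np.random.randn(1000, 1)),
                np.random.randint(0, 1000))
    y = gaussian_filter1d(y, 10)
    # sns.plt was removed from seaborn; call matplotlib.pyplot directly
    plt.plot(y, label=str(i))
plt.legend()
plt.show()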

This produces:

[image]

Instead, I would prefer something like this:

[image]

1 Answer:

Answer 0: (score: 7)

Maybe a bit hacky, but does this solve your problem?

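#!/usr/bin/env python
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
# gaussian_filter1d lives in scipy.ndimage; scipy.ndimage.filters is deprecated
from scipy.ndimage import gaussian_filter1d

sns.set_style("whitegrid")
sns.set_palette(sns.color_palette("Greens", 8))

for i in range(8):
    # Create data: a smoothed random walk, rolled by a random offset
    y = np.roll(np.cumsum(np.random.randn(1000, 1)),
                np.random.randint(0, 1000))
    y = gaussian_filter1d(y, 10)
    # sns.plt was removed from seaborn; call matplotlib.pyplot directly
    p = plt.plot(y, label=str(i))
    color = p[0].get_color()
    for x in [250, 500, 750]:
        y2 = y[x]
        # White disc as background, so the line does not cross the label
        plt.plot(x, y2, 'o', color='white', markersize=9)
        # The label itself, drawn as a mathtext marker in the line's color
        plt.plot(x, y2, marker="$%s$" % str(i), color=color,
                 markersize=7)
plt.legend()
plt.show()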

This is the result I get:

[image: My Result]
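A possible variant of the same idea, not part of the code above: instead of stacking two markers, the label can be drawn as text with a white background box. This is only a sketch under a few assumptions: it uses plain matplotlib without the seaborn styling, a fixed random seed, four lines instead of eight, and a hypothetical output file name.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line to use plt.show()
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)  # fixed seed, only for reproducibility
fig, ax = plt.subplots()

for i in range(4):
    # Random walk as stand-in data (no smoothing in this sketch)
    y = np.cumsum(rng.standard_normal(1000))
    (line,) = ax.plot(y, label=str(i))
    for x in (250, 500, 750):
        # The white box hides the line behind the label text
        ax.text(x, y[x], line.get_label(), color=line.get_color(),
                ha="center", va="center", fontsize=9,
                bbox=dict(boxstyle="round,pad=0.15", fc="white", ec="none"))

fig.savefig("inline_labels_text.png")  # hypothetical file name
```

Text labels handle multi-character labels more gracefully than mathtext markers, at the cost of a box whose size is given in font units rather than marker units.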

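Edit: I gave it some more thought and came up with a solution that automatically tries to find the best position for each label, so that labels are not placed at x-values where two lines are very close to each other (which could make the labels overlap):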

#!/usr/bin/env python
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
# gaussian_filter1d lives in scipy.ndimage; scipy.ndimage.filters is deprecated
from scipy.ndimage import gaussian_filter1d

sns.set_style("whitegrid")
sns.set_palette(sns.color_palette("Greens", 8))

# -----------------------------------------------------------------------------

def inline_legend(lines, n_markers=1):
    """
    Take a list containing the lines of a plot (typically the result of 
    calling plt.gca().get_lines()), and add the labels for those lines on the
    lines themselves; more precisely, put each label n_markers times on the
    line.
    [Source of problem: https://stackoverflow.com/q/43573623/4100721]
    """

    import matplotlib.pyplot as plt
    from scipy.interpolate import interp1d
    from math import fabs

    def chunkify(a, n):
        """
        Split list a into n approximately equally sized chunks and return the 
        indices (start/end) of those chunks.
        [Idea: Props to http://stackoverflow.com/a/2135920/4100721 :)]
        """
        k, m = divmod(len(a), n)
        return [(i * k + min(i, m), (i + 1) * k + min(i + 1, m))
                for i in range(n)]

    # Calculate linear interpolations of every line. This is necessary to 
    # compare the values of the lines if they use different x-values
    interpolations = [interp1d(_.get_xdata(), _.get_ydata()) 
                      for _ in lines]


    # Loop over all lines
    for idx, line in enumerate(lines):

        # Get basic properties of the current line
        label = line.get_label()
        color = line.get_color()
        x_values = line.get_xdata()
        y_values = line.get_ydata()

        # Get the linear interpolations of all lines that are not the
        # current line
        other_functions = interpolations[0:idx] + interpolations[idx+1:]

        # Split the x-values in chunks to get regions in which to put 
        # labels. Creating 3 times as many chunks as requested and using only
        # every third ensures that no two labels for the same line are too
        # close to each other.
        chunks = list(chunkify(line.get_xdata(), 3*n_markers))[::3]

        # For each chunk, find the optimal position of the label
        for chunk_nr in range(n_markers):

            # Start and end index of the current chunk
            chunk_start = chunks[chunk_nr][0]
            chunk_end = chunks[chunk_nr][1]

            # For the given chunk, loop over all x-values of the current line,
            # evaluate the value of every other line at every such x-value,
            # and store the absolute difference. Index y_values by position
            # (not by int(x)), so that x-values other than 0, 1, 2, ... work.
            other_values = [[fabs(y_values[i] - f(x_values[i]))
                             for i in range(chunk_start, chunk_end)]
                            for f in other_functions]

            # Now loop over these values and find the minimum, i.e. for every
            # x-value in the current chunk, find the distance to the closest
            # other line ("closest" meaning abs_value(value(current line at x)
            # - value(other lines at x)) being at its minimum)
            distances = [min(column) for column in zip(*other_values)]

            # Now find the value of x in the current chunk where the distance
            # is maximal, i.e. the best position for the label and add the
            # necessary offset to take into account that the index obtained
            # from "distances" is relative to the current chunk
            best_pos = distances.index(max(distances)) + chunks[chunk_nr][0]

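            # Position of the label: best_pos is an index into the data, so
            # look up both coordinates instead of using the index as x
            x = x_values[best_pos]
            y = y_values[best_pos]

            # Actually plot the label onto the line at the calculated
            # position: a white disc as background, then the label itself
            plt.plot(x, y, 'o', color='white', markersize=9)
            plt.plot(x, y, marker="$%s$" % label, color=color,
                     markersize=7)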

# -----------------------------------------------------------------------------

for i in range(8):
    # Create data: a smoothed random walk, rolled by a random offset
    y = np.roll(np.cumsum(np.random.randn(1000, 1)),
                np.random.randint(0, 1000))
    y = gaussian_filter1d(y, 10)
    # sns.plt was removed from seaborn; call matplotlib.pyplot directly
    plt.plot(y, label=str(i))

inline_legend(plt.gca().get_lines(), n_markers=3)
plt.show()
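To see what the chunking step inside inline_legend does, here is that step on its own. This is a minimal sketch, assuming 1000 x-values and n_markers=3 as in the example; chunkify is copied from the function above, while the names all_chunks and label_chunks are only used here for illustration.

```python
def chunkify(a, n):
    """Return (start, end) index pairs of n roughly equal chunks of a."""
    k, m = divmod(len(a), n)
    return [(i * k + min(i, m), (i + 1) * k + min(i + 1, m))
            for i in range(n)]

n_markers = 3
x_values = range(1000)

# Split into 3 times as many chunks as labels are requested ...
all_chunks = chunkify(x_values, 3 * n_markers)
# ... and keep only every third chunk as a region for one label
label_chunks = all_chunks[::3]

print(label_chunks)  # [(0, 112), (334, 445), (667, 778)]
```

Each label is searched for only within its own region, and the two skipped chunks between neighbouring regions keep labels of the same line at least about 220 points apart here.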

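Example output of this solution (note that the x-positions of the labels are no longer all the same):

[image: Improved solution]

If you want to avoid using scipy.interpolate.interp1d, you could consider a solution where, for a given x-value of line A, you find the x-value of line B that is closest to it. I think this could be problematic if the lines use very different and/or sparse grids, though.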
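The nearest-x idea could look roughly like the following. This is only a sketch and not tested against the plotting code above: nearest_y is a hypothetical helper, and it assumes that line B's x-values are sorted in ascending order and that there are at least two of them.

```python
import numpy as np

def nearest_y(x_query, x_other, y_other):
    """For each x in x_query, return the y-value of the sample of the
    other line whose x-value is closest (ties go to the left sample)."""
    x_query = np.asarray(x_query, dtype=float)
    x_other = np.asarray(x_other, dtype=float)
    y_other = np.asarray(y_other, dtype=float)

    # Index of the first sample in x_other that is >= each query value,
    # clipped so that both "right" and "right - 1" are valid indices
    right = np.clip(np.searchsorted(x_other, x_query), 1, len(x_other) - 1)
    left = right - 1

    # Pick whichever neighbour is closer to the query value
    use_left = (x_query - x_other[left]) <= (x_other[right] - x_query)
    return y_other[np.where(use_left, left, right)]

# Line A on a regular grid, line B on a sparse, irregular grid
x_a = np.array([0.0, 1.0, 2.0, 3.0])
x_b = np.array([0.0, 2.9, 3.0])
y_b = np.array([10.0, 20.0, 30.0])

print(nearest_y(x_a, x_b, y_b))  # [10. 10. 20. 30.]
```

The example also shows the concern about sparse grids: at x = 1 the nearest sample of line B is the one at x = 0, so the lookup returns 10, whereas linear interpolation between the samples at 0 and 2.9 would give about 13.4.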