我有一些代码，可以生成平行坐标图（parallel coordinates plots）。
代码看起来工作正常，只是有时很难沿着连接线追踪某一个数据集，尤其是在绘制了许多数据集、并且数据集的线条含有水平线段的情况下。让这一点更清晰的一种办法，是确保数据集的线条绘制在所有坐标轴的上层。用 matplotlib 该如何实现？
代码如下:
#!/usr/bin/python
import os
import matplotlib.pyplot
import pyprel # pip install pyprel
import shijian # pip install shijian
def main():
    """
    Render a demonstration parallel coordinates plot.

    Six example datasets of five dimensions each are plotted and saved under
    the requested filename "parallel_coordinates_1.png".
    """
    demonstration_datasets = [
        [0.9498061441975048, 0.6049394236952985, 5.168760062095979, 6.571102071810909, 0.3258710500404989],
        [0.6900900610786396, 0.044651963051884014, 5.6222021103273185, 6.413258445862534, 0.3989822610565017],
        [0.40245822087321015, 0.9000644399147708, 5.698656759965443, 5.706219545980181, 1.2501169295120753],
        [0.007788240630026866, 1.6630324065241182, 5.326062675832998, 6.198416367191167, 1.3785261503382713],
        [1.2329258057665577, 1.3760208006135484, 5.40272344446576, 5.697339583623006, 1.6475096314400532],
        [1.3736286937326467, 1.618862515954041, 5.9745190659395355, 5.252136196112335, 0.36162007500593485]
    ]
    output_filename = "parallel_coordinates_1.png"
    save_parallel_coordinates_matplotlib(
        datasets = demonstration_datasets,
        filename = output_filename
    )
def save_parallel_coordinates_matplotlib(
    datasets,
    styles = None,
    title = None,
    label_x = "",
    label_y = "",
    labels_ticks_x_axis = None, # under consideration
    filename = None,
    directory = ".",
    overwrite = True,
    LaTeX = False
    ):
    """
    Save a parallel coordinates plot of the datasets to an image file.

    Each dataset is a sequence with one value per dimension. Every dimension
    is normalized to [0, 1] independently and gets its own vertical axis whose
    tick labels show the original (unnormalized) values. Dataset lines are
    drawn above the axes spines so that they stay traceable where they cross
    an axis.

    Arguments:
        datasets: non-empty sequence of equally long sequences of numbers,
            each with at least two values.
        styles: sequence with one matplotlib format string or color per
            dataset; if None, colors are requested from the pyprel palette
            "palette21".
        title: optional figure title; also used to derive the filename when
            no filename is given.
        label_x: label passed to matplotlib.pyplot.xlabel.
        label_y: label passed to matplotlib.pyplot.ylabel.
        labels_ticks_x_axis: x positions of the vertical axes, one per
            dimension; defaults to 0, 1, ..., dimensions - 1.
        filename: requested output filename; the name actually written is
            the one returned by shijian.propose_filename.
        directory: output directory, created if it does not exist.
        overwrite: forwarded to shijian.propose_filename.
        LaTeX: if True, enable matplotlib LaTeX text rendering with a serif
            font.

    Raises:
        ValueError: if there are no datasets, if the datasets differ in
            length, or if they have fewer than two dimensions.
    """
    if len(datasets) == 0:
        raise ValueError("no datasets to plot")
    dimensions = len(datasets[0])
    if dimensions < 2:
        raise ValueError(
            "parallel coordinates need at least two dimensions per dataset"
        )
    if any(len(dataset) != dimensions for dataset in datasets):
        raise ValueError("all datasets must have the same number of dimensions")

    matplotlib.pyplot.ioff()
    if LaTeX is True:
        matplotlib.pyplot.rc("text", usetex = True)
        matplotlib.pyplot.rc("font", family = "serif")

    if filename is None:
        # Without a title there is nothing to derive a name from, so fall back
        # to a fixed default instead of failing on None.replace().
        if title is None:
            filename = "parallel_coordinates.png"
        else:
            filename = title.replace(" ", "_") + ".png"
    filename = shijian.propose_filename(
        filename = filename,
        overwrite = overwrite
    )

    if labels_ticks_x_axis is None:
        labels_ticks_x_axis = list(range(dimensions))
    else:
        labels_ticks_x_axis = list(labels_ticks_x_axis)

    # squeeze = False keeps the result a 2D array even for a single subplot
    # (two-dimensional data), so the axes can always be iterated and indexed.
    figure, axes_grid = matplotlib.pyplot.subplots(
        1,
        dimensions - 1,
        sharey = False,
        squeeze = False,
        figsize = (14, 14)
    )
    axes = axes_grid[0]

    try:
        # If no list of line styles is set, create a list of colors for lines.
        if styles is None:
            styles = pyprel.access_palette(
                name = "palette21",
                minimum_number_of_colors_needed = len(datasets)
            )

        range_minimum_maximum = _parallel_coordinates_ranges(datasets)
        datasets_normalized = _parallel_coordinates_normalize(
            datasets,
            range_minimum_maximum
        )

        # Plot every dataset on every subplot; each subplot shows only the
        # segment between its own pair of neighbouring dimensions.
        for index_axis, axis in enumerate(axes):
            for index_dataset, dataset in enumerate(datasets_normalized):
                axis.plot(
                    labels_ticks_x_axis,
                    dataset,
                    styles[index_dataset],
                    # Lines default to a z-order below the spines; raise them
                    # so that the data is drawn over the vertical axes.
                    zorder = 3
                )
            axis.set_xlim([
                labels_ticks_x_axis[index_axis],
                labels_ticks_x_axis[index_axis + 1]
            ])

        # Left-hand y-axis of every subplot: one dimension each, with ticks
        # pinned to known normalized positions so labels match their ticks.
        for dimension, axis in enumerate(axes):
            axis.set_xticks([labels_ticks_x_axis[dimension]])
            positions, labels_ticks = _parallel_coordinates_ticks(
                range_minimum_maximum[dimension][0],
                range_minimum_maximum[dimension][2]
            )
            axis.set_yticks(positions)
            axis.set_yticklabels(labels_ticks)

        # The last dimension is shown on a twin y-axis at the right of the
        # last subplot, which also carries the final x tick.
        axes[-1].set_xticks([
            labels_ticks_x_axis[-2], labels_ticks_x_axis[-1]
        ])
        axes_last = matplotlib.pyplot.twinx(axes[-1])
        # The twin axis has no data of its own; copy the limits of the axis
        # it mirrors so that its ticks line up with the plotted lines.
        axes_last.set_ylim(axes[-1].get_ylim())
        positions, labels_ticks = _parallel_coordinates_ticks(
            range_minimum_maximum[-1][0],
            range_minimum_maximum[-1][2]
        )
        axes_last.set_yticks(positions)
        axes_last.set_yticklabels(labels_ticks)
        # The twin's spines coincide with those of the last subplot but are
        # drawn after its lines; hide them so the lines stay on top.
        for spine in axes_last.spines.values():
            spine.set_visible(False)

        # Stack subplots.
        matplotlib.pyplot.subplots_adjust(wspace = 0)
        if title is not None:
            figure.suptitle(title, fontsize = 20)
        # NOTE(review): these act on the current axes, which is presumably
        # the twin axis created above -- confirm the x label is rendered.
        matplotlib.pyplot.xlabel(label_x)
        matplotlib.pyplot.ylabel(label_y)

        if not os.path.exists(directory):
            os.makedirs(directory)
        figure.savefig(
            os.path.join(directory, filename),
            bbox_inches = "tight",
            dpi = 700
        )
    finally:
        # Release the figure even if plotting or saving failed.
        matplotlib.pyplot.close(figure)


def _parallel_coordinates_ranges(datasets):
    """Return one (minimum, maximum, range) tuple per dimension."""
    ranges = []
    for values in zip(*datasets):
        minimum_value = min(values)
        maximum_value = max(values)
        if minimum_value == maximum_value:
            # A constant dimension would give a zero range; widen it to a
            # unit interval centred on the value to avoid division by zero.
            minimum_value -= 0.5
            maximum_value = minimum_value + 1.
        ranges.append((
            minimum_value,
            maximum_value,
            float(maximum_value - minimum_value)
        ))
    return ranges


def _parallel_coordinates_normalize(datasets, ranges):
    """Return the datasets with every dimension rescaled by its range."""
    return [
        [
            (value - ranges[dimension][0]) / ranges[dimension][2]
            for dimension, value in enumerate(dataset)
        ]
        for dataset in datasets
    ]


def _parallel_coordinates_ticks(minimum_value, range_of_values, number_of_ticks = 6):
    """Return evenly spaced normalized tick positions and their data-value labels."""
    positions = [
        index_tick / float(number_of_ticks - 1)
        for index_tick in range(number_of_ticks)
    ]
    labels_ticks = [
        "{value:4.2f}".format(value = minimum_value + position * range_of_values)
        for position in positions
    ]
    return positions, labels_ticks
# Script entry point: run the demonstration only on direct execution.
if __name__ == "__main__":
    main()