Creating a heatmap in Python matplotlib with x and y labels from a dict in {tuple: float} format

Asked: 2015-11-14 17:17:51

Tags: python dictionary matplotlib tuples heatmap

I have a dict with pairs of movie titles as keys and similarity scores as values:

{('Source Code ', 'Hobo with a Shotgun '): 1.0, ('Adjustment Bureau, The ', 'Just Go with It '): 1.0, ('Limitless ', 'Arthur '): 1.0, ('Adjustment Bureau, The ', 'Kung Fu Panda 2 '): 1.0, ('Rise of the Planet of the Apes ', 'Scream 4 '): 1.0, ('Source Code ', 'Take Me Home Tonight '): 1.0, ('Midnight in Paris ', 'Take Me Home Tonight '): 1.0, ('Harry Potter and the Deathly Hallows: Part 2 ', 'Pina '): 1.0, ('Avengers, The ', 'Drive Angry '): 1.0, ('Limitless ', 'Super 8 '): 1.0, ('Harry Potter and the Deathly Hallows: Part 2 ', 'Arthur '): 1.0, ('Source Code ', 'Melancholia '): 0.6666666666666666, ('Harry Potter and the Deathly Hallows: Part 2 ', 'Jane Eyre '): 1.0, ('Avengers, The ', 'Arthur '): 0.6666666666666666, ('The Artist ', 'Attack the Block '): 1.0, ('Midnight in Paris ', 'Priest '): 1.0, ('Adjustment Bureau, The ', 'Hanna '): 1.0, ('The Artist ', 'Thor '): 1.0, ('The Artist ', 'Zeitgeist: Moving Forward '): 1.0, ('The Artist ', 'Green Hornet, The '): 1.0, ('X-Men: First Class ', 'Sanctum '): 1.0, ('Source Code ', 'Green Hornet, The '): 1.0, ('Harry Potter and the Deathly Hallows: Part 2 ', 'Something Borrowed '): 1.0, ('Adjustment Bureau, The ', 'Rio '): 1.0, ('Avengers, The ', 'Mechanic, The '): 1.0, ('Rise of the Planet of the Apes ', 'Something Borrowed '): 0.6666666666666666, ('Captain America: The First Avenger ', 'Attack the Block '): 0.6666666666666666, ('Avengers, The ', 'Zeitgeist: Moving Forward '): 1.0, ('Midnight in Paris ', 'Arthur '): 1.0, ('Source Code ', 'Arthur '): 1.0, ('Limitless ', 'Take Me Home Tonight '): 1.0, ('Midnight in Paris ', 'Win Win '): 1.0, ('X-Men: First Class ', 'Something Borrowed '): 1.0, ('Avengers, The ', 'Dilemma, The '): 1.0, ('X-Men: First Class ', 'Green Hornet, The '): 1.0, ('The Artist ', 'Just Go with It '): 1.0, ('Rise of the Planet of the Apes ', 'Arthur '): 1.0, ('Captain America: The First Avenger ', 'Lincoln Lawyer, The '): 1.0, ('X-Men: First Class ', 'Hobo with a Shotgun '): 1.0, ('Limitless ', 
'Mechanic, The '): 0.6666666666666666, ('Captain America: The First Avenger ', 'Green Hornet, The '): 1.0, ('Captain America: The First Avenger ', 'Hangover Part II, The '): 1.0, ('X-Men: First Class ', 'Hanna '): 1.0, ('Rise of the Planet of the Apes ', 'Priest '): 1.0, ('Midnight in Paris ', 'I Am Number Four '): 1.0, ('Rise of the Planet of the Apes ', 'Tree of Life, The '): 1.0, ('Captain America: The First Avenger ', 'Hanna '): 1.0, ('Harry Potter and the Deathly Hallows: Part 2 ', 'Win Win '): 1.0, ('Limitless ', 'Drive Angry '): 0.6666666666666666, ('Adjustment Bureau, The ', 'Hangover Part II, The '): 1.0}

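I want to use matplotlib to create a heatmap with the first movie of each key as the y label, the second movie of each key as the x label, and the similarity score as the z axis (the cell color).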

So far I have used the following as a guide (Converting a dictionary of tuples into a numpy matrix), but it does not seem to plot the correct distribution (see image).

[Image: heatmap produced by the current code]

My current code for creating the numpy.ndarray for the z axis looks like this:

import numpy as np 
import matplotlib.pyplot as plt

# create a heatmap
sim_scores = np.array(dict_sim_scores.values())
movie_titles = np.array(dict_sim_scores.keys())
## return unique movie titles and indices of the input array
unq_titles, title_idx = np.unique(movie_titles, return_inverse=True)
title_idx = title_idx.reshape(-1,2)
n = len(unq_titles)
sim_matrix = np.zeros((n, n) ,dtype=sim_scores.dtype)
sim_matrix[title_idx[:,0], title_idx[: ,1]] = sim_scores

list_item =[]
list_other_item=[]
for i,key in enumerate(dict_sim_scores):

    list_item.append(str(key[0]))
    list_other_item.append(str(key[1]))

list_item = np.unique(list_item)
list_other_item = np.unique(list_other_item)


fig = plt.figure('Similarity Scores')
ax = fig.add_subplot(111)
cax = ax.matshow(sim_matrix, interpolation='nearest')
fig.colorbar(cax)
ax.set_xticks(np.arange(len(list_item)))
ax.set_yticks(np.arange(len(list_other_item)))
#ax.set_xticklabels(list_item,rotation=40,fontsize='x-small',ha='right')
ax.xaxis.tick_bottom()
ax.set_xticklabels(list_item,rotation=40,fontsize='x-small')
ax.set_yticklabels(list_other_item,fontsize='x-small') 
plt.show()

Any ideas on how to create this image?

1 answer:

Answer 0 (score: 4)

I would use pandas to unstack the data and then seaborn to plot the result. Here is an example using the dictionary you provided:

import pandas as pd
ser = pd.Series(list(dict_sim_scores.values()),
                  index=pd.MultiIndex.from_tuples(dict_sim_scores.keys()))
df = ser.unstack().fillna(0)
df.shape
# (10, 27)
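To see what the unstack step does, it may help to run it on a much smaller input. The sketch below uses a made-up three-entry dict (`toy_scores`, with placeholder titles 'A', 'B', 'X', 'Y' that are not from the question): the first element of each key becomes a row label, the second becomes a column label, and any pair missing from the dict is filled with 0.

```python
import pandas as pd

# Hypothetical miniature of dict_sim_scores; titles and scores are made up.
toy_scores = {
    ('A', 'X'): 1.0,
    ('A', 'Y'): 0.5,
    ('B', 'X'): 0.25,
}

# Build a Series indexed by (first title, second title) pairs.
ser = pd.Series(list(toy_scores.values()),
                index=pd.MultiIndex.from_tuples(list(toy_scores.keys())))

# Pivot the second index level into columns; the missing ('B', 'Y') pair becomes 0.
df = ser.unstack().fillna(0)
print(df)
```

The rows of `df` are the first titles and the columns are the second titles, which is the layout the question asks for.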

Now plot the result with seaborn's heatmap function:

import seaborn as sns
sns.heatmap(df);

[Image: the resulting seaborn heatmap]
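If you would rather stay with matplotlib alone, as the question asks, one likely reason the original code looks wrong is that np.unique pools first and second titles into a single index. The matrix then becomes square over all titles, while the tick labels come from two shorter lists, and the x labels are taken from the first titles instead of the second. Below is a minimal sketch that indexes rows and columns separately. The names `build_matrix`, `plot_heatmap` and `toy_scores` are invented for this illustration and are not part of the original post.

```python
import numpy as np


def build_matrix(scores):
    """Return (matrix, row_titles, col_titles) for a {(first, second): score} dict."""
    # Rows come only from the first title of each key, columns only from the second.
    row_titles = sorted({first for first, _ in scores})
    col_titles = sorted({second for _, second in scores})
    row_pos = {title: i for i, title in enumerate(row_titles)}
    col_pos = {title: j for j, title in enumerate(col_titles)}

    # Pairs that are absent from the dict stay at 0.
    matrix = np.zeros((len(row_titles), len(col_titles)))
    for (first, second), score in scores.items():
        matrix[row_pos[first], col_pos[second]] = score
    return matrix, row_titles, col_titles


def plot_heatmap(matrix, row_titles, col_titles):
    """Draw the matrix with matshow and label both axes; returns the figure."""
    # Imported here so build_matrix can be used without matplotlib installed.
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    image = ax.matshow(matrix, interpolation='nearest')
    fig.colorbar(image)
    # x ticks follow the columns (second titles), y ticks follow the rows (first titles).
    ax.set_xticks(np.arange(len(col_titles)))
    ax.set_yticks(np.arange(len(row_titles)))
    ax.xaxis.tick_bottom()
    ax.set_xticklabels(col_titles, rotation=40, fontsize='x-small', ha='right')
    ax.set_yticklabels(row_titles, fontsize='x-small')
    return fig


# Hypothetical miniature of dict_sim_scores; titles and scores are made up.
toy_scores = {('A', 'X'): 1.0, ('A', 'Y'): 0.5, ('B', 'X'): 0.25}
matrix, rows, cols = build_matrix(toy_scores)
print(matrix)
```

Passing the real `dict_sim_scores` to `build_matrix` and the result to `plot_heatmap` should give a rectangular heatmap comparable to the seaborn one; the returned figure can be saved with `fig.savefig(...)` or shown with `plt.show()`.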