使用Python避免散点图中的垂直重叠

时间:2019-12-28 07:39:14

标签: python matplotlib visualization

我正在用3臂和伯纳利返回的土匪问题模拟epsilon-greedy算法。完成实验后,我想绘制每个手臂的收益,也就是说,如果每次都选择一个手臂,则它在相应时间上所取的值就是它的收益,对于其余两个手臂,该值将设置为-1。现在我想绘制一个手臂相对于该时隙的返回。(该值将为-1或1或0)

import matplotlib.pyplot as plt
import random
from scipy import stats
class greedy():
    def __init__(self,epsilon,n):
        self.epsilon=epsilon
        self.n=n
        self.value=[0,0,0]#estimator
        self.count=[0,0,0]
        self.prob=[0.4,0.6,0.8]
        self.greedy_reward=[[0 for x in range(10000)] for y in range(3)]
    def exploration(self,i):
        max_index=np.random.choice([0,1,2])
        r=np.random.choice([0,1],p=(1-self.prob[max_index],self.prob[max_index]))#do experiment, return r
        self.count[max_index]+=1
        for time in range(3):
            self.greedy_reward[time][i]=-1
        self.greedy_reward[max_index][i]=r
        self.value[max_index]=self.value[max_index]+(1/self.count[max_index])*(r-self.value[max_index])
    def exploitation(self,i):
        max_index=self.value.index(max(self.value))
        r=np.random.choice([0,1],p=(1-self.prob[max_index],self.prob[max_index]))
        self.count[max_index]+=1
        for time in range(3):
            self.greedy_reward[time][i]=-1
        self.greedy_reward[max_index][i]=r
        self.value[max_index]=self.value[max_index]+(1/self.count[max_index])*(r-self.value[max_index])
    def EE_choice(self,i):
        output=np.random.choice(# o is exploitation,1 is exploration
        [0,1], 
        p=[1-self.epsilon,self.epsilon]
        )
        if output==1:
            self.exploration(i);
        else:
            self.exploitation(i);
    def exp(self):
        for i in range(0,self.n):

然后,我们取出一只手臂的收益,例如arm3。

import matplotlib.pyplot as plt
x=[i for i in range(1,10001)]
arm_3_y=[0 for i in range(10000)]
for j in range(10000):
    arm_3_y[j]=greedy_1.greedy_reward[2][j]
plt.scatter(x,arm_3_y,marker='o')
plt.ylim([-1,1])
plt.show()

enter image description here

我们可以看到,一条垂直线上的所有点都重叠在一起,有什么办法可以避免这种情况?

1 个答案:

答案 0 :(得分:1)

取决于您要可视化的内容,可以有多种解决方法。 如果要查看分布但不需要单个点,请使用boxplot。它会向您显示平均值,四分位数和范围。

如果您确实需要散点图并查看点,并为数据中的每个点添加一些随机性(仅用于可视化过程),它将减少数据重叠的机会,并且您可以看到它们的重叠位置聚集。

def randomize(arr):
    stdev = .01*min(arr) #use any small value, small enough to not change the distribution
    return arr + np.random.randn(len(arr)) * stdev
plt.scatter(x,randomize(arm_3_y),marker='o')

它应该有助于可视化。尝试使用系数(此处为0.01)进行乱码处理,以增加抖动。