按数据集标签定义标记和颜色设置

时间:2016-07-18 07:26:38

标签: python matplotlib

我认为我的目标很简单。但我还没有在某个地方查看我隐藏的错误。

我正在学习PLA(感知器线性算法),并尝试用Python语言实现它。

算法本身已经解决了。然后,我想通过算法绘制调整过程。

数据集

dataset = np.array([
((1, -0.4, 0.3), -1),
((1, -0.3, -0.1), -1),
((1, -0.2, 0.4), -1),
((1, -0.1, 0.1), -1),
((1, 0.9, -0.5), 1),
((1, 0.7, -0.9), 1),
((1, 0.8, 0.2), 1),
((1, 0.2, -0.6), 1)])   
  

我想用标签绘制不同风格的散点(" -1"或者#34; 1"在这个例子中)

所以,这就是我编码的内容:

def marker_choice(s):
    if s == 1:
        marker = "o"
    else:
        marker = "x"  
    return marker

def color_choice(s):
    if s == 1:
        color = "r"
    else:
        color = "b"  
    return color  


ps = [v[0] for v in dataset]
label = [v[1] for v in dataset]

fig = plt.figure()
ax = fig.add_subplot()  

ax.scatter([v[1] for v in ps], [v[2] for v in ps], s=80, \
           c=color_choice(v for v in np.array(label)), 
           marker=marker_choice(v for v in np.array(label)),
       )

enter image description here

目标

enter image description here

1 个答案:

答案 0 :(得分:1)

问题是您的一行循环不会产生您所需的输出。如果您只是运行它们来测试它们的输出,那么结果就是颜色:'b'和标记:'x'这就解释了为什么输出就是这样的。

下面的解决方案不使用一个行循环,但确实生成了所需的图形。需要注意的一点是,下面代码输出上的标记是问题中所需输出的错误方法。这只是改变marker_choice(s)函数并将if s==1更改为if s == -1的情况。

import matplotlib.pyplot as plt
import numpy as np

dataset = np.array([
((1, -0.4, 0.3), -1),
((1, -0.3, -0.1), -1),
((1, -0.2, 0.4), -1),
((1, -0.1, 0.1), -1),
((1, 0.9, -0.5), 1),
((1, 0.7, -0.9), 1),
((1, 0.8, 0.2), 1),
((1, 0.2, -0.6), 1)])

def marker_choice(s):
    if s == 1:       # change to -1 here to get the markers the other way around
        marker = "o"
    else:
        marker = "x"
    return marker

def color_choice(s):
    if s == 1:
        color = "r"
    else:
        color = "b"
    return color

ps = [v[0] for v in dataset]
label = [v[1] for v in dataset]

str_label = []
str_marker = []

for i in (label):
    a = color_choice(label[i])
    str_label.append(a)
    b = marker_choice(label[i])
    str_marker.append(b)

fig, ax = plt.subplots()

for i in range (len(ps)):
    data = ps[i]
    data_x = data[1]
    data_y = data[2]
    ax.scatter(data_x,data_y, s=80, color = str_label[i], marker=str_marker[i])

plt.show()

这会产生以下输出:

enter image description here

注意:我还没有测试过此代码的性能与原始代码的对比情况。