使用keras.utils.to_categorical的OneHotEncoding无法转换为全长类大小的numpy数组

时间:2019-02-27 21:31:29

标签: python tensorflow keras pyspark

我正在分布式环境中跟踪keras机器学习模型,因此我需要在集群中分发数据。为此,我正在使用TensorflowOnSpark库。以下是卡住我的小段代码。

def generate_rdd_data(dataRDD):
    while True:
        feature_vector = []
        lbls = []
        for item in dataRDD:
            #record = item[0]
            feature_vector.append(item[0])
            lbls.append(keras.utils.to_categorical(item[1], num_classes = 14))
        features = numpy.array(feature_vector).astype('float32')
        #labels = numpy.array(lbls).astype('float32')
        return (features, labels)

此方法工作正常,但标签的keras OneHotEncoding转换失败。下面是错误

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<stdin>", line 10, in generate_rdd_data
ValueError: setting an array element with a sequence.

这是lbls的结果:

>>> lbls.take(4)
[array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.], dtype=float32), array([1.], dtype=float32), array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.], dtype=float32), array([1.], dtype=float32)]

但是,尽管我在如下所示的虚拟方法中执行了相同的逻辑,但效果很好

def temp(data):
    no = []
    for item in data:
            no.append(keras.utils.to_categorical(item, num_classes = 15))
    return no

temp方法的输入为

a = [14,13,2,5,1,0]

输出符合预期

array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
      dtype=float32)

添加额外的信息。对于第一种方法,输入是带有数据集及其对应标签的rdd压缩文件。

rdd示例的外观如下:

[([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], 12.0)]

如您所见,rdd有两个部分,一个是feature_vector,第二个是label,因此第一个方法以numpy格式进行特征和标签。

1 个答案:

答案 0 :(得分:0)

ValueError: setting an array element with a sequence.错误是由于尝试将非多维数组交换列表转换为数组而引起的

也许您有像这样的列表[2,[3,4]],或者您有字符串和整数对象[1,2,"r"]

并且该函数不需要一一转换:

def generate_rdd_data(dataRDD):
    return dataRDD,keras.utils.to_categorical(dataRDD,num_classes=14)