PySpark: Feature extraction in a map

Date: 2017-11-07 10:36:08

Tags: pyspark classification pickle rdd

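I am trying to learn to use PySpark. I am using spark-2.2.0 with Python 3, and I am facing a problem whose source I cannot find. My project is to adapt an algorithm written by a data scientist so that it runs distributed. The code below is what I have to use to extract features from images, and I have to adapt it so that it extracts the features with PySpark.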

#!/usr/bin/env python3
import json
import sys

# Dependencies can be installed by running:
# pip install keras tensorflow h5py pillow

# Run script as:
# ./extract-features.py images/*.jpg

from keras.applications.vgg16 import VGG16
from keras.models import Model
from keras.preprocessing import image
from keras.applications.vgg16 import preprocess_input
import numpy as np


def main():
    # Load model VGG16 as described in https://arxiv.org/abs/1409.1556
    # This is going to take some time...
    base_model = VGG16(weights='imagenet')
    # The model will produce the output of the 'fc2' layer, which is the penultimate layer
    # of the network (see the paper above for more details)
    model = Model(inputs=base_model.input, outputs=base_model.get_layer('fc2').output)

    # For each image, extract the representation and save it next to the image as JSON
    for image_path in sys.argv[1:]:
        features = extract_features(model, image_path)
        with open(image_path + ".json", "w") as out:
            json.dump(features, out)


def extract_features(model, image_path):
    # Load the image resized to the 224x224 input size expected by VGG16
    img = image.load_img(image_path, target_size=(224, 224))
    x = image.img_to_array(img)
    # Add a batch dimension of size 1
    x = np.expand_dims(x, axis=0)
    x = preprocess_input(x)

    features = model.predict(x)
    # One image in the batch: return its fc2 feature vector as a plain list
    return features.tolist()[0]


if __name__ == "__main__":
    main()

I wrote the beginning of the PySpark code:

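# sc is the SparkContext, PathImages the directory containing the images
rdd = sc.binaryFiles(PathImages)  # RDD of (path, file content) pairs
base_model = VGG16(weights='imagenet')
model = Model(inputs=base_model.input, outputs=base_model.get_layer('fc2').output)
# x[0] is the file URI ("file:/..."); [5:] strips the "file:" prefix.
# The lambda refers to model, so Spark has to pickle model along with it.
rdd2 = rdd.map(lambda x: (x[0], extract_features(model, x[0][5:])))
rdd2.collect()[0]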
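For reference, the x[0][5:] slice relies on how sc.binaryFiles names its records: each element is a (path, content) pair, and for local files the path keeps its "file:" scheme, for example "file:/home/me/images/cat.jpg". The sketch below uses a hypothetical helper, local_path_from_uri (not a PySpark API), to make that conversion explicit. It also refuses non-local URIs such as hdfs://..., on the assumption that extract_features keeps reopening the image by path, which only works if every executor can see the same file on its local filesystem.

```python
def local_path_from_uri(uri):
    """Convert a local-file key from sc.binaryFiles into a filesystem path.

    Hypothetical helper (not a PySpark API). Accepts "file:/path",
    "file:///path" and "file://localhost/path"; anything else, such as an
    hdfs:// URI, raises ValueError because it cannot be opened by path.
    """
    prefix = "file:"
    if not uri.startswith(prefix):
        raise ValueError("not a local file URI: %r" % uri)
    path = uri[len(prefix):]  # same as the question's x[0][5:]
    if path.startswith("//"):
        # "file://<authority>/<path>": only an empty or localhost authority is local
        authority, _, rest = path[2:].partition("/")
        if authority not in ("", "localhost"):
            raise ValueError("not a local file URI: %r" % uri)
        path = "/" + rest
    return path
```

In the map above this would read rdd.map(lambda x: (x[0], extract_features(model, local_path_from_uri(x[0])))). It does not change how Spark pickles the lambda; it only makes the path handling explicit and fails loudly on paths the executors cannot open.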

When I try to extract the features, I get an error:

~/Code/spark-2.2.0-bin-hadoop2.7/python/pyspark/cloudpickle.py in save_file(self, obj)
    623             return self.save_reduce(getattr, (sys, 'stderr'), obj=obj)
    624         if obj is sys.stdin:
--> 625             raise pickle.PicklingError("Cannot pickle standard input")
    626         if hasattr(obj, 'isatty') and obj.isatty():
    627             raise pickle.PicklingError("Cannot pickle files that map to tty objects")

PicklingError: Cannot pickle standard input
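The traceback shows cloudpickle refusing to serialize sys.stdin, which suggests that some object reachable from the function passed to map holds a reference to standard input (or another open stream). One rough way to look for such attributes on the driver is to try pickling an object piece by piece. The sketch below uses a hypothetical helper, find_unpicklable, built on the standard-library pickle. That module is stricter than Spark's cloudpickle (it rejects lambdas and nested functions, for example), so treat its output as a list of suspects rather than a diagnosis.

```python
import pickle


def find_unpicklable(obj, name="obj", max_depth=3, _seen=None):
    """Return the attribute paths under obj that the stdlib pickle rejects.

    Hypothetical debugging aid: it only descends into instance __dict__
    attributes, list/tuple items and dict values, up to max_depth levels.
    """
    if _seen is None:
        _seen = set()
    if id(obj) in _seen:
        return []  # already examined via another path
    _seen.add(id(obj))
    try:
        pickle.dumps(obj)
        return []  # this object, and everything under it, pickles fine
    except Exception:
        pass
    if isinstance(obj, dict):
        children = [("%s[%r]" % (name, k), v) for k, v in obj.items()]
    elif isinstance(obj, (list, tuple)):
        children = [("%s[%d]" % (name, i), v) for i, v in enumerate(obj)]
    elif hasattr(obj, "__dict__"):
        children = [("%s.%s" % (name, k), v) for k, v in vars(obj).items()]
    else:
        children = []
    if max_depth <= 0 or not children:
        return [name]  # cannot narrow it down any further
    culprits = []
    for child_name, child in children:
        culprits.extend(find_unpicklable(child, child_name, max_depth - 1, _seen))
    # If every child pickles on its own, the object itself is the problem
    return culprits or [name]
```

Calling find_unpicklable(model, "model") on the driver would list the first attributes of the Keras model that refuse to pickle. This can be slow on a large model, and it says nothing about what cloudpickle would accept.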

I have tried several things, and this is my first result. I know the error comes from the following line in the method extract_features:

features = model.predict(x)

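This line works fine when I run it outside the map function or outside PySpark. I think the problem comes from the object "model" and its serialization by PySpark. Maybe I am not distributing this the right way with PySpark; if you have any tips that could help me, I will gladly take them.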
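If the captured model is indeed the problem, one widely used pattern is to stop shipping the model from the driver and build it on the executors instead, once per partition, with rdd.mapPartitions. This is a minimal sketch, not a confirmed fix for this exact traceback. It assumes Keras is installed on every worker node and that the VGG16 weights can be downloaded or are already cached there. load_fc2_model and make_partition_extractor are hypothetical helper names.

```python
def load_fc2_model():
    """Hypothetical helper: build the VGG16 'fc2' extractor from the question.

    Keras is imported inside the function so that the import and the model
    construction happen in whichever process calls it (here: an executor).
    """
    from keras.applications.vgg16 import VGG16
    from keras.models import Model

    base_model = VGG16(weights='imagenet')
    return Model(inputs=base_model.input,
                 outputs=base_model.get_layer('fc2').output)


def make_partition_extractor(load_model, extract):
    """Hypothetical helper: return a function to pass to rdd.mapPartitions.

    The returned function closes over load_model and extract only, never over
    a model instance, so no model has to be pickled on the driver. The model
    is built at most once per partition, when the first record arrives.
    """
    def process_partition(records):
        model = None
        for path, _content in records:
            if model is None:
                model = load_model()  # built lazily, so empty partitions skip it
            yield path, extract(model, path)
    return process_partition
```

With Spark this would be used roughly as rdd2 = rdd.mapPartitions(make_partition_extractor(load_fc2_model, lambda m, uri: extract_features(m, uri[5:]))), with the driver-side base_model and model lines removed so that nothing in the closure refers to them. Building VGG16 is expensive (the ImageNet weights file alone is several hundred megabytes), so fewer, larger partitions keep that per-partition cost down.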

Thanks in advance.
