使用Spark驱动程序从s3中存在的图像中提取特征导致错误

时间:2018-11-04 09:52:01

标签: keras pyspark deep-learning feature-extraction

我有一个Pyspark应用程序,它将基本上在s3处下载图像文件,并使用keras从这些图像文件中提取特征。 这是整个流程:-

1. Download images from s3 using.
    s3_files_rdd = sc.binaryFiles(s3_path) ## [('s3n://..',bytearray)]

2. Then convert the above byte inside the rdd to image object.

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from io import BytesIO

def convert_binary_to_image_obj(obj):
    img = mpimg.imread(BytesIO(obj), 'jpg')
    return img


images_rdd = s3_files_rdd.map(lambda x: (x[0], convert_binary_to_image_obj(x[1])))

3. Now pass the images_rdd to another function to extract features using keras vgg16 model.

def initVGG16():
    model = VGG16(weights='imagenet', include_top=True)
    return Model(inputs=model.input, outputs=model.get_layer("fc2").output)

def extract_features(img):
    img_data = image.img_to_array(img)
    img_data = np.expand_dims(img_data, axis=0)
    img_data = preprocess_input(img_data)
    vgg16_feature = initVGG16().predict(img_data)[0]
    return vgg16_feature


features_rdd = images_rdd.map(lambda x: (x[0], extract_features(x[1])))

但是当我尝试应用时,它会显示以下错误消息:-

ValueError: Error when checking input: expected input_1 to have shape (224, 224, 3) but got array with shape (300, 200, 3)

    at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:330)
    at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRunner.scala:470)
    at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRunner.scala:453)
    at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:284)
    at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
    at scala.collection.Iterator$class.foreach(Iterator.scala:893)

我知道这里的错误是在extract_features函数中,它期望图像的大小为224,224,3,但现在不是这种情况。因为我没有将映像保存到本地磁盘。从s3下载后,我将直接使用matplotlib lib转换为图像对象。

如何解决此问题?我基本上想要从s3下载图像,然后在内存中像image.load_img(image_path, target_size=(224, 224))函数一样调整其大小,然后将该图像对象传递给我的extract_features函数。

0 个答案:

没有答案