导入图像(.jpg)数据集的正确方法Keras,Pandas

时间:2019-03-01 06:27:37

标签: python pandas keras

最近几天,我在一个机器学习项目中工作。

我有一个图像数据集(.jpg)。我有超过50万张图片。

此外,我有一个CSV文件,其中包含图像的名称(每个图像都有唯一的名称)和两个标签(目标值)。这两个目标标签完全不同,彼此之间没有任何关系。

我将为两个目标标签使用模型分离的模型。

我的解决方案

  1. 将所有内容转换为大CSV文件。类似于CSV格式的MNIST数据集。这种方法的问题是图像尺寸大(我需要大图像)和三个通道(彩色图像)。因此CSV文件的大小变得非常大。

  2. 使用Keras ImageDataGenerator & flow_from_directory类。如前所述,我有两个标签(目标),因此需要创建同一数据集的两个副本(因为flow_from_directory需要特定的数据结构)

现在,我的两种解决方案都可以使用,但是有特定的问题。

我想知道是否还有其他导入数据集的方法。这样我就可以避免上述问题。

我正在为此项目使用Keras,Pandas,Numpy和Sklearn。我也可以自由使用任何其他库。

我没有在此问题上附加任何解决方案代码。请让我知道是否需要。

Thnx 阿比舍克

1 个答案:

答案 0 :(得分:0)

您提到了熊猫,但我认为这不能解决您的问题。

您为什么不编写自己的解决方案?

您可以尝试实现scikit-learn的方式。

Recognizing hand-written digits 为例,

示例代码

# Author: Gael Varoquaux <gael dot varoquaux at normalesup dot org>
# License: BSD 3 clause

import matplotlib.pyplot as plt    
# Import datasets, classifiers and performance metrics
from sklearn import datasets, svm, metrics

# The digits dataset 
digits = datasets.load_digits() # <--- right here

images_and_labels = list(zip(digits.images, digits.target))
for index, (image, label) in enumerate(images_and_labels[:4]):
    plt.subplot(2, 4, index + 1)
    plt.axis('off')
    plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
    plt.title('Training: %i' % label)

n_samples = len(digits.images)
data = digits.images.reshape((n_samples, -1))

classifier = svm.SVC(gamma=0.001)

classifier.fit(data[:n_samples // 2], digits.target[:n_samples // 2])

expected = digits.target[n_samples // 2:]
predicted = classifier.predict(data[n_samples // 2:])

print("Classification report for classifier %s:\n%s\n"
      % (classifier, metrics.classification_report(expected, predicted)))
print("Confusion matrix:\n%s" % metrics.confusion_matrix(expected, predicted))

images_and_predictions = list(zip(digits.images[n_samples // 2:], predicted))
for index, (image, prediction) in enumerate(images_and_predictions[:4]):
    plt.subplot(2, 4, index + 5)
    plt.axis('off')
    plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
    plt.title('Prediction: %i' % prediction)

plt.show()

源代码

scikit-learn构建了一个名为dataset的模块,仅用于加载MNIST之类的不同数据集(图像和标签)。

阅读dataset.load_digits()的源代码也将很有趣

short 整洁。希望您能找到更好的解决方案。