Keras: Keeping all images in one directory

Posted: 2017-11-21 11:02:48

Tags: python machine-learning keras deep-learning generator

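I work with a large number of images (10M+) stored in a single directory (no subfolder per class), and I use a pandas DataFrame to keep track of the class labels. The images do not fit in memory, so I have to read mini-batches from disk. So far I have used Keras' .flow_from_directory(), but it requires me to move the images into one subfolder per class (and per train/validation split). That works fine, but it becomes very impractical when I want to use different subsets of the images and define the classes in different ways. Does anyone have an alternative strategy that uses a database (e.g. a pandas.DataFrame) to keep track of which mini-batches to read, instead of moving the images into subfolders?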

1 answer:

Answer 0 (score: 2)

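You need a custom data generator: it reads each mini-batch of images from disk and looks up their labels in the DataFrame, so the files never have to be moved into subfolders.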
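Because the DataFrame is the only index of the images, picking a subset or redefining the classes becomes a pandas operation rather than a file move. A minimal sketch, assuming a DataFrame with hypothetical columns `id` (file name without extension) and `label`; the helper name `make_split` and the split fraction are illustrative, not part of Keras or pandas:

```python
import numpy as np
import pandas as pd


def make_split(df, label_col="label", val_frac=0.2, seed=0):
    """Return (classes, train_ids, val_ids) from a DataFrame; no files are moved.

    Assumes hypothetical columns 'id' (image file name without extension)
    and label_col (the class of each image).
    """
    # Sorted class names; a name's position is its integer label for one-hot encoding
    classes = sorted(df[label_col].unique())
    # Shuffle the ids reproducibly, then take the first n_val as the validation split
    rng = np.random.default_rng(seed)
    ids = rng.permutation(df["id"].to_numpy())
    n_val = int(len(ids) * val_frac)
    return classes, ids[n_val:], ids[:n_val]


# A different subset or class definition is just a different (filtered) DataFrame
df = pd.DataFrame({"id": ["a", "b", "c", "d", "e"],
                   "label": ["cat", "dog", "cat", "bird", "dog"]})
subset = df[df["label"] != "bird"]  # drop one class without touching any files
classes, train_ids, val_ids = make_split(subset, val_frac=0.25)
```

The resulting id arrays can be passed straight to a generator such as the one in this answer; redefining classes only means pointing `label_col` at another column.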


You can then call the generator with just a numpy array of ids (or image file names), as follows:

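import numpy as np
import cv2
from keras.utils import to_categorical

def batch_generator(ids, df, classes, dpath, batch_size=32,
                    label_col='column_name', img_size=(224, 224)):
    # ids:       numpy array of image ids (file names without the .jpg extension)
    # df:        pandas DataFrame with an 'id' column and a label column
    # classes:   list of class names; a name's position is its integer label
    # label_col: name of the DataFrame column holding each image's class
    labels = df.set_index('id')[label_col]        # id -> class name, built once
    class_index = {name: i for i, name in enumerate(classes)}
    while True:                                   # Keras expects an endless generator
        for start in range(0, len(ids), batch_size):
            x_batch = []
            y_batch = []
            end = min(start + batch_size, len(ids))
            for img_id in ids[start:end]:
                img = cv2.imread(dpath + 'train/{}.jpg'.format(img_id))
                # np.array() below needs every image in a batch to have the same shape
                img = cv2.resize(img, img_size, interpolation=cv2.INTER_AREA)
                x_batch.append(img)
                y_batch.append(class_index[labels.loc[img_id]])
            x_batch = np.array(x_batch, np.float32)
            y_batch = to_categorical(y_batch, len(classes))
            yield x_batch, y_batch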
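The generator above walks the ids in the same order every epoch and always reads real files. A minimal variant that reshuffles per epoch and takes image reading as a parameter (so it can run without any files, e.g. in a test) could look like this. It is a sketch under stated assumptions: `load_fn(img_id)` and the `label_of` dict are hypothetical inputs you would build yourself (for instance from `cv2.imread` plus `cv2.resize`, and from the DataFrame), and `np.eye` stands in for `to_categorical` so only NumPy is needed:

```python
import numpy as np


def shuffled_batch_generator(ids, label_of, num_classes, load_fn, batch_size=32, seed=0):
    """Yield (x_batch, y_batch) forever, reshuffling the ids at every epoch.

    label_of: dict mapping image id -> integer class index (hypothetical)
    load_fn:  callable returning one image array for an id (hypothetical)
    """
    ids = np.asarray(ids)
    rng = np.random.default_rng(seed)
    while True:
        order = rng.permutation(len(ids))  # a new visiting order each epoch
        for start in range(0, len(ids), batch_size):
            batch_ids = ids[order[start:start + batch_size]]
            x_batch = np.stack([load_fn(i) for i in batch_ids]).astype(np.float32)
            labels = np.array([label_of[i] for i in batch_ids])
            y_batch = np.eye(num_classes, dtype=np.float32)[labels]  # one-hot rows
            yield x_batch, y_batch


def fake_load(img_id):
    # Stand-in loader: every "image" is a 4x4 RGB array of zeros
    return np.zeros((4, 4, 3), np.uint8)


# 5 ids in batches of 2 -> batch sizes 2, 2, 1, then the next epoch starts
label_of = {"a": 0, "b": 1, "c": 2, "d": 0, "e": 1}
gen = shuffled_batch_generator(list(label_of), label_of, 3, fake_load, batch_size=2)
sizes = [len(next(gen)[0]) for _ in range(4)]
```

Since the last batch of an epoch may be smaller, the number of steps per epoch is `int(np.ceil(len(ids) / batch_size))`. In Keras 2.x such a generator is passed to `model.fit_generator(gen, steps_per_epoch=...)`; newer versions accept generators directly in `model.fit`.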