What is causing this Keras image-processing error?

Date: 2019-04-20 15:33:48

Tags: python keras neural-network anaconda conv-neural-network

I am trying to train a CNN, but the program fails every time, although at a random point in each run. The error raised is:

OSError: image file is truncated (15 bytes not processed)

The CNN starts iterating, but usually during epoch 2 I hit this truncated-image error, even though all of the images were augmented in the same way. For reference, this script processes the augmented data from the previous step. Does anyone have an idea?

The full code is:

import os

from keras import backend as K
from keras.layers import Activation, Dense, Dropout
from keras.layers import Conv2D, Flatten, MaxPooling2D
from keras.models import Sequential
from keras.preprocessing.image import ImageDataGenerator

import matplotlib.pyplot as plt

cwd = os.getcwd()

# dimensions of our images.
img_width, img_height = 150, 150

train_data_dir = (str(cwd) + r'\augmented\train\\')
validation_data_dir = (str(cwd) + r'\augmented\validation\\')
nb_train_samples = 1000
nb_validation_samples = 500
epochs = 20
batch_size = 10

if K.image_data_format() == 'channels_first':
    input_shape = (3, img_width, img_height)
else:
    input_shape = (img_width, img_height, 3)

model = Sequential()
model.add(Conv2D(32, (3, 3), input_shape=input_shape))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(32, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(64, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Flatten())
model.add(Dense(64))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(1))
model.add(Activation('sigmoid'))

model.compile(loss='binary_crossentropy',
              optimizer='rmsprop',
              metrics=['accuracy'])

# this is the augmentation configuration we will use for training
train_datagen = ImageDataGenerator(
    rescale=1. / 255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)

# this is the augmentation configuration we will use for testing:
# only rescaling
test_datagen = ImageDataGenerator(rescale=1. / 255)

train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode='binary')

validation_generator = test_datagen.flow_from_directory(
    validation_data_dir,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode='binary')

history = model.fit_generator(
    train_generator,
    steps_per_epoch=nb_train_samples // batch_size,
    epochs=epochs,
    validation_data=validation_generator,
    validation_steps=nb_validation_samples // batch_size)

model.save_weights('chips.h5')

acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(1, len(acc) + 1)

plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend()

plt.figure()

plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()

plt.show()

1 Answer:

Answer 0 (score: 1)

Your problem seems to be images of different sizes, and your program appears to crash because of a setting in Pillow. Here is the official code of the Pillow module; if you search for "truncated", you will see why this error occurs. The top-voted answer Here provides code that prevents truncated images from raising an error when they are loaded.
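As a sketch of that workaround (assuming you load images through Pillow, which Keras's ImageDataGenerator does by default), you can tell Pillow to tolerate truncated files before training starts:

# Minimal sketch: let Pillow load truncated image files instead of raising OSError.
# Put this near the top of the training script, before any images are read.
from PIL import ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True

With this flag set, Pillow loads whatever image data it can and skips the missing tail, so the generator keeps running; whether a partially decoded image is acceptable for training is a separate question.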

You could also delete the problem images; here is someone who ran into the same issue and simply filtered out every image smaller than 50 kB. Hope this helps.
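If you prefer to clean the data instead, below is a minimal sketch along those lines. It assumes the augmented folder layout from the question and treats files under 50 kB, or files Pillow cannot parse, as suspect; review the list before deleting anything.

# Minimal sketch: list (and optionally delete) suspect image files.
# Assumes the augmented/train and augmented/validation layout from the question;
# the 50 kB threshold is only an example and should be tuned to your data.
import os
from PIL import Image

def find_bad_images(root, min_bytes=50 * 1024):
    bad = []
    for dirpath, _, filenames in os.walk(root):
        for name in filenames:
            path = os.path.join(dirpath, name)
            if os.path.getsize(path) < min_bytes:
                bad.append(path)
                continue
            try:
                with Image.open(path) as img:
                    img.verify()  # raises an exception for files Pillow cannot parse
            except Exception:
                bad.append(path)
    return bad

for path in find_bad_images(os.path.join(os.getcwd(), 'augmented')):
    print('suspect file:', path)
    # os.remove(path)  # uncomment once you have reviewed the list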