在模型keras训练期间fit_generator中的问题

时间:2020-11-02 02:42:24

标签: python tensorflow keras cnn

我是Python和CNN的初学者。 我确实编写了一个简单的代码来训练两个班级之间的模型 我的文件夹有2个用于训练的文件夹和2个用于验证的文件夹

import keras,os
from keras.models import Sequential
from keras.layers import Conv2D
from keras.layers import MaxPooling2D
from keras.layers import Flatten
from keras.layers import Dense
from keras.preprocessing.image import ImageDataGenerator

Classifier = Sequential()
classifier.add(Conv2D(32, (3, 3), input_shape = (64, 64, 3), activation = 'relu'))
classifier.add(MaxPooling2D(pool_size = (2, 2))) 
classifier.add(Flatten()) 
classifier.add(Dense(units = 128, activation = 'relu')) 
classifier.add(Dense(units = 1, activation = 'sigmoid')) 
classifier.compile(optimizer = 'adam', loss = 'binary_crossentropy' , metrics = ['raccuracy'])

train_datagen = ImageDataGenerator(rescale = 1./255, shear_range = 0.2,zoom_range = 0.2,horizontal_flip = True)
test_datagen = ImageDataGenerator(rescale = 1./255)
training_set = train_datagen.flow_from_directory(r'C:\Users\user1\Documents\code\Test', target_size = (244, 244), color_mode="rgb", batch_size = 32, class_mode = 'binary', shuffle=True,)
test_set = test_datagen.flow_from_directory(r'C:\Users\user1\Documents\code\Valid', target_size = (244, 244), color_mode="rgb", batch_size = 32, class_mode = 'binary', shuffle=True,)
STEP_SIZE_TRAIN=training_set.n    #train_generator.batch_size
STEP_SIZE_VALID=test_set.n    #valid_generator.batch_size
classifier.fit_generator(generator = training_set, steps_per_epoch = STEP_SIZE_TRAIN,epochs = 5,validation_data = test_set,validation_steps = STEP_SIZE_VALID)

一切顺利,但是我有一个错误。我不知道是什么问题。它会在第一个纪元立即开始,如下所示

Found 518 images belonging to 2 classes.
Found 40 images belonging to 2 classes.
Epoch 1/5
TypeError                                 Traceback (most recent call last) <ipython-input-21-c93c80bb7785> in <module>
     20 STEP_SIZE_VALID=test_set.n    #valid_generator.batch_size
     21 
---> 22 classifier.fit_generator(generator = training_set, 
     23                          steps_per_epoch = STEP_SIZE_TRAIN,
     24                          epochs = 5,

~\anaconda3\lib\site-packages\tensorflow\python\util\deprecation.py in new_func(*args, **kwargs)
    322               'in a future version' if date is None else ('after %s' % date),
    323               instructions)
--> 324       return func(*args, **kwargs)
    325     return tf_decorator.make_decorator(
    326         func, new_func, 'deprecated',


~\anaconda3\lib\site-packages\tensorflow\python\keras\engine\training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)    1813     """    1814    
_keras_api_gauge.get_cell('fit_generator').set(True)
-> 1815     return self.fit(    1816         generator,    1817         steps_per_epoch=steps_per_epoch,
     ~\anaconda3\lib\site-packages\tensorflow\python\keras\engine\training.py in _method_wrapper(self, *args, **kwargs)
    106   def _method_wrapper(self, *args, **kwargs):
    107     if not self._in_multi_worker_mode():  # pylint: disable=protected-access
--> 108       return method(self, *args, **kwargs)
    109 
    110     # Running inside `run_distribute_coordinator` already.  

~\anaconda3\lib\site-packages\tensorflow\python\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)    1096                 batch_size=batch_size):   1097               callbacks.on_train_batch_begin(step)
-> 1098               tmp_logs = train_function(iterator)    1099               if data_handler.should_sync:    1100                 context.async_wait()

~\anaconda3\lib\site-packages\tensorflow\python\eager\def_function.py in __call__(self, *args, **kwds)
    778       else:
    779         compiler = "nonXla"
--> 780         result = self._call(*args, **kwds)
    781 
    782       new_tracing_count = self._get_tracing_count()

~\anaconda3\lib\site-packages\tensorflow\python\eager\def_function.py in _call(self, *args, **kwds)
    805       # In this case we have created variables on the first call, so we run the
    806       # defunned version which is guaranteed to never create variables.
--> 807       return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
    808     elif self._stateful_fn is not None:
    809       # Release the lock early so that multiple threads can perform the call TypeError: 'NoneType' object is not callable.

Can any one please tell me what is the problem ?

0 个答案:

没有答案
相关问题