多个GPU中的ImageDataGenerator

时间:2020-06-04 19:14:52

标签: python tensorflow keras

以下代码中的错误:

import json
import os
import numpy as np 
import pandas as pd
import sys
import cv2
import tensorflow as tf
from tensorflow import keras
os.environ["CUDA_VISIBLE_DEVICES"]="0,1"

config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
BATCH_SIZE = 10
with tf.compat.v1.Session(config=config) as sess:

    patch_out = "data/train"
    patch_out1 = "data/validation"

    def create_model():
        model = keras.models.Sequential() 
        model.add(keras.layers.Conv2D(32, (3, 3), input_shape = (64, 64, 3), activation = 'relu'))
        model.add(keras.layers.MaxPooling2D(pool_size = (2, 2)))
        model.add(keras.layers.Conv2D(32, (3, 3), activation = 'relu'))
        model.add(keras.layers.MaxPooling2D(pool_size = (2, 2)))
        model.add(keras.layers.Flatten())
        model.add(keras.layers.Dense(units = 128, activation = 'relu'))
        model.add(keras.layers.Dense(units = 1, activation = 'sigmoid'))
        model.compile(optimizer = 'adam', loss = 'binary_crossentropy', metrics = ['accuracy'])
        return model  
    #-------------------------
    #save the contents of the complete arrays
    np.set_printoptions(threshold=sys.maxsize)
    #-------------------------------------
    # Create a MirroredStrategy.
    strategy = tf.distribute.MirroredStrategy()
    print("Number of devices: {}".format(strategy.num_replicas_in_sync))
    # Open a strategy scope.
    with strategy.scope():
        train_datagen = keras.preprocessing.image.ImageDataGenerator(rescale = 1./255,
                                   shear_range = 0.2,
                                   zoom_range = 0.2,
                                   horizontal_flip = True)
        validation_datagen = keras.preprocessing.image.ImageDataGenerator(rescale = 1./255)
        training_set = train_datagen.flow_from_directory(patch_out,
                                                 target_size = (64, 64),
                                                 batch_size = 32,
                                                 class_mode = 'binary')
        validation_set = validation_datagen.flow_from_directory(patch_out1,
                                                        target_size = (64, 64),
                                                        batch_size = 32,
                                                        class_mode = 'binary')
        model = create_model()
        model.fit(training_set,
                         steps_per_epoch = 8000,
                         epochs = 5,
                         validation_data = validation_set,
                         validation_steps = 2000)

'''追踪(最近通话最近一次): 文件“ test_CNN_v1.py”,第82行,在 validate_steps = 2000) 文件“ /root/anaconda3/envs/tf-gpu/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py”,第728行,适合 use_multiprocessing = use_multiprocessing) 文件“ /root/anaconda3/envs/tf-gpu/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_distributed.py”,第619行,适合 epochs = epochs) _distribution_standardize_user_data中第2315行的文件“ /root/anaconda3/envs/tf-gpu/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py” 断言isinstance(x,dataset_ops.DatasetV2) AssertionError'''

需要帮助。 谢谢

0 个答案:

没有答案