Custom layer only works with a batch size of 1

Date: 2020-07-24 08:14:29

Tags: tensorflow keras tf.keras

I am trying to create a custom layer in Keras for use in my VGG model. Building and printing the model works fine, but as soon as I start training with a batch size greater than 1, I get the error described in the title. Could anyone please help me with this? Thanks in advance.

# include packages
import random

import keras
from numpy.random import seed

seed(1)
random.seed(1)

from keras import backend as K
from keras.models import Model, Sequential
from keras.layers import Flatten, Dense, Dropout, Permute, Lambda, Conv2D, multiply, Concatenate, concatenate, \
    RepeatVector
from keras.preprocessing.image import ImageDataGenerator
import numpy as np
import os
import tensorflow as tf

# TF1-style session setup: let GPU memory grow on demand
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

from keras import optimizers

data_list = os.listdir('D:/xys/X_Ray_Image_DataSet/splits/f2/train')

print(len(data_list))

# Custom layer in our experiment: replaces each cell of the 9x9 feature map
# with the average of its 8-neighbours' 512-d vectors
class Neighbours(keras.layers.Layer):
    def __init__(self):
        super(Neighbours, self).__init__()


    def call(self, vector_):
        # block4_pool output for a 150x150 input is (batch, 9, 9, 512)
        row_max, col_max = 9, 9
        # NOTE: indexing [0] keeps only the first sample in the batch
        vector = vector_[0]
        total_stack_array = []
        row_tensor = None
        for row_index in range(9):
            # NOTE: K.variable creates new graph variables on every call;
            # the slice assigns below are TF1-style variable updates
            aggregate = K.variable(tf.zeros(shape=(8, 512), dtype=tf.float32))
            zeros = K.variable(tf.zeros(shape=(1, 512), dtype=tf.float32))
            stack_array = []
            for col_index in range(9):
                hit_count = 0  # reset the neighbour count per cell, not once per row
                if (row_index - 1) >= 0:
                    # print("first ", row_index - 1, col_index)
                    # aggregate[0] = vector[row_index - 1][col_index]
                    aggregate[0].assign(vector[row_index - 1][col_index])
                    hit_count += 1
                else:
                    # aggregate[0] = K.zeros(shape=(1, 512))
                    aggregate[0].assign(zeros)

                if (col_index - 1) >= 0:
                    # print("second ", row_index, col_index - 1)
                    # aggregate[1] = vector[row_index][col_index - 1]
                    aggregate[1].assign(vector[row_index][col_index - 1])
                    hit_count += 1

                else:
                    # aggregate[1] = K.zeros(shape=(1, 512))
                    aggregate[1].assign(zeros)

                if (row_index - 1) >= 0 and (col_index - 1) >= 0:
                    # print("third ", row_index - 1, col_index - 1)
                    # aggregate[2] = vector[row_index - 1][col_index - 1]
                    aggregate[2].assign(vector[row_index - 1][col_index - 1])
                    hit_count += 1

                else:
                    # aggregate[2] = K.zeros(shape=(1, 512))
                    aggregate[2].assign(zeros)

                if (row_index + 1) < row_max and (col_index - 1) >= 0:
                    # print("fourth ", row_index + 1, col_index - 1)
                    # aggregate[3] = vector[row_index + 1][col_index - 1]
                    aggregate[3].assign(vector[row_index + 1][col_index - 1])
                    hit_count += 1
                else:
                    # aggregate[3] = K.zeros(shape=(1, 512))
                    aggregate[3].assign(zeros)

                if (row_index + 1) < row_max:
                    # print("fifth ", row_index + 1, col_index)
                    # aggregate[4] = vector[row_index + 1][col_index]
                    aggregate[4].assign(vector[row_index + 1][col_index])
                    hit_count += 1
                else:
                    # aggregate[4] = K.zeros(shape=(1, 512))
                    aggregate[4].assign(zeros)

                if (col_index + 1) < col_max:
                    # print("Sixth ", row_index, col_index + 1)
                    # aggregate[5] = vector[row_index][col_index + 1]
                    aggregate[5].assign(vector[row_index][col_index + 1])
                    hit_count += 1

                else:
                    # aggregate[5] = K.zeros(shape=(1, 512))
                    aggregate[5].assign(zeros)

                if (row_index - 1) >= 0 and (col_index + 1) < col_max:
                    # print("Seventh ", row_index - 1, col_index + 1)
                    # aggregate[6] = vector[row_index - 1][col_index + 1]
                    aggregate[6].assign(vector[row_index - 1][col_index + 1])
                    hit_count += 1

                else:
                    # aggregate[6] = K.zeros(shape=(1, 512))
                    aggregate[6].assign(zeros)

                if (row_index + 1) < row_max and (col_index + 1) < col_max:
                    # print("Eighth ", row_index + 1, col_index + 1)
                    # aggregate[7] = vector[row_index + 1][col_index + 1]
                    aggregate[7].assign(vector[row_index + 1][col_index + 1])
                    hit_count += 1

                else:
                    # aggregate[7] = K.zeros(shape=(1, 512))
                    aggregate[7].assign(zeros)
                avg = K.sum(aggregate, axis=0) / (hit_count + 1)
                stack_array.append(avg)

            total_stack_array.append(K.stack(stack_array))
        # NOTE: the outer K.stack hard-codes a batch dimension of 1
        modified_vector = K.stack([K.stack(total_stack_array)])

        # placeholder_with_default only relaxes the static shape to a dynamic
        # batch dimension; it does not make the computation batch-aware
        aux_array = tf.placeholder_with_default(modified_vector, tf.TensorShape([None, 9, 9, 512]))
        return aux_array

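# For reference, a batch-size-agnostic sketch of what Neighbours.call appears
# to compute: for every cell of the (9, 9, 512) map, the sum of its in-range
# 8-neighbours divided by (neighbour count + 1). This version uses only
# shape-preserving tensor ops (no variables, no hard-coded batch of 1), so it
# handles any batch size. `neighbour_average` is a hypothetical helper for
# illustration and is not called anywhere in this script.
def neighbour_average(x):
    # x: (batch, 9, 9, 512); zero-pad one cell on each spatial border
    padded = tf.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0]])
    shifts = [(dr, dc) for dr in (-1, 0, 1) for dc in (-1, 0, 1)
              if (dr, dc) != (0, 0)]
    # sum of the 8 shifted copies: out-of-range neighbours contribute zero
    neigh_sum = sum(padded[:, 1 + dr:10 + dr, 1 + dc:10 + dc, :]
                    for dr, dc in shifts)
    # per-cell count of in-range neighbours, computed the same way on a mask
    mask = tf.pad(tf.ones_like(x[:, :, :, :1]),
                  [[0, 0], [1, 1], [1, 1], [0, 0]])
    count = sum(mask[:, 1 + dr:10 + dr, 1 + dc:10 + dc, :] for dr, dc in shifts)
    return neigh_sum / (count + 1.0)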

root_path = "D:/xys/X_Ray_Image_DataSet/splits/f2/"
DATASET_PATH = root_path + 'train'
test_dir = root_path + 'val'
IMAGE_SIZE = (150, 150)
NUM_CLASSES = len(data_list)
BATCH_SIZE = 1  # try reducing batch size or freeze more layers if your GPU runs out of memory
NUM_EPOCHS = 1
LEARNING_RATE = 0.0001
# Train datagen here is a preprocessor
train_datagen = ImageDataGenerator(rescale=1. / 255,
                                   rotation_range=50,
                                   width_shift_range=0.2,
                                   height_shift_range=0.2,
                                   shear_range=0.25,
                                   zoom_range=0.1,
                                   channel_shift_range=20,
                                   horizontal_flip=True,
                                   vertical_flip=True,
                                   validation_split=0.2,
                                   fill_mode='constant')

# For multiclass use class_mode='categorical'; for binary use 'binary'
train_batches = train_datagen.flow_from_directory(DATASET_PATH,
                                                  target_size=IMAGE_SIZE,
                                                  shuffle=True,
                                                  batch_size=BATCH_SIZE,
                                                  subset="training",
                                                  seed=42,
                                                  class_mode="categorical"
                                                  )

valid_batches = train_datagen.flow_from_directory(DATASET_PATH,
                                                  target_size=IMAGE_SIZE,
                                                  shuffle=True,
                                                  batch_size=BATCH_SIZE,
                                                  subset="validation",
                                                  seed=42,
                                                  class_mode="categorical"
                                                  )

from keras.applications import VGG16

conv_base = VGG16(weights='imagenet',
                  include_top=False,
                  input_shape=(150, 150, 3))

conv_base = Model(inputs=conv_base.inputs, outputs=conv_base.get_layer('block4_pool').output)
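
# Sanity check (for reference): with a 150x150x3 input, VGG16's block4_pool
# output shape is (None, 9, 9, 512) -- the source of the hard-coded
# row_max, col_max = 9 in Neighbours.call.
print(conv_base.output_shape)  # expected: (None, 9, 9, 512)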


def Trunc_VGG(conv_base):
    conv_base.trainable = True
    layer_4th = conv_base.get_layer('block4_pool').output
    neighbour_layer = Neighbours()
    neighbor_module = neighbour_layer(layer_4th)
    print(neighbor_module)
    flatten = Flatten()(neighbor_module)
    dropout = Dropout(0.5)(flatten)
    dense = Dense(256, activation='relu')(dropout)
    pred = Dense(3, activation='softmax')(dense)
    model = Model(inputs=conv_base.inputs, outputs=pred)
    return model


model = Trunc_VGG(conv_base)  # load our model
model.compile(loss='categorical_crossentropy',  # for multiclass use categorical_crossentropy
              optimizer=optimizers.Adam(lr=LEARNING_RATE),
              metrics=['acc'])

model.summary()  # summary() prints itself and returns None
# FIT MODEL
print(len(train_batches))
print(len(valid_batches))

STEP_SIZE_TRAIN = train_batches.n // train_batches.batch_size
STEP_SIZE_VALID = valid_batches.n // valid_batches.batch_size

result = model.fit_generator(train_batches,
                             steps_per_epoch=STEP_SIZE_TRAIN,
                             validation_data=valid_batches,
                             validation_steps=STEP_SIZE_VALID,
                             epochs=NUM_EPOCHS,
                             )
import matplotlib

matplotlib.use('TkAgg')
import matplotlib.pyplot as plt

print(result.history.keys())
fig1 = plt.figure(1)
# summarize history for accuracy
plt.plot(result.history['acc'])
plt.plot(result.history['val_acc'])
plt.title('Model Accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Training', 'Validation'], loc='upper left')
plt.show()
# plt.savefig(root_path + '/' + 'acc.png')

# summarize history for loss
fig2 = plt.figure(2)
plt.plot(result.history['loss'])
plt.plot(result.history['val_loss'])
plt.title('Model Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Training', 'Validation'], loc='upper left')
plt.show()
# plt.savefig(root_path + '/' + 'loss.png')

test_datagen = ImageDataGenerator(rescale=1. / 255)
eval_generator = test_datagen.flow_from_directory(test_dir, target_size=IMAGE_SIZE, batch_size=1,
                                                  shuffle=False, seed=42, class_mode="categorical")
eval_generator.reset()
x = model.evaluate_generator(eval_generator,
                             steps=len(eval_generator),  # len() already gives the number of batches
                             use_multiprocessing=False,
                             verbose=1,
                             workers=1,
                             )

print('Test loss:', x[0])
print('Test accuracy:', x[1])
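
# A minimal reproduction sketch (hypothetical, for reference): a dummy batch
# of one sample passes, while a batch of two should trigger the same error
# seen when training with BATCH_SIZE > 1.
dummy_ok = model.predict(np.zeros((1, 150, 150, 3)))      # works
# dummy_bad = model.predict(np.zeros((2, 150, 150, 3)))   # fails: batch size > 1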

0 Answers:

No answers