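I am working on multi-class segmentation of skin tissue. I have a list of 1500 labels, each assigned to one of 4 classes [0, 1, 2, 3], and my goal is to segment each class separately. When I try to train my model, I get the error below. Is there something wrong with the way I defined the model? Any advice would be appreciated.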
ValueError: Input 0 of layer conv2d is incompatible with the layer: expected axis -1 of input shape to have value 1 but received input with shape (None, 224, 224, 1, 3)
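Reading the message literally, the first convolution was built for a last axis of size 1, while the data it received ends in 3. Below is a minimal sketch of one way the two could be reconciled, assuming the loaded images are 224x224 RGB arrays and that collapsing them to grayscale is acceptable. The helper name `to_single_channel_batch` and the luma weights are illustrative choices, not part of my original code.

```python
import numpy as np

# ITU-R BT.601 luma weights, used here only as one common grayscale convention.
LUMA = np.array([0.299, 0.587, 0.114], dtype=np.float32)

def to_single_channel_batch(images):
    """Stack a list of HxW or HxWxC images into one (N, H, W, 1) float32 array."""
    batch = []
    for img in images:
        arr = np.asarray(img, dtype=np.float32)
        if arr.ndim == 3 and arr.shape[-1] >= 3:
            # Keep R, G, B (drops an alpha channel if present) and collapse them.
            arr = arr[..., :3] @ LUMA
        elif arr.ndim == 3:
            # Already one channel (or gray + alpha): keep the first channel.
            arr = arr[..., 0]
        batch.append(arr[..., np.newaxis])  # restore a channel axis of size 1
    return np.stack(batch, axis=0)

# Hypothetical stand-in for the loaded list: three random RGB images.
fake_images = [np.random.rand(224, 224, 3) for _ in range(3)]
x = to_single_channel_batch(fake_images)
print(x.shape)  # (3, 224, 224, 1)
```

The other direction would be to keep the colour information and build the model with `IMG_CHANNELS = 3`. In both cases the Python list of images probably has to become a single NumPy array before it is passed to `model.fit`.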
I loaded the label set and the class associated with each label as follows:
import matplotlib.pyplot as plt
from tqdm import tqdm

imagePath = "/home/SKIN_TISSUE/TISSUE_SET_1500/"
x_train = []
y_train = []
for index, row in tqdm(x_train_list.iterrows(), total=x_train_list.shape[0]):
    try:
        loadedImage = plt.imread(imagePath + str(row[0]) + ".jpg")
        x_train.append(loadedImage)
        y_train.append(row[1])
    except Exception:
        # Try with .png file format if images are not properly loaded
        try:
            loadedImage = plt.imread(imagePath + str(row[0]) + ".png")
            x_train.append(loadedImage)
            y_train.append(row[1])
        except Exception:
            # Print file names whenever it is impossible to load image files
            print(imagePath + str(row[0]))
x_test = []
y_test = []
for index, row in tqdm(x_test_list.iterrows(), total=x_test_list.shape[0]):
    try:
        loadedImage = plt.imread(imagePath + str(row[0]) + ".jpg")
        x_test.append(loadedImage)
        y_test.append(row[1])
    except Exception:
        # Try with .png file format if images are not properly loaded
        try:
            loadedImage = plt.imread(imagePath + str(row[0]) + ".png")
            x_test.append(loadedImage)
            y_test.append(row[1])
        except Exception:
            # Print file names whenever it is impossible to load image files
            print(imagePath + str(row[0]))
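The train and test loops above are identical apart from the list they fill, so they could share one helper. This is only a sketch: the names `load_with_fallback` and `load_split` are hypothetical, and unlike the loops above it decides which extension to use by checking whether the file exists rather than by catching a failed read.

```python
import os

def load_with_fallback(base_path, reader, extensions=(".jpg", ".png")):
    """Return reader(path) for the first extension whose file exists, else None."""
    for ext in extensions:
        path = base_path + ext
        if os.path.exists(path):
            return reader(path)
    return None

def load_split(rows, image_dir, reader):
    """rows: iterable of (name, label) pairs. Returns (images, labels, missing)."""
    images, labels, missing = [], [], []
    for name, label in rows:
        image = load_with_fallback(os.path.join(image_dir, str(name)), reader)
        if image is None:
            # Remember names that could not be found instead of printing them.
            missing.append(str(name))
            continue
        images.append(image)
        labels.append(label)
    return images, labels, missing
```

With the variables from my code this might be called as `x_train, y_train, missing = load_split(zip(x_train_list[0], x_train_list[1]), imagePath, plt.imread)`, and the same way for the test list.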
And I loaded the class file as follows:
import numpy as np
import pandas as pd

# Path to superpixels class files
classes_file = "/home/DEV/SKIN_TISSUE/TISSUE_1500_CLASSES.csv"
concatenated_data = pd.read_csv(classes_file, header=None)
# Instances with targets
targets = concatenated_data[1].tolist()
# Split data according to their classes
class_0 = concatenated_data[concatenated_data[1] == 0]
class_1 = concatenated_data[concatenated_data[1] == 1]
class_2 = concatenated_data[concatenated_data[1] == 2]
class_3 = concatenated_data[concatenated_data[1] == 3]
# Holdout split train/test set (Other options are k-folds or leave-one-out)
split_proportion = 0.9
split_size_0 = int(len(class_0)*split_proportion)
split_size_1 = int(len(class_1)*split_proportion)
split_size_2 = int(len(class_2)*split_proportion)
split_size_3 = int(len(class_3)*split_proportion)
new_class_0_train = np.random.choice(len(class_0), split_size_0, replace=False)
new_class_0_train = class_0.iloc[new_class_0_train]
new_class_0_test = ~class_0.iloc[:][0].isin(new_class_0_train.iloc[:][0])
new_class_0_test = class_0[new_class_0_test]
new_class_1_train = np.random.choice(len(class_1), split_size_1, replace=False)
new_class_1_train = class_1.iloc[new_class_1_train]
new_class_1_test = ~class_1.iloc[:][0].isin(new_class_1_train.iloc[:][0])
new_class_1_test = class_1[new_class_1_test]
new_class_2_train = np.random.choice(len(class_2), split_size_2, replace=False)
new_class_2_train = class_2.iloc[new_class_2_train]
new_class_2_test = ~class_2.iloc[:][0].isin(new_class_2_train.iloc[:][0])
new_class_2_test = class_2[new_class_2_test]
new_class_3_train = np.random.choice(len(class_3), split_size_3, replace=False)
new_class_3_train = class_3.iloc[new_class_3_train]
new_class_3_test = ~class_3.iloc[:][0].isin(new_class_3_train.iloc[:][0])
new_class_3_test = class_3[new_class_3_test]
x_train_list = pd.concat(
    [new_class_0_train, new_class_1_train, new_class_2_train, new_class_3_train])
x_test_list = pd.concat(
    [new_class_0_test, new_class_1_test, new_class_2_test, new_class_3_test])
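The per-class split above repeats the same four lines for every class. A compact sketch of the same idea (a holdout split done separately within each class) using only NumPy is shown here. The function name `stratified_holdout` and the `seed` parameter are illustrative and not part of my code.

```python
import numpy as np

def stratified_holdout(labels, train_fraction=0.9, seed=0):
    """Return (train_idx, test_idx), splitting each class by train_fraction."""
    labels = np.asarray(labels)
    rng = np.random.default_rng(seed)
    train_parts, test_parts = [], []
    for cls in np.unique(labels):
        idx = np.flatnonzero(labels == cls)  # positions of this class
        rng.shuffle(idx)                     # random order within the class
        n_train = int(len(idx) * train_fraction)  # truncates, like int(len(class_k) * split_proportion)
        train_parts.append(idx[:n_train])
        test_parts.append(idx[n_train:])
    return np.concatenate(train_parts), np.concatenate(test_parts)

# Hypothetical labels: 10 samples of class 0 and 20 of class 1.
demo_labels = [0] * 10 + [1] * 20
train_idx, test_idx = stratified_holdout(demo_labels, train_fraction=0.5)
print(len(train_idx), len(test_idx))  # 15 15
```

Applied to my data this would look like `train_idx, test_idx = stratified_holdout(concatenated_data[1])`, followed by `concatenated_data.iloc[train_idx]` and `concatenated_data.iloc[test_idx]`.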
The model is defined as follows:
from sklearn.utils import class_weight
#from tensorflow.keras import backend as K
#K.set_image_data_format('channels_first')
class_weights = class_weight.compute_sample_weight('balanced', y_train)
print("Class weights are :", class_weights)
IMG_HEIGHT = 224
IMG_WIDTH = 224
IMG_CHANNELS = 1
### Defining the model
def get_model():
    return multi_unet_model(n_classes=4, IMG_HEIGHT=IMG_HEIGHT, IMG_WIDTH=IMG_WIDTH, IMG_CHANNELS=IMG_CHANNELS)
model = get_model()
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
#model.compile(optimizer='adam', loss='focal_loss', metrics=['accuracy'])
model.summary()
history = model.fit(x_train, y_train, batch_size=16, verbose=1, epochs=50,
                    #class_weight = class_weights,
                    validation_data=(x_test, y_test), shuffle=False)
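A side note on the commented-out `class_weight` argument: `compute_sample_weight` returns one weight per sample, whereas the `class_weight` argument of Keras `fit` is documented as a dictionary mapping class indices to weights. A minimal sketch of building such a dictionary with the "balanced" formula `n_samples / (n_classes * count)` follows. The function name `balanced_class_weights` is hypothetical, and the sketch assumes the labels are plain integers in `[0, n_classes)`.

```python
import numpy as np

def balanced_class_weights(labels, n_classes):
    """Return {class_index: weight} using n_samples / (n_classes * count)."""
    labels = np.asarray(labels, dtype=np.int64)
    counts = np.bincount(labels, minlength=n_classes)
    weights = {}
    for cls, count in enumerate(counts):
        if count > 0:  # skip classes that never occur to avoid dividing by zero
            weights[cls] = float(len(labels) / (n_classes * count))
    return weights

# Hypothetical labels: classes 0 and 3 are three times as frequent as 1 and 2.
demo = balanced_class_weights([0, 0, 0, 1, 2, 3, 3, 3], n_classes=4)
print(demo[1], demo[2])  # 2.0 2.0
```

Rare classes end up with larger weights, which is the intent of the "balanced" setting.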
The model summary output is:
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) [(None, 224, 224, 1) 0
__________________________________________________________________________________________________
conv2d (Conv2D) (None, 224, 224, 16) 160 input_1[0][0]
__________________________________________________________________________________________________
dropout (Dropout) (None, 224, 224, 16) 0 conv2d[0][0]
__________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, 224, 224, 16) 2320 dropout[0][0]
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 112, 112, 16) 0 conv2d_1[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D) (None, 112, 112, 32) 4640 max_pooling2d[0][0]
__________________________________________________________________________________________________
dropout_1 (Dropout) (None, 112, 112, 32) 0 conv2d_2[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D) (None, 112, 112, 32) 9248 dropout_1[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D) (None, 56, 56, 32) 0 conv2d_3[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D) (None, 56, 56, 64) 18496 max_pooling2d_1[0][0]
__________________________________________________________________________________________________
dropout_2 (Dropout) (None, 56, 56, 64) 0 conv2d_4[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D) (None, 56, 56, 64) 36928 dropout_2[0][0]
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D) (None, 28, 28, 64) 0 conv2d_5[0][0]
__________________________________________________________________________________________________
conv2d_6 (Conv2D) (None, 28, 28, 128) 73856 max_pooling2d_2[0][0]
__________________________________________________________________________________________________
dropout_3 (Dropout) (None, 28, 28, 128) 0 conv2d_6[0][0]
__________________________________________________________________________________________________
conv2d_7 (Conv2D) (None, 28, 28, 128) 147584 dropout_3[0][0]
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D) (None, 14, 14, 128) 0 conv2d_7[0][0]
__________________________________________________________________________________________________
conv2d_8 (Conv2D) (None, 14, 14, 256) 295168 max_pooling2d_3[0][0]
__________________________________________________________________________________________________
dropout_4 (Dropout) (None, 14, 14, 256) 0 conv2d_8[0][0]
__________________________________________________________________________________________________
conv2d_9 (Conv2D) (None, 14, 14, 256) 590080 dropout_4[0][0]
__________________________________________________________________________________________________
conv2d_transpose (Conv2DTranspo (None, 28, 28, 128) 131200 conv2d_9[0][0]
__________________________________________________________________________________________________
concatenate (Concatenate) (None, 28, 28, 256) 0 conv2d_transpose[0][0]
conv2d_7[0][0]
__________________________________________________________________________________________________
conv2d_10 (Conv2D) (None, 28, 28, 128) 295040 concatenate[0][0]
__________________________________________________________________________________________________
dropout_5 (Dropout) (None, 28, 28, 128) 0 conv2d_10[0][0]
__________________________________________________________________________________________________
conv2d_11 (Conv2D) (None, 28, 28, 128) 147584 dropout_5[0][0]
__________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTrans (None, 56, 56, 64) 32832 conv2d_11[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate) (None, 56, 56, 128) 0 conv2d_transpose_1[0][0]
conv2d_5[0][0]
__________________________________________________________________________________________________
conv2d_12 (Conv2D) (None, 56, 56, 64) 73792 concatenate_1[0][0]
__________________________________________________________________________________________________
dropout_6 (Dropout) (None, 56, 56, 64) 0 conv2d_12[0][0]
__________________________________________________________________________________________________
conv2d_13 (Conv2D) (None, 56, 56, 64) 36928 dropout_6[0][0]
__________________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTrans (None, 112, 112, 32) 8224 conv2d_13[0][0]
__________________________________________________________________________________________________
concatenate_2 (Concatenate) (None, 112, 112, 64) 0 conv2d_transpose_2[0][0]
conv2d_3[0][0]
__________________________________________________________________________________________________
conv2d_14 (Conv2D) (None, 112, 112, 32) 18464 concatenate_2[0][0]
__________________________________________________________________________________________________
dropout_7 (Dropout) (None, 112, 112, 32) 0 conv2d_14[0][0]
__________________________________________________________________________________________________
conv2d_15 (Conv2D) (None, 112, 112, 32) 9248 dropout_7[0][0]
__________________________________________________________________________________________________
conv2d_transpose_3 (Conv2DTrans (None, 224, 224, 16) 2064 conv2d_15[0][0]
__________________________________________________________________________________________________
concatenate_3 (Concatenate) (None, 224, 224, 32) 0 conv2d_transpose_3[0][0]
conv2d_1[0][0]
__________________________________________________________________________________________________
conv2d_16 (Conv2D) (None, 224, 224, 16) 4624 concatenate_3[0][0]
__________________________________________________________________________________________________
dropout_8 (Dropout) (None, 224, 224, 16) 0 conv2d_16[0][0]
__________________________________________________________________________________________________
conv2d_17 (Conv2D) (None, 224, 224, 16) 2320 dropout_8[0][0]
__________________________________________________________________________________________________
conv2d_18 (Conv2D) (None, 224, 224, 4) 68 conv2d_17[0][0]
==================================================================================================
Total params: 1,940,868
Trainable params: 1,940,868
Non-trainable params: 0
__________________________________________________________________________________________________
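The parameter counts in the summary are consistent with a first layer built for a single input channel. The check below assumes 3x3 kernels for the encoder convolutions and a 1x1 kernel for the last layer. The definition of `multi_unet_model` is not shown here, so those kernel sizes are inferred from the counts rather than confirmed.

```python
def conv2d_params(kernel_h, kernel_w, in_channels, filters):
    """Weights plus one bias per filter for a standard 2D convolution."""
    return (kernel_h * kernel_w * in_channels + 1) * filters

print(conv2d_params(3, 3, 1, 16))   # 160, matches conv2d in the summary
print(conv2d_params(3, 3, 16, 16))  # 2320, matches conv2d_1
print(conv2d_params(1, 1, 16, 4))   # 68, matches conv2d_18
print(conv2d_params(3, 3, 3, 16))   # 448, what a 3-channel input would give
```

Since the summary shows 160 rather than 448 for `conv2d`, the model as built expects inputs whose last axis has size 1.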