具有数据生成器的MNIST的三重态损失

时间:2019-12-27 17:02:20

标签: keras deep-learning

我为MNIST创建了三重学习代码。无法弄清楚是什么问题。 请帮忙。代码的第一行是连接到Google云端硬盘-我使用的是Google Colab。如果您使用其他环境,请忽略它。

from google.colab import drive
drive.mount("/content/drive/", force_remount=True)

import os
import random 
import numpy as np 

import matplotlib.pyplot as plt
import matplotlib.patheffects as PathEffects
from sklearn.manifold import TSNE
from keras.preprocessing.image import ImageDataGenerator
from keras.preprocessing.image import array_to_img, img_to_array, load_img
from keras.layers import Input, Conv2D, MaxPooling2D, Dense, Activation, Dropout
from keras.layers import Flatten, Lambda, concatenate # ,BatchNormalization #, GaussianNoise
from keras.callbacks import ModelCheckpoint
from keras.optimizers import SGD, Adam
from keras.models import Model
import keras.backend as K
from keras.models import Sequential
from keras.datasets import mnist

import seaborn as sns

np.random.seed(7)

working_path = "/content/drive/My Drive/TripletMNIST/data/"
!ls "/content/drive/My Drive/TripletMNIST/data"

best_weights_filepath = working_path + "models/2_1_TripletMNIST.txt"
last_weights_filepath = working_path + "models/2_1_TripletMNIST_Last.txt"

(x_train, y_train), (x_test, y_test) = mnist.load_data()

x = np.concatenate((x_train, x_test), axis=0)
y = np.concatenate((y_train, y_test), axis=0)

unique, counts = np.unique(y_train, return_counts=True)
print(unique, counts)

MIN_NUM_OF_IMAGES_PER_CLASS = min(counts) # This is a minimum num. of images among 10 digits
print(MIN_NUM_OF_IMAGES_PER_CLASS)

img = np.array(x_train[0], dtype='float')
#pixels = first_image.reshape((28, 28))
plt.imshow(img, cmap='gray')
plt.show()

IMAGE_SIZE = 28
input_shape=(IMAGE_SIZE, IMAGE_SIZE, 1)

BATCH_SIZE = 256
TRAINING_IMAGES_PER_CLASS = int(MIN_NUM_OF_IMAGES_PER_CLASS * 0.6)
VALIDATION_IMAGES_PER_CLASS = int(MIN_NUM_OF_IMAGES_PER_CLASS * 0.2)
TESTING_IMAGES_PER_CLASS = MIN_NUM_OF_IMAGES_PER_CLASS 
    - TRAINING_IMAGES_PER_CLASS - VALIDATION_IMAGES_PER_CLASS

我们将数据保存在一个数组中,其中包含ID(类名),ImageNames(该类的所有图像,此处为“ 0”,“ 1”,以此类推,直到“ 9”)和“ Images”(二进制图像)。 / p>

arrLabels = []
for strClassName in np.unique(y):
  # All images for that class
  fltr = np.where((y == strClassName))
  #print(len(fltr[0]))

  arrImages = []
  arrImageNames = []
  for nIdx in fltr[0]:
    if(len(arrImageNames) > MIN_NUM_OF_IMAGES_PER_CLASS):
      break
    strFileName = y[nIdx]
    arrImageNames.append(strFileName)
    arrImages.append(x[nIdx])

  arrLabels.append(
  {
    'Id':strClassName,
    'ImageNames':arrImageNames[:MIN_NUM_OF_IMAGES_PER_CLASS],
    'Images':arrImages[:MIN_NUM_OF_IMAGES_PER_CLASS]
  })

print("Classes selected: ", len(arrLabels)) 

仅在散点图中使用(以下)

arrTrainImages = []
arrTrainClasses = []

arrTestImages = []
arrTestClasses = []

for nClassIdx in range(len(arrLabels)):
  for nImageIdx in range(TRAINING_IMAGES_PER_CLASS + VALIDATION_IMAGES_PER_CLASS):
    arrTrainClasses.append(arrLabels[nClassIdx]['Id'])
    arrTrainImages.append(img_to_array(arrLabels[nClassIdx]['Images'][nImageIdx]))
  for nImageIdx in range(TRAINING_IMAGES_PER_CLASS + VALIDATION_IMAGES_PER_CLASS + 1, TRAINING_IMAGES_PER_CLASS + VALIDATION_IMAGES_PER_CLASS + TESTING_IMAGES_PER_CLASS):
    arrTestClasses.append(arrLabels[nClassIdx]['Id'])
    arrTestImages.append(img_to_array(arrLabels[nClassIdx]['Images'][nImageIdx]))

arrTrainClasses = np.array(arrTrainClasses)
arrTrainImages = np.array(arrTrainImages)
arrTestClasses = np.array(arrTestClasses)
arrTestImages = np.array(arrTestImages)

arrTrainImagesFlat = arrTrainImages.reshape(-1,IMAGE_SIZE*IMAGE_SIZE*1)
arrTestImagesFlat = arrTestImages.reshape(-1,IMAGE_SIZE*IMAGE_SIZE*1)

# Define our own plot function
def scatter(x, labels, subtitle=None, bShowClasterLabels=0):
    # Choose a color palette with seaborn.

    cmap = np.array(sns.color_palette("hls", len(labels)))

    # Create a scatter plot.
    f = plt.figure(figsize=(8, 8))
    ax = plt.subplot(aspect='equal')
    sc = ax.scatter(x[:,0], x[:,1], lw=0, s=40, c=cmap)
    plt.xlim(-25, 25)
    plt.ylim(-25, 25)
    ax.axis('off')
    ax.axis('tight')

    # Add the labels for each digit. Use if clasters are present.
    if(bShowClasterLabels):
      txts = []
      for i in range(len(arrLabels)):
          # Position of each label.
          xtext, ytext = np.median(x[labels == arrLabels[i]['Id'], :], axis=0)
          txt = ax.text(xtext, ytext, str(i), fontsize=24)
          txt.set_path_effects([
              PathEffects.Stroke(linewidth=5, foreground="w"),
              PathEffects.Normal()])
          txts.append(txt)

    if subtitle != None:
        plt.suptitle(subtitle)

    plt.savefig(subtitle)

在集群化之前绘制点

tsne = TSNE()

arr = list(range(len(arrTrainImagesFlat)))
arrSamples = random.sample(arr, 100)
train_tsne_embeds = tsne.fit_transform(arrTrainImagesFlat[arrSamples])
scatter(train_tsne_embeds, arrTrainClasses[arrSamples], "Samples from Training Data")

arr = list(range(len(arrTestImagesFlat)))
arrSamples = random.sample(arr,20)
eval_tsne_embeds = tsne.fit_transform(arrTestImagesFlat[arrSamples])
scatter(eval_tsne_embeds, arrTestClasses[arrSamples], "Samples from Test Data", 0)


datagen = ImageDataGenerator(
    rotation_range=15,
    width_shift_range=0.2,
    height_shift_range=0.2,
    zoom_range=0.5
)

def loadImage(cClass, nImageIdx, datagen):
  img = cClass['Images'][nImageIdx]

  arrImg = img_to_array(img)

  arrImg = datagen.random_transform(arrImg)

  return np.array(arrImg, dtype="float32")


def deleteSavedNet(best_weights_filepath):
    if(os.path.isfile(best_weights_filepath)):
        os.remove(best_weights_filepath)
        print("deleteSavedNet():File removed")
    else:
        print("deleteSavedNet():No file to remove") 


def plotChart_1(arrX, arrY, strTitle, strLabelY, strLabelX):
    plt.plot(arrX, arrY)
    plt.title(strTitle)
    plt.ylabel(strLabelY)
    plt.xlabel(strLabelX)
    plt.show()

def plotHistoryAccForSubmit():
    plt.plot(history.history['acc'])
    plt.title('Model accuracy')
    plt.ylabel('Accuracy')
    plt.xlabel('Epoch')
    plt.show()

def plotHistoryAcc():
    plt.plot(history.history['accuracy'])
    plt.plot(history.history['val_accuracy'])
    plt.title('Model accuracy')
    plt.ylabel('Accuracy')
    plt.xlabel('Epoch')
    plt.show()

def plotHistoryLoss():
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('Model loss')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.show()    

def plotLoss(arrTrainingLoss, arrValidationLoss):
    plt.plot(arrTrainingLoss)
    plt.plot(arrValidationLoss)
    plt.title('Model loss')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.show()   


def triplet_loss(y_true, y_pred, alpha = 0.2):

    #print('y_pred.shape = ',y_pred)

    VECTOR_SIZE = BATCH_SIZE
    anchor = y_pred[:, :VECTOR_SIZE]
    positive = y_pred[:, VECTOR_SIZE:2*VECTOR_SIZE]
    negative = y_pred[:, 2*VECTOR_SIZE:]

    # distance between the anchor and the positive
    pos_dist = K.sum(K.square(anchor-positive),axis=1)

    # distance between the anchor and the negative
    neg_dist = K.sum(K.square(anchor-negative),axis=1)

    # compute loss
    basic_loss = pos_dist-neg_dist+alpha
    loss = K.maximum(basic_loss,0.0)

    return loss


def getBaseModel(input_shape, type, nL2):

  # https://towardsdatascience.com/lossless-triplet-loss-7e932f990b24
  if(type=="lossless"):
    base_model = Sequential()
    base_model.add(Conv2D(128,(7,7),padding='same',input_shape=input_shape,activation='relu',name='conv1'))
    base_model.add(MaxPooling2D((2,2),(2,2),padding='same',name='pool1'))
    base_model.add(Conv2D(256,(5,5),padding='same',activation='relu',name='conv2'))
    base_model.add(MaxPooling2D((2,2),(2,2),padding='same',name='pool2'))
    base_model.add(Flatten(name='flatten'))
    base_model.add(Dense(EMBEDDING_DIM,name='embeddings', activation="sigmoid"))

  return base_model  


def createModel(nL2):

  base_model = getBaseModel(input_shape, "lossless", nL2)

  input_anchor = Input(shape=input_shape, name='input_anchor')
  input_positive = Input(shape=input_shape, name='input_pos')
  input_negative = Input(shape=input_shape, name='input_neg')

  base_model_anchor = base_model(input_anchor)
  base_model_positive = base_model(input_positive)
  base_model_negative = base_model(input_negative)

  merged_vector = concatenate([base_model_anchor, base_model_positive, base_model_negative], axis=-1, name='merged_layer')

  base_lr = 0.00004

  triplet_model = Model(inputs=[input_anchor, input_positive, input_negative], outputs=merged_vector)

  #triplet_model.compile(loss=triplet_loss, optimizer=adam_optim)

  # triplet_model.summary()  
  triplet_model.compile(loss=triplet_loss, optimizer=Adam(base_lr), metrics=[my_accuracy])

  return base_model, triplet_model


def get_triplet(nSplitIdx, bIsTrain):
  # Select random class
  positiveClass = np.random.choice(arrLabels)

  # Depending train or validate, select range. 
  # Say we have 10 images per class, and 70% does to train. Then 0-6 (train); 7-9 (valid, at least 3)
  if(bIsTrain):
    nMinIdx = 0
    nMaxIdx = nSplitIdx - 1
  else:
    nMinIdx = nSplitIdx
    nMaxIdx = MIN_NUM_OF_IMAGES_PER_CLASS - 1 - TESTING_IMAGES_PER_CLASS

  # Get 3 indices: for base image and for positive example, from same class. And one more for negative example.
  # TBD: figure (test) if SAME image should be used in a positive pair, like [img[1], img[1]]?
  arrImageIdx = np.random.choice(range(nMinIdx, nMaxIdx), 3)

  while arrImageIdx[0] == arrImageIdx[1]:
    arrImageIdx[1] = np.random.choice(range(nMinIdx, nMaxIdx))

  negativeClass = np.random.choice(arrLabels)
  while negativeClass['Id'] == positiveClass['Id']:
    negativeClass = np.random.choice(arrLabels)

  return arrImageIdx, positiveClass, negativeClass


def gen(bIsTrain):
  #nSplitIdx = int(NUM_OF_IMAGES_PER_CLASS * TESTING_SPLIT)
  while True:
    arrBaseExamples = []
    arrPositiveExamples = []
    arrNegativeExamples = []

    for i in range(BATCH_SIZE):
      nImageIdx, positiveClass, negativeClass = get_triplet(TRAINING_IMAGES_PER_CLASS, bIsTrain)

      #t0 = time()
      baseExampleImg = loadImage(positiveClass, nImageIdx[0], datagen)      
      positiveExampleImg = loadImage(positiveClass, nImageIdx[1], datagen)
      negativeExampleImg = loadImage(negativeClass, nImageIdx[2], datagen)


      arrBaseExamples.append(baseExampleImg)
      arrPositiveExamples.append(positiveExampleImg)
      arrNegativeExamples.append(negativeExampleImg)

    base = np.array(arrBaseExamples) / 255.
    positive = np.array(arrPositiveExamples) / 255.
    negative = np.array(arrNegativeExamples) / 255.

    label = None

    yield (np.stack(base, positive, negative), label) 

checkpoint = ModelCheckpoint(best_weights_filepath, monitor="val_my_accuracy", save_best_only=True, save_weights_only=True, mode='max')#, verbose=1)

callbacks_list = [checkpoint] #, save_model_at_epoch_end_callback]  # , early]

gen_train = gen(True)
gen_valid = gen(False)

EPOCHS=50

nMultiplier = 10
arrParams = [[0.8, 3]]

for i in range(0, len(arrParams)):

  nL2 = arrParams[i][0]
  EMBEDDING_DIM = arrParams[i][1]

  deleteSavedNet(best_weights_filepath)

  #random.seed(datetime.now())
  random.seed(7)
  embedding_model, triplet_model = createModel(nL2)

  #loadBestModel()

  # No need to set classifier layers to True, as already True as created (?)
  nNumOfClasses = len(arrLabels)
  nNumOfTrainSamples = TRAINING_IMAGES_PER_CLASS * nNumOfClasses
  nNumOfValidSamples = VALIDATION_IMAGES_PER_CLASS * nNumOfClasses
  STEP_SIZE_TRAIN = nMultiplier * nNumOfTrainSamples // BATCH_SIZE
  if(STEP_SIZE_TRAIN == 0):
    STEP_SIZE_TRAIN = 1

  STEP_SIZE_VALID = nMultiplier * nNumOfValidSamples // BATCH_SIZE
  if(STEP_SIZE_VALID == 0):
    STEP_SIZE_VALID = 1

  print("Available metrics: ", triplet_model.metrics_names)

  history = triplet_model.fit_generator(gen_train, validation_data=gen_valid, 
    epochs=EPOCHS, steps_per_epoch=STEP_SIZE_TRAIN, validation_steps=STEP_SIZE_VALID, callbacks=callbacks_list)
    #workers=4, use_multiprocessing=True)

  print(nL2, EMBEDDING_DIM)
  plotHistoryLoss()    
  plotHistoryAcc()

错误我得到: ValueError:尺寸必须相等,但输入形状为[?,9],[?, 0]的'loss_4 / merged_layer_loss / sub'(op:'Sub')的尺寸应为9和0。

0 个答案:

没有答案