图像理解-CNN三重损失

时间:2018-12-01 16:34:29

标签: python tensorflow keras neural-network deep-learning

我是NN的新手,正在尝试创建一个用于图像理解的简单NN。

我尝试使用三元组损失方法,但是不断出现错误,使我觉得我缺少一些基本概念。

我的代码是:

def triplet_loss(x):
  anchor, positive, negative = tf.split(x, 3)

  pos_dist = tf.reduce_sum(tf.square(tf.subtract(anchor, positive)), 1)
  neg_dist = tf.reduce_sum(tf.square(tf.subtract(anchor, negative)), 1)

  basic_loss = tf.add(tf.subtract(pos_dist, neg_dist), ALPHA)
  loss = tf.reduce_mean(tf.maximum(basic_loss, 0.0), 0)

  return loss


def build_model(input_shape):
  K.set_image_data_format('channels_last')

  positive_example = Input(shape=input_shape)
  negative_example = Input(shape=input_shape)
  anchor_example = Input(shape=input_shape)

  embedding_network = create_embedding_network(input_shape)

  positive_embedding = embedding_network(positive_example)
  negative_embedding = embedding_network(negative_example)
  anchor_embedding = embedding_network(anchor_example)

  merged_output = concatenate([anchor_embedding, positive_embedding, negative_embedding])
  loss = Lambda(triplet_loss, (1,))(merged_output)

  model = Model(inputs=[anchor_example, positive_example, negative_example],
              outputs=loss)
  model.compile(loss='mean_absolute_error', optimizer=Adam())

  return model



def create_embedding_network(input_shape):
  input_shape = Input(input_shape)
  x = Conv2D(32, (3, 3))(input_shape)
  x = PReLU()(x)
  x = Conv2D(64, (3, 3))(x)
  x = PReLU()(x)

  x = Flatten()(x)
  x = Dense(10, activation='softmax')(x)
  model = Model(inputs=input_shape, outputs=x)
  return model

使用:

读取每个图像
imageio.imread(imagePath, pilmode="RGB")

每个图像的形状:

(1024, 1024, 3)

然后我使用自己的三元组方法(仅创建3组锚点,正负)

triplets = get_triplets(data)
triplets.shape

形状为(示例数,三元组,x_image,y_image,通道数  (RGB)):

(20, 3, 1024, 1024, 3)

然后我使用build_model函数:

model = build_model((1024, 1024, 3))

问题从这里开始:

model.fit(triplets, y=np.zeros(len(triplets)), batch_size=1)

对于这行代码,当我尝试训练模型时,出现此错误:

error

有关更多详细信息,我的代码在此collab notebook

我使用的图片可以在此Drive中找到 为了使它无缝运行,请将此文件夹放在

  

我的云端硬盘/ Colab笔记本/图像/

1 个答案:

答案 0 :(得分:0)

对于任何也在挣扎的人

我的问题实际上是每个观察的维度。 通过更改注释中建议的尺寸

(?, 1024, 1024, 3)

使用解决方案更新的colab笔记本

P.s-我还将图片的大小更改为256 * 256,以便代码在我的PC上运行得更快。