Custom Keras loss function incorporating classes

Time: 2018-10-16 12:48:00

Tags: python-3.x tensorflow keras loss-function

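I have some data in which each point belongs to one of two primary classes and to one of several subclasses, and is a vector of real values. The subclasses are common to both primary classes. So the data looks like this: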

Data for class a

| primary class | subclass | x0  | ... | xn  |
|---------------|----------|-----|-----|-----|
|       a       |    0     | x00 | ... | x0n |
|       a       |    1     | x10 | ... | x1n |
|       a       |    2     | x20 | ... | x2n |
|      ...      |   ...    | ... | ... | ... |
|       a       |    0     | xm0 | ... | xmn |

Data for class b

| primary class | subclass | x0  | ... | xn  |
|---------------|----------|-----|-----|-----|
|       b       |    0     | x00 | ... | x0n |
|       b       |    1     | x10 | ... | x1n |
|       b       |    2     | x20 | ... | x2n |
|      ...      |   ...    | ... | ... | ... |
|       b       |    2     | xk0 | ... | xkn |

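The data is split into a source set and a target set. The network should learn a map that makes the source distribution resemble the target distribution. At the moment, my cost function does not take the subclasses within the primary classes into account, and I would like it to. The current cost looks like this: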

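from keras import backend as K

# target_train, target_sample_size, target_train_size, distance and
# net_output_layer are assumed to be defined elsewhere in the script
def cost(y_true, y_pred):
    # create a random sample of row indices into the target data
    sample = K.cast(K.round(K.random_uniform_variable(shape=(target_sample_size,), low=0, high=target_train_size-1)), 'int32')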

    # get the sample from the target training data
    target_sample = K.gather(target_train, sample)

    # the loss is a distance metric between the output of the net
    # from that batch and the sample we got from the target
    loss = distance(net_output_layer, target_sample)

    return loss

Note that y_true and y_pred are essentially dummy variables here (although I suppose the output layer of the network is the same thing as y_pred?).
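On the y_pred point: Keras calls a custom loss with the model's output for the current batch as y_pred, so the loss can read y_pred rather than a global reference to the output layer. Below is a minimal sketch that checks this, assuming TensorFlow 2 with tf.keras. The tiny model, the `output_mean_loss` function and the dummy data are made up for illustration and are not part of the original code.

```python
import numpy as np
import tensorflow as tf


def output_mean_loss(y_true, y_pred):
    # y_true is ignored on purpose; y_pred is whatever the model
    # produced for the current batch.
    return tf.reduce_mean(y_pred, axis=-1)


# Hypothetical toy model: 3 input features, 2 outputs.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(3,)),
    tf.keras.layers.Dense(2),
])
model.compile(optimizer="sgd", loss=output_mean_loss)

x = np.random.rand(8, 3).astype("float32")
dummy_y = np.zeros((8, 2), dtype="float32")  # placeholder labels

# Loss reported by Keras for one batch containing all 8 rows.
reported = model.evaluate(
    x, dummy_y, batch_size=8, verbose=0, return_dict=True
)["loss"]

# Mean of the model's own predictions on the same rows.
expected = float(np.mean(model.predict(x, verbose=0)))
```

If `reported` and `expected` agree up to floating-point error, the loss really did receive the network output as y_pred.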

I would like the cost to look like this:

def cost(y_true, y_pred):
    # create a random sample of the target data
    sample = K.cast(K.round(K.random_uniform_variable(shape=(target_sample_size,), low=0, high=target_train_size-1)), 'int32')

    # get the sample from the target training data
    target_sample = K.gather(target_train, sample)

    # split the target sample based on the subclasses
    # how can this be done?
    target_0 = ???
    target_1 = ???
    target_2 = ???

    # split the source mini batch based on the subclasses
    # how can this be done?
    source_0 = ???
    source_1 = ???
    source_2 = ???

    # the loss is the sum of the distances between the subclasses
    loss = distance(source_0, target_0) + distance(source_1, target_1) + distance(source_2, target_2)

    return loss

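So for a mini-batch, I would like to split the mini-batch into its subclasses and then compute the distance between matching subclasses. I can pass the subclass labels in as y_true, but I cannot figure out how to use them to select only those parts of net_output_layer that belong to a given subclass. I have searched a lot but could not find anything similar. Any help would be greatly appreciated.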
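One possible way to do the split is with boolean masks, sketched here under several assumptions: TensorFlow 2 is used instead of the Keras backend calls above, the subclass labels of the mini-batch arrive as y_true, and a label array for the target data (called `target_labels` here, a hypothetical name) is available alongside target_train. `mean_distance` is only a stand-in for the question's `distance`, whose definition is not shown, and `make_subclass_cost` is likewise an illustrative name.

```python
import numpy as np
import tensorflow as tf

NUM_SUBCLASSES = 3  # subclasses 0, 1, 2 as in the tables above


def mean_distance(a, b):
    # Stand-in for the question's `distance` (an assumption):
    # squared Euclidean distance between the two batch means.
    diff = tf.reduce_mean(a, axis=0) - tf.reduce_mean(b, axis=0)
    return tf.reduce_sum(tf.square(diff))


def make_subclass_cost(target_train, target_labels, target_sample_size):
    target_train = tf.cast(tf.convert_to_tensor(target_train), tf.float32)
    target_labels = tf.cast(tf.convert_to_tensor(target_labels), tf.int32)
    n_target = int(target_train.shape[0])

    def cost(y_true, y_pred):
        # y_true carries the subclass label of each source row.
        source_labels = tf.cast(tf.reshape(y_true, [-1]), tf.int32)
        y_pred = tf.cast(y_pred, tf.float32)

        # Random target rows, together with their subclass labels.
        idx = tf.random.uniform(
            [target_sample_size], maxval=n_target, dtype=tf.int32)
        target_sample = tf.gather(target_train, idx)
        sample_labels = tf.gather(target_labels, idx)

        loss = tf.constant(0.0)
        for k in range(NUM_SUBCLASSES):
            # Keep only the rows whose label equals k.
            source_k = tf.boolean_mask(y_pred, tf.equal(source_labels, k))
            target_k = tf.boolean_mask(
                target_sample, tf.equal(sample_labels, k))
            # Skip a subclass that is missing on either side, since the
            # mean of an empty tensor would be NaN.
            both_present = tf.logical_and(
                tf.shape(source_k)[0] > 0, tf.shape(target_k)[0] > 0)
            loss = loss + tf.cond(
                both_present,
                lambda s=source_k, t=target_k: mean_distance(s, t),
                lambda: tf.constant(0.0))
        return loss

    return cost


# Toy check: every target row of subclass k is the vector [k, k].
target_labels = np.repeat(np.arange(3), 10)
target_train = target_labels[:, None] * np.ones((1, 2), dtype="float32")
cost = make_subclass_cost(target_train, target_labels, target_sample_size=200)

source_labels = np.array([0, 1, 2, 0, 1, 2], dtype="float32").reshape(-1, 1)
source_matched = tf.constant(source_labels * np.ones((1, 2), dtype="float32"))
loss_matched = float(cost(source_labels, source_matched))
loss_shifted = float(cost(source_labels, source_matched + 1.0))
```

In this sketch the target labels are captured by the closure, while the source labels travel through y_true, which is what makes the per-batch split possible. Whether skipping an empty subclass is the right behaviour depends on the real `distance`; it is just one way to avoid NaN values.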

0 answers:

No answers yet