自定义损失函数的外部变量访问

时间:2019-07-19 15:16:19

标签: keras classification cluster-analysis loss-function

我在Keras中有一个分类神经网络,我想为此编写一个自定义损失函数。在这个损失函数中,我想访问一个外部变量(浮点数组)(全局簇列表),并提取该变量的第i个索引,其中“ i”是网络预测的类(预测的群集号)。样本由不应进入模型的(类别变量)和数值变量(坐标(x,y))组成。 (损失函数)的最终目标是将(坐标(x,y))(集群坐标(x',y'))< / strong>并计算其RMSE。 (下面的图形说明很小)

http://prntscr.com/ohd53e(我只能发布链接,因为堆栈溢出将不允许我显示图片)

下面已经尝试过的代码,我正在使用Keras的功能API将**(分类变量)输入到NN中,然后使用自定义损失函数进行编译。

def custom_loss(global_list_of_clusters):
    def my_mse(real_coordinates,outputs) :
        predicted_cluster_numbers = K.argmax(outputs,axis=1)
        return K.sqrt(K.mean(K.square(global_list_of_clusters[predicted_cluster_numbers] - real_coordinates), axis=[0,1])) 
    return my_mse

def init_model(nb_cat_feat,global_list_of_clusters) :

    input_cat = Input(shape=(None,nb_cat_feat),name='cat_feat')

    dense1 = Dense(50)(input_cat)
    dense2 = Dense(25)(dense1)
    output = Dense(n_clusters,activation='softmax')(dense2)

    model = Model(inputs=input_cat,outputs=output)

    model.compile(loss=custom_loss(global_list_of_clusters), optimizer='adam')

    return model

培训将这样称呼:

model.fit(cat_feat,inv_sequences,epochs=5,batch_size=500)

我的代码当前给我的错误是这个错误:

File "/home/courbes_usr/courbes-clustering/NN_model.py", line 21, in my_mse
return K.sqrt(K.mean(K.square(global_list_of_clusters[predicted_cluster_number] - real_coordinates), axis=[0,1]))
IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices

这使我相信我不正确地访问了(global_list_of_clusters)结构。但是,我还没有找到正确的方法来做自己想要的事情。

谢谢您的帮助!

0 个答案:

没有答案