Keras自定义损失函数访问python全局变量时的内部机制是什么?

时间:2018-12-24 22:25:47

标签: python tensorflow keras

Keras自定义损失函数是否接受全局python变量?

我正在构建自己的Keras自定义损失函数,该函数仅接受y_truey_pred作为参数。但是损失函数非常复杂,并且依赖于其他变量。当前在我的实现中,损失函数只是直接在同一python代码脚本中使用全局变量。训练模型后,如果我想使用模型进行预测,则将更改python环境中的那些全局变量。我的问题是,是否需要再次编译模型,以确保模型已使用这些外部全局变量的最新版本进行了更新?

Rlist=....

def custom_loss(y_true,y_pred):
    z = 0.0
    #Rlist is the global variable 
    for j in Rlist:
        z = z  +K.log(K.sum(K.exp(K.gather(y_pred,j[0])))) \
        - K.log(K.sum(K.exp(K.gather(y_pred,j))))
    z = -z 
    return z
#below build the model and compile it with loss=custom_loss
model=...
model.compile(loss=custom_loss,....
model.fit(x=train_x,y=train_y,...)
#Rlist=...  update Rlist which is adaptive to test dataset
#Do I need to recompile in the code below,or whether Rlist is updated
#in custom_loss when it is called?
model.predict(x=test_x,y=test_y,...)

在我的损失函数中(实际上这是Cox比例风险模型的损失函数),每个样本的损失值之间的损失不是相加的。 Rlist是我的Keras代码的python环境中的全局变量 我的问题是,在训练模型后,如果我将此Rlist更改为 测试数据集,Keras会自动更新Rlist,还是在编译和构建计算图时使用此变量Rlist的旧版本?

是否有任何解释,如果我直接在损失函数中从python环境中引用全局变量,那么当Tensorflow构建其计算图时会发生什么? 我知道使用全局变量不是一个好习惯,还建议更好的建议。

1 个答案:

答案 0 :(得分:0)

“我的Keras代码的python环境”到底是什么意思?如果在训练时在代码中将Rlist变量设置为[1,2,3]。然后在预测/生产模式下将其更改为[3,2,1],您自定义损失将看到[3,2,1]变量。

我不确定您要达到的目标,我想这可能有效: A)用RList创建一个真实的ENV_Variable B)用RList创建一个JSON文件(这样,您就可以在服务器或云上的生产模式下使用RList数据)。 C)在您的代码中创建一个Dict,例如

var names = [{name: 'steven', age: 22}, {name: 'bill', age: 13}];

function findBill(array) {
  return array.find((obj) => (
    obj.name === 'bill'
  ))
};

console.log(findBill(names))