如何获得在Keras train_on_batch的损失函数中计算的tf张量值而无需计算两次或编写自定义循环?

时间:2019-08-05 19:50:27

标签: tensorflow keras

我有一个模型,并且已经实现了自定义损失函数:

def custom_loss(labels, predictions):
    global diff
    #actual code uses decorator so no globals
    diff = labels - predictions
    return tf.square(diff)

model.compile(loss=custom_loss, optimizer=opt.RMSprop())

...

model.train_on_batch(input, labels)

#

在运行diff之后如何获取train_on_batch而又不会导致 它会重新运行以预测第二次出现在幕后(不必要的减速)并弄混了可训练的/批处理规范(可能的问题)?

我想避免进行手动的原始tensorflow train_op循环等,跟踪学习阶段和其他情况。 Is this my only choice

我正在使用tensorflow 1.14的keras模块

1 个答案:

答案 0 :(得分:0)

我已经解决了(发现了控件依赖性和记住的变量)

基本上,我为变量操作创建了一个diff赋值,并且在control_dependencies的帮助下,每次计算op时,tf都要执行此操作,这样,当我获得此变量时,就不会导致图形重新计算

@Bean
public ReactiveJwtDecoder jwtDecoder() {
    return new NimbusReactiveJwtDecoder(keySourceUrl);
}
@Bean
public ReactiveAuthenticationManager authenticationManager() {
    return new JwtReactiveAuthenticationManager(jwtDecoder());
}

测试代码

diff_var = tf.Variable()

def custom_loss(labels, predictions):
    diff = labels - predictions
    diff_var_op = diff_var.assign(diff)
    with tf.control_dependencies([diff_var_op]):
        return tf.square(diff)

我已经在16个小时前将这个问题发布到了datascience stackexchange上,但是没有得到任何答案,事实证明,关于stackexchange的问题还更多,所以我希望自己在此工作上能更快得到答案。 。 不确定如何处理ds.exchange问​​题,也将答案复制到那里