在keras中制作自定义丢失功能

时间:2017-08-30 13:11:57

标签: python machine-learning tensorflow keras

您好我一直在尝试在keras中为dice_error_coefficient创建自定义丢失函数。它的实现在 tensorboard 中,我尝试在keras中使用相同的函数和tensorflow,但是当我使用 model.train_on_batch 时它仍然返回 NoneType model.fit ,因为它在模型中的指标中使用时会给出正确的值。可以请有人帮我解决我该怎么办?我曾经尝试过像ahundt这样的Keras-FCN这样的库,在那里他使用了自定义丢失函数,但似乎都没有。代码中的目标和输出分别为y_true和y_pred,用于keras中的losses.py文件。

def dice_hard_coe(target, output, threshold=0.5, axis=[1,2], smooth=1e-5):
    """References
    -----------
    - `Wiki-Dice <https://en.wikipedia.org/wiki/Sørensen–Dice_coefficient>`_
    """

    output = tf.cast(output > threshold, dtype=tf.float32)
    target = tf.cast(target > threshold, dtype=tf.float32)
    inse = tf.reduce_sum(tf.multiply(output, target), axis=axis)
    l = tf.reduce_sum(output, axis=axis)
    r = tf.reduce_sum(target, axis=axis)
    hard_dice = (2. * inse + smooth) / (l + r + smooth)
    hard_dice = tf.reduce_mean(hard_dice)
    return hard_dice

2 个答案:

答案 0 :(得分:57)

在Keras中实现参数化自定义丢失功能有两个步骤。首先,编写系数/度量的方法。其次,编写一个包装函数来按照Keras需要的方式格式化事物。

  1. 使用Keras后端而不是tensorflow直接用于简单的自定义丢失功能(如DICE)实际上相当简洁。以下是以这种方式实现的系数示例:

    import keras.backend as K
    def dice_coef(y_true, y_pred, smooth, thresh):
        y_pred = y_pred > thresh
        y_true_f = K.flatten(y_true)
        y_pred_f = K.flatten(y_pred)
        intersection = K.sum(y_true_f * y_pred_f)
    
        return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
    
  2. 现在是棘手的部分。 Keras损失函数必须只取(y_true,y_pred)作为参数。所以我们需要一个单独的函数来返回另一个函数。

    def dice_loss(smooth, thresh):
      def dice(y_true, y_pred)
        return -dice_coef(y_true, y_pred, smooth, thresh)
      return dice
    
  3. 最后,您可以在Keras编译中使用它。

    # build model 
    model = my_model()
    # get the loss function
    model_dice = dice_loss(smooth=1e-5, thresh=0.5)
    # compile model
    model.compile(loss=model_dice)
    

答案 1 :(得分:0)

根据documentation,您可以使用这样的自定义损失函数:

<块引用>

任何带有签名 const defaultVals : string[] = ["val1","val1","val1","val2] 的可调用函数返回损失数组(输入批次中的一个样本)都可以作为损失传递给 compile()。请注意,任何此类损失都会自动支持样本加权。

举个简单的例子:

const auth = require("../middleware/auth");

const bcrypt = require("bcrypt");

const _ = require("lodash");

const { User, validate } = require("../models/user");

const express = require("express");

const router = express.Router();

router.get("/me", auth, async (req, res) => {

     const user = await User.findById(req.user._id).select("-password");

     res.send(user);

});

router.post("/", async (req, res) => {

   const { error } = validate(req.body);

   if (error) return res.status(400).send(error.details[0].message);

   let user = await User.findOne({ email: req.body.email });

   if (user) return res.status(400).send("User already registered.");

   user = new User(_.pick(req.body, ["name", "email", "password"]));

   bcrypt.genSalt(10, function (_err, salt) {

  bcrypt.hash(user.name, salt, function (_err, hash) {

  // Store hash in your password DB.

   user.password = hash;

  });

  });



    await user.save();

    const token = user.generateAuthToken();

    res

     .header("x-auth-token", token)

     .header("access-control-expose-headers", "x-auth-token")

     .send(_.pick(user, ["_id", "name", "email"]));

  });

  module.exports = router;

完整示例:

 import React from "react";

 import Joi from "joi-browser";

 import Form from "./common/form";

 import * as userService from '../services/userService'; // Import methods *

 class RegisterForm extends Form {

    state = {

    data: { username: "", password: "", name: "" },

    errors: {}

   };

   schema = {

   username: Joi.string()

    .required()

    .email()

    .label("Username"),

   password: Joi.string()

    .required()

    .min(5)

    .label("Password"),

 name: Joi.string()

    .required()

   .label("Name")

 };

 doSubmit = async () => {

   try {

     const response = await userService.register(this.state.data);

    console.log(response);

    localStorage.setItem('token', response.header['x-auth-token']);

   this.props.history.push("/");

  }

  catch (ex) {

       if (ex.response && ex.response.status === 400) {

      const errors = {...this.state.errors};

      errors.username = ex.response.data;

      this.setState({ errors });

    }

   }



  userService.register(this.state.data);

  };

  render() {

   return (

       <div>

         <h1>Register</h1>

          <form onSubmit={this.handleSubmit}>

             {this.renderInput("username", "Username")}

            {this.renderInput("password", "Password", "password")}

            {this.renderInput("name", "Name")}

            {this.renderButton("Register")}

        </form>

    </div>

    );

   }

  }

  export default RegisterForm;