求 scipy.stats.norm.pdf(x) 的梯度时发生 Autograd TypeError

时间:2018-09-16 00:12:55

标签: python scipy typeerror normal-distribution autograd

我想在 Python 中使用 autograd,对 scipy.stats.norm 给出的正态分布 pdf 求一个简单的梯度。

import scipy.stats as stat
import autograd.numpy as np
from autograd import grad

def f(x):
    return stat.norm.pdf(x, 0.0, 1.0)

grad_f = grad(f)

print(grad_f(-1.0))

但是,我遇到了如下错误:

Traceback (most recent call last):
  File "/Users/Lars/anaconda3/lib/python3.6/site-packages/autograd/core.py", line 62, in forward_pass
    try: end_node = fun(*args, **kwargs)
  File "error.py", line 7, in f
    return stat.norm.pdf(x, 0.0, 1.0)
  File "/Users/Lars/anaconda3/lib/python3.6/site-packages/scipy/stats/_distn_infrastructure.py", line 1657, in pdf
    putmask(output, (1-cond0)+np.isnan(x), self.badvalue)
TypeError: ufunc 'isnan' not supported for the input types, and the inputs could not be safely coerced to any supported typesaccording to the casting rule ''safe''

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "error.py", line 11, in <module>
    print(grad_f(-1.0))
  File "/Users/Lars/anaconda3/lib/python3.6/site-packages/autograd/core.py", line 21, in gradfun
    return backward_pass(*forward_pass(fun,args,kwargs,argnum))
  File "/Users/Lars/anaconda3/lib/python3.6/site-packages/autograd/core.py", line 63, in forward_pass
    except Exception as e: add_extra_error_message(e)
  File "/Users/Lars/anaconda3/lib/python3.6/site-packages/autograd/core.py", line 392, in add_extra_error_message
    raise_(etype, value, traceback)
  File "/Users/Lars/anaconda3/lib/python3.6/site-packages/future/utils/__init__.py", line 413, in raise_
    raise exc.with_traceback(tb)
  File "/Users/Lars/anaconda3/lib/python3.6/site-packages/autograd/core.py", line 62, in forward_pass
    try: end_node = fun(*args, **kwargs)
  File "error.py", line 7, in f
    return stat.norm.pdf(x, 0.0, 1.0)
  File "/Users/Lars/anaconda3/lib/python3.6/site-packages/scipy/stats/_distn_infrastructure.py", line 1657, in pdf
    putmask(output, (1-cond0)+np.isnan(x), self.badvalue)
TypeError: ufunc 'isnan' not supported for the input types, and the inputs could not be safely coerced to any supported typesaccording to the casting rule ''safe''

抱歉贴了这么多代码。我不知道问题出在哪里。据我所知,autograd 支持对 scipy.stats.norm.pdf()/cdf()/logpdf()/logcdf() 求梯度,见代码 https://github.com/HIPS/autograd/blob/master/autograd/scipy/stats/norm.py

1 个答案:

答案 0 :(得分:1)

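您需要从 autograd 导入 scipy,因为 autograd 会对 scipy 的函数进行适当包装,使其支持求梯度。也就是把 `import scipy.stats as stat` 改为 `import autograd.scipy.stats as stat`。以下代码可以正常运行: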

// this is from models/Battle

const mongoose = require('mongoose');
const Schema = mongoose.Schema;

// Create Schema
/**
 * Mongoose schema for a battle document.
 *
 * Fields:
 * - user:         ObjectId of the creating user (references model 'users').
 * - date:         defaults to the time the document is created (Date.now).
 * - category:     required number, taken from the selected category.
 * - winner:       number, defaults to 0.
 * - status:       number, defaults to 0 (meaning documented inline below).
 * - participants: array of subdocuments, each holding a required
 *                 `participant` ObjectId.
 */
const BattleSchema = new Schema({
  user: {
    type: Schema.Types.ObjectId,
    ref: 'users',
  },
  date: {
    type: Date,
    default: Date.now,
  },
  category: {
    type: Number,
    required: true, // this will come from the selected category
  },
  winner: {
    type: Number,
    default: 0, // NOTE(review): presumably 0 means "no winner yet" — confirm
  },
  status: {
    type: Number,
    // 0 means the battle is closed, 1 means the battle is open for votes;
    // the status stays 0 until all participants have dropped.
    default: 0,
  },
  participants: [
    {
      participant: {
        type: Schema.Types.ObjectId,
        required: true,
      },
    },
  ],
});

// Bind the model to a declared constant. The previous chained assignment
// `module.exports = Battle = mongoose.model(...)` created `Battle` as an
// implicit global, which throws a ReferenceError under strict mode.
const Battle = mongoose.model('battles', BattleSchema);

module.exports = Battle;

//this is from routes/api/battles

// @route   POST api/battles/create-battle
// @desc    Create a battle owned by the authenticated user, seeded with the
//          participant given in the request body
// @access  Private (JWT)
router.post(
  '/create-battle',
  passport.authenticate('jwt', { session: false }),
  (req, res) => {
    const { errors, isValid } = validateBattleInput(req.body);

    // Check validation
    if (!isValid) {
      // If any errors, send 400 with errors object
      return res.status(400).json(errors);
    }

    const { category, participant } = req.body;

    // Seed `participants` before the first save. The schema has no top-level
    // `participant` path, so passing it to the constructor was silently
    // dropped. Pushing onto `battle.participants` after `save()` resolved was
    // never persisted because the document was not saved again. Each array
    // entry must be a `{ participant }` subdocument.
    const newBattle = new Battle({
      user: req.user.id,
      category,
      participants: [{ participant }],
    });

    newBattle
      .save()
      .then((battle) => res.json(battle))
      .catch((err) => {
        // Cast/validation failures (e.g. a malformed participant id) are
        // caused by the client's input; anything else is a server fault.
        if (err.name === 'ValidationError' || err.name === 'CastError') {
          return res.status(400).json({ message: err.message });
        }
        console.error(err);
        return res.status(500).json({ message: 'Could not create battle' });
      });
  }
);

// this is battle validation 
const Validator = require('validator');
const isEmpty = require('./is-empty');
var bodyParser = require('body-parser');

module.exports = function validateBattleInput(data) {
  let errors = {};

  data.category = !isEmpty(data.category) ? data.category : '';
  data.participant = !isEmpty(data.participant) ? data.participant : '';

  if (Validator.isEmpty(data.category)) {
    errors.category = 'Category field is required';
  }

  // if (Validator.isEmpty(data.challenger)) {
  //     errors.challenger = 'Challenger field is required';
  // }

  if (Validator.isEmpty(data.participant)) {
    errors.participant = 'Participant field is required';
  }

  return {
    errors,
    isValid: isEmpty(errors)
  };
};