I want to use autograd in Python to compute a simple gradient of the normal distribution's pdf, using scipy.stats.norm.
import scipy.stats as stat
import autograd.numpy as np
from autograd import grad
def f(x):
    return stat.norm.pdf(x, 0.0, 1.0)
grad_f = grad(f)
print(grad_f(-1.0))
However, I get the following error:
Traceback (most recent call last):
  File "/Users/Lars/anaconda3/lib/python3.6/site-packages/autograd/core.py", line 62, in forward_pass
    try: end_node = fun(*args, **kwargs)
  File "error.py", line 7, in f
    return stat.norm.pdf(x, 0.0, 1.0)
  File "/Users/Lars/anaconda3/lib/python3.6/site-packages/scipy/stats/_distn_infrastructure.py", line 1657, in pdf
    putmask(output, (1-cond0)+np.isnan(x), self.badvalue)
TypeError: ufunc 'isnan' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "error.py", line 11, in <module>
    print(grad_f(-1.0))
  File "/Users/Lars/anaconda3/lib/python3.6/site-packages/autograd/core.py", line 21, in gradfun
    return backward_pass(*forward_pass(fun,args,kwargs,argnum))
  File "/Users/Lars/anaconda3/lib/python3.6/site-packages/autograd/core.py", line 63, in forward_pass
    except Exception as e: add_extra_error_message(e)
  File "/Users/Lars/anaconda3/lib/python3.6/site-packages/autograd/core.py", line 392, in add_extra_error_message
    raise_(etype, value, traceback)
  File "/Users/Lars/anaconda3/lib/python3.6/site-packages/future/utils/__init__.py", line 413, in raise_
    raise exc.with_traceback(tb)
  File "/Users/Lars/anaconda3/lib/python3.6/site-packages/autograd/core.py", line 62, in forward_pass
    try: end_node = fun(*args, **kwargs)
  File "error.py", line 7, in f
    return stat.norm.pdf(x, 0.0, 1.0)
  File "/Users/Lars/anaconda3/lib/python3.6/site-packages/scipy/stats/_distn_infrastructure.py", line 1657, in pdf
    putmask(output, (1-cond0)+np.isnan(x), self.badvalue)
TypeError: ufunc 'isnan' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''
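Sorry for the wall of code. I don't know what could be wrong here. As far as I know, autograd supports gradients of scipy.stats.norm.pdf() / cdf() / logpdf() / logcdf(), as shown in the source at https://github.com/HIPS/autograd/blob/master/autograd/scipy/stats/norm.py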
Answer 0 (score: 1)
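You need to import scipy from autograd, because autograd wraps scipy's functions so that they can be differentiated. Plain scipy.stats.norm.pdf does not know how to handle the boxed values autograd passes in, which is why np.isnan fails inside it. Replacing scipy.stats with autograd.scipy.stats in the question's code makes it work.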