来自R package nnet的函数multinom如何计算多项式概率权重?

时间:2014-04-07 07:15:18

标签: r multinomial

我知道我的标题问题的理​​论答案,该问题在Stack Overflow上讨论herethis之前的问题。我的问题是,即使考虑一些数字舍入,我使用R函数multinom中拟合的系数计算的概率权重与直接从同一函数获得的权重(通过{{1 }})。我尝试在predict(fit, newdata = dat, "probs")Java中以数字方式计算这些权重,并且在两种实现中我都获得了相同的结果,这实际上与R返回的值不同。

您是否知道我如何发现predict函数predict(..., "probs")的函数R的实现?

1 个答案:

答案 0 :(得分:2)

我首先安装nnet并打开nnet功能的帮助页面。我看到该函数创建了一个nnet对象。

我尝试predict.nnet但没有出现。这意味着包装未加载,功能不存在或隐藏。 methods("predict")显示该对象实际上是隐藏的(由*表示。)

> methods("predict")
 [1] predict.ar*                predict.Arima*             predict.arima0*            predict.glm               
 [5] predict.HoltWinters*       predict.lm                 predict.loess*             predict.mlm               
 [9] predict.multinom*          predict.nls*               predict.nnet*              predict.poly              
[13] predict.ppr*               predict.prcomp*            predict.princomp*          predict.smooth.spline*    
[17] predict.smooth.spline.fit* predict.StructTS*    

显式调用此函数会显示其代码。

> nnet:::predict.nnet
function (object, newdata, type = c("raw", "class"), ...) 
{
    if (!inherits(object, "nnet")) 
        stop("object not of class \"nnet\"")
    type <- match.arg(type)
    if (missing(newdata)) 
        z <- fitted(object)
    else {
        if (inherits(object, "nnet.formula")) {
            newdata <- as.data.frame(newdata)
            rn <- row.names(newdata)
            Terms <- delete.response(object$terms)
            m <- model.frame(Terms, newdata, na.action = na.omit, 
                xlev = object$xlevels)
            if (!is.null(cl <- attr(Terms, "dataClasses"))) 
                .checkMFClasses(cl, m)
            keep <- match(row.names(m), rn)
            x <- model.matrix(Terms, m, contrasts = object$contrasts)
            xint <- match("(Intercept)", colnames(x), nomatch = 0L)
            if (xint > 0L) 
                x <- x[, -xint, drop = FALSE]
        }
        else {
            if (is.null(dim(newdata))) 
                dim(newdata) <- c(1L, length(newdata))
            x <- as.matrix(newdata)
            if (any(is.na(x))) 
                stop("missing values in 'x'")
            keep <- 1L:nrow(x)
            rn <- rownames(x)
        }
        ntr <- nrow(x)
        nout <- object$n[3L]
        .C(VR_set_net, as.integer(object$n), as.integer(object$nconn), 
            as.integer(object$conn), rep(0, length(object$wts)), 
            as.integer(object$nsunits), as.integer(0L), as.integer(object$softmax), 
            as.integer(object$censored))
        z <- matrix(NA, nrow(newdata), nout, dimnames = list(rn, 
            dimnames(object$fitted.values)[[2L]]))
        z[keep, ] <- matrix(.C(VR_nntest, as.integer(ntr), as.double(x), 
            tclass = double(ntr * nout), as.double(object$wts))$tclass, 
            ntr, nout)
        .C(VR_unset_net)
    }
    switch(type, raw = z, class = {
        if (is.null(object$lev)) stop("inappropriate fit for class")
        if (ncol(z) > 1L) object$lev[max.col(z)] else object$lev[1L + 
            (z > 0.5)]
    })
}
<bytecode: 0x0000000009305fd8>
<environment: namespace:nnet>