我知道我的标题问题的理论答案,该问题在Stack Overflow上讨论here或this之前的问题。我的问题是,即使考虑一些数字舍入,我使用R
函数multinom
中拟合的系数计算的概率权重与直接从同一函数获得的权重(通过{{1 }})。我尝试在predict(fit, newdata = dat, "probs")
和Java
中以数字方式计算这些权重,并且在两种实现中我都获得了相同的结果,这实际上与R
返回的值不同。
您是否知道我如何发现predict
函数predict(..., "probs")
的函数R
的实现?
答案 0 :(得分:2)
我首先安装nnet
并打开nnet
功能的帮助页面。我看到该函数创建了一个nnet
对象。
我尝试predict.nnet
但没有出现。这意味着包装未加载,功能不存在或隐藏。 methods("predict")
显示该对象实际上是隐藏的(由*
表示。)
> methods("predict")
[1] predict.ar* predict.Arima* predict.arima0* predict.glm
[5] predict.HoltWinters* predict.lm predict.loess* predict.mlm
[9] predict.multinom* predict.nls* predict.nnet* predict.poly
[13] predict.ppr* predict.prcomp* predict.princomp* predict.smooth.spline*
[17] predict.smooth.spline.fit* predict.StructTS*
显式调用此函数会显示其代码。
> nnet:::predict.nnet
function (object, newdata, type = c("raw", "class"), ...)
{
if (!inherits(object, "nnet"))
stop("object not of class \"nnet\"")
type <- match.arg(type)
if (missing(newdata))
z <- fitted(object)
else {
if (inherits(object, "nnet.formula")) {
newdata <- as.data.frame(newdata)
rn <- row.names(newdata)
Terms <- delete.response(object$terms)
m <- model.frame(Terms, newdata, na.action = na.omit,
xlev = object$xlevels)
if (!is.null(cl <- attr(Terms, "dataClasses")))
.checkMFClasses(cl, m)
keep <- match(row.names(m), rn)
x <- model.matrix(Terms, m, contrasts = object$contrasts)
xint <- match("(Intercept)", colnames(x), nomatch = 0L)
if (xint > 0L)
x <- x[, -xint, drop = FALSE]
}
else {
if (is.null(dim(newdata)))
dim(newdata) <- c(1L, length(newdata))
x <- as.matrix(newdata)
if (any(is.na(x)))
stop("missing values in 'x'")
keep <- 1L:nrow(x)
rn <- rownames(x)
}
ntr <- nrow(x)
nout <- object$n[3L]
.C(VR_set_net, as.integer(object$n), as.integer(object$nconn),
as.integer(object$conn), rep(0, length(object$wts)),
as.integer(object$nsunits), as.integer(0L), as.integer(object$softmax),
as.integer(object$censored))
z <- matrix(NA, nrow(newdata), nout, dimnames = list(rn,
dimnames(object$fitted.values)[[2L]]))
z[keep, ] <- matrix(.C(VR_nntest, as.integer(ntr), as.double(x),
tclass = double(ntr * nout), as.double(object$wts))$tclass,
ntr, nout)
.C(VR_unset_net)
}
switch(type, raw = z, class = {
if (is.null(object$lev)) stop("inappropriate fit for class")
if (ncol(z) > 1L) object$lev[max.col(z)] else object$lev[1L +
(z > 0.5)]
})
}
<bytecode: 0x0000000009305fd8>
<environment: namespace:nnet>