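I trained an artificial neural network in R with caret and nnet. I'm trying to produce meaningful output, ideally a confusion matrix, but I keep getting errors such as "`data` and `reference` should be factors with the same levels" or "all arguments must have the same length".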
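Both messages are about the inputs, not the model. caret's `confusionMatrix()` checks that the predictions and the true labels are factors with the same levels, and the length message is what base R's `table()` reports when its arguments differ in length. A minimal base-R sketch of those two checks follows. `align_for_cm` is a hypothetical helper, not part of caret, and it cross-tabulates with `table()` rather than calling `confusionMatrix()`.

```r
# Hypothetical helper (not part of caret): checks the two conditions behind
# the error messages, then cross-tabulates with base R's table().
align_for_cm <- function(pred, ref) {
  if (length(pred) != length(ref)) {
    stop(sprintf("length mismatch: %d predictions vs %d reference labels",
                 length(pred), length(ref)))
  }
  ref <- factor(ref)
  # Any predicted class that is not a level of the reference would become NA
  extra <- setdiff(unique(as.character(pred)), levels(ref))
  if (length(extra) > 0) {
    stop("predicted classes missing from reference: ",
         paste(extra, collapse = ", "))
  }
  # Re-level the predictions so both factors share the same levels, in order
  pred <- factor(pred, levels = levels(ref))
  table(Prediction = pred, Reference = ref)
}

ref  <- factor(c("B", "S", "X", "S", "B"))
pred <- factor(c("B", "S", "S", "S", "B"))  # "X" is never predicted
print(levels(pred))                         # only "B" and "S"
tab <- align_for_cm(pred, ref)
print(tab)                                  # full 3 x 3 table, X row all zero
```

Re-leveling is what fixes the "same levels" complaint when a class happens never to be predicted; it cannot fix a length mismatch, which always means the two vectors describe different rows.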
```r
library(caret)

pitchData <- read.csv(file.choose(), header = TRUE)
summary(pitchData)

# cleanPitch: presumably a cleaned copy of pitchData (that step is not shown)
set.seed(75)
DataSplit <- createDataPartition(cleanPitch$type, p = 0.75, list = FALSE)
trainData = cleanPitch[DataSplit,]
testData = cleanPitch[-DataSplit,]

#ANN for pitcher's case -- physical description variables only
set.seed(2713)
ANNscout <- train(type ~ code + pitch_type + b_score + b_count + s_count + outs + pitch_num + on_1b + on_2b + on_3b,
                  data = trainData, method = "nnet", trace = FALSE)
summary(ANNscout)
predictScout = predict(ANNscout, newData = testData)
confusionMatrix(testData$type, ANNscout)
```
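The error is raised by `confusionMatrix(testData$type, ANNscout)`. I also tried `confusionMatrix(predictScout, testData$type)`, because summarizing the two gives:

```
> summary(testData$type)
    B     S     X
65126 82996 31456
> summary(predictScout)
     B      S      X
195279 248965  94492
```

so I assumed they were factors of the same length, and so on.

I have also tried the `table()` function, as suggested elsewhere, but that doesn't seem to fix the underlying problem.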
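The two summaries share the same levels, but not the same length. A quick check of the totals, using only the counts printed above, makes that visible: the predictions outnumber the test labels about three to one, which is what you would expect if they had been made on the 75% training split rather than on `testData`. This is only arithmetic on the printed counts, not a rerun of the model.

```r
# Class counts copied from the two summary() calls above
test_counts <- c(B = 65126, S = 82996, X = 31456)
pred_counts <- c(B = 195279, S = 248965, X = 94492)

# Same class labels in the same order...
same_levels <- identical(names(test_counts), names(pred_counts))
# ...but very different totals
n_test <- sum(test_counts)
n_pred <- sum(pred_counts)
ratio  <- n_pred / n_test

cat("same levels:", same_levels, "\n")  # TRUE
cat("test labels:", n_test, "\n")       # 179578
cat("predictions:", n_pred, "\n")       # 538736
cat("ratio:", round(ratio, 2), "\n")    # 3
```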
Link to the dataset: https://www.kaggle.com/pschale/mlb-pitch-data-20152018#pitches.csv
Answer
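You built the confusion matrix from the wrong object: `ANNscout` is the fitted model, but `confusionMatrix()` needs the predicted classes. caret's signature is `confusionMatrix(data, reference)`, with the predictions as `data` and the true classes as `reference`.

A second bug explains the length error you got with `confusionMatrix(predictScout, testData$type)`: the argument of `predict()` is `newdata`, all lowercase. `newData = testData` does not match it, so it is swallowed by `...` and `predict()` falls back to predicting the training data. That is why `predictScout` holds 538,736 values against 179,578 in `testData$type`, roughly the 3:1 ratio of your 75/25 split. With both fixed:

```r
predictScout <- predict(ANNscout, newdata = testData)
confusionMatrix(data = predictScout, reference = testData$type)
```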
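To see why the misspelled argument goes unnoticed, here is a self-contained sketch with a hypothetical `toy_predict()` that mimics only the relevant shape of caret's `predict.train()`: a `newdata = NULL` argument, a `...` catch-all, and a fallback to the training data when `newdata` is missing. It is not caret's code, just an illustration of R's case-sensitive argument matching.

```r
# Hypothetical stand-in for a predict method (not caret's implementation):
# like predict.train(), it takes `newdata = NULL` plus `...`, and predicts
# the training rows when `newdata` is NULL.
toy_predict <- function(train, newdata = NULL, ...) {
  if (is.null(newdata)) {
    newdata <- train          # fallback: reuse the training data
  }
  rep("S", nrow(newdata))     # dummy class label, one per row
}

train <- data.frame(x = 1:300)
test  <- data.frame(x = 1:100)

# `newData` matches no formal argument (R is case-sensitive), so it lands
# in `...` and is ignored: one prediction per TRAINING row.
wrong <- toy_predict(train, newData = test)
# `newdata` matches exactly: one prediction per TEST row.
right <- toy_predict(train, newdata = test)

length(wrong)  # 300
length(right)  # 100
```

No warning is raised in the first call, which is why the mistake only surfaces later as a length mismatch.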
Here is a reproducible example:
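```r
library(caret)

# stringsAsFactors = TRUE: since R 4.0.0, read.csv() no longer converts
# strings to factors by default, and confusionMatrix() needs factors
pitchData <- read.csv("pitches.csv.gz", header = TRUE, stringsAsFactors = TRUE)
set.seed(75)
COLS <- c("type", "code", "pitch_type", "b_score", "b_count",
          "s_count", "outs", "pitch_num", "on_1b", "on_2b", "on_3b")
trainData = pitchData[sample(1:nrow(pitchData), 100), COLS]
testData = pitchData[sample(1:nrow(pitchData), 100), COLS]

#ANN for pitcher's case -- physical description variables only
set.seed(2713)
ANNscout <- train(type ~ code + pitch_type + b_score + b_count + s_count + outs + pitch_num + on_1b + on_2b + on_3b,
                  data = trainData, method = "nnet", trace = FALSE)
# the argument is `newdata` (lowercase); `newData` would be silently ignored
predictScout = predict(ANNscout, newdata = testData)

# throws an error: the second argument is the model, not its predictions
confusionMatrix(testData$type, ANNscout)

# correct call: predictions first (data), true classes second (reference)
confusionMatrix(data = predictScout, reference = testData$type)
```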
```
Confusion Matrix and Statistics

          Reference
Prediction   B   S   X
         B 144 161  84
         S 137 209  84
         X  77  68  36

Overall Statistics

               Accuracy : 0.389
                 95% CI : (0.3586, 0.42)
    No Information Rate : 0.438
    P-Value [Acc > NIR] : 0.9992
```
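The overall statistics above can be reproduced from the 3×3 table alone, which is a useful sanity check on what caret reports. The sketch assumes caret's documented definitions: accuracy is the diagonal over the total, the No Information Rate is the largest column (reference) share, the confidence interval comes from `binom.test()`, and the p-value is a one-sided `binom.test()` of the accuracy against the NIR.

```r
# Counts copied from the printed confusion matrix
# (rows = Prediction, columns = Reference)
cm <- as.table(matrix(c(144, 161, 84,
                        137, 209, 84,
                         77,  68, 36),
                      nrow = 3, byrow = TRUE,
                      dimnames = list(Prediction = c("B", "S", "X"),
                                      Reference  = c("B", "S", "X"))))

n        <- sum(cm)
correct  <- sum(diag(cm))
accuracy <- correct / n            # 389 / 1000
nir      <- max(colSums(cm)) / n   # largest reference-class share

# Exact (Clopper-Pearson) 95% interval for the accuracy
ci <- binom.test(correct, n)$conf.int
# One-sided test: is the accuracy better than the No Information Rate?
p_value <- binom.test(correct, n, p = nir, alternative = "greater")$p.value

cat("Accuracy:", accuracy, "\n")              # 0.389
cat("95% CI:", round(ci, 4), "\n")
cat("No Information Rate:", nir, "\n")        # 0.438
cat("P-Value [Acc > NIR]:", round(p_value, 4), "\n")
```

Note that the NIR depends on which factor is passed as `reference`, since it is computed from the column totals.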
The sampled data used in the example, as `dput()` output:
trainData <-
structure(list(type = structure(c(2L, 2L, 2L, 2L, 1L, 1L, 3L,
2L, 2L, 3L, 2L, 2L, 1L, 1L, 1L, 3L, 2L, 3L, 3L, 1L, 3L, 2L, 2L,
2L, 2L, 2L, 2L, 2L, 1L, 2L, 1L, 1L, 2L, 3L, 1L, 2L, 2L, 1L, 3L,
3L, 1L, 3L, 2L, 1L, 1L, 1L, 1L, 3L, 1L, 2L, 1L, 2L, 1L, 1L, 2L,
2L, 2L, 2L, 1L, 1L, 1L, 2L, 2L, 2L, 1L, 1L, 1L, 3L, 1L, 2L, 1L,
2L, 1L, 1L, 2L, 3L, 2L, 2L, 2L, 2L, 1L, 3L, 3L, 1L, 2L, 1L, 1L,
2L, 2L, 1L, 1L, 2L, 1L, 1L, 1L, 3L, 1L, 2L, 2L, 2L), .Label = c("B",
"S", "X"), class = "factor"), code = structure(c(7L, 4L, 7L,
7L, 3L, 3L, 6L, 16L, 4L, 5L, 4L, 15L, 3L, 3L, 3L, 19L, 4L, 19L,
19L, 3L, 19L, 18L, 7L, 4L, 4L, 7L, 4L, 4L, 3L, 15L, 3L, 3L, 15L,
5L, 3L, 15L, 7L, 3L, 19L, 6L, 3L, 19L, 15L, 3L, 3L, 3L, 3L, 19L,
3L, 15L, 3L, 7L, 3L, 3L, 4L, 7L, 15L, 4L, 3L, 3L, 3L, 4L, 7L,
7L, 2L, 3L, 3L, 5L, 3L, 7L, 3L, 7L, 3L, 3L, 4L, 19L, 4L, 4L,
4L, 16L, 3L, 19L, 19L, 8L, 7L, 3L, 3L, 4L, 4L, 3L, 3L, 15L, 3L,
3L, 3L, 19L, 3L, 4L, 4L, 1L), .Label = c("", "*B", "B", "C",
"D", "E", "F", "H", "I", "L", "M", "P", "Q", "R", "S", "T", "V",
"W", "X", "Z"), class = "factor"), pitch_type = structure(c(13L,
4L, 8L, 11L, 17L, 8L, 3L, 11L, 8L, 11L, 8L, 11L, 17L, 8L, 18L,
11L, 11L, 11L, 4L, 4L, 11L, 18L, 8L, 11L, 14L, 8L, 8L, 8L, 11L,
4L, 8L, 18L, 8L, 18L, 18L, 8L, 17L, 18L, 4L, 11L, 3L, 3L, 18L,
11L, 8L, 8L, 4L, 11L, 8L, 8L, 8L, 8L, 3L, 13L, 17L, 8L, 4L, 8L,
18L, 8L, 18L, 17L, 8L, 18L, 18L, 17L, 8L, 8L, 8L, 8L, 8L, 3L,
8L, 17L, 8L, 17L, 7L, 18L, 8L, 8L, 11L, 17L, 18L, 11L, 3L, 7L,
8L, 11L, 18L, 17L, 18L, 4L, 8L, 11L, 3L, 8L, 8L, 8L, 3L, 3L), .Label = c("",
"AB", "CH", "CU", "EP", "FA", "FC", "FF", "FO", "FS", "FT", "IN",
"KC", "KN", "PO", "SC", "SI", "SL", "UN"), class = "factor"),
b_score = c(2, 0, 8, 0, 0, 3, 7, 2, 3, 0, 1, 6, 7, 3, 6,
2, 1, 1, 0, 4, 0, 0, 6, 2, 0, 0, 0, 0, 2, 1, 2, 1, 0, 8,
0, 1, 2, 2, 0, 0, 0, 1, 4, 2, 3, 0, 3, 8, 1, 0, 0, 0, 2,
0, 3, 2, 7, 1, 5, 1, 2, 0, 0, 0, 4, 6, 5, 0, 3, 10, 2, 1,
1, 0, 0, 0, 0, 3, 0, 0, 1, 0, 0, 6, 1, 2, 1, 6, 1, 2, 2,
1, 0, 0, 3, 1, 0, 6, 0, 4), b_count = c(1, 0, 3, 0, 2, 0,
2, 0, 2, 3, 0, 0, 0, 0, 0, 3, 0, 0, 1, 0, 0, 1, 0, 3, 1,
2, 2, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 2, 0, 1,
1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 0, 1, 0, 0, 0, 1,
2, 3, 3, 2, 1, 3, 2, 1, 1, 0, 1, 3, 1, 1, 3, 0, 0, 0, 1,
0, 1, 1, 0, 2, 2, 2, 0, 0, 1, 2, 1, 1, 0, 0, 1, 0, 1), s_count = c(2,
0, 2, 1, 0, 0, 1, 1, 2, 1, 1, 1, 0, 0, 0, 2, 0, 1, 2, 1,
0, 2, 1, 0, 1, 2, 2, 0, 2, 1, 0, 1, 2, 2, 0, 0, 0, 0, 2,
0, 1, 2, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 1, 0, 1, 0, 2, 0,
2, 0, 0, 0, 1, 2, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
1, 0, 0, 0, 0, 2, 2, 2, 0, 2, 1, 0, 1, 0, 1, 2, 2, 1, 1,
1, 1, 0, 2), outs = c(1, 2, 2, 0, 0, 0, 0, 1, 2, 0, 2, 2,
2, 0, 1, 2, 0, 1, 0, 2, 2, 0, 1, 1, 0, 2, 0, 0, 2, 1, 1,
2, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0,
1, 0, 1, 0, 0, 2, 2, 0, 1, 1, 1, 2, 0, 0, 1, 2, 0, 2, 2,
0, 2, 1, 1, 0, 0, 1, 0, 2, 1, 1, 2, 0, 0, 1, 0, 2, 2, 1,
1, 2, 2, 0, 0, 1, 0, 0, 0, 0, 1, 0), pitch_num = c(4, 1,
7, 2, 3, 1, 4, 2, 5, 5, 2, 2, 1, 1, 1, 9, 1, 2, 4, 2, 1,
4, 2, 4, 3, 5, 5, 1, 5, 2, 2, 3, 5, 4, 2, 2, 2, 1, 4, 1,
2, 5, 1, 2, 2, 2, 2, 3, 1, 1, 1, 1, 2, 1, 3, 2, 5, 1, 4,
1, 1, 1, 3, 5, 6, 4, 4, 2, 4, 3, 2, 2, 1, 2, 4, 2, 3, 5,
1, 1, 1, 2, 3, 6, 4, 1, 9, 4, 3, 2, 1, 3, 7, 4, 3, 2, 2,
3, 1, 6), on_1b = c(0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1,
0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1,
0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0,
0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1), on_2b = c(0, 1, 1, 0, 0,
0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0),
on_3b = c(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0)), row.names = c(2572335L, 2473476L,
1773596L, 1225945L, 533967L, 538045L, 490237L, 2032341L, 1471986L,
885702L, 1452254L, 925502L, 2844400L, 1491337L, 1892438L, 298502L,
942660L, 23942L, 1408325L, 1079270L, 85860L, 285803L, 1986834L,
847131L, 1857056L, 2291578L, 2123168L, 1938826L, 738920L, 1908592L,
864041L, 1065750L, 1526969L, 167075L, 2678451L, 620148L, 1925244L,
2473758L, 1937885L, 1794853L, 1822660L, 1740177L, 772699L, 2634322L,
1586198L, 848353L, 773501L, 1263338L, 2187469L, 750081L, 2012897L,
476469L, 2585175L, 144072L, 465180L, 805632L, 1070865L, 1506775L,
305483L, 2575955L, 731114L, 842298L, 1437696L, 2821218L, 311496L,
2630474L, 1390322L, 1345644L, 1595676L, 1716804L, 477467L, 2098794L,
2161436L, 1422499L, 695945L, 90115L, 2711212L, 803555L, 1485935L,
1051585L, 171143L, 2618698L, 949978L, 1972503L, 612458L, 522836L,
235720L, 1150378L, 1892321L, 2599680L, 2025615L, 2675539L, 903434L,
1555927L, 178249L, 2747402L, 2186038L, 1544343L, 1995247L, 519849L
), class = "data.frame")
testData <-
structure(list(type = structure(c(2L, 1L, 2L, 2L, 2L, 1L, 2L,
2L, 2L, 1L, 1L, 2L, 2L, 3L, 2L, 2L, 2L, 1L, 1L, 1L, 1L, 1L, 1L,
3L, 1L, 2L, 2L, 2L, 1L, 2L, 3L, 1L, 3L, 2L, 2L, 2L, 2L, 2L, 1L,
2L, 2L, 1L, 2L, 2L, 1L, 3L, 1L, 2L, 1L, 3L, 3L, 2L, 2L, 1L, 2L,
1L, 2L, 2L, 1L, 2L, 2L, 2L, 2L, 1L, 1L, 2L, 3L, 2L, 1L, 2L, 1L,
2L, 1L, 3L, 2L, 1L, 3L, 1L, 1L, 2L, 2L, 2L, 2L, 3L, 1L, 2L, 2L,
3L, 3L, 3L, 1L, 2L, 2L, 3L, 1L, 1L, 2L, 1L, 3L, 2L), .Label = c("B",
"S", "X"), class = "factor"), code = structure(c(4L, 3L, 15L,
7L, 7L, 3L, 7L, 7L, 7L, 2L, 2L, 15L, 15L, 19L, 7L, 7L, 7L, 3L,
3L, 3L, 2L, 3L, 3L, 5L, 3L, 7L, 4L, 15L, 3L, 4L, 5L, 3L, 19L,
7L, 7L, 7L, 4L, 7L, 3L, 7L, 4L, 3L, 15L, 7L, 3L, 5L, 3L, 7L,
3L, 5L, 19L, 7L, 15L, 3L, 7L, 3L, 7L, 4L, 3L, 15L, 7L, 7L, 4L,
3L, 3L, 4L, 19L, 4L, 3L, 15L, 3L, 7L, 3L, 19L, 7L, 3L, 19L, 3L,
3L, 7L, 15L, 7L, 4L, 19L, 3L, 4L, 4L, 19L, 19L, 5L, 3L, 7L, 4L,
5L, 3L, 2L, 7L, 3L, 5L, 15L), .Label = c("", "*B", "B", "C",
"D", "E", "F", "H", "I", "L", "M", "P", "Q", "R", "S", "T", "V",
"W", "X", "Z"), class = "factor"), pitch_type = structure(c(18L,
8L, 3L, 17L, 18L, 18L, 3L, 8L, 8L, 3L, 3L, 18L, 3L, 3L, 8L, 11L,
8L, 8L, 3L, 18L, 18L, 17L, 3L, 11L, 17L, 8L, 8L, 8L, 3L, 18L,
3L, 18L, 8L, 8L, 11L, 8L, 8L, 3L, 4L, 18L, 8L, 3L, 8L, 13L, 3L,
3L, 17L, 8L, 1L, 8L, 17L, 8L, 8L, 8L, 8L, 18L, 17L, 8L, 17L,
10L, 18L, 3L, 8L, 8L, 8L, 8L, 18L, 18L, 3L, 8L, 11L, 3L, 8L,
8L, 11L, 4L, 11L, 11L, 4L, 8L, 4L, 11L, 4L, 4L, 8L, 11L, 8L,
11L, 18L, 18L, 11L, 8L, 8L, 18L, 3L, 13L, 8L, 8L, 11L, 8L), .Label = c("",
"AB", "CH", "CU", "EP", "FA", "FC", "FF", "FO", "FS", "FT", "IN",
"KC", "KN", "PO", "SC", "SI", "SL", "UN"), class = "factor"),
b_score = c(5, 8, 0, 0, 0, 0, 10, 4, 1, 2, 12, 0, 2, 3, 4,
4, 0, 3, 0, 6, 3, 2, 3, 0, 0, 0, 2, 3, 0, 2, 0, 2, 7, 1,
7, 1, 4, 0, 2, 3, 2, 4, 5, 0, 2, 6, 0, 2, 1, 4, 2, 0, 1,
2, 2, 2, 1, 2, 3, 0, 3, 6, 6, 3, 0, 6, 0, 6, 1, 5, 1, 1,
0, 0, 5, 14, 1, 0, 1, 1, 0, 3, 1, 7, 0, 6, 1, 0, 0, 1, 1,
4, 4, 7, 1, 0, 11, 3, 3, 10), b_count = c(0, 1, 2, 0, 3,
1, 0, 2, 3, 0, 1, 1, 0, 1, 0, 1, 2, 1, 1, 1, 3, 0, 1, 1,
0, 1, 0, 2, 0, 1, 1, 1, 0, 3, 0, 1, 0, 1, 0, 2, 0, 1, 1,
0, 0, 1, 0, 0, 0, 1, 2, 3, 1, 0, 0, 0, 0, 2, 0, 1, 2, 0,
0, 0, 0, 0, 1, 1, 0, 0, 2, 0, 0, 3, 1, 0, 0, 1, 1, 2, 0,
0, 0, 2, 0, 0, 0, 1, 0, 0, 3, 1, 0, 2, 0, 0, 0, 0, 1, 0),
s_count = c(0, 2, 0, 1, 2, 0, 1, 2, 1, 0, 0, 1, 1, 0, 1,
0, 2, 1, 2, 0, 2, 1, 2, 1, 1, 0, 0, 2, 1, 1, 1, 0, 0, 2,
1, 2, 0, 2, 0, 2, 0, 2, 1, 1, 0, 1, 0, 2, 1, 0, 1, 2, 0,
2, 1, 1, 2, 1, 1, 1, 2, 0, 0, 1, 1, 0, 0, 1, 0, 1, 2, 1,
0, 2, 1, 1, 0, 1, 2, 2, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1,
0, 0, 2, 0, 2, 0, 0, 1, 1), outs = c(1, 1, 0, 0, 1, 1, 2,
1, 2, 1, 2, 1, 2, 2, 0, 2, 0, 2, 1, 1, 1, 0, 2, 1, 0, 0,
2, 2, 0, 1, 2, 1, 0, 2, 2, 0, 1, 1, 2, 0, 2, 1, 2, 0, 1,
0, 1, 2, 2, 0, 2, 2, 1, 1, 2, 2, 1, 1, 0, 2, 1, 1, 0, 1,
1, 0, 0, 0, 2, 0, 0, 1, 2, 2, 2, 2, 1, 1, 0, 2, 1, 1, 1,
2, 0, 2, 0, 1, 2, 2, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0), pitch_num = c(1,
4, 3, 2, 7, 2, 2, 5, 5, 1, 2, 3, 2, 2, 2, 2, 5, 3, 4, 2,
7, 2, 4, 3, 2, 2, 1, 5, 2, 3, 3, 2, 1, 7, 2, 4, 1, 4, 1,
6, 1, 5, 3, 2, 1, 3, 1, 3, 2, 2, 4, 7, 2, 3, 2, 2, 5, 4,
2, 3, 5, 1, 1, 2, 2, 1, 2, 3, 1, 2, 6, 2, 1, 8, 3, 2, 1,
3, 5, 5, 2, 1, 1, 4, 1, 1, 1, 2, 1, 1, 5, 2, 1, 5, 1, 4,
1, 1, 3, 2), on_1b = c(1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0,
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1,
0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1,
0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0,
1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0,
1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1), on_2b = c(1, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0,
0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1,
1), on_3b = c(0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0,
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0)), row.names = c(1626717L, 1405275L,
122726L, 1108751L, 2093357L, 179555L, 2607549L, 1412512L, 638641L,
416297L, 2598530L, 382011L, 157089L, 1593057L, 1019948L, 1694849L,
2708797L, 2700695L, 2101221L, 1561311L, 1122920L, 1041141L, 2010001L,
1444355L, 715603L, 1977289L, 1769783L, 2759144L, 2689437L, 167822L,
2170300L, 377341L, 1608546L, 2118229L, 1733000L, 1211742L, 327439L,
2462799L, 1488833L, 2789777L, 1344427L, 634185L, 1701253L, 1032484L,
1739743L, 2685610L, 215036L, 2071714L, 548057L, 2173749L, 972140L,
2254162L, 984278L, 2581566L, 2773565L, 2526454L, 2796506L, 538995L,
1941397L, 1977176L, 2377087L, 2338603L, 2548693L, 1586820L, 1003003L,
1949343L, 2742108L, 1892406L, 1165382L, 2587610L, 197021L, 2245196L,
2341574L, 280293L, 2160792L, 2534355L, 309769L, 834342L, 1428917L,
1342288L, 1888427L, 1633950L, 2475553L, 1867478L, 642364L, 2724974L,
437424L, 430426L, 2575340L, 2370985L, 1070394L, 2190709L, 820645L,
271133L, 1820342L, 2396685L, 2062304L, 782662L, 422914L, 1030451L
), class = "data.frame")