81-mlr3初体验

1、创建任务

# Attach mlr3 through pacman's p_load().
library(pacman)
p_load(mlr3)

# Inspect the built-in iris data: four numeric measurements plus the
# factor column Species, which is the classification target below.
str(iris)
## 'data.frame':    150 obs. of  5 variables:
##  $ Sepal.Length: num  5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ...
##  $ Sepal.Width : num  3.5 3 3.2 3.1 3.6 3.9 3.4 3.4 2.9 3.1 ...
##  $ Petal.Length: num  1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
##  $ Petal.Width : num  0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
##  $ Species     : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ...
# Build a classification task on iris that predicts Species from the
# remaining columns.
task <- TaskClassif$new(
  target = "Species",
  backend = iris,
  id = "iris"
)

2、选择学习器

# Learning algorithm: a decision tree from the rpart package.
# cp = 0.1 is rpart's complexity parameter; minsplit = 10 is the minimum
# node size at which rpart will attempt a split.
lrner <- lrn(
  "classif.rpart",
  cp = 0.1,
  minsplit = 10
)

3、拆分训练集和测试集

set.seed(123)
# Split the task's rows 80:20 into a training set and a test set.
# Sample from task$row_ids rather than 1..task$nrow: a task's row ids are not
# guaranteed to be 1..n (e.g. after task$filter() or with a custom backend).
# sample.int() on the index avoids sample()'s special case for length-1
# input. For this task it draws exactly what sample(150, 120) would, so the
# outputs recorded below are unchanged.
all_ids <- task$row_ids
n_train <- round(0.8 * length(all_ids))  # explicit integer size, no truncation
dtrain <- all_ids[sample.int(length(all_ids), n_train)]
dtest <- setdiff(all_ids, dtrain)

4、训练模型

# Fit the tree using only the training rows.
lrner$train(task = task, row_ids = dtrain)
# Display the fitted rpart model.
print(lrner$model)
## n= 120 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
## 1) root 120 75 virginica (0.33333333 0.29166667 0.37500000)  
##   2) Petal.Length< 2.45 40  0 setosa (1.00000000 0.00000000 0.00000000) *
##   3) Petal.Length>=2.45 80 35 virginica (0.00000000 0.43750000 0.56250000)  
##     6) Petal.Length< 4.75 32  1 versicolor (0.00000000 0.96875000 0.03125000) *
##     7) Petal.Length>=4.75 48  4 virginica (0.00000000 0.08333333 0.91666667) *

5、预测

# Predict the held-out test rows and show the predicted class labels.
pred <- lrner$predict(task = task, row_ids = dtest)
print(pred$response)
##  [1] setosa     setosa     setosa     setosa     setosa     setosa     setosa     setosa     setosa     setosa    
## [11] versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor versicolor
## [21] virginica  versicolor virginica  versicolor versicolor virginica  virginica  virginica  virginica  virginica 
## Levels: setosa versicolor virginica

6、模型评估

# Confusion matrix on the test rows: rows are predicted classes (response),
# columns are the true classes (truth).
print(pred$confusion)
##             truth
## response     setosa versicolor virginica
##   setosa         10          0         0
##   versicolor      0         13         0
##   virginica       0          2         5
# Test-set accuracy: the share of test rows predicted correctly.
print(pred$score(measures = msr("classif.acc")))
## classif.acc 
##   0.9333333

7、交叉验证

# 10-fold cross-validation of the same learner on the full task, then the
# mean accuracy over the folds.
# NOTE: no seed is set here, so the random fold split continues from the RNG
# state left by the earlier set.seed(123) and sampling. The score below is
# only reproduced when the script runs top to bottom.
resampling <- rsmp("cv", folds = 10L)
rr <- resample(task, learner = lrner, resampling = resampling)
print(rr$aggregate(msr("classif.acc")))
## classif.acc 
##   0.9266667

你可能感兴趣的:(81-mlr3初体验)