作者:黄天元,复旦大学博士在读,目前研究涉及文本挖掘、社交网络分析和机器学习等。希望与大家分享学习经验,推广并加深R语言在业界的应用。
微信:hope9057
Kaggle上最经典的泰坦尼克号入门级教程,我们这里尝试玩转它(https://www.kaggle.com/c/titanic)。先讲数据背景,我们有各种各样的乘客数据,想要利用这些数据,预测在泰坦尼克号沉船的时候,这个乘客是否能够存活。具体的数据字典可以参照:
https://www.kaggle.com/c/titanic/data。
先导入数据
#数据导入 set.seed(201891) library(pacman) p_load(tidyverse) p_load(caret,caretEnsemble) setwd("E:\\_data_hope\\Titanic\\data") read_csv("train.csv") -> train_raw1 read_csv("test.csv") -> test_raw1 read_csv("gender_submission.csv") -> gs
人工筛选变量是第一步,这是机器学习无法逾越的高度,因为我们知道哪些变量是真正“有关”的,哪些即使是真的提高了预测精度也只是假象而已。我们应该知道,乘客的ID号码,乘客叫什么名字,乘客在哪里上船,还有买票的号码,是与存活率完全没有直接关系的,直接删除掉。
train_raw1 %>% select(-PassengerId,-Name,-Ticket,-Embarked) -> train_raw2 test_raw1 %>% select(-PassengerId,-Name,-Ticket,-Embarked) -> test_raw2
如果数据中有一些属性含有大量缺失值,那么它对预测的贡献几乎为零,甚至具有不良的干扰。当然有的时候缺和不缺本来就是一种信息,但是这里我们无法深入判断。首先我们先看看是否有缺失值,有的话缺多少?
p_load(VIM,Amelia) missmap(train_raw2)
missmap(test_raw2)
train_raw2 %>% aggr()
test_raw2 %>% aggr()
Cabin,也就是舱位号码缺了很多,因此我们应该直接删除掉整列。年龄数据存在缺失,但是缺失比例不大,而且年龄可能会提供重要信息,所以需要保留。能够直接删除缺失行吗?答案是不行,因为待预测的验证集包含有缺失值,因此必须对它们进行必要的处理才行。
这个例子中,我倾向于使用KNN插值法,原理就是,相似的乘客可能会有相同的年纪。需要注意的是,KNN插值法不允许变量中包含有非数值型变量,因此这里直接先转为因子再转为数值。性别只有两个,因此没有关系,直接化为因子就可以。如果有多于两个的因子,应该先用one-hot encoding这种方法把它化为稀疏矩阵再来做。
p_load(DMwR) #KNN插值法需要用的包 train_raw2 %>% select(-Cabin) -> train1 test_raw2 %>% select(-Cabin) -> test1 train1 %>% mutate(Sex=as.numeric(as.factor(Sex))) %>% as.data.frame() %>% knnImputation() %>% pull(Age) -> train_age test1 %>% mutate(Sex=as.numeric(as.factor(Sex))) %>% as.data.frame() %>% knnImputation() %>% pull(Age) -> test_age train1 %>% mutate(Age=train_age) -> train.wash test1 %>% mutate(Age=test_age) -> test.wash
这样一来我们就得到了清洗好的训练集train.wash和测试集test.wash。
一般建模之初,应该设定两个模型:零模型与全模型。零模型即随机猜测我们能够得到的正确率。什么?你认为是50%?这不对,虽然我们最终结果只有存活和不存活,但是因为样本中存活和非存活的比例不同,因此需要特殊对待。
train.wash %>% count(Survived) %>% mutate(n/sum(n)) ## # A tibble: 2 x 3 ## Survived n `n/sum(n)` ## <int> <int> <dbl> ## 10 549 0.616 ## 21 342 0.384 gs %>% count(Survived) %>% mutate(n/sum(n)) ## # A tibble: 2 x 3 ## Survived n `n/sum(n)` ## <int> <int> <dbl> ## 10 266 0.636 ## 21 152 0.364
我们可以看到,有61.6%的乘客最后不能存活,38.4%的乘客可以存活。也就是我们对任意一个乘客都假设他不能够存活,我们就会得到61.6%的准确率。如果我们的模型在训练集中最后准确率不能够超越这个数值,那么就白忙一场了。
在验证集中也一样,如果最终我们的accuracy没有超越63.6%,那么还不如瞎猜这个乘客肯定不能够存活更好。
首先,我们的问题数据量不大,我们看看样本量多少。
train.wash ## # A tibble: 891 x 7 ## Survived Pclass Sex Age SibSp Parch Fare ## <int> <int> <chr> <dbl> <int> <int> <dbl> ## 1 0 3 male 22 1 0 7.25 ## 2 1 1 female 38 1 0 71.3 ## 3 1 3 female 26 0 0 7.92 ## 4 1 1 female 35 1 0 53.1 ## 5 0 3 male 35 0 0 8.05 ## 6 0 3 male 27.1 0 0 8.46 ## 7 0 1 male 54 0 0 51.9 ## 8 0 3 male 2 3 1 21.1 ## 9 1 3 female 27 0 2 11.1 ## 10 1 2 female 14 1 0 30.1 ## # ... with 881 more rows
891个样本量的时候,我们决定进行三折交叉验证,不过尝试进行重复的交叉验证,这里我们先重复五次,设定如下:
ctrl= trainControl(method = "repeatedcv",number = 3,repeats=5,search="random", summaryFunction = twoClassSummary, classProbs = TRUE, savePredictions = "final")
注意我们用了search=“random”,从而采取了随机超参数搜索,对于一些模型来说设置网格比较费时,我们先看个大概,因此采用这种方法。需要注意的是,建模前最好把所有变量都转化为数值变量,计算机只认得数字,任何情况都是如此,就算有字符串也是转为因此变量再来做的,我们这里就先转化为因子变量来做。
train.wash %>% mutate(Sex=as.factor(Sex)) %>% mutate(Survived=ifelse(Survived==1,"Alive","Dead")) -> train test.wash %>% mutate(Sex=as.factor(Sex)) -> test gs %>% mutate(Survived=ifelse(Survived==1,"Alive","Dead")) -> gs
能够进行二分类的模型非常多,大类是线性和非线性。线性一般来说解释性强但是效果一般,非线性效果好一点但是解释性弱一点,而且容易出现过拟合。我们用零模型设定了基准,这里我们广泛采用不同的模型看看哪个表现更好。采用的线性模型包括:逻辑回归(glm)、具有惩罚项的逻辑回归(glmnet)、偏最小二乘判别分析(pls)、线性判别分析(lda)和PAM模型(pam)来做;非线性模型包括:非线性判别(mda)、神经网络(nnet)、灵活判别分析(fda)、支持向量机(svm)、K近邻(KNN)、朴素贝叶斯(nb)、随机森林(rf)还有大名鼎鼎的Xgboost(xgbLinear/xgbTree)。需要注意的是,这里神经网络就是三层的全连接神经网络,这个问题还没有如此有“深度”,因此还没有涉及深度学习的领域。为了能够一下子拟合所有模型,我们祭出caretEnsemble::caretList这个利器。这样我们可以对各种模型做一个初筛,虽然只能方便地比较训练集而不是把测试集一起比较了,但是尽管在训练集表现好不一定在测试集表现就好,但是在训练集表现不好的在测试集一般来说一定就不太好。
model_list=caretList( Survived~.,data=train, trControl=ctrl, metric="ROC", preProcess=c("center","scale"), methodList=c("glm","glmnet","pls","lda","pam", "mda","fda","svmRadialCost","knn","nb","rf","xgbLinear","xgbTree"), tuneList = list(nnet=caretModelSpec(method="nnet",trace=F)) ) ## 1234567891011121314151617181920212223242526272829301111111111111111 results <- resamples(model_list) summary(results) ## ## Call: ## summary.resamples(object = results) ## ## Models: nnet, glm, glmnet, pls, lda, pam, mda, fda, svmRadialCost, knn, nb, rf, xgbLinear, xgbTree ## Number of resamples: 15 ## ## ROC ## Min. 1st Qu. Median Mean 3rd Qu. Max. ## nnet 0.8129374 0.8509251 0.8606078 0.8582015 0.8679657 0.9011121 ## glm 0.8117390 0.8493433 0.8603921 0.8581664 0.8661202 0.9032212 ## glmnet 0.8134647 0.8514284 0.8602243 0.8585722 0.8672946 0.9025022 ## pls 0.8159093 0.8515842 0.8592657 0.8583757 0.8660004 0.9016873 ## lda 0.8167242 0.8518958 0.8597929 0.8584827 0.8661442 0.9012559 ## pam 0.7858067 0.8249569 0.8399722 0.8366248 0.8498346 0.8795897 ## mda 0.8059870 0.8384503 0.8553590 0.8514077 0.8654252 0.8974691 ## fda 0.8157655 0.8521474 0.8620458 0.8590595 0.8693917 0.9042997 ## svmRadialCost 0.8175630 0.8466710 0.8583549 0.8574521 0.8698471 0.9058096 ## knn 0.8234829 0.8474140 0.8626929 0.8630508 0.8763901 0.9067683 ## nb 0.7802943 0.8309007 0.8414582 0.8397006 0.8490916 0.8993864 ## rf 0.8379829 0.8630045 0.8843352 0.8818362 0.8948687 0.9349295 ## xgbLinear 0.8255201 0.8616743 0.8763302 0.8782395 0.8927835 0.9344262 ## xgbTree 0.8233870 0.8598289 0.8675822 0.8697121 0.8826934 0.9328204 ## NA's ## nnet 0 ## glm 0 ## glmnet 0 ## pls 0 ## lda 0 ## pam 0 ## mda 0 ## fda 0 ## svmRadialCost 0 ## knn 0 ## nb 0 ## rf 0 ## xgbLinear 0 ## xgbTree 0 ## ## Sens ## Min. 1st Qu. Median Mean 3rd Qu. Max. ## nnet 0.6140351 0.6622807 0.6842105 0.6847953 0.7105263 0.7543860 ## glm 0.6491228 0.6842105 0.7105263 0.7128655 0.7280702 0.8333333 ## glmnet 0.6315789 0.6710526 0.6929825 0.6964912 0.7149123 0.8070175 ## pls 0.6140351 0.6710526 0.7017544 0.6970760 0.7149123 0.8070175 ## lda 0.6228070 0.6710526 0.7017544 0.6988304 0.7192982 0.8070175 ## pam 0.2368421 0.2807018 0.3070175 0.3087719 0.3333333 0.3947368 ## mda 0.6754386 0.6842105 0.7105263 0.7140351 0.7368421 0.7982456 ## fda 0.6491228 0.6842105 0.7105263 0.7187135 0.7368421 0.8245614 ## svmRadialCost 0.6842105 0.7017544 0.7192982 0.7239766 0.7543860 0.7631579 ## knn 0.6491228 0.6842105 0.7017544 0.7093567 0.7324561 0.7894737 ## nb 0.6315789 0.6754386 0.6929825 0.7011696 0.7324561 0.7807018 ## rf 0.6666667 0.7105263 0.7543860 0.7485380 0.7850877 0.8333333 ## xgbLinear 0.6842105 0.7017544 0.7280702 0.7485380 0.7982456 0.8508772 ## xgbTree 0.6315789 0.6666667 0.6929825 0.7029240 0.7280702 0.8421053 ## NA's ## nnet 0 ## glm 0 ## glmnet 0 ## pls 0 ## lda 0 ## pam 0 ## mda 0 ## fda 0 ## svmRadialCost 0 ## knn 0 ## nb 0 ## rf 0 ## xgbLinear 0 ## xgbTree 0 ## ## Spec ## Min. 1st Qu. Median Mean 3rd Qu. Max. ## nnet 0.8469945 0.8633880 0.8688525 0.8699454 0.8797814 0.8961749 ## glm 0.8251366 0.8579235 0.8633880 0.8633880 0.8743169 0.8852459 ## glmnet 0.8360656 0.8551913 0.8633880 0.8652095 0.8797814 0.8907104 ## pls 0.8415301 0.8497268 0.8633880 0.8601093 0.8688525 0.8797814 ## lda 0.8415301 0.8497268 0.8579235 0.8586521 0.8688525 0.8743169 ## pam 0.9508197 0.9781421 0.9890710 0.9857923 0.9945355 1.0000000 ## mda 0.8524590 0.8579235 0.8743169 0.8703097 0.8743169 0.8961749 ## fda 0.8360656 0.8524590 0.8633880 0.8619308 0.8743169 0.8852459 ## svmRadialCost 0.8524590 0.8743169 0.8907104 0.8918033 0.9071038 0.9289617 ## knn 0.8415301 0.8497268 0.8633880 0.8637523 0.8743169 0.8961749 ## nb 0.7978142 0.8251366 0.8469945 0.8404372 0.8524590 0.8743169 ## rf 0.8306011 0.8469945 0.8743169 0.8721311 0.8879781 0.9289617 ## xgbLinear 0.7868852 0.8469945 0.8688525 0.8619308 0.8825137 0.9125683 ## xgbTree 0.8743169 0.8825137 0.8961749 0.8965392 0.9098361 0.9234973 ## NA's ## nnet 0 ## glm 0 ## glmnet 0 ## pls 0 ## lda 0 ## pam 0 ## mda 0 ## fda 0 ## svmRadialCost 0 ## knn 0 ## nb 0 ## rf 0 ## xgbLinear 0 ## xgbTree 0 dotplot(results)
# correlation between results modelCor(results) ## nnet glm glmnet pls lda pam ## nnet 1.0000000 0.9964623 0.9985821 0.9959158 0.9959797 0.9740745 ## glm 0.9964623 1.0000000 0.9983260 0.9948820 0.9955163 0.9594457 ## glmnet 0.9985821 0.9983260 1.0000000 0.9980440 0.9981462 0.9659061 ## pls 0.9959158 0.9948820 0.9980440 1.0000000 0.9997897 0.9642498 ## lda 0.9959797 0.9955163 0.9981462 0.9997897 1.0000000 0.9635982 ## pam 0.9740745 0.9594457 0.9659061 0.9642498 0.9635982 1.0000000 ## mda 0.9173182 0.9373646 0.9223877 0.9055397 0.9077165 0.8484849 ## fda 0.9466791 0.9473209 0.9505168 0.9481600 0.9463399 0.9136922 ## svmRadialCost 0.6297685 0.6462695 0.6458028 0.6397843 0.6367630 0.5758754 ## knn 0.8617429 0.8809390 0.8734330 0.8571032 0.8584613 0.8115088 ## nb 0.8887790 0.8916700 0.8939506 0.8849860 0.8812802 0.8525138 ## rf 0.8374398 0.8593235 0.8511846 0.8363713 0.8369300 0.7998105 ## xgbLinear 0.8295344 0.8349773 0.8296049 0.8133293 0.8109480 0.8317105 ## xgbTree 0.9298302 0.9301195 0.9347325 0.9292743 0.9271813 0.9103171 ## mda fda svmRadialCost knn nb ## nnet 0.9173182 0.9466791 0.6297685 0.8617429 0.8887790 ## glm 0.9373646 0.9473209 0.6462695 0.8809390 0.8916700 ## glmnet 0.9223877 0.9505168 0.6458028 0.8734330 0.8939506 ## pls 0.9055397 0.9481600 0.6397843 0.8571032 0.8849860 ## lda 0.9077165 0.9463399 0.6367630 0.8584613 0.8812802 ## pam 0.8484849 0.9136922 0.5758754 0.8115088 0.8525138 ## mda 1.0000000 0.9161242 0.6471664 0.8950366 0.8427981 ## fda 0.9161242 1.0000000 0.6160408 0.8759144 0.9175588 ## svmRadialCost 0.6471664 0.6160408 1.0000000 0.7903877 0.6854170 ## knn 0.8950366 0.8759144 0.7903877 1.0000000 0.8883843 ## nb 0.8427981 0.9175588 0.6854170 0.8883843 1.0000000 ## rf 0.8622639 0.8366536 0.7694248 0.9629252 0.8754177 ## xgbLinear 0.8163424 0.8056738 0.7492483 0.9250997 0.8380506 ## xgbTree 0.8711245 0.8857206 0.7031082 0.8450892 0.8545409 ## rf xgbLinear xgbTree ## nnet 0.8374398 0.8295344 0.9298302 ## glm 0.8593235 0.8349773 0.9301195 ## glmnet 0.8511846 0.8296049 0.9347325 ## pls 0.8363713 0.8133293 0.9292743 ## lda 0.8369300 0.8109480 0.9271813 ## pam 0.7998105 0.8317105 0.9103171 ## mda 0.8622639 0.8163424 0.8711245 ## fda 0.8366536 0.8056738 0.8857206 ## svmRadialCost 0.7694248 0.7492483 0.7031082 ## knn 0.9629252 0.9250997 0.8450892 ## nb 0.8754177 0.8380506 0.8545409 ## rf 1.0000000 0.9312219 0.8407419 ## xgbLinear 0.9312219 1.0000000 0.8413261 ## xgbTree 0.8407419 0.8413261 1.0000000 splom(results)
筛选发现,所有模型准确率大致都在0.83~0.89之间,不会相差太大。其中,基于决策树的模型表现比较好,以随机森林为最好,其次是xgbLinear。不过,我们发现基于决策树之间的结果相关性比较大,但是它们与KNN、朴素贝叶斯、PAM方法相关性比较弱,于是我们决定要进行集成学习(Ensemble);其中KNN和PAM相关性比较强,我们仅采用其中ROC值更高的KNN模型。主模型采用随机森林(rf),辅助模型采用KNN,NaiveBayes。目前我们单独采用随机森林能够达到的ROC值(AUC)为0.8875979。希望经过集成学习后能够突破它。
对模型进行初筛之后,我们来确定一下模型列表:
model_list2=caretList( Survived~.,data=train, trControl=ctrl, metric="ROC", preProcess=c("center","scale"), methodList=c("rf","nb","knn") )
然后,我们进行集成学习建模。因为是二分类问题,我们用逻辑回归glm来进行集成学习。
glm_ensemble <- caretStack( model_list2, method="glm", metric="ROC", trControl=trainControl( method="boot", number=10, savePredictions="final", classProbs=TRUE, summaryFunction=twoClassSummary ) ) glm_ensemble ## A glm ensemble of 2 base models: rf, nb, knn ## ## Ensemble results: ## Generalized Linear Model ## ## 4455 samples ## 3 predictor ## 2 classes: 'Alive', 'Dead' ## ## No pre-processing ## Resampling: Bootstrapped (10 reps) ## Summary of sample sizes: 4455, 4455, 4455, 4455, 4455, 4455, ... ## Resampling results: ## ## ROC Sens Spec ## 0.8784721 0.7300053 0.8954497
这个结果中集成学习还不如单纯用随机森林得到的效果好。注意每次运行都有随机性,所以结果是不唯一的。我们这里不set.seed,但是需要知道每次的结果都不尽相同,但是一般来说集成学习都会提高总体的准确率。
目前我们已经确定了模型,首先我们认为随机森林模型是比较好的;其次我们认为以随机森林为主,辅助以KNN和朴素贝叶斯方法有提高模型表现的可能,因此要用集成学习方法。在验证阶段,我们需要构建随机森林模型和它的集成模型,并比较两种方法的效果。
test %>% mutate(PassengerId=test_raw1$PassengerId) %>% na.omit -> new.test predict(glm_ensemble,newdata=new.test) -> pre.ensemble predict(model_list2[["rf"]],newdata=new.test) -> pre.rf new.test %>% mutate(rf=pre.rf,ensemble=pre.ensemble) %>% select(PassengerId,rf,ensemble) %>% left_join(gs) %>% mutate_all(funs(as.factor(as.character(.))))-> pre ## Joining, by = "PassengerId" confusionMatrix(pre$rf,pre$Survived) ## Confusion Matrix and Statistics ## ## Reference ## Prediction Alive Dead ## Alive 112 31 ## Dead 40 234 ## ## Accuracy : 0.8297 ## 95% CI : (0.7902, 0.8646) ## No Information Rate : 0.6355 ## P-Value [Acc > NIR] : <2e-16 ## ## Kappa : 0.6278 ## Mcnemar's Test P-Value : 0.3424 ## ## Sensitivity : 0.7368 ## Specificity : 0.8830 ## Pos Pred Value : 0.7832 ## Neg Pred Value : 0.8540 ## Prevalence : 0.3645 ## Detection Rate : 0.2686 ## Detection Prevalence : 0.3429 ## Balanced Accuracy : 0.8099 ## ## 'Positive' Class : Alive ## confusionMatrix(pre$ensemble,pre$Survived) ## Confusion Matrix and Statistics ## ## Reference ## Prediction Alive Dead ## Alive 36 244 ## Dead 116 21 ## ## Accuracy : 0.1367 ## 95% CI : (0.1052, 0.1734) ## No Information Rate : 0.6355 ## P-Value [Acc > NIR] : 1 ## ## Kappa : -0.5798 ## Mcnemar's Test P-Value : 2.179e-11 ## ## Sensitivity : 0.23684 ## Specificity : 0.07925 ## Pos Pred Value : 0.12857 ## Neg Pred Value : 0.15328 ## Prevalence : 0.36451 ## Detection Rate : 0.08633 ## Detection Prevalence : 0.67146 ## Balanced Accuracy : 0.15804 ## ## 'Positive' Class : Alive ##
在验证集中,我们发现集成学习出现了严重的过拟合现象,不如单纯使用随机森林的效果好。这里其实我没有对模型的超参数进行调整,因为我认为这个准确率已经能够接受,其实可以让模型自动再对超参数进行优化,可能会得到更好的效果。继续做下去的话,就是选定随机森林之后对我们的模型进行进一步超参数的调整。
发现网上有人能做到百分百,其实这是完全没有意义的。泰坦尼克号案例就是学习用的,具体应用场景我能够想到的,就是保险业,给每个人投保的时候需要考虑乘客的存活率。不过泰坦尼克的例子已经是多年以前了,现在能够拿到的乘客信息比以前要多得多,更加精细,在具体问题的时候我们还是要不断调整我们的模型。
公众号后台回复关键字即可学习
回复 爬虫 爬虫三大案例实战
回复 Python 1小时破冰入门回复 数据挖掘 R语言入门及数据挖掘
回复 人工智能 三个月入门人工智能
回复 数据分析师 数据分析师成长之路
回复 机器学习 机器学习的商业应用
回复 数据科学 数据科学实战
回复 常用算法 常用数据挖掘算法