Random Forest

0) Importing Packages & Data , EDA

library(tidyverse)
## Warning: 패키지 'tidyverse'는 R 버전 3.6.3에서 작성되었습니다
## -- Attaching packages ------------------------------------------------------ tidyverse 1.3.0 --
## √ ggplot2 3.3.1     √ purrr   0.3.4
## √ tibble  3.0.1     √ dplyr   1.0.0
## √ tidyr   1.1.0     √ stringr 1.4.0
## √ readr   1.3.1     √ forcats 0.5.0
## Warning: 패키지 'tibble'는 R 버전 3.6.3에서 작성되었습니다
## Warning: 패키지 'tidyr'는 R 버전 3.6.3에서 작성되었습니다
## Warning: 패키지 'purrr'는 R 버전 3.6.3에서 작성되었습니다
## Warning: 패키지 'forcats'는 R 버전 3.6.3에서 작성되었습니다
## -- Conflicts --------------------------------------------------------- tidyverse_conflicts() --
## x dplyr::filter() masks stats::filter()
## x dplyr::lag()    masks stats::lag()
library(tidymodels)
## Warning: 패키지 'tidymodels'는 R 버전 3.6.3에서 작성되었습니다
## -- Attaching packages ----------------------------------------------------- tidymodels 0.1.0 --
## √ broom     0.5.6      √ rsample   0.0.6 
## √ dials     0.0.6      √ tune      0.1.0 
## √ infer     0.5.1      √ workflows 0.1.1 
## √ parsnip   0.1.1      √ yardstick 0.0.6 
## √ recipes   0.1.12
## Warning: 패키지 'broom'는 R 버전 3.6.3에서 작성되었습니다
## Warning: 패키지 'dials'는 R 버전 3.6.3에서 작성되었습니다
## Warning: 패키지 'scales'는 R 버전 3.6.3에서 작성되었습니다
## Warning: 패키지 'infer'는 R 버전 3.6.3에서 작성되었습니다
## Warning: 패키지 'parsnip'는 R 버전 3.6.3에서 작성되었습니다
## Warning: 패키지 'recipes'는 R 버전 3.6.3에서 작성되었습니다
## Warning: 패키지 'rsample'는 R 버전 3.6.3에서 작성되었습니다
## Warning: 패키지 'tune'는 R 버전 3.6.3에서 작성되었습니다
## Warning: 패키지 'workflows'는 R 버전 3.6.3에서 작성되었습니다
## Warning: 패키지 'yardstick'는 R 버전 3.6.3에서 작성되었습니다
## -- Conflicts -------------------------------------------------------- tidymodels_conflicts() --
## x scales::discard() masks purrr::discard()
## x dplyr::filter()   masks stats::filter()
## x recipes::fixed()  masks stringr::fixed()
## x dplyr::lag()      masks stats::lag()
## x dials::margin()   masks ggplot2::margin()
## x yardstick::spec() masks readr::spec()
## x recipes::step()   masks stats::step()
library(pROC)
## Warning: 패키지 'pROC'는 R 버전 3.6.3에서 작성되었습니다
## Type 'citation("pROC")' for a citation.
## 
## 다음의 패키지를 부착합니다: 'pROC'
## The following objects are masked from 'package:stats':
## 
##     cov, smooth, var
library(randomForest)
## Warning: 패키지 'randomForest'는 R 버전 3.6.3에서 작성되었습니다
## randomForest 4.6-14
## Type rfNews() to see new features/changes/bug fixes.
## 
## 다음의 패키지를 부착합니다: 'randomForest'
## The following object is masked from 'package:dials':
## 
##     margin
## The following object is masked from 'package:dplyr':
## 
##     combine
## The following object is masked from 'package:ggplot2':
## 
##     margin
library(plot3D)
## Warning: 패키지 'plot3D'는 R 버전 3.6.3에서 작성되었습니다
data = read_csv("data/pulsar_stars.csv")
## Parsed with column specification:
## cols(
##   `Mean of the integrated profile` = col_double(),
##   `Standard deviation of the integrated profile` = col_double(),
##   `Excess kurtosis of the integrated profile` = col_double(),
##   `Skewness of the integrated profile` = col_double(),
##   `Mean of the DM-SNR curve` = col_double(),
##   `Standard deviation of the DM-SNR curve` = col_double(),
##   `Excess kurtosis of the DM-SNR curve` = col_double(),
##   `Skewness of the DM-SNR curve` = col_double(),
##   target_class = col_double()
## )
head(data)
## # A tibble: 6 x 9
##   `Mean of the in~ `Standard devia~ `Excess kurtosi~ `Skewness of th~
##              <dbl>            <dbl>            <dbl>            <dbl>
## 1            141.              55.7          -0.235            -0.700
## 2            103.              58.9           0.465            -0.515
## 3            103.              39.3           0.323             1.05 
## 4            137.              57.2          -0.0684           -0.636
## 5             88.7             40.7           0.601             1.12 
## 6             93.6             46.7           0.532             0.417
## # ... with 5 more variables: `Mean of the DM-SNR curve` <dbl>, `Standard
## #   deviation of the DM-SNR curve` <dbl>, `Excess kurtosis of the DM-SNR
## #   curve` <dbl>, `Skewness of the DM-SNR curve` <dbl>, target_class <dbl>

이 데이터는 Pulsar Star를 찾는 데이터셋이다. Pulsar star란 전자기파 광선을 뿜는 자전하는 중성자별이라고 한다. target_class가 1이면 Pulsar star고 0이면 일반 별이다.

skimr::skim(data)
Table 1: Data summary
Name data
Number of rows 17898
Number of columns 9
_______________________
Column type frequency:
numeric 9
________________________
Group variables None

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
Mean of the integrated profile 0 1 111.08 25.65 5.81 100.93 115.08 127.09 192.62 ▁▁▇▇▁
Standard deviation of the integrated profile 0 1 46.55 6.84 24.77 42.38 46.95 51.02 98.78 ▂▇▁▁▁
Excess kurtosis of the integrated profile 0 1 0.48 1.06 -1.88 0.03 0.22 0.47 8.07 ▅▇▁▁▁
Skewness of the integrated profile 0 1 1.77 6.17 -1.79 -0.19 0.20 0.93 68.10 ▇▁▁▁▁
Mean of the DM-SNR curve 0 1 12.61 29.47 0.21 1.92 2.80 5.46 223.39 ▇▁▁▁▁
Standard deviation of the DM-SNR curve 0 1 26.33 19.47 7.37 14.44 18.46 28.43 110.64 ▇▁▁▁▁
Excess kurtosis of the DM-SNR curve 0 1 8.30 4.51 -3.14 5.78 8.43 10.70 34.54 ▂▇▂▁▁
Skewness of the DM-SNR curve 0 1 104.86 106.51 -1.98 34.96 83.06 139.31 1191.00 ▇▁▁▁▁
target_class 0 1 0.09 0.29 0.00 0.00 0.00 0.00 1.00 ▇▁▁▁▁

0-1) Random Forest를 위해 변수명을 바꾸는 작업이 필요하다.

colnames(data) = c("Mean_of_the_integrated_profile","Standard_deviation_of_the_integrated_profile",
                   "Excess_kurtosis_of_the_integrated_profile","Skewness_of_the_integrated_profile",
                   "Mean_of_the_DM_SNR_curve","Standard_deviation_of_the_DM_SNR_curve",
                   "Excess_kurtosis_of_the_DM_SNR_curve","Skewness_of_the_DM_SNR_curve","target_class" )

0-2) Random Forest를 위해 target_class 변수를 factor로 바꿔주자

data$target_class = as.factor(data$target_class)

1) Train-Test Split

set.seed(0226)
star_split <- rsample::initial_split(data, prop = 0.5, strata = target_class)

train <- training(star_split)
test <- testing(star_split)

print(nrow(train))
## [1] 8949
print(nrow(test))
## [1] 8949

2) Random Forest Modeling

p = ncol(train) - 1 # class 뺀 X변수들의 숫자
n = nrow(train)

mtry = floor(seq(0.1, 1, by = 0.1) * p) 
nsize = round(c(1, n * seq(0.01, 0.1, by = 0.01)))

pna = matrix(NA, nrow = length(mtry), ncol = length(nsize))
pnadf = data.frame()

for( i in 1:length(mtry)){
  for( j in 1:length(nsize)){
    m = mtry[i]
    n = nsize[j]
    rf = randomForest(target_class~., 
                      data = train, ntree = 100, mtry = m, 
                      nodesize = n, set.seed(0226))
    pred = predict(rf, newdata = test, type = 'prob')
    
    roccurve0 = roc(test$target_class ~ pred[,2])

    auc = roccurve0$auc %>% as.numeric()
    
    pna[i,j] = auc
    pnadf = rbind(pnadf, c(m,n,auc))
  }
}

rownames(pna) = mtry
colnames(pna) = nsize

pna
##           1        89       179       268       358       447       537
## 0 0.9655894 0.9478434 0.9514007 0.9489635 0.9482192 0.9526655 0.9480747
## 1 0.9655894 0.9478434 0.9514007 0.9489635 0.9482192 0.9526655 0.9480747
## 2 0.9658301 0.9545747 0.9501529 0.9506757 0.9499466 0.9527863 0.9425445
## 3 0.9672799 0.9502650 0.9495637 0.9495012 0.9497061 0.9484140 0.9440998
## 4 0.9648768 0.9487849 0.9530485 0.9451542 0.9483413 0.9450359 0.9446473
## 4 0.9648768 0.9487849 0.9530485 0.9451542 0.9483413 0.9450359 0.9446473
## 5 0.9678321 0.9556995 0.9489808 0.9474433 0.9456023 0.9461395 0.9421235
## 6 0.9678004 0.9523885 0.9511687 0.9407779 0.9421318 0.9456583 0.9436566
## 7 0.9651534 0.9533287 0.9447496 0.9446937 0.9436194 0.9438521 0.9458295
## 8 0.9669586 0.9510696 0.9443655 0.9332648 0.9322425 0.9249996 0.9302646
##         626       716       805       895
## 0 0.9497824 0.9510453 0.9502049 0.9364539
## 1 0.9497824 0.9510453 0.9502049 0.9364539
## 2 0.9383514 0.9507290 0.9505942 0.9487513
## 3 0.9393807 0.9404574 0.9345050 0.9449459
## 4 0.9392114 0.9394300 0.9411547 0.9422191
## 4 0.9392114 0.9394300 0.9411547 0.9422191
## 5 0.9393145 0.9383569 0.9352619 0.9372310
## 6 0.9368283 0.9326136 0.9399175 0.9374897
## 7 0.9453812 0.9360155 0.9294285 0.9285424
## 8 0.9213968 0.9208511 0.9205902 0.9200889

randomForest 과정중에 반복되는 메세지와 경고가 출력되어 rmarkdown에 message=FALSE, warning=FALSE를 주어 생략하였다.

위와 같은 mtrym nodesize auc 조합이 만들어졌다.

3) 3D Surface Graph

persp3D(mtry, nsize, pna, theta=110, phi=30, axes=TRUE,scale= 0.75, box=TRUE, nticks=5, 

        ticktype="detailed", xlab = "mtry", ylab= "nodesize", zlab = "AUROC")

colnames(pnadf) = c("mtry", "nodesize", "AUC")

그래프를 봤을 때, nodesize가 200 이하이고, mtry 가 5인 지점에서 가장 짙은 적색이 보이며 AUC 약 0.97로 가장 높아보인다.

4) AUC를 max로 하는 mtry과 nodesize 찾기

df = pnadf %>%
  filter(AUC == max(AUC))

print(df)
##   mtry nodesize       AUC
## 1    5        1 0.9678321

그래프에서 봤던 대로 mtry가 5이고 nodesize가 1인 지점에서 AUC 약 0.968로 가장 높았다.