DataScience/MachineLearning

[트리계열 이해하기] 4. GBM

mkk4726 2023. 7. 29. 11:44

1. 기본 컨셉

 

Gradient Boosting Machine 은 이름처럼 Boosting 계열의 트리 모델입니다.

 

2023.07.28 - [DataScience/MachineLearning] - [트리계열 이해하기] 3. AdaBoost

 

[트리계열 이해하기] 3. AdaBoost

AdaBoost는 이름에서 알 수 있 듯, Boosting 계열 모델 중 하나입니다. Boosting은 Ensemble 기법 중 하나로, Bagging과 다르게 Sequenital한 기법입니다. 2023.07.28 - [DataScience/MachineLearning] - [트리계열 이해하기] 2.

mkk4726.tistory.com

AdaBoost가 틀린 데이터에 대해서는 양의 가중치를, 맞은 데이터에 대해서는 음의 가중치를 부여하며 규칙을 만들어갔다면, GBM은 잔차를 계속해서 학습하는 규칙을 만들어갑니다.

그림1. GBM 기본 컨셉

GBM의 기본 컨셉은 "잔차를 학습한다" 입니다.


2. 잔차를 학습하는 방법

 

잔차를 학습하는 방법은 꽤나 직관적입니다.

그림2. 잔차를 학습하는 방법

처음에는 평균으로 예측을 합니다. 

그 후 원래 값에 평균을 빼서 잔차를 구합니다.

 

그리고 그 다음 트리부터는 앞에서 구한 잔차를 구분하는 규칙을 만듭니다.

예측치는 구분한 잔차에 평균을 더한 값이 됩니다.

이렇게만 예측해도 잔차가 거의 없는, overfitting 하게 되는 문제가 발생합니다.

 

그리고 저는 이렇게 예측할거면 "왜 잔차를 학습하는거지?"라는 의문도 들었습니다.

잔차를 구분하고 나서 평균을 더하나, 원래 값으로 구분하나 똑같으니까요.

저는 이 뒤 과정을 보고 나서 잔차를 학습한다는 개념이 이해가 됐습니다.

그림5. 잔차에 학습률을 곱해 예측

GBM은 이렇게 잔차를 학습해 얻은 예측치에 학습률을 곱합니다.

오버피팅을 예방하기 위하기 위해서라고 합니다. 

그림6. 학습 과정

이렇게 구한 예측치로 잔차를 다시 구하고, 이를 다음 트리가 학습합니다.

 

그림7. GBM의 최종 예측치

이 과정을 반복하면 모델은 잔차를 완전히 학습하게 됩니다.

최종 예측치는 앞에서 구한 잔차에 학습률을 곱한 값이 됩니다.

 

여기서 잔차를 학습한다고 했지만, 정확히는 해당 iter에서의 gradient를 통해 학습하는 것이며,

따라서 gradient boosting machine이라는 이름이 붙었습니다.

 


3. Overfitting을 예방하는 법

 

애초에 트리구조 자체가 오버피팅에 취약하기도 하고, 잔차를 학습하기에 오버피팅이 더욱 쉽게 일어나는 구조입니다.

따라서 대표적으로 3가지 방법을 통해 오버피팅을 예방합니다.

 

1. Subsampling

그림8. Subsampling

iteration 마다 데이터를 복원추출하여 사용합니다.  ( bagging또한 가능하다고 합니다. )

2. Shrinkage

그림9. Shrinkage

학습률과 비슷한 개념으로 iter마다 발생하는 impact를 조금만 반영하는 것을 말합니다.

 

3. Early Stopping

그림10. Early Stopping

train set에 너무 오버피팅 되지 않도록 valid set과 차이가 벌어지면 학습을 중단하는 것입니다.


4. Importance Score

GBM에서의 Importance Score는 Information Gain을 얼마나 얻었느냐로 평가합니다.

 

먼저 피처 j 의 트리 T에서의 Influence는 다음과 같이 구합니다.

 

$Influence_j(T) = \sum_{i=1}^{L-1}(IG_i \times 1(S_j=j))$

 

그리고 이를 모든 트리에서 구하고 평균해준 것이 피처 j의 importance score 입니다.

 

$Influence_j = \frac{1}{M}\sum_{k=1}^{M}Influence_j(T_k)$

 

 


- Reference

패스트캠퍼스, 초격차 패키지 : 50개 프로젝트로 완벽하게 끝내는 머신러닝