소소한 컴퓨터 이야기

One-Hot Encoding

by Cori

1. One-Hot Encoding

0) 정의

· N개의 클래스를 N차원의 One-Hot 벡터로 표현되도록 변환한다 (고유값들을 피처로 만들고 정답에 해당하는 열은 1, 나머진 0으로 -)

· 숫자의 차이가 모델에 영향을 미치는 선형 계열 모델에서 범주형 데이터로 변환할 때 사용한다.

· Decision Tree 계열의 알고리즘은 Feature에 0이 많은 경우 성능이 떨어지기 때문에 Label Encoding을 수행한다.

 

1) 함수

· sklearn.preprocessing.OneHotEncoder

더보기

· fit(데이터셋): 데이터셋을 기준으로 어떻게 변환할 지 학습

· transform(데이터셋): argument로 받은 데이터셋을 원핫인코딩 처리

· fit_transform(데이터셋): 학습과 변환을 한 번에 처리

· get_feature_names(): 원핫인코딩으로 변환된 컬럼의 이름을 반환

 

* 데이터셋은 2차원 배열을 전달하며, Feature 별로 원핫인코딩 처리한다. 

 

· get_dummies(DataFrame [, columns=[변환할 컬럼명]]) 

-> DataFrame에서 범주형 (object, category) 변수만 변환한다. 

* 범주형 변수인데 숫자 값을 가지는 경우가 있는데 (별점 등), 이런 경우 get_dummies(columns=['컬럼명', '컬럼명'] 매개변수로 

  컬럼들을 명시한다. 

 

2) 예제

원-핫 인코딩

· Dataset 생성

· One Hot Encoder 생성, 학습 및 변환 

* ohe.fit_transform(df)으로 fit & transform 과정을 한 번에 처리할 수 있다. 

 

· One Hot Encoding 결과 값을 DataFrame으로 변환 

· get_dummies()

get_dummies() 함수를 통해 object, categorical 타입의 feature들은 모두 변환한다. 

* get_dummies() 함수에 변환할 컬럼을 지정하면 숫자형 컬럼들 또한 변환할 수 있음 

3) Adult Dataset에 One Hot Encoding 적용

Adult-data의 컬럼들을 정리해보면 다음과 같다.

이 중에서, 'age', 'workclass', 'education', 'occupation', 'gender', 'hours-per-week' 컬럼만 사용한다. 

* 범주형 Feature 중 income은 출력 데이터이므로 Label Encoding 처리를 하고, 나머지 범주형 Feature들은 One-hot Encoding 처리한다. 

 

· 데이터 로딩

skipinitialspace=True로 설정하면 첫 글자가 공백인 것의 공백은 제거한다. 

· 결측치 제거

· 필요한 Feature들만 추출 

age와 hours-per-week 컬럼을 제외한 컬럼들은 모두 object 타입이다. 

 

· income 데이터 레이블 인코딩 수행 

fit_transform() 함수를 통해 간결하게 구현하였음  * fit(), transform() 따로 구현 

 

· One Hot Encoding 

'workclass', 'education', 'occupation', 'gender' 4개의 컬럼에 대해 원 핫 인코딩을 수행하였다.

 

* 생성시에 sparse = False로 주었는데, sparse를 False로 주지 않으면 scipy의 csr_matrix (희소행렬 객체)로 반환한다. 희소행렬은 대부분 0으로 구성된 행렬과 계산이나 메모리 효율을 이용해 0이 아닌 값의 index만 관리한다. 

원 핫 인코딩된 컬럼들 조회

· One-hot Encoding된 컬럼 + 연속형 변수 컬럼 병합 

원-핫 인코딩 된 변수 타입을 조회해보면 numpy 배열이다. 연속형 변수들은 DataFrame 형태이기 때문에, 이 둘을 병합하기 위해 연속형 변수 컬럼들을 numpy 배열로 변환해주었다. 

두 개의 numpy 배열을 numpy 라이브러리의 concatenate 함수를 이용하여 합침 

 

· 데이터셋 분리 

· DecisionTree & LogisticRegression 모델 학습 

· 학습 결과 추론 

· 평가 

scikit-learn에서 제공하는 accuracy_score 함수를 이용하여 정확도를 측정해 보았다.

블로그의 프로필 사진

블로그의 정보

코딩하는 오리

Cori

활동하기