Scikit-learn을 활용한 최근접 이웃 분류: 씨앗 데이터셋

최근접 이웃 분류

이번 포스트에서는 최근접 이웃 분류(일명 knn)에 관해 간단하게 알아보도록 하겠습니다. 최근접 이웃 분류란 해당 데이터 포인트에서 가장 가까운 곳에 위치한 데이터 포인트의 라벨을 예상값으로 하는 알고리즘입니다. Scikit-learn을 이용하여 교차 검증을 거쳐 최근접 이웃 분류를 해보도록 하겠습니다.

우선 import 부터 선언하겠습니다.

import numpy as np
from sklearn.neighbors import KNeighborsClassifier

우리가 사용할 데이터셋은 '씨앗 데이터셋'입니다.

참고:: https://www.dropbox.com/s/9xmhhvi6gm5dwo0/iris.txt?dl=0

씨앗 데이터셋은 7개의 Feature와 1개의 Label을 가진 210종의 씨앗을 모아둔 데이터셋입니다.

with open("C:₩₩~~~₩₩seeds.tsv", 'r') as f:
    entire_file = list(line.split() for line in f)

entire_file = np.array(entire_file)
print(np.shape(entire_file))     # (210, 8)

사용하기 좋게 Feature와 Label을 Numpy 배열에 나눠 담는 작업을 합니다. 여기서 주의할 점은 KNN에 들어갈 Label은 꼭 1차원 배열이어야 한다는 것입니다. 아래 코드에서 reshape하지 않고 사용했더니 정확도가 30% 가량이 나왔습니다.(나온 것도 신기...)

features = entire_file[:, :-1]

labels = entire_file[:, [-1]]
labels = np.reshape(labels, -1)

KNN 분류기를 생성합니다. 생성 시 n_neighbors를 인자로 전달하게 되는데 이는 '몇 명의 이웃'을 보고 예상값을 전달해줄지 결정해주는 인자입니다. 이 값이 1이라면 가장 가까운 1명의 이웃을 보고 예상값을 반환해줍니다. 이 값이 10이라면 가장 가까운 10명의 이웃의 Label을 살펴보고 빈도수가 가장 높은 Label을 반환해줍니다.

classifier = KNeighborsClassifier(n_neighbors=10)

교차 검증을 하고자 하니 데이터를 학습용과 검증용으로 나눌 필요가 있습니다. Scikit-learn에서는 이를 편리하게 하는 모듈을 제공합니다. kf에는 훈련용 데이터와 검증용 데이터가 불 배열 형태로 List에 담겨있습니다.

from sklearn.cross_validation import KFold
kf = KFold(len(features), n_folds=3, shuffle=True)

교차 검증의 중첩 당 정확도를 구한 뒤 한번에 합하기 위해 List를 만듭니다.

means = []

kf에서 훈련용 데이터와 검증용 데이터를 꺼내와 가장 적합한(fit) KNN 모델을 찾고 정확도를 측정합니다.

for training, testing in kf:
    classifier.fit(features[training], labels[training])
    prediction = classifier.predict(features[testing])
    curmean = np.mean(prediction == labels[testing])
    means.append(curmean)
    
print("전체 정확도: {:.1%}".format(np.mean(means)))

* 출력값
전체 정확도: 90.5%

 

3 thoughts on “Scikit-learn을 활용한 최근접 이웃 분류: 씨앗 데이터셋

  1. 안녕하세요 이번에 졸업과제로 knn을 사용해야하는 대학생입니다! 씨앗 데이터셋 올려주신 예제를 해보는데 sklearn.cross_validation에서 KFold가 import되지 않아요ㅜㅜ 그래서 혹시나 하고 sklearn.model_selection에서 KFold를 import해봤는데 ‘KFold’ object is not iterable이라고 TypeError가 뜨더라구요… 혹시 이것과 관련해서 답변 주시면 정말 감사하겠습니다ㅜㅜㅜ

    1. 말씀해주신대로

      from sklearn.cross_validation import KFold

      로 import 해서 진행하면 정상적으로 잘 되는 것 같습니다만, 어느 코드에서 문제가 발생하나요?

    2. 아, 기존에 올렸던 데이터의 tab이 조금 이상하게 되어 있어서 그 부분 새로 수정했습니다.
      다시 한번 해보세요~

댓글 남기기

이메일은 공개되지 않습니다. 필수 입력창은 * 로 표시되어 있습니다