scikit-learn的基本用法(一)——KNN算法的使用

文章作者:Tyan
博客:noahsnail.com  |  CSDN  |  简书

本文主要使用scikit-learn中的KNN算法进行Iris数据集的分类。Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集。通过花萼长度,花萼宽度,花瓣长度,花瓣宽度4个属性预测鸢尾花卉属于(Setosa,Versicolour,Virginica)三个种类中的哪一类。

  • Demo
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import numpy as np
from sklearn import datasets
from sklearn.cross_validation import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# 加载iris数据集
iris = datasets.load_iris()
# 读取特征
iris_X = iris.data
# 读取分类标签
iris_y = iris.target
# 将数据分为训练、测试两部分
X_train, X_test, y_train, y_test = train_test_split(iris_X, iris_y, test_size = 0.2)
# 定义分类器
knn = KNeighborsClassifier()
# 进行分类
knn.fit(X_train, y_train)
# 计算预测值
y_predict = knn.predict(X_test)
# 计算准确率, 由于每次数据集划分不同, 可能不一样
print np.sum(np.fabs(y_predict - y_test)) / float(len(y_test))
  • 结果
1
0.0666666666667

参考资料

  1. https://www.youtube.com/user/MorvanZhou
如果有收获,可以请我喝杯咖啡!