您当前的位置:首页 > 计算机 > 编程开发 > Python

KNN分类算法原理与Python+sklearn实现根据身高和体重对体型分类

时间:12-27来源:作者:点击数:

KNN算法是k-Nearest Neighbor Classification的简称,也就是k近邻分类算法。基本思路是在特征空间中查找k个最相似或者距离最近的样本,然后根据k个最相似的样本对未知样本进行分类。基本步骤为:

(1)计算已知样本空间中所有点与未知样本的距离;

(2)对所有距离按升序排列;

(3)确定并选取与未知样本距离最小的k个样本或点;

(4)统计选取的k个点所属类别的出现频率;

(5)把出现频率最高的类别作为预测结果,即未知样本所属类别。

下面的代码模拟了上面的算法思路和步骤,以身高+体重对肥胖程度进行分类为例,采用欧几里得距离。

from collections import Counter

import numpy as np

# 已知样本数据

# 每行数据分别为性别,身高,体重

knownData = ((1, 180, 85),

             (1, 180, 86),

             (1, 180, 90),

             (1, 180, 100),

             (1, 185, 120),

             (1, 175, 80),

             (1, 175, 60),

             (1, 170, 60),

             (1, 175, 90),

             (1, 175, 100),

             (1, 185, 90),

             (1, 185, 80))

knownTarget = ('稍胖', '稍胖', '稍胖', '过胖',

               '太胖', '正常', '偏瘦', '正常',

               '过胖', '太胖', '正常', '偏瘦')

def KNNPredict(current, knownData=knownData, knownTarget=knownTarget, k=3):

    # current为未知样本,格式为(性别,身高,体重)

    data = dict(zip(knownData, knownTarget))

    # 如果未知样本与某个已知样本精确匹配,直接返回结果

    if current in data.keys():

        return data[current]

   

    # 按性别过滤,只考虑current性别一样的样本数据

    g = lambda item:item[0][0]==current[0]

    samples = list(filter(g, data.items()))

    g = lambda item:((item[0][1]-current[1])**2+\

                     (item[0][2]-current[2])**2)**0.5

    distances = sorted(samples, key=g)

    # 选取距离最小的前k个

    distances = (item[1] for item in distances[:k])

    # 计算选取的k个样本所属类别的出现频率

    # 选择频率最高的类别作为结果

    return Counter(distances).most_common(1)[0][0]

unKnownData = [(1, 180, 70), (1, 160, 90), (1, 170, 85)]

for current in unKnownData:

    print(current, ':', KNNPredict(current))

运行结果为:

(1, 180, 70) : 偏瘦

(1, 160, 90) : 过胖

(1, 170, 85) : 正常

下面的代码使用扩展库sklearn中的k近邻分类算法处理了同样的问题:

# 使用sklearn库的k近邻分类模型

from sklearn.neighbors import KNeighborsClassifier

# 创建并训练模型

clf = KNeighborsClassifier(n_neighbors=3, weights='distance')

clf.fit(knownData, knownTarget)

# 分类

for current in unKnownData:

    print(current, end=' : ')

    current = np.array(current).reshape(1,-1)

    print(clf.predict(current)[0])

运行结果为:

(1, 180, 70) : 偏瘦

(1, 160, 90) : 过胖

(1, 170, 85) : 正常

方便获取更多学习、工作、生活信息请关注本站微信公众号城东书院 微信服务号城东书院 微信订阅号
推荐内容
相关内容
栏目更新
栏目热门