本节我们将线性判别算法应用到花卉分类场景中。花卉分类指通过花卉不同的特征,如花瓣的长和宽、花蕊的长和宽,将花卉分为不同的类别。本节先将多维数据简化为二维数据,以便和理论知识部分相呼应。
In [1]: import numpy as np
...: import matplotlib.pyplot as plt
...: from sklearn import datasets
...: from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
...: from sklearn.model_selection import train_test_split
In [2]: iris = datasets.load_iris()
In [3]: X = iris['data']
In [4]: y = iris['target']
In [5]: target_names = iris['target_names']
数据集如图 1 所示,这里只使用 sepal length 和 sepal width 两个属性。
In [11]: for m,i,target_name in zip('vo^',range(2),target_names[0:2]):
...: sl = X[y == i,0] # sl = sepal length (cm)
...: sw = X[y == i,1] # sw = sepal width (cm)
...: plt.scatter(sl,sw,marker=m,label=target_name,s=30,c='k')
...:
...: plt.xlabel('sepal length (cm)') # 绘制x轴和y轴标签名
...: plt.ylabel('sepal width (cm)')
...: plt.show()
In [7]: plt.close()
获取 sepal length 和 sepal width 两个属性的自变量矩阵;获取 sepal length 和 sepal width 两个属性的因变量矩阵。
In [8]: X=X[(y==1) | (y==0),0:2]
...: y=y[(y==1) | (y==0)]
通过 n_components 参数设置压缩之后的维度为 1。
In [9]: lda = LinearDiscriminantAnalysis(n_components=1)
In [10]: ld = lda.fit(X,y)
这一步实际上就是通过模型进行降维。
In [11]: X_t =ld.transform(X)
因为压缩到 1 维,所以y轴坐标全部为 0。
In [12]: y_t = np.zeros(X_t.shape)
结果如图 2 所示。
In [13]: for m,i,target_name in zip('ov^',range(2),target_names[0:2]): # 做压缩后
#的图像
...: plt.scatter(X_t[y == i],y_t[y == i],marker=m,label=target_name,s=30, c='k')
...:
...: plt.legend()
...: plt.show()
In [14]: plt.close()
这里取 80% 作为训练集,20% 作为测试集。
In [15]: X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.2)
In [16]: lda = LinearDiscriminantAnalysis(n_components=1)
In [17]: ld = lda.fit(X_train,y_train)
In [18]: pre = ld.predict(X_test)
In [19]: list(zip(pre,y_test,pre==y_test))
Out[19]:
[(0, 0, True),
(0, 0, True),
(1, 1, True),
(1, 1, True),
(1, 1, True),
(0, 0, True),
(0, 0, True),
(1, 1, True),
(1, 1, True),
(1, 1, True),
(1, 1, True),
(1, 1, True),
(1, 1, True),
(0, 0, True),
(0, 0, True),
(1, 1, True),
(0, 0, True),
(0, 0, True),
(1, 1, True),
(1, 1, True)]
In [20]: ld.score(X_test,y_test)
Out[20]: 1.0