本节我们将逻辑回归算法应用到鸢尾花数据集上,看其分类效果。
这里我们用到了 Numpy 来提取数据,使用 Matplotlib 做最终的展示,使用 Scikit 中的 iris 作为数据集,导入线性模块 linear_model。使用 sklearn.model_selection进行测试集和训练集的划分。
In [1]: import numpy as np
...: import matplotlib.pyplot as plt
...: from sklearn import linear_model, datasets
...: from sklearn.model_selection import train_test_split
In [2]: iris = datasets.load_iris() # 导入相关数据
这里我们取 iris 数据集中的前两个属性。
In [3]: X = iris.data[:, :2] # 我们只使用前两个属性
...: X
Out[3]:
array([[5.1, 3.5],
[4.9, 3],
[4.7, 3.2],
[4.6, 3.1],
[5. , 3.6],
[5.4, 3.9],
[4.6, 3.4],
......
[6.8, 3.2],
[6.7, 3.3],
[6.7, 3],
[6.3, 2.5],
[6.5, 3],
[6.2, 3.4],
[5.9, 3]])
In [4]: y = iris.target # 获得目标变量
train_test_split() 方法的第 1 个参数传入的是属性矩阵,第 2 个参数是目标变量,第 3 个参数是测试集所占的比重。它返回了 4 个值,按顺序分别是训练集属性、测试集属性、训练集目标变量、测试集目标变量。
In [5]: X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2) #分割训练集和测试集
为了接下来的作图做准备。
In [6]: h = .02 # 设置网格的步长
In [7]: logreg = linear_model.LogisticRegression(C=1e5) #创建模型对象
In [8]: logreg.fit(X_train, y_train) # 训练
Out[8]:
LogisticRegression(C=100000.0, class_weight=None, dual=False,
fit_intercept=True, intercept_scaling=1, max_iter=100,
multi_class='ovr', n_jobs=1, penalty='l2', random_state=None,
solver='liblinear', tol=0.0001, verbose=0, warm_start=False)
分别设置第 1 维度的网格数据和第 2 维度的网格数据。
In [9]: x_min, x_max = X[:, 0].min() -.5, X[:, 0].max() + .5 # 第1维度网格数据预备
...: y_min, y_max = X[:, 1].min() -.5, X[:, 1].max() + .5 # 第2维度网格数据预备
创建网格数据,“xx,yy”是一个网格类型,主要是为了作面积图。
In [10]: xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max,h)) # 创建网格数据
In [11]: Z = logreg.predict(np.c_[xx.ravel(), yy.ravel()]) # 预测
In [12]: Z = Z.reshape(xx.shape) # 将 Z 矩阵转换为与 xx 相同的形状
绘制模型分类器的结果图像。
In [13]: plt.figure(figsize=(4, 4)) # 设置画板
...: plt.pcolormesh(xx, yy, Z, cmap=plt.cm.Paired) # 作网格图
Out[13]: <matplotlib.collections.QuadMesh at 0xae38cc0>
效果如图 1 所示。
绘制模型图像以及样本点的图像。
In [14]: plt.figure(figsize=(4, 4)) # 设置画板
...: plt.pcolormesh(xx, yy, Z, cmap=plt.cm.Paired) # 作网格图
...: plt.scatter(X_test[:, 0], X_test[:, 1], c=y_test, edgecolors='k', cmap= plt.cm.Paired) # 画出预测的结果
...:
...: plt.xlabel('Sepal length') # 作x轴标签
...: plt.ylabel('Sepal width') # 作y轴标签
...: plt.xlim(xx.min(), xx.max()) # 设置x轴范围
...: plt.ylim(yy.min(), yy.max()) # 设置y轴范围
...: plt.xticks(()) # 隐藏x轴刻度
...: plt.yticks(()) # 隐藏y轴刻度
效果如图 2 所示。