您当前的位置:首页 > 计算机 > 编程开发 > 人工智能

Sklearn交叉验证分层与分组使用介绍

时间:10-06来源:作者:点击数:

有些数据集分布并不均匀,因此在训练模型后会出现极大的不平衡。这时就需要采用分层抽样,即分层交叉验证迭代器,可以理解为分层抽样。

1.分层交叉验证

#导入相关模块
In [1]: from sklearn.model_selection import StratifiedKFold
#导入相关数据
I [2]: X = ["a", "b", "c", "d", "e", "f"]
#导入相关数据
In [3]: y = [0, 0, 1, 1, 1, 1,]
#创建分层交叉验证对象
In [4]: skf = StratifiedKFold(n_splits=2)
#查看分组结果
In [5]: for train, test in skf.split(X, y):
   ...:     print("%s-%s" % (train, test))
[1 4 5]-[0 2 3]
[0 2 3]-[1 4 5]

2. 分组交叉验证

有时测试集的数据可能是分组得来的,这时可能出现的情况就是组内的各个变量之间不是独立的,而组间是独立的。我们需要去除这个影响因素,也就是说测试集中的样本组别不能来自训练集中样本的组别。

分组迭代器有以下几种。

1) 组K折

#导入相关模块
In [1]: from sklearn.model_selection import GroupKFold
#导入相关数据
In [2]: X = ["a", "b", "c", "d", "e", "f"]
#导入相关数据
In [3]: y = [0, 0, 1, 1, 1, 1,]
#导入分组标签
In [4]: groups = [1, 1, 2, 2, 2, 2]
#创建分组对象
In [5]: gkf = GroupKFold(n_splits=2)
#查看分组结果
In [6]: for train, test in gkf.split(X, y, groups=groups):
   ...:     print("%s-%s" % (train, test))
[0 1]-[2 3 4 5]
[2 3 4 5]-[0 1]

2) 留一组交叉验证

#导入相关模块
In [1]: from sklearn.model_selection import LeaveOneGroupOut
#导入相关数据
In [2]: X = ["a", "b", "c", "d", "e", "f"]
#导入相关数据
In [3]: y = [0, 0, 1, 1, 1, 1,]
#导入分组标签
In [4]: groups = [1, 2, 2, 2, 2, 2]
#创建分组对象
In [5]: logo = LeaveOneGroupOut()
#查看分组结果
In [6]: for train, test in logo.split(X, y, groups=groups):
   ...:     print("%s-%s" % (train, test))
[1 2 3 4 5]-[0]
[0]-[1 2 3 4 5]

3) 留P组交叉验证

#导入相关模块
In [1]: from sklearn.model_selection import LeavePGroupsOut
#导入相关数据
In [2]: X = ["a", "b", "c", "d", "e", "f"]
#导入相关数据
In [3]: y = [0, 0, 1, 1, 1, 1,]
#导入分组标签
In [4]: groups = [1, 1, 2, 2, 3, 3]
#创建分组对象
In [5]: lpgo = LeavePGroupsOut(n_groups=2)
#查看分组结果
In [6]: for train, test in lpgo.split(X, y, groups=groups):
   ...:     print("%s-%s" % (train, test))
[4 5]-[0 1 2 3]
[2 3]-[0 1 4 5] 
[0 1]-[2 3 4 5]

4) 随机排列组交叉验证

#导入相关模块
In [1]: from sklearn.model_selection import GroupShuffleSplit
#导入相关数据
In [2]: X = ["a", "b", "c", "d", "e", "f"]
#导入相关数据
In [3]: y = [0, 0, 1, 1, 1, 1,]
#导入分组标签
In [4]: groups = [1, 1, 2, 2, 3, 3]
#创建分组对象
In [5]: gss = GroupShuffleSplit(n_splits=3, test_size=0.5)
#查看分组结果
In [6]: for train, test in gss.split(X, y, groups=groups):
   ...:     print("%s-%s" % (train, test))
[4 5]-[0 1 2 3]
[0 1]-[2 3 4 5]
[4 5]-[0 1 2 3]
方便获取更多学习、工作、生活信息请关注本站微信公众号城东书院 微信服务号城东书院 微信订阅号
推荐内容
相关内容
栏目更新
栏目热门