百度360必应搜狗淘宝本站头条
当前位置:网站首页 > IT技术 > 正文

Python机器学习库Sklearn系列教程(22)-交叉验证(K折、

wptr33 2025-03-25 18:08 18 浏览

(K折、留一,留p,随机)

学习预测函数的参数,并在相同数据集上进行测试是一种错误的做法: 一个仅给出测试用例标签的模型将会获得极高的分数,但对于尚未出现过的数据它则无法预测出任何有用的信息。 这种情况称为 overfitting(过拟合). 为了避免这种情况,在进行(监督)机器学习实验时,通常取出部分可利用数据作为 test set(测试数据集) X_test, y_test。

利用 scikit-learn 包中的 train_test_split 辅助函数可以很快地将实验数据集划分为任何训练集(training sets)和测试集(test sets)。

计算交叉验证的指标

使用交叉验证最简单的方法是在估计器和数据集上调用 cross_val_score 辅助函数。

保留数据的数据转换

正如在训练集中保留的数据上测试一个 predictor (预测器)是很重要的一样,预处理(如标准化,特征选择等)和类似的 data transformations 也应该从训练集中学习,并应用于预测数据以进行预测

cross_validate 函数和多度量评估

cross_validate 函数与 cross_val_score 在下面的两个方面有些不同 -

  • 它允许指定多个指标进行评估.
  • 除了测试得分之外,它还会返回一个包含训练得分,拟合次数, score-times (得分次数)的一个字典。

交叉验证迭代器

K折交叉验证: KFold 将所有的样例划分为 k 个组,称为折叠 (fold) (如果 k = n, 这等价于 Leave One Out(留一) 策略),都具有相同的大小(如果可能)。预测函数学习时使用 k - 1 个折叠中的数据,最后一个剩下的折叠会用于测试。

K折重复多次: RepeatedKFold 重复 K-Fold n 次。当需要运行时可以使用它 KFold n 次,在每次重复中产生不同的分割。

留一交叉验证: LeaveOneOut (或 LOO) 是一个简单的交叉验证。每个学习集都是通过除了一个样本以外的所有样本创建的,测试集是被留下的样本。 因此,对于 n 个样本,我们有 n 个不同的训练集和 n 个不同的测试集。这种交叉验证程序不会浪费太多数据,因为只有一个样本是从训练集中删除掉的:

留P交叉验证: LeavePOut 与 LeaveOneOut 非常相似,因为它通过从整个集合中删除 p 个样本来创建所有可能的 训练/测试集。对于 n 个样本,这产生了 {n \choose p} 个 训练-测试 对。与 LeaveOneOut 和 KFold 不同,当 p > 1 时,测试集会重叠。

用户自定义数据集划分: ShuffleSplit 迭代器将会生成一个用户给定数量的独立的训练/测试数据划分。样例首先被打散然后划分为一对训练测试集合。

设置每次生成的随机数相同: 可以通过设定明确的 random_state ,使得伪随机生成器的结果可以重复。

基于类标签、具有分层的交叉验证迭代器

如何解决样本不平衡问题? 使用StratifiedKFold和StratifiedShuffleSplit 分层抽样。 一些分类问题在目标类别的分布上可能表现出很大的不平衡性:例如,可能会出现比正样本多数倍的负样本。在这种情况下,建议采用如 StratifiedKFold 和 StratifiedShuffleSplit 中实现的分层抽样方法,确保相对的类别频率在每个训练和验证 折叠 中大致保留。

StratifiedKFold 是 k-fold 的变种,会返回 stratified(分层) 的折叠:每个小集合中, 各个类别的样例比例大致和完整数据集中相同。

StratifiedShuffleSplit 是 ShuffleSplit 的一个变种,会返回直接的划分,比如: 创建一个划分,但是划分中每个类的比例和完整数据集中的相同。

用于分组数据的交叉验证迭代器

如何进一步测试模型的泛化能力? **留出一组特定的不属于测试集和训练集的数据。**有时我们想知道在一组特定的 groups 上训练的模型是否能很好地适用于看不见的 group 。为了衡量这一点,我们需要确保验证对象中的所有样本来自配对训练折叠中完全没有表示的组。

GroupKFold 是 k-fold 的变体,它确保同一个 group 在测试和训练集中都不被表示。 例如,如果数据是从不同的 subjects 获得的,每个 subject 有多个样本,并且如果模型足够灵活以高度人物指定的特征中学习,则可能无法推广到新的 subject 。 GroupKFold 可以检测到这种过拟合的情况。

LeaveOneGroupOut 是一个交叉验证方案,它根据第三方提供的 array of integer groups (整数组的数组)来提供样本。这个组信息可以用来编码任意域特定的预定义交叉验证折叠。

每个训练集都是由除特定组别以外的所有样本构成的。

LeavePGroupsOut 类似于 LeaveOneGroupOut ,但为每个训练/测试集删除与 P 组有关的样本。

GroupShuffleSplit 迭代器是 ShuffleSplit 和 LeavePGroupsOut 的组合,它生成一个随机划分分区的序列,其中为每个分组提供了一个组子集。

时间序列分割

TimeSeriesSplit 是 k-fold 的一个变体,它首先返回 k 折作为训练数据集,并且 (k+1) 折作为测试数据集。 请注意,与标准的交叉验证方法不同,连续的训练集是超越前者的超集。 另外,它将所有的剩余数据添加到第一个训练分区,它总是用来训练模型。

这个类可以用来交叉验证以固定时间间隔观察到的时间序列数据样本。

代码实现

from sklearn.model_selection import train_test_split,cross_val_score,cross_validate # 交叉验证所需的函数
from sklearn.model_selection import KFold,LeaveOneOut,LeavePOut,ShuffleSplit # 交叉验证所需的子集划分方法
from sklearn.model_selection import StratifiedKFold,StratifiedShuffleSplit # 分层分割
from sklearn.model_selection import GroupKFold,LeaveOneGroupOut,LeavePGroupsOut,GroupShuffleSplit # 分组分割
from sklearn.model_selection import TimeSeriesSplit # 时间序列分割
from sklearn import datasets  # 自带数据集
from sklearn import svm  # SVM算法
from sklearn import preprocessing  # 预处理模块
from sklearn.metrics import recall_score  # 模型度量
iris = datasets.load_iris()  # 加载数据集
print('样本集大小:',iris.data.shape,iris.target.shape)
# ===================================数据集划分,训练模型==========================
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.4, random_state=0)  # 交叉验证划分训练集和测试集.test_size为测试集所占的比例
print('训练集大小:',X_train.shape,y_train.shape)  # 训练集样本大小
print('测试集大小:',X_test.shape,y_test.shape)  # 测试集样本大小
clf = svm.SVC(kernel='linear', C=1).fit(X_train, y_train) # 使用训练集训练模型
print('准确率:',clf.score(X_test, y_test))  # 计算测试集的度量值(准确率)
#  如果涉及到归一化,则在测试集上也要使用训练集模型提取的归一化函数。
scaler = preprocessing.StandardScaler().fit(X_train)  # 通过训练集获得归一化函数模型。(也就是先减几,再除以几的函数)。在训练集和测试集上都使用这个归一化函数
X_train_transformed = scaler.transform(X_train)
clf = svm.SVC(kernel='linear', C=1).fit(X_train_transformed, y_train) # 使用训练集训练模型
X_test_transformed = scaler.transform(X_test)
print(clf.score(X_test_transformed, y_test))  # 计算测试集的度量值(准确度)
# ===================================直接调用交叉验证评估模型==========================
clf = svm.SVC(kernel='linear', C=1)
scores = cross_val_score(clf, iris.data, iris.target, cv=5)  #cv为迭代次数。
print(scores)  # 打印输出每次迭代的度量值(准确度)
print("Accuracy: %0.2f (+/- %0.2f)" % (scores.mean(), scores.std() * 2))  # 获取置信区间。(也就是均值和方差)
# ===================================多种度量结果======================================
scoring = ['precision_macro', 'recall_macro'] # precision_macro为精度,recall_macro为召回率
scores = cross_validate(clf, iris.data, iris.target, scoring=scoring,cv=5, return_train_score=True)
sorted(scores.keys())
print('测试结果:',scores)  # scores类型为字典。包含训练得分,拟合次数, score-times (得分次数)
# ==================================K折交叉验证、留一交叉验证、留p交叉验证、随机排列交叉验证==========================================
# k折划分子集
kf = KFold(n_splits=2)
for train, test in kf.split(iris.data):
    print("k折划分:%s %s" % (train.shape, test.shape))
    break
# 留一划分子集
loo = LeaveOneOut()
for train, test in loo.split(iris.data):
    print("留一划分:%s %s" % (train.shape, test.shape))
    break
# 留p划分子集
lpo = LeavePOut(p=2)
for train, test in loo.split(iris.data):
    print("留p划分:%s %s" % (train.shape, test.shape))
    break
# 随机排列划分子集
ss = ShuffleSplit(n_splits=3, test_size=0.25,random_state=0)
for train_index, test_index in ss.split(iris.data):
    print("随机排列划分:%s %s" % (train.shape, test.shape))
    break
# ==================================分层K折交叉验证、分层随机交叉验证==========================================
skf = StratifiedKFold(n_splits=3)  #各个类别的比例大致和完整数据集中相同
for train, test in skf.split(iris.data, iris.target):
    print("分层K折划分:%s %s" % (train.shape, test.shape))
    break
skf = StratifiedShuffleSplit(n_splits=3)  # 划分中每个类的比例和完整数据集中的相同
for train, test in skf.split(iris.data, iris.target):
    print("分层随机划分:%s %s" % (train.shape, test.shape))
    break
# ==================================组 k-fold交叉验证、留一组交叉验证、留 P 组交叉验证、Group Shuffle Split==========================================
X = [0.1, 0.2, 2.2, 2.4, 2.3, 4.55, 5.8, 8.8, 9, 10]
y = ["a", "b", "b", "b", "c", "c", "c", "d", "d", "d"]
groups = [1, 1, 1, 2, 2, 2, 3, 3, 3, 3]
# k折分组
gkf = GroupKFold(n_splits=3)  # 训练集和测试集属于不同的组
for train, test in gkf.split(X, y, groups=groups):
    print("组 k-fold分割:%s %s" % (train, test))
# 留一分组
logo = LeaveOneGroupOut()
for train, test in logo.split(X, y, groups=groups):
    print("留一组分割:%s %s" % (train, test))
# 留p分组
lpgo = LeavePGroupsOut(n_groups=2)
for train, test in lpgo.split(X, y, groups=groups):
    print("留 P 组分割:%s %s" % (train, test))
# 随机分组
gss = GroupShuffleSplit(n_splits=4, test_size=0.5, random_state=0)
for train, test in gss.split(X, y, groups=groups):
    print("随机分割:%s %s" % (train, test))
# ==================================时间序列分割==========================================
tscv = TimeSeriesSplit(n_splits=3)
TimeSeriesSplit(max_train_size=None, n_splits=3)
for train, test in tscv.split(iris.data):
    print("时间序列分割:%s %s" % (train, test))

相关推荐

MySQL进阶五之自动读写分离mysql-proxy

自动读写分离目前,大量现网用户的业务场景中存在读多写少、业务负载无法预测等情况,在有大量读请求的应用场景下,单个实例可能无法承受读取压力,甚至会对业务产生影响。为了实现读取能力的弹性扩展,分担数据库压...

Postgres vs MySQL_vs2022连接mysql数据库

...

3分钟短文 | Laravel SQL筛选两个日期之间的记录,怎么写?

引言今天说一个细分的需求,在模型中,或者使用laravel提供的EloquentORM功能,构造查询语句时,返回位于两个指定的日期之间的条目。应该怎么写?本文通过几个例子,为大家梳理一下。学习时...

一文由浅入深带你完全掌握MySQL的锁机制原理与应用

本文将跟大家聊聊InnoDB的锁。本文比较长,包括一条SQL是如何加锁的,一些加锁规则、如何分析和解决死锁问题等内容,建议耐心读完,肯定对大家有帮助的。为什么需要加锁呢?...

验证Mysql中联合索引的最左匹配原则

后端面试中一定是必问mysql的,在以往的面试中好几个面试官都反馈我Mysql基础不行,今天来着重复习一下自己的弱点知识。在Mysql调优中索引优化又是非常重要的方法,不管公司的大小只要后端项目中用到...

MySQL索引解析(联合索引/最左前缀/覆盖索引/索引下推)

目录1.索引基础...

你会看 MySQL 的执行计划(EXPLAIN)吗?

SQL执行太慢怎么办?我们通常会使用EXPLAIN命令来查看SQL的执行计划,然后根据执行计划找出问题所在并进行优化。用法简介...

MySQL 从入门到精通(四)之索引结构

索引概述索引(index),是帮助MySQL高效获取数据的数据结构(有序),在数据之外,数据库系统还维护者满足特定查询算法的数据结构,这些数据结构以某种方式引用(指向)数据,这样就可以在这些数据结构...

mysql总结——面试中最常问到的知识点

mysql作为开源数据库中的榜一大哥,一直是面试官们考察的重中之重。今天,我们来总结一下mysql的知识点,供大家复习参照,看完这些知识点,再加上一些边角细节,基本上能够应付大多mysql相关面试了(...

mysql总结——面试中最常问到的知识点(2)

首先我们回顾一下上篇内容,主要复习了索引,事务,锁,以及SQL优化的工具。本篇文章接着写后面的内容。性能优化索引优化,SQL中索引的相关优化主要有以下几个方面:最好是全匹配。如果是联合索引的话,遵循最...

MySQL基础全知全解!超详细无废话!轻松上手~

本期内容提醒:全篇2300+字,篇幅较长,可搭配饭菜一同“食”用,全篇无废话(除了这句),干货满满,可收藏供后期反复观看。注:MySQL中语法不区分大小写,本篇中...

深入剖析 MySQL 中的锁机制原理_mysql 锁详解

在互联网软件开发领域,MySQL作为一款广泛应用的关系型数据库管理系统,其锁机制在保障数据一致性和实现并发控制方面扮演着举足轻重的角色。对于互联网软件开发人员而言,深入理解MySQL的锁机制原理...

Java 与 MySQL 性能优化:MySQL分区表设计与性能优化全解析

引言在数据库管理领域,随着数据量的不断增长,如何高效地管理和操作数据成为了一个关键问题。MySQL分区表作为一种有效的数据管理技术,能够将大型表划分为多个更小、更易管理的分区,从而提升数据库的性能和可...

MySQL基础篇:DQL数据查询操作_mysql 查

一、基础查询DQL基础查询语法SELECT字段列表FROM表名列表WHERE条件列表GROUPBY分组字段列表HAVING分组后条件列表ORDERBY排序字段列表LIMIT...

MySql:索引的基本使用_mysql索引的使用和原理

一、索引基础概念1.什么是索引?索引是数据库表的特殊数据结构(通常是B+树),用于...