本文共 1105 字,大约阅读时间需要 3 分钟。
用法:
from sklearn.model_selection import StratifiedShuffleSplitStratifiedShuffleSplit(n_splits=10,test_size=None,train_size=None, random_state=None)
参数说明
参数 n_splits是将训练数据分成train/test对的组数,可根据需要进行设置,默认为10
参数test_size和train_size是用来设置train/test对中train和test所占的比例。例如:
1.提供10个数据num进行训练和测试集划分 2.设置train_size=0.8 test_size=0.2 3.train_num=num*train_size=8 test_num=num*test_size=2 4.即10个数据,进行划分以后8个是训练数据,2个是测试数据注:train_num≥2,test_num≥2 ;test_size+train_size可以小于1
参数 random_state控制是将样本随机打乱
例子:
from sklearn.model_selection import StratifiedShuffleSplitss = StratifiedShuffleSplit(n_splits=1, test_size=0.2,random_state=0) #n_slpit 全体数据分组数目,random_state 不将样本随机打乱import numpy as npX = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [1, 2],[3, 4], [1, 2], [3, 4]])#训练数据集8*2y = np.array([0, 0, 1, 1,0,0,1,1])#类别数据集8*1train_idx, val_idx = next(ss.split(X, y))print("train_idx:",train_idx)print("val_idx:",val_idx)for train_index, test_index in ss.split(X, y): print("train_idx:", train_index) print("val_idx:", test_index)
输出:
train_idx: [5 2 6 4 1 3]val_idx: [7 0]train_idx: [5 2 6 4 1 3]val_idx: [7 0]
转载地址:http://bkwmi.baihongyu.com/