此代码可以将数据分为K份,并返回每份索引所构成的列表。
def get_cross_validation_list(data, fold):
"""
K折交叉验证
把每份的索引返回为列表形式,列表的元素是索引所构成的列表
:param data:原始数据
:param fold:折数
:return:
"""
# 有几个组的样本数是多一个的(我们称之为不普通组),例如:199个样本分为10份,肯定有9份样本数是20,1份样本数是19
number_of_unusual_group = len(data) % fold
# 普通组的样本数
number_of_usual_group_sample = int(len(data) / fold)
# 不普通组的样本数
number_of_unusual_group_sample = int(len(data) / fold) + 1
# 存放所有的索引
all_index = list(range(len(data)))
# 记录不普通组的数目
flag = 1
final_index_list = []
for i in range(fold):
# 如果还有不普通组,就多选一个
if flag <= number_of_unusual_group:
choice_index = random.sample(all_index, number_of_unusual_group_sample)
final_index_list.append(choice_index)
flag += 1
else:
choice_index = random.sample(all_index, number_of_usual_group_sample)
final_index_list.append(choice_index)
# 在所有索引中去除掉已经被选择的
for k in choice_index:
all_index.remove(k)
return final_index_list
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34