使用numpy截取连续的测试集及剩余作为训练集

这个repo 用来记录一些python技巧、书籍、学习链接等,欢迎star

github地址

本来想从numpy中找一个可以将array格式的数据提取出连续的测试集,其余作为训练集的函数,但是没找到,就自己写了几个,记录之~

import numpy as np

"""
utils:
    用于截取连续的测试集及剩余作为训练集
"""

def trunc(arr, start, end):
    sample = arr[start:end]
    remainder = np.delete(arr, np.s_[start:end], axis=0)
    return sample, remainder

def trunc2(arr, start, end):
    rows = arr.shape[0]
    sample = arr.take(range(start, end), axis=0)
    remaind = arr.take(list(set(range(rows)) - set(range(start, end))), axis=0)
    return sample, remaind

def trunc3(arr, start, end):
    head, middle, tail = np.split(arr, (start, end))
    head_tail = np.r_[head, tail]
    return middle, head_tail

if __name__ == "__main__":
    arr = np.arange(24).reshape(6, 4)
    print(trunc(arr, 2, 5))
    print(trunc2(arr, 2, 5))
    print(trunc3(arr, 2, 5))

这个repo 用来记录一些python技巧、书籍、学习链接等,欢迎star

github地址

你可能感兴趣的:(numpy,数据分割)