数据集
继承 torch.utils.data.Dataset
,实现 3 个方法:
__init__
:构造函数,读取数据文件,存储到变量中__len__
:返回样本总数__getitem__
:返回下标为idx
的一个样本
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, file):
super().__init__()
self.data = read_file(file)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]