toch.utils.data 공식 문서를 읽고 pytoch dataloder 사용법을 기록합니다.

Dataloader

torch.utils.data.DataLoader class constructor arguments는 다음과 같습니다.

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

Dataset

PyTorch에서 지원하는 dataset의 타입은 두 가지가 있습니다.

Map-style datasets
  • Python __getitem__(), __len__() 기능 사용
  • indices/keys로 부터 데이터 샘플을 매핑하는 방식
  • dataset[idx] 방식으로 데이터 읽음
Iterable-style datasets
  • Python __iter__() 기능 사용
  • Random read 비용이 많이 들고 힘들거나, 가져온 데이터에 따라 배치 크기가 달라지는 경우에 유용
  • iter(dataset) 방식으로 데이터 읽음
  • 자세한 내용은 IterableDataset 참고

Sampler

Iterable-style dataset의 data loading 순서는 user-defined iterable을 통해 온전히 조작 가능합니다. 반면에, map-style dataset의 data loading에 사용되는 indices/keys의 순서는 torch.utils.data.Sampler class를 사용해 명시 가능합니다.

  • Sampler는 dataset indices에 대한 iterable object를 나타냅니다.
  • Sampler를 따로 정의하지 않더라도 Dataloader class constructor의 shuffle argument에 의해서 sequential(shuffle False인 경우) 혹은 shuffled(shuffle True인 경우) sampler가 자동으로 생성됩니다. 이 경우 batch_sizedrop_last를 목적에 따라 적절히 선택합시다.
  • 매번 list of batch indices를 뱉는 커스텀 Sampler를 새로 정의하는 경우에는, 해당 Sampler를 정의한 후에 Dataloader constructor의 batch_sampler argument에 전달하면 됩니다.

Example 1.

가장 간단한 예시로 BatchSampler를 확인해보겠습니다. mini-batch SGD에서 사용되는 Sampler이며 아래의 코드는 PyTorch 코드를 보기 쉽게 정리한 것입니다.

batch_size를 통해 batch size를 결정하고 drop_last를 통해 데이터 셋 제일 마지막 나머지 부분을 버릴지 말지 결정합니다.

class BatchSampler(Sampler):
    def __init__(self, sampler, batch_size, drop_last):
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size

Example 2.

개인적으로 작업하고 있는 내용 중에서 서로 다른 두 개의 데이터 셋을 concat 해야하는 상황이 존재하여 아래와 같이 코드를 작성해 보았습니다.

두 개의 데이터 셋을 입력으로 받아 __getitem__을 통해 이미지 데이터를 출력하는 ConcatDataset라는 이름의 map-style dataset을 제작하였습니다. __getitem__는 (데이터셋 index, 이미지 index) 형태의 indicies tuple을 입력으로 받습니다. 따라서 추후 custom sampler class에 indicies 순서 선정 방법을 명시해주면 됩니다.

class ConcatDataset(torch.utils.data.DataLoader):
    def __init__(self, first_dataset, second_dataset):
        self.datasets = [first_dataset, second_dataset]

    def __len__(self):
        return sum([dataset.__len__() for dataset in self.datasets])

    def __getitem__(self, indicies):
        dataset_idx = indicies[0]
        data_idx = indicies[1]
        return self.datasets[dataset_idx].__getitem__(data_idx)

indicies 선택을 위한 ConcatSampler class를 제작합니다. 저는 첫 번째 데이터 셋은 84 batch size, 두 번째 데이터 셋은 42 batch size, 총 126 batch size를 갖는 데이터로더가 필요하여 아래와 같이 코드를 작성하였습니다.

class ConcatSampler():
    def __init__(self, samplers, first_batch_size=84, second_batch_size=42):
        self.samplers = samplers
        self.batch_sizes = [first_batch_size, second_batch_size]
        self.total_batch_size = first_batch_size + second_batch_size

        # Select larger number as n_batches
        if len(samplers[0]) // first_batch_size > len(samplers[1]) // second_batch_size:
            self.n_batches = len(samplers[0]) // first_batch_size
        else:
            self.n_batches = len(samplers[1]) // second_batch_size

    def __iter__(self):
        for i in range(self.n_batches):
            batch = []
            for dataset_idx in [0, 1]:
                for idx in self.samplers[dataset_idx]:
                    batch.append((dataset_idx, idx))
                    if len(batch) == self.batch_sizes[0] and dataset_idx == 0:
                        break
                    if len(batch) == self.total_batch_size and dataset_idx == 1:
                        break
            yield batch

    def __len__(self):
        return self.n_batches

실제 사용을 위해서 ConcatSampler와 ConcatDataset를 생성하고, 이를 DataLoader constructor argument로 전달합니다. train_loader를 사용하여 데이터를 로드했을 때, 첫 번째 데이터 셋에서는 84개의 데이터 샘플이, 두 번째 데이터 셋에서는 42개의 데이터 샘플이 로드됩니다.

sampler = ConcatSampler([torch.utils.data.sampler.RandomSampler(dataset) for dataset in [first_dataset, second_dataset]])
dataset = ConcatDataset(first_dataset, second_dataset)
train_loader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler, num_workers=8, pin_memory=True)

References