본문 바로가기

Programming/PyTorch

[PyTorch] DataLoader의 역할 및 사용법

728x90
반응형

DataLoader는 데이터를 미니 배치 단위로 나누어서 제공해주는 역할을 합니다.

학습을 하기 위해서 데이터를 읽어올 때 사용하게 됩니다.

dataset 인자에는 pytorch Dataset 객체를 넣어주면 됩니다.

 

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

 

 

인자를 하나씩 살펴보겠습니다.

 

- batch_size : 말 그대로 batch_size. 지정한 값에 따라 데이터를 나누어 반환

- shuffle : 데이터를 섞어서 사용할 것인가의 여부

- sampler : 데이터 인덱스를 다루는 방법. 인덱스를 직접 다루기 때문에 shuffle 파라미터는 반드시 False여야 사용할 수 있음

                   불균형 데이터셋의 경우 클래스의 비율에 맞게 데이터를 제공해야할 필요가 있으며 이때 사용할 수 있는 옵션입니다.

                    map-style 데이터셋에서는 __len__과 __iter__를 통해 구현이 가능하며 그 외로는

                    1) SequentialSampler : 항상 같은 순서

                    2) RandomSampler : 랜덤, replacement 여부 선택, 개수 선택 가능

                    3) SubsetRandomSampler : 랜덤 리스트

                    4) WeightRandomSampler : 가중치에 따른 확률

                    5) BatchSampler : batch 단위로 sampling

                    6) DistributedSampler : 분산처리

                   * [sampler 관련 추가 참고]

- num_works : 데이터를 불러올 때 사용하는 서브 프로세스의 개수

                          num_works가 무조건 높다고 좋은 것은 아님. 오히려 병목현상이 생길 수도 있음

- collate_fn : 미니 배치를 생성하기 위해 샘플 리스트를 병합하는 역할. 

                       주로 zero-padding이나 variable-size 등 데이터 사이즈를 맞추기 위해 함수를 만들어 적용함.

                       프린터기에서 인쇄할 때 묶어서 인쇄하기같은 기능이라고 합니다.

                      ((피쳐1, 라벨1), (피처2, 라벨2))와 같은 배치 단위 데이터를 ((피처1, 피처2), (라벨1, 라벨2))처럼 바뀌게 됩니다. 

https://www.coastalcreative.com/wp-content/uploads/2019/10/collated-not-collated-543x600.jpg

def collate_fn(batch):
    print('Original:\n', batch)
    print('-'*100)
    
    data_list, label_list = [], []
    
    for _data, _label in batch:
        data_list.append(_data)
        label_list.append(_label)
    
    print('Collated:\n', [torch.Tensor(data_list), torch.LongTensor(label_list)])
    print('-'*100)
    
    return torch.Tensor(data_list), torch.LongTensor(label_list)

 

- pin_memory : True로 설정 시 tensor를 CUDA 고정 메모리에 할당. 고정 된 메모리에서 데이터를 가져오므로 훨씬 빠름

                            하지만 일반적인 경우에는 많이 사용하지 않는다고 합니다!

 

- drop_last : batch 단위로 데이터를 불러오면 마지막 batch의 길이가 달라질 수 있음. 

                      batch의 크기에 따른 의존도가 높은 함수를 사용할 때에는 마지막 batch를 drop하여 사용하지 않을 수도 있음

 

- time_out : 데이터로더가 데이터를 불러올 때 제한시간을 둠

 

- worker_init_fn : 어떤 worker를 불러올 것인가

 

 

 

 

 

torch.utils.data — PyTorch 1.10.1 documentation

torch.utils.data At the heart of PyTorch data loading utility is the torch.utils.data.DataLoader class. It represents a Python iterable over a dataset, with support for These options are configured by the constructor arguments of a DataLoader, which has si

pytorch.org

 

728x90
반응형