DataLoader는 데이터를 미니 배치 단위로 나누어서 제공해주는 역할을 합니다.
학습을 하기 위해서 데이터를 읽어올 때 사용하게 됩니다.
dataset 인자에는 pytorch Dataset 객체를 넣어주면 됩니다.
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
인자를 하나씩 살펴보겠습니다.
- batch_size : 말 그대로 batch_size. 지정한 값에 따라 데이터를 나누어 반환
- shuffle : 데이터를 섞어서 사용할 것인가의 여부
- sampler : 데이터 인덱스를 다루는 방법. 인덱스를 직접 다루기 때문에 shuffle 파라미터는 반드시 False여야 사용할 수 있음
불균형 데이터셋의 경우 클래스의 비율에 맞게 데이터를 제공해야할 필요가 있으며 이때 사용할 수 있는 옵션입니다.
map-style 데이터셋에서는 __len__과 __iter__를 통해 구현이 가능하며 그 외로는
1) SequentialSampler : 항상 같은 순서
2) RandomSampler : 랜덤, replacement 여부 선택, 개수 선택 가능
3) SubsetRandomSampler : 랜덤 리스트
4) WeightRandomSampler : 가중치에 따른 확률
5) BatchSampler : batch 단위로 sampling
6) DistributedSampler : 분산처리
* [sampler 관련 추가 참고]
- num_works : 데이터를 불러올 때 사용하는 서브 프로세스의 개수
num_works가 무조건 높다고 좋은 것은 아님. 오히려 병목현상이 생길 수도 있음
- collate_fn : 미니 배치를 생성하기 위해 샘플 리스트를 병합하는 역할.
주로 zero-padding이나 variable-size 등 데이터 사이즈를 맞추기 위해 함수를 만들어 적용함.
프린터기에서 인쇄할 때 묶어서 인쇄하기같은 기능이라고 합니다.
((피쳐1, 라벨1), (피처2, 라벨2))와 같은 배치 단위 데이터를 ((피처1, 피처2), (라벨1, 라벨2))처럼 바뀌게 됩니다.
def collate_fn(batch):
print('Original:\n', batch)
print('-'*100)
data_list, label_list = [], []
for _data, _label in batch:
data_list.append(_data)
label_list.append(_label)
print('Collated:\n', [torch.Tensor(data_list), torch.LongTensor(label_list)])
print('-'*100)
return torch.Tensor(data_list), torch.LongTensor(label_list)
- pin_memory : True로 설정 시 tensor를 CUDA 고정 메모리에 할당. 고정 된 메모리에서 데이터를 가져오므로 훨씬 빠름
하지만 일반적인 경우에는 많이 사용하지 않는다고 합니다!
- drop_last : batch 단위로 데이터를 불러오면 마지막 batch의 길이가 달라질 수 있음.
batch의 크기에 따른 의존도가 높은 함수를 사용할 때에는 마지막 batch를 drop하여 사용하지 않을 수도 있음
- time_out : 데이터로더가 데이터를 불러올 때 제한시간을 둠
- worker_init_fn : 어떤 worker를 불러올 것인가
'Programming > PyTorch' 카테고리의 다른 글
[PyTorch] IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1) (0) | 2022.02.22 |
---|---|
[PyTorch] Tensor와 tensor (0) | 2022.01.27 |
[PyTorch] Dataset Types 정리 (Map-style datasets, Iterable-style datasets) (0) | 2022.01.25 |
[PyTorch] torch.nn.Embedding 의 역할 (0) | 2021.11.22 |
[PyTorch] torch.clamp 함수 (0) | 2021.11.22 |