본문 바로가기

DL & ML/Graph

[DGL] TypeError : default_collate 에러 (DGL.batch, collate_fn)

728x90
반응형

데이터셋을 만들고 dataloader로 넘겨주었는데 dataloader에서 값을 받는 중에 만난 에러!

내가 만들고자 한 것은 dgl 라이브러리의 heterogeneous graph 타입의 그래프를 반환해주는 dataset / dataloader 였다.

dataset과 dataloader는 PyTorch 타입을 사용했는데 아래와 같은 에러를 만났다.

 

TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'dgl.heterograph.DGLHeteroGraph'>

 

찾아보니 배치로 넘겨줄 때의 output 값은 꼭 tensors, numpy arrays, numbers, dicts, lists 중 하나여야 한다고..

그래프 데이터를 넘겨주는 건 이번이 처음이라서 이러한 제약조건이 있는지 몰랐다.

그냥 내가 원하는 데이터를 자동으로 배치로 만들어서 보내주겠거니 했지...

 

그렇다면 나는 그래프 데이터를 사용해야하는데 이럴때는 어떻게 해야할까?

바로! dataset의 collate_fn과 dgl.batch를 활용해주면 된다.

 

우선 collate_fn에 대해 살펴보자.

 

 

torch.utils.data — PyTorch 1.10 documentation

torch.utils.data At the heart of PyTorch data loading utility is the torch.utils.data.DataLoader class. It represents a Python iterable over a dataset, with support for These options are configured by the constructor arguments of a DataLoader, which has si

pytorch.org

 

 

공식 문서의 설명을 참고해보면!

 

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

 

Dataset과 DataLodaer의 관계는 아래의 그림과 같다.

https://hulk89.github.io/pytorch/2019/09/30/pytorch_dataset/

 

Dataset은 기본적으로 index를 통해서 데이터를 접근할 수 있게 해준다.

이때 Dataset을 호출 시 어떤 인덱스의 데이터를 가져올지를 결정해주는 역할이 Sampler이다.

Dataset을 순서대로 접근이 아니라 shuffle 하여 랜덤한 순서로 접근하고 싶을 때의 역할이 바로 Sampler!

DataLoader를 통해 batch 단위로 sampling 할 시에 사용되는 것은 batch_sampler!

 

데이터셋이 batch_sampler를 통해 묶이고 나면 그 다음 collate_fn을 호출하여 최종적인 배치로 만들어준다.

collate_fn이 자주 쓰이는 예로는 서로 Input 길이가 다르면 배치로 묶지 못하므로 이런 부분을 조정해주는 역할을 주로 한다.

 

그럼 우리는 collate_fn을 잘 활용하여 데이터로더에서 기본적으로 제공하지 않는 데이터 타입도 데이터로더로 받을 수 있다.

어떻게? 바로 dgl에서 제공하는 batch 기능을 활용하면 된다!

 

 

 

dgl.batch — DGL 0.6.1 documentation

© Copyright 2018, DGL Team. Revision 5906fa4c.

docs.dgl.ai

 

dgl에서 제공하는 배치란 여러 개의 DGLGraph를 하나의 그래프로 만들어 더 효과적으로 그래프 계산을 할 수 있게 도와주는 것이다.

여러 개의 그래프를 하나의 그래프로 만들면 내가 원하는 그래프와 달라지는데?! 하고 생각할 수 있는데

여기서 여러 개의 그래프를 모두 뭉뚱그려 하나로 만드는 것이 아니라 각각을 disjoint 한 요소로 보아 하나의 그래프로 만든다.

 

아래의 표를 보면 조금 더 이해하기 쉬울 것 같다.

 

Original node ID 0 ~ N_0 0 ~ N_1 0 ~ N_k
New node ID 0 ~ N_0 N_0+1 ~ N_0+N_1+1 1+sum_{i=0}^{k-1} N_i ~ 1+sum_{i=0}^k N_i

 

기존에 여러 개의 그래프 노드 id를 새로운 노드 id로 mapping 하여 각각을 연결되지 않은, disjoint한 그래프 집합으로 만든다는 것!

이렇게 batch 형태로 그래프를 만들면 그래프 각각에 대해 수행하는 것과 동일하게 연산이 가능하지만 병렬적으로, 훨씬 효율적으로 처리가 가능하다고 한다.

 

heterograph를 입력으로 줄 때는 반듯이 동일한 relations 형태의 그래프(노드 / 엣지 타입이 동일해야)여야 각각에 대해서 수행이 잘 된다! 그래야 최종적인 결과도 동일한 relations을 가진 heterograph로 나올 수 있다고 한다.

 

그런데 batch로 만들어서 그래프를 확인해보면 그래프 내의 노드 개수가 각가의 그래프의 노드 개수를 통합한 개수인 것을 확인할 수 있다. (당연히 여러 그래프를 통합해서 하나의 그래프를 만드는 것이니까...!)

 

그렇다면 각각의 그래프에 대해서 노드 수 또는 엣지 수를 확인하고 싶으면 어떻게 해야할까?

바로 batch_num_nodes()를 통해 각 배치 별 그래프 노드의 개수를 확인할 수 있다!

 

예를 들어 batch로 만든 그래프가 다음과 같다고 하면, 노드와 엣지 개수가 매우 많은 것을 확인할 수 있다.

Graph(num_nodes={'item': 8554, 'user': 7930},
      num_edges={('item', 'is_rated', 'user'): 196290, ('user', 'is_rating', 'item'): 196290},
      metagraph=[('item', 'user', 'is_rated'), ('user', 'item', 'is_rating')])

 

batch_num_nodes를 사용하면 배치 내 각 그래프 별로 노드의 개수를 확인할 수 있다!

 

>>> g.batch_num_nodes('user')
tensor([150,  55, 234, 201, 101, 269, 183,  48, 147,  53,  24, 140,  75,  34,
        200,  71,  81, 228, 101,  72, 324, 164,  40,  82, 246,  95,  91, 139,
         31, 124,  32, 200,  13,  78,  14, 227,  94, 168, 164,  38,  19, 177,
         13, 254,  53, 232,  80,  96, 101,  76,  91, 182,  92, 324, 113, 141,
        140,  49, 158,  72,  94, 122, 185, 235])

 

batch_num_edges를 사용하면 배치 내 각 그래프 별로 엣지의 개수를 확인할 수 있다!

 

>>> g.batch_num_edges(('is_rating'))
tensor([ 2697,   814,  5385,  2665,   902,  1575,  5711,  3395, 10838,  4209,
          210,  4202,  2594,  1605,  1309,  5086,  2285,  4100,  1037,  1678,
        10144,  1254,   925,  2729,  6340,  2200,   619,  2653,   144,  1827,
         1033,  6108,   865,  1233,   574,  4229,  2884,  3699,  1197,  1417,
          123,  5538,   382,  7359,  3316,  6393,  1793,  3396,  2569,  4834,
         3663,   838,  2704,  2833,  2455,  4527,  3579,   403,  4923,   423,
         5083,  7091,  4497,  3197])

 

 

추가적으로 batch로 만들어진 그래프에서는 feature를 어떻게 처리할 수 있을까?

default로는 DGL.batch에서는 node / edge features는 모든 input graphs의 feature tensors를 concat하여 일괄 처리한다.

따라서 동일한 이름의 feature는 data type과 feature size가 동일해야 한다.

ndata 또는 edata를 None 값으로 주어 feature 일괄 처리를 막을 수도 있고 어떤 feature를 처리할지 리스트로 전달할 수도 있다.

 

 

DGL에서의 batch 사용법을 간단하게 다뤄보았다!

그렇다면 결론적으로 어떻게 graph를 데이터로더를 통해 전달할 수 있는가?

collate_fn 함수에서 아래와 같이 batch를 형성해 전달해주면 된다!

 

def _collate_fn(data) :
    graph_lst, labe_lst = map(list, zip(*data))
    graph_batch = dgl.batch(graph_lst)
    return graph_batch, label_lst

 

728x90
반응형