본문 바로가기

DL & ML/Graph

[DGL] 기본 message passing layer 만들기

728x90
반응형

GNN에서는 message passing을 통해 이웃 노드들의 feature를 통합하고, 이를 기반으로 타겟 노드의 새로운 embedding을 생성합니다.

이 과정을 위해 dgl 라이브러리에서는 여러 built-in function을 제공합니다.

먼저, 이웃 노드의 feature를 통합하여 message 형태로 만드는 역할을 message function,

message로 새로운 vector를 만드는 역할을 하는 것을 reduce function이라고 합니다.

 

대부분의 GNN에서 message passing은 반드시 필요합니다. message passing을 통해서 이웃 노드의 feature를 가지고 오고, 가지고 온 feature들을 통해서 그 노드의 새로운 을 결정합니다. DGL에서는 이웃 노드의 feature를 message 형태로 모으는 과정을 message function, 모은 message를 가공하여 새로운 vector를 만들어내는 과정을 reduce function을 사용하여 처리합니다. message function과 reduce function을 동시에 실행시키려면 update_all(message_function, reduce_function) 함수를 사용하면 됩니다.

 

종류는 크게 Message Function과 Reduce Function으로 나누어집니다.

 

dgl.DGLGraph.update_all란?

update_all이란 message function과 reduce function을 동시에 실행시킬 수 있도록 도와주는 역할입니다.

즉, 이웃 노드의 feature를 수집하고 가공까지 한 번에 해주는 것이죠.

사용법은 아래와 같습니다. 

 

dgl.g.update_all(message_func, reduce_fun) # g : dgl 그래프

 

적용할 그래프에 대해 message function과 reduce function을 파라미터로 넣어주면 되는 것이죠!

그렇다면 message function과 reduce function에는 어떤 것이 있을까요?

 

 

dgl.function.copy_u 란? (Message Function)

copy_u란 source node의 feature를 사용하여 message를 계산해주는 역할을 한다고 공식 문서 상에 나와있습니다.

dgl tutorial(github)의 설명을 참고하면 message가 이웃에 전달 될 때 노드 피쳐 'h'를 복사하는 역할을 한다고 합니다!

 

사용법은 매우 간단합니다.

 

copy_u(u, out)

 

1) u - 소스 노드의 feature field (str)

2) out - output 메세지 field (str)

 

이렇게만 보면 잘 감이 안 옵니다. 

그렇다면 공식 문서의 예제를 살펴봅시다.

만약 우리가 아래와 같이 copy_u를 사용하여 message_func을 정의하였다면!

 

message_func = dgl.function.copy_u('h', 'm')

 

message_func은 아래와 같은 함수와 동일한 역할을 한다고 합니다.

 

def message_func(edges):
    return {'m': edges.src['h']}

 

출발 노드의 'h' feature를 message의 'm'이라는 feature에 넣어서 도착 노드에 전달한다라고 생각하면 됩니다.

이 외에도 여러 message function이 존재하니 공식 문서를 한 번 참고해보세요!

 

 

dgl.function.mean란? (Reduce Function)

 

dgl.function.mean('m', 'h')

 

mean 함수는 이름에서 느껴지듯이 모든 수신된 메세지의 'm' feature에 대해 평균을 내고, 그 결과를 도착 노드의 새로운 feature 'h'에 저장하는 역할을 합니다.

이건 매우 간단하죠?

이 외에도 sum, min, max 등의 여러 reduce function이 존재합니다!

 

 

이렇게 간단하게! message passing layer를 만드는 방법을 살펴봤습니다.

자세한 예제를 살펴보고 싶다면 아래 github tutorial을 참고하시길 바랍니다.

 

 

 

GitHub - myeonghak/DGL-tutorial: 그래프 딥러닝 라이브러리 DGL 쉽게 배우기

그래프 딥러닝 라이브러리 DGL 쉽게 배우기. Contribute to myeonghak/DGL-tutorial development by creating an account on GitHub.

github.com

 

 

 

 

 

728x90
반응형