본문 바로가기

Programming/PyTorch

[PyTorch] torch.nn.ModuleList()란?

728x90
반응형

Tensorflow만 주로 사용하다가 pytorch를 사용한 모델을 베이스로 개선하려니 모르는 것 투성이라서 하나씩 정리해보기로 한다!

 

 

ModuleList — PyTorch 1.10.0 documentation

Shortcuts

pytorch.org

 

[ModuleList]

모듈을 리스트 형태로 저장하는 것을 의미한다.

파이썬의 일반적인 리스트처럼 인덱스로 접근이 가능하다.

공식 홈페이지에 있는 사용 예제를 살펴보자.

 

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

    def forward(self, x):
        # ModuleList can act as an iterable, or be indexed using ints
        for i, l in enumerate(self.linears):
            x = self.linears[i // 2](x) + l(x)
        return x

 

1) ModuleList에 nn.Linear 모듈을 10개를 저장함

2) 이를 리스트처럼 순서대로 iterable하게 접근하여 사용

 

간단하고 매우 직관적이다!

말그대로 모듈을 리스트로 저장하여 좀 더 간단하고 편리하게 다룰 수 있게 도와준다고 보면 될 것 같다.

728x90
반응형