본문 바로가기

Computer Vision

nn.Module 입력과 forward 함수: Pytorch에서 모델 호출법

nn.Module의 입력은 뭐고, 클래스 메소드 forward()에 모델 입력은 어떻게 전달하는가?

(1) 모델의 입력 X가 nn.Module(X)로써 전달이 된다면 이 X는 forward에 입력 인자로 전달된다.

     forward 함수는 사용자가 직접 호출하지 않기 때문

(2) 이렇게 입력 받은 모델의 input은 자동으로 forward의 함수 입력으로 들어가서 자동으로 forward propagtion한다. 
(3) 이때 입력의 길이, 갯수는 가변인자(*input) 임

pytorch document : nn.Module의 forward 함수에 대한 설명

nn.Module을 활용한 모델 디자인

nn.Module을 상속하여 다음과 같이 모델을 디자인할 수 있다.

일반적으로, pytorch의 nn.Module을 사용하여 모델을 디자인 하기 위해서는

    (1) nn.Module을 상속하는 모델 클래스의 선언

    (2) __init__ 함수와 forward 함수 등의 오버라이딩 필수 함수들을 오버라이딩

 

하여 디자인하고, 이렇게 디자인된 모델에 대하여 하기와 같이 인스턴스 입력에 input을 넣어주면 자동으로 forward의 입력 인자로 전달 및 forward propagtion이 수행된다. 

# NeuralNetwork는 모델을 디자인한 class
model = NeuralNetwork().to(device)

# 모델의 입력이 X일 때 forward propagation
pred = model(X)

예시

Pytorch Tutorial의 Quick Start 예시 코드(필자 일부 수정)를 통해 그 사용법을 확인 할 수 있다.

class NeuralNetwork(nn.Module):	# 괄호는 본 class가 nn.Module을 상속함을 의미
    def __init__(self):
        super(NeuralNetwork, self).__init__() # 상위의 모든 class 상속
        
        self.flatten = nn.Flatten()	# instansiate: nn.Flatten을 인스턴스화 하여 
        				# 해당 클래스를 호출 시 self.flatten(입력)으로 가능
                                                
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):	# x는 모델의 입력
        x = self.flatten(x) 	# nn.Flatten(x)의 호출
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork().to(device)
print(model)

X, y in dataloader
pred = model(X)		# model.forward() 함수에 input 전달 및 forward propagation

# 코드 출처: https://tutorials.pytorch.kr/beginner/basics/quickstart_tutorial.html

 

 

 

[참고] 

  • evnet trigger인 hook에 의해 forward()함수를 호출하지 않아도 자동으로 forward propagation이 수행
  • 모델의 인스턴스 선언( ex. Model(input))만으로도 forward() 함수에 입력 전달 및 forward propagation 수행 

 

출처: 빠른 시작(Quickstart) — 파이토치 한국어 튜토리얼 (PyTorch tutorials in Korean)

 

빠른 시작(Quickstart)

파이토치(PyTorch) 기본 익히기|| 빠른 시작|| 텐서(Tensor)|| Dataset과 Dataloader|| 변형(Transform)|| 신경망 모델 구성하기|| Autograd|| 최적화(Optimization)|| 모델 저장하고 불러오기 이번 장에서는 기계 학습의

tutorials.pytorch.kr

Module — PyTorch 1.12 documentation

728x90