nn.Module의 입력은 뭐고, 클래스 메소드 forward()에 모델 입력은 어떻게 전달하는가?
(1) 모델의 입력 X가 nn.Module(X)로써 전달이 된다면 이 X는 forward에 입력 인자로 전달된다.
forward 함수는 사용자가 직접 호출하지 않기 때문
(2) 이렇게 입력 받은 모델의 input은 자동으로 forward의 함수 입력으로 들어가서 자동으로 forward propagtion한다.
(3) 이때 입력의 길이, 갯수는 가변인자(*input) 임
nn.Module을 활용한 모델 디자인
nn.Module을 상속하여 다음과 같이 모델을 디자인할 수 있다.
일반적으로, pytorch의 nn.Module을 사용하여 모델을 디자인 하기 위해서는
(1) nn.Module을 상속하는 모델 클래스의 선언
(2) __init__ 함수와 forward 함수 등의 오버라이딩 필수 함수들을 오버라이딩
하여 디자인하고, 이렇게 디자인된 모델에 대하여 하기와 같이 인스턴스 입력에 input을 넣어주면 자동으로 forward의 입력 인자로 전달 및 forward propagtion이 수행된다.
# NeuralNetwork는 모델을 디자인한 class
model = NeuralNetwork().to(device)
# 모델의 입력이 X일 때 forward propagation
pred = model(X)
예시
Pytorch Tutorial의 Quick Start 예시 코드(필자 일부 수정)를 통해 그 사용법을 확인 할 수 있다.
class NeuralNetwork(nn.Module): # 괄호는 본 class가 nn.Module을 상속함을 의미
def __init__(self):
super(NeuralNetwork, self).__init__() # 상위의 모든 class 상속
self.flatten = nn.Flatten() # instansiate: nn.Flatten을 인스턴스화 하여
# 해당 클래스를 호출 시 self.flatten(입력)으로 가능
self.linear_relu_stack = nn.Sequential(
nn.Linear(28*28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10)
)
def forward(self, x): # x는 모델의 입력
x = self.flatten(x) # nn.Flatten(x)의 호출
logits = self.linear_relu_stack(x)
return logits
model = NeuralNetwork().to(device)
print(model)
X, y in dataloader
pred = model(X) # model.forward() 함수에 input 전달 및 forward propagation
# 코드 출처: https://tutorials.pytorch.kr/beginner/basics/quickstart_tutorial.html
[참고]
- evnet trigger인 hook에 의해 forward()함수를 호출하지 않아도 자동으로 forward propagation이 수행
- 모델의 인스턴스 선언( ex. Model(input))만으로도 forward() 함수에 입력 전달 및 forward propagation 수행
출처: 빠른 시작(Quickstart) — 파이토치 한국어 튜토리얼 (PyTorch tutorials in Korean)
728x90
'Computer Vision' 카테고리의 다른 글
[환경 구축] Windows11 / WSL2 / NVIDIA GeForce GTX 1660 Ti / Pytorch / VS code 환경에서 딥러닝 개발하기 (0) | 2022.09.30 |
---|---|
(진행 중) [논문리뷰] Batch Normalization (0) | 2022.09.13 |
[Python] .scatter()함수 이해 (0) | 2022.08.18 |
(pred.argmax(1) == y).type(torch.float).sum().item() 해석 (0) | 2022.08.17 |
Python 변수로 Class 명 치환(Instanstiate) (0) | 2022.08.16 |