(pred.argmax(1) == y).type(torch.float).sum().item() 해석
1. pred는 model(x)의 출력으로 입력 이미지에 대하여 각 10개 class별 socre를 출력하여 모은 값.
[
[img1 10개 스코어],
[img2 10개 스코어],
[img3 10개 스코어]
...
]
2. pred.argmax(1)은 각 row의 방향에서 가장 큰 값을 찾고 해당 값의 인덱스를 반환함, 즉, 각 이미지가 가장 높은 score를 보인 class가 어딘지를 반환함
3. 높은 값을 보인 class(예측값)이 Groud Truth인 y와 동일한지 확인하여 True/False 를 반환(pred.argmax(1) == y)
각 이미지에 대한 True/Fase이므로 한번에 트레이닝 시키는 이미지 개수만큼 반환(batch)
4. 각 True/False로 이뤄진 텐서에 대하여 torch.float 형식으로 변환함 -> o, 또는 1, 로 변환
5. batch 단위로 0 또는 1(각 이미지 분류가 G.T 라벨과 같은지) 64개가 모인 tensor의 각 요소를 합하여 tensor(값.) 형식으로 반환함:sum()
6. accuracy 값! 이 필요하므로 tensor의 요소값을 빼내기 위하여 item()함수를 활용
def test(dataloader, model, loss_fn):
size=len(dataloader.dataset)
num_batches= len(dataloader)
model.eval()
test_loss, correct =0,0
with torch.no_grad(): # because it is test time
for X, y in dataloader: # in batch
X, y = X.to(device), y.to(device) # to GPU
pred=model(X) # forward propa
test_loss+= loss_func(pred,y).item() # if reduction is true of nn.CrossEntropyLoss(), its output is scalar or same as input
# loss btw 2 probabilities
correct+= (pred.argmax(1) == y).type(torch.float).sum().item() # y is G.T,
# output : prob of each class : vector
print((pred.argmax(1) == y).type(torch.float))
test_loss /= num_batches
correct /= size
print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
epoch = 5
for t in range(5):
train(train_dataloader, model, loss_func, optimizer)
test(test_dataloader, model, loss_func)
print("done!")
pred=model(x)에 대하여 pred 값 출력
1. pred는 model로 부터 출력된 score의 모음
2. 아래 각 색깔 별로 하나의 이미지고, 한 이미지는 10개의 class에 대하여 분류되므로 1 이미지당 10 socre(색깔 별)
코드 출처: 빠른 시작(Quickstart) — 파이토치 한국어 튜토리얼 (PyTorch tutorials in Korean)
'Computer Vision' 카테고리의 다른 글
[환경 구축] Windows11 / WSL2 / NVIDIA GeForce GTX 1660 Ti / Pytorch / VS code 환경에서 딥러닝 개발하기 (0) | 2022.09.30 |
---|---|
(진행 중) [논문리뷰] Batch Normalization (0) | 2022.09.13 |
[Python] .scatter()함수 이해 (0) | 2022.08.18 |
nn.Module 입력과 forward 함수: Pytorch에서 모델 호출법 (0) | 2022.08.16 |
Python 변수로 Class 명 치환(Instanstiate) (0) | 2022.08.16 |