본문 바로가기

Computer Vision

(pred.argmax(1) == y).type(torch.float).sum().item() 해석

(pred.argmax(1) == y).type(torch.float).sum().item() 해석

1. pred는 model(x)의 출력으로 입력 이미지에 대하여 각 10개 class별 socre를 출력하여 모은 값.

 

[

[img1 10개 스코어],

 [img2 10개 스코어],

 [img3 10개 스코어]

 ...

]

2. pred.argmax(1)은 각 row의 방향에서 가장 큰 값을 찾고 해당 값의 인덱스를 반환함, 즉, 각 이미지가 가장 높은 score를 보인 class가 어딘지를 반환

 

3. 높은 값을 보인 class(예측값)이 Groud Truth인 y와 동일한지 확인하여 True/False 를 반환(pred.argmax(1) == y) 

각 이미지에 대한 True/Fase이므로 한번에 트레이닝 시키는 이미지 개수만큼 반환(batch)

 

4. 각 True/False로 이뤄진 텐서에 대하여 torch.float 형식으로 변환함 -> o, 또는 1, 로 변환

 

5. batch 단위로 0 또는 1(각 이미지 분류가 G.T 라벨과 같은지) 64개가 모인 tensor의 각 요소를 합하여 tensor(값.) 형식으로 반환함:sum()

 

6. accuracy 값! 이 필요하므로 tensor의 요소값을 빼내기 위하여 item()함수를 활용 

def test(dataloader, model, loss_fn):
  size=len(dataloader.dataset)
  num_batches= len(dataloader)
  model.eval()

  test_loss, correct =0,0

  with torch.no_grad(): # because it is test time
    for X, y in dataloader: # in batch 
        X, y = X.to(device), y.to(device) # to GPU
        pred=model(X)   # forward propa
 
        test_loss+= loss_func(pred,y).item() # if reduction is true of nn.CrossEntropyLoss(), its output is scalar or same as input 
                                             # loss btw 2 probabilities 
        correct+= (pred.argmax(1) == y).type(torch.float).sum().item() # y is G.T, 
                                                                       # output : prob of each class : vector
        print((pred.argmax(1) == y).type(torch.float))

  test_loss /= num_batches
  correct /= size

  print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
  
epoch = 5

for t in range(5):
    train(train_dataloader, model, loss_func, optimizer)
    test(test_dataloader, model, loss_func)
  
print("done!")

 

pred=model(x)에 대하여 pred 값 출력

1. pred는 model로 부터 출력된 score의 모음

2. 아래 각 색깔 별로 하나의 이미지고, 한 이미지는 10개의 class에 대하여 분류되므로 1 이미지당 10 socre(색깔 별)

 

코드 출처: 빠른 시작(Quickstart) — 파이토치 한국어 튜토리얼 (PyTorch tutorials in Korean)

728x90