본문 바로가기

Computer Vision

[에러해결] nll_loss 2d forward kernel: block: [5,0,0], thread: [845,0,0] Assertion `t >= 0 && t < n_classes` failed.

※ Segmentation / classification 시 target (혹은 ground truth) 의 class label이 범위를 넘어선 경우에 발생 

 

1. 필자 같은 경우는 배경(0), 객체(1~20), 경계선(255)으로 labeling 된 segmenatation mask를 nn.CrossEntropy() 에 target으로 넘겨주면서 상기의 에러메세지 발생

2. Pytorch의 CE는 long 타입으로 입력 받는 target에 대하여 one-hot encoding을 수행하고 이를 input에 대하여 softmax 및 cross entropy를 계산함.

3. 자동으로 one-hot encoding 하는 과정에서 연속적인 레이블(0~20)에 대하여 21개의 channel로 원핫 인코딩 가능하였으나, pixel value(혹은 label 값)가 255인 지점으로 인하여 에러 발생

4. (해결) 다른 레이블 범위와 동떨어진 레이블 값을 연속적인 범위 내로 맞추면 됨

    e.g., 255를 -1로 변환

5. 필자는 255에 해당하는 범위가 필요 없어 0으로 변환하여 트레이닝에 활용함 

 

에러가 발생한 코드 

 for i, (img, target) in enumerate(train):
            optimizer.zero_grad() 
        
            img, target = img.to(device), target.to(device)
            output = model(img) 
            
            target = torch.squeeze(target, dim=1)
            loss = criterion(output, target) # float, long , one-hot coding for target? // add ppt or question 
            losses.update(loss.item())
            loss.backward()
            optimizer.step()

 

 

해결한 코드

 for i, (img, target) in enumerate(train):
            optimizer.zero_grad() 
           
            # boundary(255) -> background(0)
            # backgroud(0), object(1-20)
            target = target - (target == 255).long() * 255
            img, target = img.to(device), target.to(device)
            
            output = model(img) # output ranges -1 ~ 1
            
            target = torch.squeeze(target, dim=1)
            loss = criterion(output, target) # float, long , one-hot coding for target? // add ppt or question 
            losses.update(loss.item())
            loss.backward()
            optimizer.step()

 

728x90