※ Segmentation / classification 시 target (혹은 ground truth) 의 class label이 범위를 넘어선 경우에 발생
1. 필자 같은 경우는 배경(0), 객체(1~20), 경계선(255)으로 labeling 된 segmenatation mask를 nn.CrossEntropy() 에 target으로 넘겨주면서 상기의 에러메세지 발생
2. Pytorch의 CE는 long 타입으로 입력 받는 target에 대하여 one-hot encoding을 수행하고 이를 input에 대하여 softmax 및 cross entropy를 계산함.
3. 자동으로 one-hot encoding 하는 과정에서 연속적인 레이블(0~20)에 대하여 21개의 channel로 원핫 인코딩 가능하였으나, pixel value(혹은 label 값)가 255인 지점으로 인하여 에러 발생
4. (해결) 다른 레이블 범위와 동떨어진 레이블 값을 연속적인 범위 내로 맞추면 됨
e.g., 255를 -1로 변환
5. 필자는 255에 해당하는 범위가 필요 없어 0으로 변환하여 트레이닝에 활용함
에러가 발생한 코드
for i, (img, target) in enumerate(train):
optimizer.zero_grad()
img, target = img.to(device), target.to(device)
output = model(img)
target = torch.squeeze(target, dim=1)
loss = criterion(output, target) # float, long , one-hot coding for target? // add ppt or question
losses.update(loss.item())
loss.backward()
optimizer.step()
해결한 코드
for i, (img, target) in enumerate(train):
optimizer.zero_grad()
# boundary(255) -> background(0)
# backgroud(0), object(1-20)
target = target - (target == 255).long() * 255
img, target = img.to(device), target.to(device)
output = model(img) # output ranges -1 ~ 1
target = torch.squeeze(target, dim=1)
loss = criterion(output, target) # float, long , one-hot coding for target? // add ppt or question
losses.update(loss.item())
loss.backward()
optimizer.step()
'Computer Vision' 카테고리의 다른 글
GPU에 남아있는 process 지우기 / kill -9 [process_id] 가 안될 때 /RuntimeError: Address already in use (0) | 2022.11.22 |
---|---|
ssl.SSLError: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed 오류 해결 방법 (1) | 2022.11.22 |
VOCSegmentation DataLoader 버그 (0) | 2022.11.14 |
[환경 구축] Windows11 / WSL2 / NVIDIA GeForce GTX 1660 Ti / Pytorch / VS code 환경에서 딥러닝 개발하기 (0) | 2022.09.30 |
[Python] .scatter()함수 이해 (0) | 2022.08.18 |