본문 바로가기

Computer Vision

VOCSegmentation DataLoader 버그

※ pytorch 를 활용하여 VOC 2011를 활용하고자 하는 경우, transforms 옵션으로는 정상적인 data transformation이 수행되지 않는다.

  • VOCSegmentation에 data transform을 적용할 수 있는 option은 'transform', 'transform_target', 'transforms'이다.
    • transform은 원본 RGB 이미지에 대한 data transformation을 적용한다.
    • transform_target은 segmentaition label image에 대한 transformation을 적용한다.
    • transforms은 이미지와 레이블에 대해 모두 transformation을 적용한다.
  • 그러나 실제로는 아래와 같이  transforms 옵션을 활용하면 정상적으로 적용되지 않는다.
    data_transform = transforms.Compose([transforms.Resize(size=400),
                                          transforms.ToTensor()])   
    
    if(mode=='train'):
        trainset = torchvision.datasets.VOCSegmentation(root='./data', 
                                                image_set= 'train', 
                                                year='2011',
                                                download=False,
                                                transforms = data_transform
                                                )
  • 에러코드는 아래와 같다.
    data = self._next_data()
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    data = [self.dataset[idx] for idx in possibly_batched_index]
    data = [self.dataset[idx] for idx in possibly_batched_index]
    img, target = self.transforms(img, target)
TypeError: __call__() takes 2 positional arguments but 3 were given

 

  • 이를 해결하기 위해서는 transform과 transform_target을 사용하여 각각 이미지와 레이블에 대한 transformation을 수행하여 해결할 수 있다. 
    data_transform = transforms.Compose([transforms.Resize(size=400),
                                          transforms.ToTensor()])   
    
    if(mode=='train'):
        trainset = torchvision.datasets.VOCSegmentation(root='./data', 
                                                image_set= 'train', 
                                                year='2011',
                                                download=False,
                                                transform = data_transform,
                                                target_transform=data_transform
                                                )

        dataloader = torch.utils.data.DataLoader(trainset, batch_size = batch_size, shuffle=True)
728x90