※ pytorch 를 활용하여 VOC 2011를 활용하고자 하는 경우, transforms 옵션으로는 정상적인 data transformation이 수행되지 않는다.
- VOCSegmentation에 data transform을 적용할 수 있는 option은 'transform', 'transform_target', 'transforms'이다.
- transform은 원본 RGB 이미지에 대한 data transformation을 적용한다.
- transform_target은 segmentaition label image에 대한 transformation을 적용한다.
- transforms은 이미지와 레이블에 대해 모두 transformation을 적용한다.
- 그러나 실제로는 아래와 같이 transforms 옵션을 활용하면 정상적으로 적용되지 않는다.
data_transform = transforms.Compose([transforms.Resize(size=400),
transforms.ToTensor()])
if(mode=='train'):
trainset = torchvision.datasets.VOCSegmentation(root='./data',
image_set= 'train',
year='2011',
download=False,
transforms = data_transform
)
- 에러코드는 아래와 같다.
data = self._next_data()
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
data = [self.dataset[idx] for idx in possibly_batched_index]
data = [self.dataset[idx] for idx in possibly_batched_index]
img, target = self.transforms(img, target)
TypeError: __call__() takes 2 positional arguments but 3 were given
- 이를 해결하기 위해서는 transform과 transform_target을 사용하여 각각 이미지와 레이블에 대한 transformation을 수행하여 해결할 수 있다.
data_transform = transforms.Compose([transforms.Resize(size=400),
transforms.ToTensor()])
if(mode=='train'):
trainset = torchvision.datasets.VOCSegmentation(root='./data',
image_set= 'train',
year='2011',
download=False,
transform = data_transform,
target_transform=data_transform
)
dataloader = torch.utils.data.DataLoader(trainset, batch_size = batch_size, shuffle=True)