요약
1. .scatter(dimension, index, src) 꼴이라면 다음과 같다.
- dimension은 어느 방향으로 업데이트할지
- index는 src의 값을 어떻게 선택할지
- src는 어떤 값으로 업데이트를 할지
2. index 텐서의 각 element 의 위치에 대응되는 src의 element가 output의 어느 위치로 갈지 결정한다.
예시1
>>> src = torch.arange(1, 11).reshape((2, 5))
>>> src
tensor([[ 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10]])
>>> index = torch.tensor([[0, 1, 2, 0]])
>>> output = torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
tensor([[1, 0, 0, 4, 0],
[0, 2, 0, 0, 0],
[0, 0, 3, 0, 0]])
예시 코드를 기준으로, scatter_(0, index, src)의 dimension 값이 0이므로 row 방향으로 업데이트하는 걸 생각해보자. (1이면 column 방향으로 업데이트함)
- 우선 output은 torch.zeros(3, 5, dtype=src.dtype)로 0으로만 이루어진 3 x 5 tensor로 초기화된다.
- index 텐서는 1 x 5의 2D tensor 이다. 그러므로 src (2 x 5의 2D tensor)에서 첫 번째 row인 [1, 2, 3, 4, 5]로만 output을 업데이트한다.
- index의 하나 뿐인 row는 [0, 1, 2, 0]인데, 이는 src의 첫 번째 row에 각각 매핑되며, 각 값은 src의 element가 들어가는 output의 위치를 의미한다.
- index [0, 1, 2, 0]이고, src의 [1, 2, 3, 4, 5]이므로 src의 1은 output의 index 0 column에 위치 (index의 크기가 1x5 이므로 한번에 한 column씩 업데이트)
- index [0, 1, 2, 0]이고, src의 [1, 2, 3, 4, 5]이므로 src의 2는 output의 index 1 column에 위치 (index의 크기가 1x5 이므로 한번에 한 column씩 업데이트
- index [0, 1, 2, 0]이고, src의 [1, 2, 3, 4, 5]이므로 src의 3은 output의 index 2 column에 위치 (index의 크기가 1x5 이므로 한번에 한 column씩 업데이트
- index [0, 1, 2, 0]이고, src의 [1, 2, 3, 4, 5]이므로 src의 4는 output의 index 0 column에 위치 (index의 크기가 1x5 이므로 한번에 한 column씩 업데이트)
- 크기 넘어가는 대상은 모두 0으로
예시2
>>> src = torch.arange(1, 11).reshape((2, 5))
>>> src
tensor([[ 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10]])
>>> index = torch.tensor([[0, 1, 2], [0, 1, 4]])
>>> output = torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
tensor([[1, 2, 3, 0, 0],
[6, 7, 0, 0, 8],
[0, 0, 0, 0, 0]])
예시 코드를 기준으로, scatter_(1, index, src)의 dimension 값이1이므로 column 방향으로 업데이트하는 걸 생각해보자.
- 우선 output은 torch.zeros(3, 5, dtype=src.dtype)로 0으로만 이루어진 3 x 5 tensor로 초기화된다.
- index 텐서는 2 x 3의 2D tensor 이다. 그러므로 src (2 x 5의 2D tensor)에서 모든 row에 대하여 output을 업데이트한다.
- index의 인덱스0 번째의 row는 [0, 1, 2]인데, 이는 src의 첫 번째 element에 각각 매핑되며, 각 값은 src의 element가 들어가는 output의 위치를 의미한다.
- index [0, 1, 2]이고, src의 [1, 2, 3, 4, 5]이므로 src의 1은 output의 (0, 0)에 위치
- index [0, 1, 2]이고, src의 [1, 2, 3, 4, 5]이므로 src의 2는 output의 (0, 1)에 위치
- index [0, 1, 2]이고, src의 [1, 2, 3, 4, 5]이므로 src의 3은 output의 (0, 2)에 위치
- 이 외는 인덱스에 맵핑되지 않으므로 0
- index의 인덱스 1번째의 row는 [0, 1, 4]인데, 이는 src의 두 번째 element에 각각 매핑되며, 각 값은 src의 element가 들어가는 output의 위치를 의미한다.
- index [0, 1, 4]이고, src의 [6, 7, 8, 9, 10]이므로 src의 6은 output의 (1, 0)에 위치
- index [0, 1, 4]이고, src의 [6, 7, 8, 9, 10]이므로 src의 7는 output의 (1, 1)에 위치
- index [0, 1, 4]이고, src의 [6, 7, 8, 9, 10]이므로 src의 8은 output의 (1, 4)에 위치
- 이 외는 인덱스에 맵핑되지 않으므로 0
코드 원리
우선 Pytorch 공식 document 에 따르면 하기의 rule로 업데이트 된다.
해당 내용은 3D 텐서 기준이다.
self[index[i][j][k]][j][k] *= src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] *= src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] *= src[i][j][k] # if dim == 2
'Computer Vision' 카테고리의 다른 글
VOCSegmentation DataLoader 버그 (0) | 2022.11.14 |
---|---|
[환경 구축] Windows11 / WSL2 / NVIDIA GeForce GTX 1660 Ti / Pytorch / VS code 환경에서 딥러닝 개발하기 (0) | 2022.09.30 |
(pred.argmax(1) == y).type(torch.float).sum().item() 해석 (1) | 2022.08.17 |
nn.Module 입력과 forward 함수: Pytorch에서 모델 호출법 (0) | 2022.08.16 |
Python 변수로 Class 명 치환(Instanstiate) (0) | 2022.08.16 |