본문 바로가기

Computer Vision

[Python] .scatter()함수 이해

요약

1.  .scatter(dimension, index, src) 꼴이라면 다음과 같다.

  • dimension은 어느 방향으로 업데이트할지
  • index는 src의 값을 어떻게 선택할지
  • src는 어떤 값으로 업데이트를 할지

2. index 텐서의 각 element 의 위치에 대응되는 src의 element가 output의 어느 위치로 갈지 결정한다.

 

예시1

>>> src = torch.arange(1, 11).reshape((2, 5))
>>> src
tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10]])
>>> index = torch.tensor([[0, 1, 2, 0]])
>>> output = torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
tensor([[1, 0, 0, 4, 0],
        [0, 2, 0, 0, 0],
        [0, 0, 3, 0, 0]])

예시 코드를 기준으로, scatter_(0, index, src)dimension 값이 0이므로 row 방향으로 업데이트하는 걸 생각해보자. (1이면 column 방향으로 업데이트함)

  1. 우선 output은 torch.zeros(3, 5, dtype=src.dtype)로 0으로만 이루어진 3 x 5 tensor로 초기화된다.
  2. index 텐서는 1 x 5의 2D tensor 이다. 그러므로 src (2 x 5의 2D tensor)에서 첫 번째 row인 [1, 2, 3, 4, 5]로만 output을 업데이트한다. 
  3. index의 하나 뿐인 row는 [0, 1, 2, 0]인데, 이는 src의 첫 번째  row에 각각 매핑되며, 각 값은 src의 element가 들어가는 output의 위치를 의미한다.
    • index [0, 1, 2, 0]이고, src의 [1, 2, 3, 4, 5]이므로 src의 1은 output의 index 0 column에 위치 (index의 크기가 1x5 이므로 한번에 한 column씩 업데이트)
    • index [0, 1, 2, 0]이고, src의 [1, 2, 3, 4, 5]이므로 src의 2는 output의 index 1 column에 위치 (index의 크기가 1x5 이므로 한번에 한 column씩 업데이트
    • index [0, 1, 2, 0]이고, src의 [1, 2, 3, 4, 5]이므로 src의 3은 output의 index 2 column에 위치 (index의 크기가 1x5 이므로 한번에 한 column씩 업데이트
    • index [0, 1, 2, 0]이고, src의 [1, 2, 3, 4, 5]이므로 src의 4는 output의 index 0 column에 위치 (index의 크기가 1x5 이므로 한번에 한 column씩 업데이트)
    • 크기 넘어가는 대상은 모두 0으로

예시2

>>> src = torch.arange(1, 11).reshape((2, 5))
>>> src
tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10]])
        
>>> index = torch.tensor([[0, 1, 2], [0, 1, 4]])
>>> output = torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
tensor([[1, 2, 3, 0, 0],
        [6, 7, 0, 0, 8],
        [0, 0, 0, 0, 0]])

예시 코드를 기준으로, scatter_(1, index, src) dimension 값이1이므로 column 방향으로 업데이트하는 걸 생각해보자. 

  1.  우선 output은 torch.zeros(3, 5, dtype=src.dtype)로 0으로만 이루어진 3 x 5 tensor로 초기화된다.
  2.  index 텐서는 2 x 3의 2D tensor 이다. 그러므로 src (2 x 5의 2D tensor)에서 모든 row에 대하여 output을 업데이트한다.
  3. index의 인덱스0 번째의 row는 [0, 1, 2]인데, 이는 src의 첫 번째  element에 각각 매핑되며, 각 값은 src의 element가 들어가는 output의 위치를 의미한다.
    • index [0, 1, 2]이고, src의 [1, 2, 3, 4, 5]이므로 src의 1은 output의 (0, 0)에 위
    • index [0, 1, 2]이고, src의 [1, 2, 3, 4, 5]이므로 src의 2 output의 (0, 1)에 위치
    • index [0, 1, 2]이고, src의 [1, 2, 3, 4, 5]이므로 src의 3은 output의 (0, 2)에 위
    • 이 외는 인덱스에 맵핑되지 않으므로 0
  4. index의 인덱스 1번째의 row는 [0, 1, 4]인데, 이는 src의 두 번째  element에 각각 매핑되며, 각 값은 src의 element가 들어가는 output의 위치를 의미한다.
    • index [0, 1, 4]이고, src의 [6, 7, 8, 9, 10]이므로 src의 6은 output의 (1, 0)에 위
    • index [0, 1, 4]이고, src의 [6, 7, 8, 9, 10]이므로 src의 7 output의 (1, 1)에 위치
    • index [0, 1, 4]이고, src의 [6, 7, 8, 9, 10]이므로 src의 8은 output의 (1, 4)에 위
    • 이 외는 인덱스에 맵핑되지 않으므로 0

 

코드 원리

우선 Pytorch 공식 document 에 따르면 하기의 rule로 업데이트 된다. 

해당 내용은 3D 텐서 기준이다. 

self[index[i][j][k]][j][k] *= src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] *= src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] *= src[i][j][k]  # if dim == 2

 

출처: torch.Tensor.scatter_ — PyTorch 1.12 documentation

728x90