본문 바로가기
Ai/pytorch

괴랄한 torch.scatter( ) 이해하기

by yooom 2023. 11. 14.

scatter를 이해하기 위해

https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html#torch.Tensor.scatter_

 

torch.Tensor.scatter_ — PyTorch 2.1 documentation

Shortcuts

pytorch.org

여기에 들어가보면

 

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

 

이렇게 친절하게 알려준다.....

친절하게 예시도 준다...(못알아 먹겠다.)

scatter와 scatter_는 하나만 이해하면 반대쪽도 이해할 수 있다.

 

 

예시 하나만 보자.

index = torch.tensor([[0, 1, 2, 0]])

src = torch.arange(1, 11).reshape((2, 5))
# tensor([[ 1,  2,  3,  4,  5],
#         [ 6,  7,  8,  9, 10]])

torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
# tensor([[1, 0, 0, 4, 0],
#         [0, 2, 0, 0, 0],
#         [0, 0, 3, 0, 0]])

# self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0

axis =0일 때

index[0][0]은 0 → tensor[index[0][0]] [0] = tensor[0][0] 에 src[0][0]값인 1이 대입된다. 

index[0][1]은 0 → tensor[index[0][1]] [1] = tensor[1][1] 에 src[0][1]값인 2가 대입된다.

index[0][2]은 0 → tensor[index[0][2]] [2] = tensor[2][2] 에 src[0][2]값인 3이 대입된다. 

index[0][3]은 0 → tensor[index[0][3]] [3] = tensor[0][3] 에 src[0][3]값인 4가 대입된다. 

 

솔직히 뭔 말인지는 알겠지만 직관적이지 않다.

 

그림으로 보면 좀 쉬울 것 같아서 정리했다.

 

import torch

 

axis = 0

index = torch.tensor([[0, 1, 2], [2, 0, 4]])  # 보낼 위치

src = torch.arange(1, 11).reshape((2, 5))  # 보낼 값

torch.zeros(5, 5, dtype=src.dtype).scatter_(0, index, src)

 

axis= 1

index = torch.tensor([[0, 1, 2], [2, 0, 4]])  # 보낼 위치

src = torch.arange(1, 11).reshape((2, 5))  # 보낼 값

torch.zeros(5, 5, dtype=src.dtype).scatter_(1, index, src)

 

728x90

'Ai > pytorch' 카테고리의 다른 글

timm model list  (0) 2023.12.20
cuda 버전 확인, cuda 설치  (0) 2023.11.21
torch.swapdims(x,0,1) 의 이해  (0) 2023.11.14
괴랄한 torch.gather 명령어 이해  (0) 2023.11.13

댓글