scatter를 이해하기 위해
https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html#torch.Tensor.scatter_
torch.Tensor.scatter_ — PyTorch 2.1 documentation
Shortcuts
pytorch.org
여기에 들어가보면
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
이렇게 친절하게 알려준다.....
친절하게 예시도 준다...(못알아 먹겠다.)
scatter와 scatter_는 하나만 이해하면 반대쪽도 이해할 수 있다.
예시 하나만 보자.
index = torch.tensor([[0, 1, 2, 0]])
src = torch.arange(1, 11).reshape((2, 5))
# tensor([[ 1, 2, 3, 4, 5],
# [ 6, 7, 8, 9, 10]])
torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
# tensor([[1, 0, 0, 4, 0],
# [0, 2, 0, 0, 0],
# [0, 0, 3, 0, 0]])
# self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
axis =0일 때
index[0][0]은 0 → tensor[index[0][0]] [0] = tensor[0][0] 에 src[0][0]값인 1이 대입된다.
index[0][1]은 0 → tensor[index[0][1]] [1] = tensor[1][1] 에 src[0][1]값인 2가 대입된다.
index[0][2]은 0 → tensor[index[0][2]] [2] = tensor[2][2] 에 src[0][2]값인 3이 대입된다.
index[0][3]은 0 → tensor[index[0][3]] [3] = tensor[0][3] 에 src[0][3]값인 4가 대입된다.
솔직히 뭔 말인지는 알겠지만 직관적이지 않다.
그림으로 보면 좀 쉬울 것 같아서 정리했다.
import torch
axis = 0
index = torch.tensor([[0, 1, 2], [2, 0, 4]]) # 보낼 위치
src = torch.arange(1, 11).reshape((2, 5)) # 보낼 값
torch.zeros(5, 5, dtype=src.dtype).scatter_(0, index, src)
axis= 1
index = torch.tensor([[0, 1, 2], [2, 0, 4]]) # 보낼 위치
src = torch.arange(1, 11).reshape((2, 5)) # 보낼 값
torch.zeros(5, 5, dtype=src.dtype).scatter_(1, index, src)
728x90
'Ai > pytorch' 카테고리의 다른 글
timm model list (0) | 2023.12.20 |
---|---|
cuda 버전 확인, cuda 설치 (0) | 2023.11.21 |
torch.swapdims(x,0,1) 의 이해 (0) | 2023.11.14 |
괴랄한 torch.gather 명령어 이해 (0) | 2023.11.13 |
댓글