a = torch.arange(64).view(4,4,4)
print(a.shape)
print(a)
indi = torch.tensor([[[0,1,2,3],[0,1,2,3]],[[0,1,2,3],[0,1,2,3]],[[0,1,2,3],[0,1,2,3]]])
out = torch.gather(a,0,indi) # axis = 0
print(out)
out = torch.gather(a,1,indi) # axis = 1
print(out)
out = torch.gather(a,2,indi) # axis = 2
print(out)
axis= 0, 1, 2 의 경우를 관찰 !
여기서 indices는 2x4x3 (row col channel)으로 만들었는데 (torch.shape는 3x2x4)
input이 indices이므로
output도 2x4x3으로 나와야한다. (torch.shape는 3x2x4 )
2x4 행렬이 3 채널 나올 건데,
각 2x4 행렬이 만들어내는 결과를 아래 figure에서 살펴보자.
◆ axis = 0
#indi = torch.tensor([[[0,1,2,3],[0,1,2,3]],[[0,1,2,3],[0,1,2,3]],[[0,1,2,3],[0,1,2,3]]])
out = torch.gather(a,0,indi)
print(out)
◆ axis = 1
#indi = torch.tensor([[[0,1,2,3],[0,1,2,3]],[[0,1,2,3],[0,1,2,3]],[[0,1,2,3],[0,1,2,3]]])
out = torch.gather(a,1,indi)
print(out)
◆ axis = 2
#indi = torch.tensor([[[0,1,2,3],[0,1,2,3]],[[0,1,2,3],[0,1,2,3]],[[0,1,2,3],[0,1,2,3]]])
out = torch.gather(a,2,indi)
print(out)
728x90
'Ai > pytorch' 카테고리의 다른 글
timm model list (0) | 2023.12.20 |
---|---|
cuda 버전 확인, cuda 설치 (0) | 2023.11.21 |
괴랄한 torch.scatter( ) 이해하기 (0) | 2023.11.14 |
torch.swapdims(x,0,1) 의 이해 (0) | 2023.11.14 |
댓글