본문 바로가기
Ai/pytorch

괴랄한 torch.gather 명령어 이해

by yooom 2023. 11. 13.
a = torch.arange(64).view(4,4,4)
print(a.shape)
print(a)


indi = torch.tensor([[[0,1,2,3],[0,1,2,3]],[[0,1,2,3],[0,1,2,3]],[[0,1,2,3],[0,1,2,3]]])


out = torch.gather(a,0,indi)  # axis = 0
print(out)

out = torch.gather(a,1,indi) # axis = 1
print(out)

out = torch.gather(a,2,indi) # axis = 2
print(out)

axis= 0, 1, 2 의 경우를 관찰 !

 

여기서 indices는 2x4x3 (row col channel)으로 만들었는데 (torch.shape는 3x2x4)

input이 indices이므로

output도 2x4x3으로 나와야한다. (torch.shape는 3x2x4 )

 

2x4 행렬이 3 채널 나올 건데,

각 2x4 행렬이 만들어내는 결과를 아래 figure에서 살펴보자.

 

 

◆ axis = 0

#indi = torch.tensor([[[0,1,2,3],[0,1,2,3]],[[0,1,2,3],[0,1,2,3]],[[0,1,2,3],[0,1,2,3]]])
out = torch.gather(a,0,indi)
print(out)

 

 

 

◆ axis = 1

#indi = torch.tensor([[[0,1,2,3],[0,1,2,3]],[[0,1,2,3],[0,1,2,3]],[[0,1,2,3],[0,1,2,3]]])
out = torch.gather(a,1,indi)
print(out)

 

 

◆ axis = 2

#indi = torch.tensor([[[0,1,2,3],[0,1,2,3]],[[0,1,2,3],[0,1,2,3]],[[0,1,2,3],[0,1,2,3]]])
out = torch.gather(a,2,indi)
print(out)

 

 

 

728x90

'Ai > pytorch' 카테고리의 다른 글

timm model list  (0) 2023.12.20
cuda 버전 확인, cuda 설치  (0) 2023.11.21
괴랄한 torch.scatter( ) 이해하기  (0) 2023.11.14
torch.swapdims(x,0,1) 의 이해  (0) 2023.11.14

댓글