괴랄한 torch.gather 명령어 이해
a = torch.arange(64).view(4,4,4) print(a.shape) print(a) indi = torch.tensor([[[0,1,2,3],[0,1,2,3]],[[0,1,2,3],[0,1,2,3]],[[0,1,2,3],[0,1,2,3]]]) out = torch.gather(a,0,indi) # axis = 0 print(out) out = torch.gather(a,1,indi) # axis = 1 print(out) out = torch.gather(a,2,indi) # axis = 2 print(out) axis= 0, 1, 2 의 경우를 관찰 ! 여기서 indices는 2x4x3 (row col channel)으로 만들었는데 (torch.shape는 3x2x4) input이 i..
2023. 11. 13.