torch.gather(input, dim, index, *, sparse_grad=False, out=None) → [Tensor]
Gathers values along an axis specified by dim.
Parameters
- input (Tensor) – the source tensor
- dim (int) – the axis along which to index
- index (LongTensor) – the indices of elements to gather
Keyword Arguments
- sparse_grad (bool, optional) – If
True
, gradient w.r.t.input
will be a sparse tensor.- out (Tensor, optional) – the destination tensor
Example:
>>> t = torch.tensor([[1, 2], [3, 4]])
>>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
tensor([[ 1, 1],
[ 4, 3]])
通俗解释
a = torch.arange(12).reshape(3, 4)
print(a)
# tensor([[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11]])
gather_from_a = torch.gather(a, 0, torch.tensor(
[[2, 2, 1, 0],
[1, 1, 2, 2],
[2, 1, 2, 2],
[0, 0, 1, 2]]
))
print(gather_from_a)
# tensor([[ 8, 9, 6, 3],
# [ 4, 5, 10, 11],
# [ 8, 5, 10, 11],
# [ 0, 1, 6, 11]])
从原tensor中沿着由dim指定的轴收集数值,然后返回一个新的tensor。
对于二阶矩阵:
dim = 0
torch.gather()
的第三个参数,为一个tensor数组,其中每个元素的值代表从a
中的n-1
(下标从0开始)行取值,每个元素所在的位置,代表的是从a
中的某一列,这样横轴和纵轴都确定了,所以就能从a
中取值赋给gather_from_a
的指定位置。见下方例子。x_index\y_index
0
1
2
3
0
0 1 2 3 1
4 5 6 7 2
8 9 10 11 这是
tensor a
的轴下标定义x_index\y_index
0
1
2
3
0
2 2 1 0 1
1 1 2 2 2
2 1 2 2 3
0 0 1 2 这是
torch.gather()
的第三个参数的轴下标定义,叫做position
吧- 现在我们看,
position[0][0]=2
那么就代表,gather_from_a[0][0]
位置的值的横轴是a
中的x_index=2
,现在只需要固定一个纵轴,就可以从a
中取值了。看position[0][0]
所处于position
中的y_index
,不难发现y_index=0
,所以是从a
中的x_inde=2,y_index=0
即a[2][0]=8
取值赋给gather_from_a[0][0]
;接下来再随便看一个,比如说position[3][2]=1
,即从a
的x_index=1
,因为值等于横轴下标:n-1
,然后看纵轴,position[3][2]
,所在位置的y_index=2
,所以该位置从a
中取得值为a[1][2]=6
;🆗了吗?其余的也可以自己试试。
dim = 1
则,相反,position
中的值代表纵轴下标:n-1
,所在位置的x_index
代表从a
的哪个横轴下标。即a[x_index][position[x_index][y_index]]
注:因为是返回一个新的张量,该函数只是从原tensor中取值,所以返回张量的大小与
position
的shape有关。position
中的值注意别超过a
中的index
就行。
dim=0
position中数的值代表的是行数,所处位置代表的是列数。 dim=1 position中数的值代表的是列数,所处位置代表的是行数。
还不快抢沙发