torch.gather(input, dim, index, *, sparse_grad=False, out=None) → [Tensor]Gathers values along an axis specified by dim.
Parameters
- input (Tensor) – the source tensor
- dim (int) – the axis along which to index
- index (LongTensor) – the indices of elements to gather
Keyword Arguments
- sparse_grad (bool, optional) – If
True, gradient w.r.t.inputwill be a sparse tensor.- out (Tensor, optional) – the destination tensor
Example:
>>> t = torch.tensor([[1, 2], [3, 4]])
>>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
tensor([[ 1, 1],
[ 4, 3]])通俗解释
a = torch.arange(12).reshape(3, 4)
print(a)
# tensor([[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11]])
gather_from_a = torch.gather(a, 0, torch.tensor(
[[2, 2, 1, 0],
[1, 1, 2, 2],
[2, 1, 2, 2],
[0, 0, 1, 2]]
))
print(gather_from_a)
# tensor([[ 8, 9, 6, 3],
# [ 4, 5, 10, 11],
# [ 8, 5, 10, 11],
# [ 0, 1, 6, 11]])从原tensor中沿着由dim指定的轴收集数值,然后返回一个新的tensor。
对于二阶矩阵:
dim = 0torch.gather()的第三个参数,为一个tensor数组,其中每个元素的值代表从a中的n-1(下标从0开始)行取值,每个元素所在的位置,代表的是从a中的某一列,这样横轴和纵轴都确定了,所以就能从a中取值赋给gather_from_a的指定位置。见下方例子。x_index\y_index012300 1 2 3 14 5 6 7 28 9 10 11 这是
tensor a的轴下标定义x_index\y_index012302 2 1 0 11 1 2 2 22 1 2 2 30 0 1 2 这是
torch.gather()的第三个参数的轴下标定义,叫做position吧- 现在我们看,
position[0][0]=2那么就代表,gather_from_a[0][0]位置的值的横轴是a中的x_index=2,现在只需要固定一个纵轴,就可以从a中取值了。看position[0][0]所处于position中的y_index,不难发现y_index=0,所以是从a中的x_inde=2,y_index=0即a[2][0]=8取值赋给gather_from_a[0][0];接下来再随便看一个,比如说position[3][2]=1,即从a的x_index=1,因为值等于横轴下标:n-1,然后看纵轴,position[3][2],所在位置的y_index=2,所以该位置从a中取得值为a[1][2]=6;🆗了吗?其余的也可以自己试试。
dim = 1则,相反,position中的值代表纵轴下标:n-1,所在位置的x_index代表从a的哪个横轴下标。即a[x_index][position[x_index][y_index]]注:因为是返回一个新的张量,该函数只是从原tensor中取值,所以返回张量的大小与
position的shape有关。position中的值注意别超过a中的index就行。
dim=0
position中数的值代表的是行数,所处位置代表的是列数。 dim=1 position中数的值代表的是列数,所处位置代表的是行数。

还不快抢沙发