TORCH.GATHER

Python,Torch,Daily life,Share 2024-03-12 155 次浏览 次点赞

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → [Tensor]

Gathers values along an axis specified by dim.

Parameters

  • input (Tensor) – the source tensor
  • dim (int) – the axis along which to index
  • index (LongTensor) – the indices of elements to gather

Keyword Arguments

  • sparse_grad (bool, optional) – If True, gradient w.r.t. input will be a sparse tensor.
  • out (Tensor, optional) – the destination tensor

Example:

>>> t = torch.tensor([[1, 2], [3, 4]])
>>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
tensor([[ 1,  1],
        [ 4,  3]])

通俗解释

a = torch.arange(12).reshape(3, 4)
print(a)
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11]])
gather_from_a = torch.gather(a, 0, torch.tensor(
        [[2, 2, 1, 0],
         [1, 1, 2, 2],
         [2, 1, 2, 2],
         [0, 0, 1, 2]]
))
print(gather_from_a)
# tensor([[ 8,  9,  6,  3],
#         [ 4,  5, 10, 11],
#         [ 8,  5, 10, 11],
#         [ 0,  1,  6, 11]])

从原tensor中沿着由dim指定的轴收集数值,然后返回一个新的tensor。

对于二阶矩阵:

  • dim = 0

    • torch.gather()的第三个参数,为一个tensor数组,其中每个元素的值代表从a中的n-1(下标从0开始)行取值,每个元素所在的位置,代表的是从a中的某一列,这样横轴和纵轴都确定了,所以就能从a中取值赋给gather_from_a的指定位置。见下方例子。
    • x_index\y_index0123
      00123
      14567
      2891011

      这是tensor a的轴下标定义

      x_index\y_index0123
      02210
      11122
      22122
      30012

      这是torch.gather()的第三个参数的轴下标定义,叫做position

    • 现在我们看,position[0][0]=2那么就代表,gather_from_a[0][0]位置的值的横轴是a中的x_index=2,现在只需要固定一个纵轴,就可以从a中取值了。看position[0][0]所处于position中的y_index,不难发现y_index=0,所以是从a中的x_inde=2,y_index=0a[2][0]=8取值赋给gather_from_a[0][0];接下来再随便看一个,比如说position[3][2]=1,即从ax_index=1,因为值等于横轴下标:n-1,然后看纵轴,position[3][2],所在位置的y_index=2,所以该位置从a中取得值为a[1][2]=6;🆗了吗?其余的也可以自己试试。
  • dim = 1则,相反,position中的值代表纵轴下标:n-1,所在位置的x_index代表从a的哪个横轴下标。即a[x_index][position[x_index][y_index]]

    注:因为是返回一个新的张量,该函数只是从原tensor中取值,所以返回张量的大小与position的shape有关。position中的值注意别超过a中的index就行。

dim=0

    position中数的值代表的是行数,所处位置代表的是列数。
    dim=1
    position中数的值代表的是列数,所处位置代表的是行数。

本文由 fmujie 创作,采用 知识共享署名 3.0,可自由转载、引用,但需署名作者且注明文章出处。

还不快抢沙发

添加新评论

召唤看板娘