pytorch scatter的用法

官方给的用法:

scatter(dim, index, src)
self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

一个例子

import torch
input = torch.randn(2, 4)
print(input)
output = torch.zeros(2, 5)
index = torch.tensor([[3, 1, 2, 0], [1, 2, 0, 3]])
output = output.scatter(1, index, input)
print(output)

输出:

tensor([[ 0.0461,  0.4024, -1.0115,  0.2167],
        [-0.6123,  0.5036,  0.2310,  0.6931]])
tensor([[ 0.2167,  0.4024, -1.0115,  0.0461,  0.0000],
        [ 0.2310, -0.6123,  0.5036,  0.6931,  0.0000]])

scatter(scatter_)是将input tensor按照index赋值给output tensor来达到更新output的效果的。
我们从index下手,
index[0][0]=3,由于dim=1,那么我们取input[0][0]=0.0461, 赋值给output[0][index[0][0]]=output[0][3]
index[0][3]=0, input[0][3]=0.2167,赋值给output[0][inde[0][3]]=output[0][0]
index[1][2]=0, input[1][2]=0.2310, 赋值给output[1][index[1][2]]=output[1][0]
index[1][3]=3, input[1][3]=0.6931, 赋值给output[1][index[1][3]]=output[1][3]

也就是index的下标和input的下标是一致的,取出来的这个值赋值给谁呢,这个是index对应的值以及dim来确定的,如果dim=1, 那么更新的是output[i][index[i][j]]=input[i][j],官方文档给的是三维的情况,dim是多少,那么index的值就放在第几维。
scatter一个很重要的应用就是生成one-hot矩阵
假设总共有5类,现在一个batch有3个样本,分别对应的标签为1,2,0。那么生成的one-hot矩阵应该是这样的:

index=torch.tensor([[1], [2], [0]])
y=torch.zeros(3, 5)
y=y.scatter(1, index, 1)
print(y)

输出:

tensor([[0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [1., 0., 0., 0., 0.]])
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容

  • 介绍 相比TensorFlow的静态图开发,Pytorch的动态图特性使得开发起来更加人性化,选择Pytorch的...
    dawsonenjoy阅读 25,035评论 2 18
  • stack使用stack是为了保留两个信息: 序列(先后)和 张量矩阵信息。比如在循环神经网络中,网络的输出数据...
    lzjngu阅读 392评论 0 0
  • 1.pytorch中的索引 index_select(x, dim, indices)dim代表维度,indice...
    yumiii_阅读 5,423评论 0 0
  • scatter_(input, dim, index, src)将src中数据根据index中的索引按照dim的方...
    cjhfhb阅读 5,885评论 0 0
  • 我是黑夜里大雨纷飞的人啊 1 “又到一年六月,有人笑有人哭,有人欢乐有人忧愁,有人惊喜有人失落,有的觉得收获满满有...
    陌忘宇阅读 8,603评论 28 53