scatter_(input, dim, index, src)将src中数据根据index中的索引按照dim的方向填进input中。
>>> x = torch.rand(2, 5)
>>> x
0.4319 0.6500 0.4080 0.8760 0.2355
0.2609 0.4711 0.8486 0.8573 0.1029
[torch.FloatTensor of size 2x5]
index的shape刚好与x的shape对应,也就是index中每个元素指定x中一个数据的填充位置。dim=0,表示按行填充,主要理解按行填充。举例index中的第0行第2列的值为2,表示在第2行(从0开始)进行填充,对应到input = zeros(3, 5)中就是位置(2,2)。所以此处要求input的列数要与x列数相同,而index中的最大值应与zeros(3, 5)行数相一致。
>>> torch.zeros(3, 5).scatter_(0, torch.LongTensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
0.4319 0.4711 0.8486 0.8760 0.2355
0.0000 0.6500 0.0000 0.8573 0.0000
0.2609 0.0000 0.4080 0.0000 0.1029
[torch.FloatTensor of size 3x5]
同上理,可以把1.23看成[[1.23], [1.23]]。此处按列填充,index中的2对应zeros(2, 4)的(0,2)位置。
>>> z = torch.zeros(2, 4).scatter_(1, torch.LongTensor([[2], [3]]), 1.23)
>>> z
0.0000 0.0000 1.2300 0.0000
0.0000 0.0000 0.0000 1.2300
[torch.FloatTensor of size 2x4]
综上,几点要注意:
index的shape要与填充数据src的shape一致,如果不一致,将进行广播
index中的索引指的是要把src中对应位置的数据按照指定那个维度(即dim)填充到原数据input中,我们知道了要填充的数据是什么,填充到input的哪行那列呢,dim指定哪个维度,这个维度就是index索引值,另一个维度就是这个索引在index中的位置。
scatter() 和 scatter_() 的作用是一样的,只不过 scatter() 不会直接修改原来的 Tensor,而 scatter_() 会。PyTorch 中,一般函数加下划线代表直接在原来的 Tensor 上修改
scatter() 一般可以用来对标签进行 one-hot 编码,这就是一个典型的用标量来修改张量的一个例子
class_num = 10
batch_size = 4
label = torch.LongTensor(batch_size, 1).random_() % class_num
#tensor([[6],
# [0],
# [3],
# [2]])
torch.zeros(batch_size, class_num).scatter_(1, label, 1)
#tensor([[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
# [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
# [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
# [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]])
转载于https://blog.csdn.net/qq_16234613/article/details/79827006
转载于https://www.cnblogs.com/dogecheng/p/11938009.html