-
max()
torch.max(input, dim)
dim参数指出删去哪一维度,0-行,1-列;输出两个tensor,第一个得到最大值结果,第二个给出相对位置(0-index)
>>> a = torch.randn(4, 4)
>>> a
tensor([[-1.2360, -0.2942, -0.1222, 0.8475],
[ 1.1949, -1.1127, -2.2379, -0.6702],
[ 1.5717, -0.9207, 0.1297, -1.8768],
[-0.6172, 1.0036, -0.6060, -0.2432]])
>>> torch.max(a, 1)
(tensor([ 0.8475, 1.1949, 1.5717, 1.0036]), tensor([ 3, 0, 0, 1]))
dim=1,删除列的维度,只有1列,每一行为该行最大值,第二个tensor给出该最大值所在的列数
等同于a.max(1)
例:在训练网络时
output = net(img)
_, predicted = output.max(1)
output为对img的预测输出,batch行label列,每行是一个图片的输出,每次输出batch组。所以预测结果需要看每行的最大值,找每行最大值的位置。output.max(1)
找到每行最大值,有两个tensor输出,第一个为最大值,第二个为最大值所在位置,所关注的是位置,所以第一个下划线_
舍弃掉最大值。
item()
把tensor转换成数torch.nn.Sequential语法
nn.Sequential(a, b, c)
括号,逗号torchvision.transforms.Composed语法
transforms.Composed([a, b, c])
括号,方括号,逗号