我们已经知道pytorch的tensor由“头信息区”和“storage”两部分组成,其中tensor的实际数据是以一维数组(storage)的方式存于某个连续的内存中的。。
1. tensor.view()
view 从字面意思上就是“视图”的意思,就是将原tensor以某种排列方式展示给我们,view()不会改变原storage、也不会新建storage,只新建头信息区。
本质上,view()返回的是针对一维storage的某种排列视图,并且这种视图只能是连续、等距切分storage再连续竖向叠加形成的视图,不能跳跃式切分,如下图例子。
注意:如果tensor是不连续的,则不能使用view()(//www.greatytc.com/p/51678ea7a959)。
>>>a = torch.tensor([1,2,3,4,5,6])
>>>b = a.view(2,3)
>>>print(a)
tensor([1, 2, 3, 4, 5, 6])
>>>print(b)
tensor([[1, 2, 3],
[4, 5, 6]])
# 查看两者头信息区,确实不一样
>>>print(id(a))
>>>print(id(b))
2639015088992
2639015088272
# 初步查看两者的storage,看起来一样
>>>print(a.storage())
>>>print(b.storage())
1
2
3
4
5
6
[torch.LongStorage of size 6]
1
2
3
4
5
6
[torch.LongStorage of size 6]
# 进一步确认两者的storage,确实一样
>>>print(a.storage().data_ptr())
>>>print(b.storage().data_ptr())
2638924338496
2638924338496
2. tensor.reshape()
我们知道,tensor不连续是不能使用 view() 方法的。
只有将不连续tensor转化为连续tensor(利用contiguous(),//www.greatytc.com/p/51678ea7a959)后,才能使用view()。
reshape()正是先完成连续化,然后再进行view()。
reshape() 和 view() 的区别:
(1)当 tensor 满足连续性要求时,reshape() = view(),和原来 tensor 共用存储区;
(2)当 tensor不满足连续性要求时,reshape() = **contiguous() + view(),会产生有新存储区的 tensor,与原来tensor 不共用存储区。
3.tensor.resize_()
前面说到的 view()和reshape()都必须要用到全部的原始数据,比如你的原始数据只有12个,无论你怎么变形都必须要用到12个数字,不能多不能少。因此你就不能把只有12个数字的 tensor 强行 reshap 成 2×5 的。
但是 resize_() 可以做到,无论原始存储区有多少个数字,我都能变成你想要的维度,数字不够怎么办?随机产生凑!数字多了怎么办?就取我需要的部分!
1.截取时:
会改变原tensor a,但不会改变storage(地址和值都不变),且a和b共用storage(这里是2638930351680
)。
# 原tensor a
>>>a = torch.tensor([1,2,3,4,5,6,7])
>>>print(a)
>>>print(a.storage())
>>>print(a.storage( ).data_ptr())
tensor([1, 2, 3, 4, 5, 6, 7])
1
2
3
4
5
6
7
[torch.LongStorage of size 7]
2638930351680
# b是a的截取,并reshape成2×3
>>>b = a.resize_(2,3)
>>>print(a)
>>>print(b)
tensor([[1, 2, 3],
[4, 5, 6]]) #a变了
tensor([[1, 2, 3],
[4, 5, 6]])
>>>print(a.storage())
>>>print(b.storage())
1
2
3
4
5
6
7
[torch.LongStorage of size 7]
1
2
3
4
5
6
7
[torch.LongStorage of size 7]
>>>print(a.storage( ).data_ptr())
>>>print(b.storage( ).data_ptr())
2638930352576
2638930352576
2.添加时:会改变原tensor a,且会改变storage(地址和值都变),但a和b还是共用storage(这里是2638924338752
)。
>>>a = torch.tensor([1,2,3,4,5])
>>>print(a)
>>>print(a.storage())
>>>print(a.storage( ).data_ptr())
tensor([1, 2, 3, 4, 5])
1
2
3
4
5
[torch.LongStorage of size 5]
2638924334528
>>>b = a.resize_(2,3)
>>>print(a)
>>>print(b)
tensor([[1, 2, 3],
[4, 5, 0]]) #a变了
tensor([[1, 2, 3],
[4, 5, 0]])
>>>print(a.storage())
>>>print(b.storage())
1
2
3
4
5
0
[torch.LongStorage of size 6]
1
2
3
4
5
0
[torch.LongStorage of size 6]
>>>print(a.storage( ).data_ptr())
>>>print(b.storage( ).data_ptr())
2638924338752
2638924338752