PyTorch Basics: The Tensor permute() Method
Related reading: PyTorch Basics — https://blog.csdn.net/weixin_45791458/category_12457644.html

In PyTorch, permute() is an important method of the Tensor class, and it is also available as a function in the torch module. Their signatures are:

```python
Tensor.permute(*dims) -> Tensor
torch.permute(input, dims) -> Tensor
```

- input (Tensor) – the input tensor
- dims (tuple of int) – the desired ordering of dimensions

The official documentation describes permute() as returning a view of the original input tensor with its dimensions permuted. "View" here means a new tensor object whose data shares storage with the original, i.e. the underlying data elements are the same. The returned tensor may be non-contiguous; the is_contiguous() method tells you whether a tensor is contiguous. For more detail on non-contiguous tensors, see:

PyTorch Basics: Tensor Contiguity — https://blog.csdn.net/weixin_45791458/article/details/140736700

The following examples illustrate permute():

```python
import torch

# Create a tensor
x = torch.rand(3, 3, 3)
# Reverse the order of all three dimensions with permute
y = x.permute(2, 1, 0)
print(x, y)
```

Output:

```
tensor([[[0.9701, 0.7507, 0.8002],
         [0.5876, 0.1460, 0.0386],
         [0.5126, 0.1538, 0.5863]],

        [[0.8500, 0.8774, 0.2415],
         [0.1053, 0.5650, 0.7321],
         [0.8260, 0.1564, 0.7447]],

        [[0.5131, 0.7111, 0.3469],
         [0.6031, 0.8140, 0.9770],
         [0.7578, 0.0223, 0.5515]]]) tensor([[[0.9701, 0.8500, 0.5131],
         [0.5876, 0.1053, 0.6031],
         [0.5126, 0.8260, 0.7578]],

        [[0.7507, 0.8774, 0.7111],
         [0.1460, 0.5650, 0.8140],
         [0.1538, 0.1564, 0.0223]],

        [[0.8002, 0.2415, 0.3469],
         [0.0386, 0.7321, 0.9770],
         [0.5863, 0.7447, 0.5515]]])
```

```python
print(id(x), id(y))
# 4554479952 4811331200  -> x and y are distinct tensor objects
print(x.storage().data_ptr(), y.storage().data_ptr())
# 4830094080 4830094080  -> but they share the same underlying data storage
```

```python
y[0, 0] = 7
print(x, y)
```

Output:

```
tensor([[[7.0000, 0.7507, 0.8002],
         [0.5876, 0.1460, 0.0386],
         [0.5126, 0.1538, 0.5863]],

        [[7.0000, 0.8774, 0.2415],
         [0.1053, 0.5650, 0.7321],
         [0.8260, 0.1564, 0.7447]],

        [[7.0000, 0.7111, 0.3469],
         [0.6031, 0.8140, 0.9770],
         [0.7578, 0.0223, 0.5515]]]) tensor([[[7.0000, 7.0000, 7.0000],
         [0.5876, 0.1053, 0.6031],
         [0.5126, 0.8260, 0.7578]],

        [[0.7507, 0.8774, 0.7111],
         [0.1460, 0.5650, 0.8140],
         [0.1538, 0.1564, 0.0223]],

        [[0.8002, 0.2415, 0.3469],
         [0.0386, 0.7321, 0.9770],
         [0.5863, 0.7447, 0.5515]]])
```

Modifying the new tensor changed the original tensor as well.

```python
print(x.is_contiguous(), y.is_contiguous())
# True False  -> x is contiguous, y is not
```

This resembles the behavior described in the earlier article on shallow copies of lists: modifying an element of a list nested inside the new list also affects the original list, as shown below.

```python
import copy

my_list = [1, 2, [1, 2]]
your_list = list(my_list)      # factory function
his_list = my_list[:]          # slicing
her_list = copy.copy(my_list)  # copy() from the copy module

your_list[2][0] = 3
print(my_list)
print(your_list)
print(his_list)
print(her_list)

his_list[2][1] = 4
print(my_list)
print(your_list)
print(his_list)
print(her_list)

her_list[2].append(5)
print(my_list)
print(your_list)
print(his_list)
print(her_list)
```

Output:

```
[1, 2, [3, 2]]
[1, 2, [3, 2]]
[1, 2, [3, 2]]
[1, 2, [3, 2]]
[1, 2, [3, 4]]
[1, 2, [3, 4]]
[1, 2, [3, 4]]
[1, 2, [3, 4]]
[1, 2, [3, 4, 5]]
[1, 2, [3, 4, 5]]
[1, 2, [3, 4, 5]]
[1, 2, [3, 4, 5]]
```

Unlike lists, however, modifying non-nested content in one tensor also affects the other, as shown below.

```python
import torch

# Create a tensor
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
# Swap the two dimensions with permute
y = x.permute(1, 0)
print(x, y)
```

Output:

```
tensor([[1, 2, 3],
        [4, 5, 6]]) tensor([[1, 4],
        [2, 5],
        [3, 6]])
```

```python
x[0] = torch.tensor([4, 4, 4])  # overwrite row 0 of x
print(x, y)
```

Output:

```
tensor([[4, 4, 4],
        [4, 5, 6]]) tensor([[4, 4],
        [4, 5],
        [4, 6]])
```

PyTorch also offers the transpose() method, which overlaps with permute() in functionality but differs in scope: transpose() can only swap two dimensions at a time, while permute() can reorder any number of dimensions in a single call, making it more flexible in complex scenarios. For more detail, see:

PyTorch Basics: The Tensor transpose() Method — https://blog.csdn.net/weixin_45791458/article/details/133470992
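The relationship between transpose() and permute() described above can be sketched concretely: a single transpose() is equivalent to a permute() that swaps two entries, while an arbitrary permute() may need a chain of transposes. This is an illustrative sketch, not from the original post:

```python
import torch

x = torch.rand(2, 3, 4)

# transpose() swaps exactly two dimensions
a = x.transpose(0, 2)      # shape (4, 3, 2)

# permute() can express the same swap...
b = x.permute(2, 1, 0)     # shape (4, 3, 2)
print(torch.equal(a, b))   # True

# ...and also reorderings that would need chained transposes
c = x.permute(1, 2, 0)                 # shape (3, 4, 2)
d = x.transpose(0, 1).transpose(1, 2)  # same result via two transposes
print(torch.equal(c, d))   # True
```

Both methods return views over the same storage, so neither copies the data.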
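As noted above, the view returned by permute() is usually non-contiguous. A minimal sketch (not from the original post) of how to obtain a contiguous copy when one is needed, e.g. before calling view():

```python
import torch

x = torch.rand(2, 3, 4)
y = x.permute(2, 0, 1)   # non-contiguous view over x's storage
print(y.is_contiguous())  # False

# contiguous() copies the data into a fresh contiguous block of memory
z = y.contiguous()
print(z.is_contiguous())              # True
print(z.data_ptr() == x.data_ptr())   # False: z no longer shares x's storage
```

Note that after contiguous(), modifying z no longer affects x, since the storage is no longer shared.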
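Because the 3x3x3 tensor in the examples above has equal-sized dimensions, the shape change is invisible; with distinct sizes the reordering shows up directly in the shape. A common practical use (a hypothetical example, not from the original post) is converting an image batch from NCHW to NHWC layout:

```python
import torch

# A batch of 2 images with 3 channels, height 4, width 5 (NCHW layout)
imgs = torch.rand(2, 3, 4, 5)

# Move the channel dimension to the end: NCHW -> NHWC
nhwc = imgs.permute(0, 2, 3, 1)
print(imgs.shape)  # torch.Size([2, 3, 4, 5])
print(nhwc.shape)  # torch.Size([2, 4, 5, 3])
```

Since this is only a view, the conversion itself costs no data movement; copying happens only if contiguous() is called afterwards.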