Pytorch 扩展Tensor维度、压缩Tensor维度的方法
1.扩展Tensor维度
相信刚接触Pytorch的宝宝们,会遇到这样一个问题,输入的数据维度和实验需要维度不一致,输入的可能是2维数据或3维数据,实验需要用到3维或4维数据,那么我们需要扩展这个维度。其实特别简单,只要对数据加一个扩展维度方法就可以了。
1.1torch.unsqueeze(self:Tensor,dim:_int)
torch.unsqueeze(self:Tensor,dim:_int)
参数说明:self:输入的tensor数据,dim:要对哪个维度扩展就输入那个维度的整数,可以输入0,1,2……
1.2Code
第一种方式,输入数据后直接加unsqueeze()
扩展第一维和第二维为1
importtorch defreset_unsqueeze1(): data=torch.rand([3,3]) data1=data.unsqueeze(dim=0).unsqueeze(dim=1) print("data_size:",data.shape) print("data:",data) print("data1_size:",data1.shape) print("data1:",data1)
结果显示
第二种方式,用torch.unsqueeze()
importtorch defreset_unsqueeze2(): data=torch.rand([3,3]) data1=torch.unsqueeze(data,dim=0) print("data_size:",data.shape) print("data:",data) print("data1_size:",data1.shape) print("data1:",data1)
结果显示
2.压缩Tensor维度
2.1torch.squeeze(self:Tensor,dim:_int)
这个方法刚好和torch.unsqueeze()方法效果相反,压缩Tensor维度。
2.2Code
第一种方式,输入数据后直接加squeeze()
importtorch defreset_squeeze1(): data=torch.rand([1,1,3,3]) data1=data.squeeze(dim=0).squeeze(dim=1) print("data_size:",data.shape) print("data:",data) print("data1_size:",data1.shape) print("data1:",data1)
结果显示
第二种方式,用torch.squeeze()
importtorch defreset_squeeze2(): data=torch.rand([1,1,3,3]) data1=torch.squeeze(data,dim=0) print("data_size:",data.shape) print("data:",data) print("data1_size:",data1.shape) print("data1:",data1)
结果显示
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。