Loading... # numpy和tensor互相转换 `用于记录python在深度学习中的使用` 经常使用到的场景是从本地opencv读取一张图片,需要将图片由numpy格式转为tensor格式,并经过一些通道变换(h,w,c)->(b,c,w,h) 。转为pytorch可以推理的张量tensor。 推理结束之后还需要将tensor转为numpy数据,用于图片的保存或者显示。 ```python def tensor2img(tensor): img = tensor.squeeze(0).permute(1,2,0).numpy() #转为tensor img = (img[:,:,::-1]*255.0).astype(np.uint8) #rgb格式转为bgr格式,并转为0-255 int8格式 return img #(224,224,3) 图片格式 def img2tensor(img): img = (img[:,:,::-1].astype(np.float32))/255.0 #bgr格式转为rgb格式,并转为0-1 float32格式 img = np.transpose(img,(2,0,1)) # (224,224,3) 转为(3,224,224) img = img[np.newaxis,:,:,:].copy() #(1,3,224,224) tensor = torch.from_numpy(img) return tensor #(1,3,224,224) 模型输入格式 ``` > 上面代码为img和tensor的具体转换实现。 其中代码的transpose、permute、unsqueeze、squeeze、np.newaxis都是维度的变化 ## transpose/permute transpose 可以理解为修改维度的先后顺序,例如将第三维度拉到第一维度,拉完之后在相应维度上的信息没有发生改变。 numpy 的transpose可以变换多个维度 torch tensor 的transpose 只能切换两个维度,permute才能切换多个维度 **numpy** ```python # numpy num = np.random.randn(3,4,5,6) ##两种实现方式 tran = num.transpose(1,2,3,0) #维度变成(4,5,6,3) tran2 = np.transpose(num,(1,2,3,0)) #维度变成(4,5,6,3) print(tran[:,:,:,0]==(num[0,:,:,:])) #相等 ``` 结果 ```python [[[ True True True True True True] [ True True True True True True] [ True True True True True True] [ True True True True True True] [ True True True True True True]]... ``` **tensor** ```python # torch num_tensor = torch.randn(3,4,5,6) #tran3 = num_tensor.transpose(3,0,1,2) #会出错 tran3 = num_tensor.permute(3,0,1,2) #会出错 print(num_tensor[:,:,:,0]==(tran3[0,:,:,:])) #相等 ``` 结果 ```python tensor([[[True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True]],.... ``` ## 维度变化之unsqueeze,squeeze,np.newaxis **numpy** ```python img = np.random.randn(224,224,3) print(img.shape) img = img[:,:,:,np.newaxis] print(img.shape) ``` 结果 ```python (224, 224, 3) (224, 224, 3, 1) ``` **tensor** ```python tensor = torch.randn(3,224,224) tensor = tensor.unsqueeze(0) #在第一个扩展一个维度 print(tensor.shape) tensor = tensor.squeeze(0) #在第一个减少一个维度 print(tensor.shape) ``` 结果 ```python torch.Size([1, 3, 224, 224]) torch.Size([3, 224, 224]) ``` 最后修改:2025 年 02 月 22 日 © 允许规范转载 打赏 赞赏作者 支付宝微信 赞 如果觉得我的文章对你有用,请随意赞赏