计算pytorch标准化(Normalize)所需要数据集的均值和方差实例
pytorch做标准化利用transforms.Normalize(mean_vals,std_vals),其中常用数据集的均值方差有:
if'coco'inargs.dataset: mean_vals=[0.471,0.448,0.408] std_vals=[0.234,0.239,0.242] elif'imagenet'inargs.dataset: mean_vals=[0.485,0.456,0.406] std_vals=[0.229,0.224,0.225]
计算自己数据集图像像素的均值方差:
importnumpyasnp importcv2 importrandom #calculatemeansandstd train_txt_path='./train_val_list.txt' CNum=10000#挑选多少图片进行计算 img_h,img_w=32,32 imgs=np.zeros([img_w,img_h,3,1]) means,stdevs=[],[] withopen(train_txt_path,'r')asf: lines=f.readlines() random.shuffle(lines)#shuffle,随机挑选图片 foriintqdm_notebook(range(CNum)): img_path=os.path.join('./train',lines[i].rstrip().split()[0]) img=cv2.imread(img_path) img=cv2.resize(img,(img_h,img_w)) img=img[:,:,:,np.newaxis] imgs=np.concatenate((imgs,img),axis=3) #print(i) imgs=imgs.astype(np.float32)/255. foriintqdm_notebook(range(3)): pixels=imgs[:,:,i,:].ravel()#拉成一行 means.append(np.mean(pixels)) stdevs.append(np.std(pixels)) #cv2读取的图像格式为BGR,PIL/Skimage读取到的都是RGB不用转 means.reverse()#BGR-->RGB stdevs.reverse() print("normMean={}".format(means)) print("normStd={}".format(stdevs)) print('transforms.Normalize(normMean={},normStd={})'.format(means,stdevs))
以上这篇计算pytorch标准化(Normalize)所需要数据集的均值和方差实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。