Advertisement

计算数据集的均值和标准差

阅读量:

cifar10:

复制代码
 def unpickle(file):
     """Load and return the object stored in a pickle file.

     Args:
         file: Path of the pickle file; it is opened in binary mode.

     Returns:
         The unpickled object. ``encoding='latin1'`` is used so that 8-bit
         strings pickled by Python 2 can be decoded.
     """
     import pickle

     # NOTE: unpickling can execute arbitrary code; only load trusted files.
     with open(file, 'rb') as f:
         return pickle.load(f, encoding='latin1')
    
  
    
 def get_mean_and_std(root_dir):
     """Compute per-channel mean and std of the CIFAR-10 training batches.

     Reads ``data_batch_1`` .. ``data_batch_5`` from ``root_dir``, scales the
     values by 1/255 and reshapes them to ``(-1, 3, 32, 32)``; one mean and one
     standard deviation are then computed for each index of axis 1.

     Args:
         root_dir: Directory that contains the ``data_batch_N`` files.

     Returns:
         Array of shape ``(2, 3)``: row 0 holds the three means, row 1 the
         three standard deviations. The data shape and this array are also
         printed.
     """
     import numpy as np  # local import keeps the snippet self-contained

     batches = [
         unpickle(f'{root_dir}/data_batch_{n}')['data'] for n in range(1, 6)
     ]

     train_data = np.array(batches, dtype=np.float32) / 255.
     train_data = np.reshape(train_data, (-1, 3, 32, 32))
     print(train_data.shape)

     mean_std = np.array([
         [np.mean(train_data[:, c]), np.std(train_data[:, c])]
         for c in range(3)
     ])
     # Transpose so that rows are (mean, std) and columns are the channels.
     mean_std = mean_std.transpose((1, 0))
     print(mean_std)
     return mean_std
    
  
    
 if __name__ == '__main__':
     # Forward slashes work on both Windows and POSIX and avoid backslash
     # escape-sequence pitfalls inside a normal string literal.
     data_path = '../dataset/cifar10/cifar-10-batches-py'
     get_mean_and_std(data_path)
    
    
    
    
    python
    
    
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-08-18/IOoPfNFu5C2We0ndYSiw6lxstgA4.png)

result:

复制代码
 (50000, 3, 32, 32)

    
 [[0.49139968 0.48215827 0.44653124]
    
  [0.24703233 0.24348505 0.26158768]]

true_value(常用参考值;其中 std 与上面的计算结果不同,一般认为这组常用值是先对每张图像分别求 std 再取平均得到的,而不是对整个训练集的全部像素直接求 std):

复制代码
 mean = [0.4914, 0.4822, 0.4465]

    
 std = [0.2023, 0.1994, 0.2010]

全部评论 (0)

还没有任何评论哟~