Training PyTorch on your own dataset
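Use a model that ships with PyTorch (torchvision) and replace its fully connected layer so that its output size equals the number of classes in your dataset.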
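import torch
import torch.nn as nn
from torchvision import models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.resnet50(weights=None)  # torchvision < 0.13: models.resnet50(pretrained=False)
class_num = 62  # number of classes in your dataset
fc_features = model.fc.in_features  # input size of the original 1000-way classifier
model.fc = nn.Linear(fc_features, class_num)  # new classifier with class_num outputs
model = model.to(device)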
From the dataset folder, generate a txt file in which each line contains an image file path and its label.
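import os
data = []    # image paths relative to the split folder, e.g. "/00000/xxx.ppm"
labels = []  # class index parsed from the folder name, e.g. "00000" -> 0
filetype = ".ppm"
# Run inside the Training (or Testing) folder; each class is a subfolder such as ./00000
for root, dirs, files in os.walk("./"):
    for f in files:
        if f.endswith(filetype):  # safer than f.split('.')[1], which fails on names without a dot
            data.append(root[1:] + "/" + f)
            labels.append(int(root[2:]))
# One "path label" pair per line, no trailing newline (also safe when data is empty)
with open("datalist.txt", "w") as fi:
    fi.write("\n".join("{} {}".format(d, l) for d, l in zip(data, labels)))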
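Before building the Dataset, it can help to read datalist.txt back and count the images per label. This catches empty or misnamed class folders. The sketch below is minimal and uses `read_datalist` and `class_counts` as hypothetical helper names. Like the file format itself, it assumes paths contain no spaces.

```python
from collections import Counter


def read_datalist(path):
    """Parse "relative/path label" lines into (path, label) pairs."""
    pairs = []
    with open(path, "r") as f:
        for line in f:
            parts = line.split()
            if not parts:
                continue  # tolerate blank lines
            pairs.append((parts[0], int(parts[1])))
    return pairs


def class_counts(pairs):
    """Number of images per label; a missing label means an empty or skipped folder."""
    return Counter(label for _, label in pairs)
```

For example, `class_counts(read_datalist("datalist.txt"))` should have one key per class folder.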
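This produces a datalist.txt in which each line holds a relative image path followed by its label, as shown in the figure.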

Reference: <>
Create a Dataset class and define how the data is read.
from PIL import Image
from torch.utils.data import Dataset


def default_loader(path):
    # Open the image and force 3 channels
    return Image.open(path).convert('RGB')


class BTSCDataset(Dataset):
    def __init__(self, root_dir, transforms=None, loader=default_loader, train=True):
        super(BTSCDataset, self).__init__()
        self.sub_directory = 'Training' if train else 'Testing'
        # e.g. <root_dir>/BelgiumTSC_Training/Training
        base = root_dir + "/BelgiumTSC_" + self.sub_directory + "/" + self.sub_directory
        imgs = []
        with open(base + "/datalist.txt", "r") as f:
            for line in f:
                parts = line.split()
                if not parts:
                    continue  # skip blank lines
                img, label = parts[0], int(parts[1])
                # img already starts with "/", e.g. "/00000/xxx.ppm"
                imgs.append((base + img, label))
        self.imgs = imgs
        self.transforms = transforms
        self.loader = loader

    def __getitem__(self, item):
        fn, label = self.imgs[item]
        img = self.loader(fn)
        if self.transforms is not None:
            img = self.transforms(img)  # transform the image, not the dataset object
        return img, label  # return the label (not the file name) so the loss can be computed

    def __len__(self):
        return len(self.imgs)
Reference: https://www.pythonf.cn/read/156398
Design the training process
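A typical supervised training loop for the model and dataset above could look like the sketch below. The optimizer (SGD with momentum 0.9), learning rate and epoch count are assumptions, not values from the post. `train_one_epoch`, `evaluate` and `fit` are hypothetical names.

```python
import torch
import torch.nn as nn


def train_one_epoch(model, loader, criterion, optimizer, device):
    model.train()  # enable training behaviour (e.g. batch-norm statistics updates)
    total_loss, correct, seen = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)            # (batch, class_num) raw logits
        loss = criterion(outputs, labels)  # CrossEntropyLoss takes logits + integer labels
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * labels.size(0)
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        seen += labels.size(0)
    return total_loss / seen, correct / seen


@torch.no_grad()
def evaluate(model, loader, device):
    model.eval()  # inference behaviour, no gradients tracked
    correct, seen = 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        correct += (model(images).argmax(dim=1) == labels).sum().item()
        seen += labels.size(0)
    return correct / seen


def fit(model, train_loader, test_loader, device, epochs=10, lr=0.01):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    model.to(device)
    test_acc = 0.0
    for epoch in range(epochs):
        loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
        test_acc = evaluate(model, test_loader, device)
        print(f"epoch {epoch + 1}: loss {loss:.4f}  train acc {train_acc:.3f}  test acc {test_acc:.3f}")
    return test_acc
```

With the Dataset class above, a call could look like `fit(model, DataLoader(BTSCDataset(root, transforms=tf), batch_size=32, shuffle=True), DataLoader(BTSCDataset(root, transforms=tf, train=False), batch_size=32), device)`. Here `root` and `tf` are placeholders, and `tf` must end with `transforms.ToTensor()` and produce a fixed image size so that batches can be stacked.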
