Pytorch采用AlexNet实现猫狗数据集分类(训练与预测)
介绍
AlexNet模型是CNN网络中经典的网络模型,适合初学者学习。本文对AlexNet的结构参数做初步说明,详细内容可以下载论文查阅。本文通过AlexNet对Kaggle的猫狗数据集进行训练和预测,相关资料为搜集总结。
AlexNet网络模型
如图是2012年AlexNet网络模型结构,由于之前GPU内存小,当时网络是采用了两块GPU,现在训练是不需要的,AlexNet的特点归纳为以下几点:
猫狗数据集
Kaggle猫狗数据集,可以直接在官网下载https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition/data。下载的压缩包解压后,如图:
train是25000张猫狗图片,各占一半,图片名字都进行了标记(以cat或dog开头)。由于所有图片在同一个文件夹中,在数据处理阶段,需要打乱图片顺序,并根据图片名对其进行标记分类。训练阶段,只需要用到train文件夹的图片。首先写个dataset方便图片读取和相关操作,文件命名为my_dataset.py(与后面训练代码中的导入路径A_alexnet.tools.my_dataset保持一致)。
import os
import random

from PIL import Image
from torch.utils.data import Dataset

random.seed(1)


class CatDogDataset(Dataset):
    """Kaggle cats-vs-dogs dataset with a seeded train/valid split.

    All ``.jpg`` files directly inside ``data_dir`` are listed, sorted,
    shuffled with ``rng_seed`` and split at ``split_n``. A file whose name
    starts with ``cat`` gets label 0; every other file gets label 1.
    """

    def __init__(self, data_dir, mode="train", split_n=0.9, rng_seed=620,
                 transform=None):
        """
        :param data_dir: str, directory that contains the images
        :param mode: str, "train" (first ``split_n`` share of the shuffled
            files) or "valid" (the remainder)
        :param split_n: float, fraction of the files used for training
        :param rng_seed: int, seed of the shuffle; use the same value for
            the train and valid instances so the two subsets do not overlap
        :param transform: callable applied to the RGB PIL image, or None
        :raises ValueError: if ``mode`` is neither "train" nor "valid"
        """
        self.mode = mode
        self.data_dir = data_dir
        self.rng_seed = rng_seed
        self.split_n = split_n
        # List of (image path, label) pairs; indexed by __getitem__.
        self.data_info = self._get_img_info()
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``."""
        path_img, label = self.data_info[index]
        # convert() loads the pixels into a new image, so the file handle
        # can be released as soon as the block ends.
        with Image.open(path_img) as img_file:
            img = img_file.convert('RGB')  # 0~255

        if self.transform is not None:
            img = self.transform(img)  # e.g. crop / ToTensor / normalize

        return img, label

    def __len__(self):
        """Return the number of samples; raise if the subset is empty."""
        if not self.data_info:
            raise Exception(
                "\ndata_dir:{} is a empty dir! Please checkout your path to images!".format(
                    self.data_dir))
        return len(self.data_info)

    def _get_img_info(self):
        """Build the (path, label) list of the subset selected by ``mode``."""
        # Fail on a bad mode before touching the file system.
        if self.mode not in ("train", "valid"):
            raise ValueError("self.mode 无法识别,仅支持(train, valid)")

        # os.listdir() returns names in arbitrary, platform-dependent order;
        # sorting first makes the seeded shuffle (and therefore the split)
        # reproducible across machines and runs.
        img_names = sorted(
            n for n in os.listdir(self.data_dir) if n.endswith('.jpg'))

        # A private generator yields the same permutation as seeding the
        # global one, without resetting the RNG state of the whole process.
        random.Random(self.rng_seed).shuffle(img_names)

        img_labels = [0 if n.startswith('cat') else 1 for n in img_names]

        split_idx = int(len(img_labels) * self.split_n)  # 25000 * 0.9 = 22500
        if self.mode == "train":
            img_set = img_names[:split_idx]
            label_set = img_labels[:split_idx]
        else:
            img_set = img_names[split_idx:]
            label_set = img_labels[split_idx:]

        path_img_set = [os.path.join(self.data_dir, n) for n in img_set]
        return list(zip(path_img_set, label_set))
AlexNet网络训练
在开始网络搭建前,还需要做以下准备工作:
为了提高模型分类准确度,引入AlexNet在ImageNet比赛时的预训练模型,AlexNet结构就不重写了,直接调用Pytorch中的预设模型,AlexNet最后全连接层是1000分类的,所以之后代码中还需要修改最后一层参数。
链接:https://pan.baidu.com/s/16xd6PjmjPrAKIbta81yUdw
提取码:cyh9
简单的函数对预训练模型读取。
def get_model(path_state_dict, vis_model=False):
    """
    Create an AlexNet and load parameters from a saved state dict.

    :param path_state_dict: str, path of the ``.pth`` file holding the
        state dict to load
    :param vis_model: bool, if True print a layer summary with torchsummary
    :return: the model, moved to the module-level ``device``
    """
    model = models.alexnet()
    # Deserialize onto the CPU first so a checkpoint that holds CUDA
    # tensors can still be loaded on a machine without a GPU; the model is
    # moved to the target device below.
    pretrained_state_dict = torch.load(path_state_dict, map_location="cpu")
    model.load_state_dict(pretrained_state_dict)

    if vis_model:
        from torchsummary import summary
        # The model is still on the CPU at this point.
        summary(model, input_size=(3, 224, 224), device="cpu")

    model.to(device)
    return model
训练全代码
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader

from A_alexnet.tools.my_dataset import CatDogDataset

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_model(path_state_dict, vis_model=False):
    """
    Create an AlexNet and load parameters from a saved state dict.

    :param path_state_dict: str, path of the ``.pth`` file holding the
        state dict to load
    :param vis_model: bool, if True print a layer summary with torchsummary
    :return: the model, moved to ``device``
    """
    model = models.alexnet()
    # Deserialize onto the CPU first so the checkpoint loads on machines
    # without a GPU as well; the model is moved to ``device`` below.
    pretrained_state_dict = torch.load(path_state_dict, map_location="cpu")
    model.load_state_dict(pretrained_state_dict)

    if vis_model:
        from torchsummary import summary
        summary(model, input_size=(3, 224, 224), device="cpu")

    model.to(device)
    return model


if __name__ == "__main__":

    # config
    data_dir = os.path.join(BASE_DIR, "..", "data", "train")
    # ImageNet-pretrained AlexNet weights
    path_state_dict = os.path.join(BASE_DIR, "..", "data", "alexnet-owt-4df8aa71.pth")
    # binary classification: cat / dog
    num_classes = 2

    MAX_EPOCH = 3       # adjustable; more epochs usually give better results
    BATCH_SIZE = 200    # adjustable; larger is faster if memory allows
    LR = 0.001          # adjustable
    log_interval = 1    # adjustable; print training info every N iterations
    val_interval = 1    # adjustable; validate every N epochs
    start_epoch = -1
    lr_decay_step = 1   # adjustable; decay the learning rate every N epochs

    # ============================ step 1/5 data ============================
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]

    train_transform = transforms.Compose([
        # An int resizes the shorter side to 256 and keeps the aspect
        # ratio, unlike (256, 256) which forces a square.
        transforms.Resize((256)),
        transforms.CenterCrop(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ])

    normalizes = transforms.Normalize(norm_mean, norm_std)
    # TenCrop yields 10 crops per image; they are stacked into one tensor
    # of shape (10, C, H, W) and averaged at validation time.
    valid_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.TenCrop(224, vertical_flip=False),
        transforms.Lambda(lambda crops: torch.stack(
            [normalizes(transforms.ToTensor()(crop)) for crop in crops])),
    ])

    # build the dataset instances
    train_data = CatDogDataset(data_dir=data_dir, mode="train", transform=train_transform)
    valid_data = CatDogDataset(data_dir=data_dir, mode="valid", transform=valid_transform)

    # build the DataLoaders
    train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
    valid_loader = DataLoader(dataset=valid_data, batch_size=4)

    # ============================ step 2/5 model ============================
    alexnet_model = get_model(path_state_dict, False)

    # Replace the 1000-way ImageNet head with a num_classes-way layer.
    num_ftrs = alexnet_model.classifier._modules["6"].in_features
    alexnet_model.classifier._modules["6"] = nn.Linear(num_ftrs, num_classes)

    # The new Linear layer is created on the CPU, so move the model again.
    alexnet_model.to(device)

    # ============================ step 3/5 loss ============================
    criterion = nn.CrossEntropyLoss()

    # ============================ step 4/5 optimizer ============================
    optimizer = optim.SGD(alexnet_model.parameters(), lr=LR, momentum=0.9)
    # multiply the learning rate by 0.1 every lr_decay_step epochs
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_step, gamma=0.1)

    # ============================ step 5/5 training ============================
    train_curve = list()
    valid_curve = list()

    for epoch in range(start_epoch + 1, MAX_EPOCH):

        loss_mean = 0.
        correct = 0
        total = 0

        alexnet_model.train()
        for i, data in enumerate(train_loader):

            # forward
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = alexnet_model(inputs)

            # backward
            optimizer.zero_grad()
            loss = criterion(outputs, labels)
            loss.backward()

            # update weights
            optimizer.step()

            # classification statistics
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # print training info
            loss_mean += loss.item()
            train_curve.append(loss.item())
            if (i + 1) % log_interval == 0:
                loss_mean = loss_mean / log_interval
                print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                    epoch, MAX_EPOCH, i + 1, len(train_loader), loss_mean, correct / total))
                loss_mean = 0.

        scheduler.step()  # update the learning rate

        # validate the model
        if (epoch + 1) % val_interval == 0:

            correct_val = 0
            total_val = 0
            loss_val = 0.
            alexnet_model.eval()
            with torch.no_grad():
                for j, data in enumerate(valid_loader):
                    inputs, labels = data
                    inputs, labels = inputs.to(device), labels.to(device)

                    bs, ncrops, c, h, w = inputs.size()  # e.g. [4, 10, 3, 224, 224]
                    # Fold the crops into the batch dimension, run the
                    # model, then average the outputs of each image's crops.
                    outputs = alexnet_model(inputs.view(-1, c, h, w))
                    outputs_avg = outputs.view(bs, ncrops, -1).mean(1)

                    loss = criterion(outputs_avg, labels)

                    predicted = outputs_avg.argmax(dim=1)
                    total_val += labels.size(0)
                    correct_val += (predicted == labels).sum().item()

                    loss_val += loss.item()

                loss_val_mean = loss_val / len(valid_loader)
                valid_curve.append(loss_val_mean)
                print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                    epoch, MAX_EPOCH, j + 1, len(valid_loader), loss_val_mean, correct_val / total_val))
            alexnet_model.train()

    train_x = range(len(train_curve))
    train_y = train_curve

    train_iters = len(train_loader)
    # valid_curve holds one loss per validated epoch; convert those points
    # to iteration indices so both curves share the same x axis.
    valid_x = np.arange(1, len(valid_curve) + 1) * train_iters * val_interval
    valid_y = valid_curve

    # save the trained parameters
    torch.save(alexnet_model.state_dict(), 'whole_CatDog_params.pth')

    plt.plot(train_x, train_y, label='Train')
    plt.plot(valid_x, valid_y, label='Valid')

    plt.legend(loc='upper right')
    plt.ylabel('loss value')
    plt.xlabel('Iteration')
    plt.show()
运行训练代码,就开始计算。
一般设置3个epoch,训练集和验证集准确率都在96%以上。
注意:在代码中,torch.save(alexnet_model.state_dict(), 'whole_CatDog_params.pth')已经保存了最终训练得到的模型参数。保存路径可以自己设置,这里直接保存在当前运行目录下。
预测
import os
os.environ['NLS_LANG'] = 'SIMPLIFIED CHINESE_CHINA.UTF8'
import time
import json

import torch
import torchvision.transforms as transforms
from PIL import Image
from matplotlib import pyplot as plt
import torchvision.models as models

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def img_transform(img_rgb, transform=None):
    """
    Convert an image into the form the model reads.

    :param img_rgb: PIL Image
    :param transform: torchvision transform to apply; required
    :return: tensor
    :raises ValueError: if ``transform`` is None
    """
    if transform is None:
        raise ValueError("找不到transform!必须有transform对img进行处理")

    img_t = transform(img_rgb)
    return img_t


def load_class_names(p_clsnames, p_clsnames_cn):
    """
    Load class names.

    :param p_clsnames: path of a JSON file with the class names
    :param p_clsnames_cn: path of a UTF-8 text file, one name per line
    :return: (parsed JSON content, list of raw lines incl. line endings)
    """
    with open(p_clsnames, "r", encoding='UTF-8') as f:
        class_names = json.load(f)
    with open(p_clsnames_cn, encoding='UTF-8') as f:
        class_names_cn = f.readlines()
    return class_names, class_names_cn


def get_model(path_state_dict, num_classes, vis_model=False):
    """
    Create an AlexNet with ``num_classes`` outputs and load its parameters.

    The network is built with the final layer already sized to
    ``num_classes`` so that its shapes match the saved state dict.

    :param path_state_dict: str, path of the saved state dict
    :param num_classes: int, number of output classes
    :param vis_model: bool, if True print a layer summary with torchsummary
    :return: the model in eval mode, moved to ``device``
    """
    model = models.alexnet(num_classes=num_classes)
    # The weights may have been saved from a CUDA model; deserialize onto
    # the CPU so prediction also works on a machine without a GPU.
    pretrained_state_dict = torch.load(path_state_dict, map_location="cpu")
    model.load_state_dict(pretrained_state_dict)
    model.eval()

    if vis_model:
        from torchsummary import summary
        summary(model, input_size=(3, 224, 224), device="cpu")

    model.to(device)
    return model


def process_img(path_img):
    """
    Load an image file and turn it into a one-image batch on ``device``.

    :param path_img: str, path of the image file
    :return: (tensor with a leading batch dimension, RGB PIL image)
    """
    # hard code
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]
    inference_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ])

    # path --> img
    img_rgb = Image.open(path_img).convert('RGB')

    # img --> tensor
    img_tensor = img_transform(img_rgb, inference_transform)
    img_tensor.unsqueeze_(0)        # chw --> bchw
    img_tensor = img_tensor.to(device)

    return img_tensor, img_rgb


if __name__ == "__main__":

    num_classes = 2
    # Index -> name; the dataset labels names starting with "cat" as 0.
    class_names = ("cat", "dog")

    # config
    # File name written by the training script's torch.save(); change it
    # here if the checkpoint was saved under another name or directory.
    path_state_dict = os.path.join(BASE_DIR, "whole_CatDog_params.pth")
    path_img = os.path.join(BASE_DIR, "..", "data", "272.jpg")

    # 1/5 load img
    img_tensor, img_rgb = process_img(path_img)

    # 2/5 load model
    alexnet_model = get_model(path_state_dict, num_classes, True)

    # 3/5 inference
    with torch.no_grad():
        # CUDA kernels run asynchronously; synchronize so the timer
        # measures the forward pass rather than just the kernel launch.
        if device.type == "cuda":
            torch.cuda.synchronize()
        time_tic = time.time()
        outputs = alexnet_model(img_tensor)
        if device.type == "cuda":
            torch.cuda.synchronize()
        time_toc = time.time()

    # 4/5 index to class names
    pred_idx = outputs.argmax(dim=1).item()
    pred_str = class_names[pred_idx]
    print("img: {} is: {}".format(os.path.basename(path_img), pred_str))

    print("time consuming:{:.2f}s".format(time_toc - time_tic))

    # 5/5 visualization
    plt.imshow(img_rgb)
    plt.title("predict:{}".format(pred_str))
    plt.text(5, 45, "top {}:{}".format(1, pred_str), bbox=dict(fc='yellow'))
    plt.show()
预测模型有需要注意的地方:get_model在创建模型时直接把最后全连接层的分类数设为2(models.alexnet(num_classes=2)),跟训练时的修改方法不一致。如果沿用训练时的方法(先创建默认1000分类的模型并加载参数,再替换最后一层)会一直报错,原因是训练保存的参数中最后一层已经是2分类,与默认1000分类模型的参数形状不匹配,所以这里换成先指定分类数、再加载参数的方法。预测模型没有导入全部test照片进行预测,只是截取一张预测一个结果,感兴趣可以自行导入全部test照片。
推荐使用torchsummary,结果可以很直观看到网络结构及相关参数。直接pip install torchsummary就可以安装了。
预测图片
仅作为学习总结分享,有错误望小伙伴们指正。
相关知识
CNN简单实战:PyTorch搭建CNN对猫狗图片进行分类
详解pytorch实现猫狗识别98%附代码
基于Pytorch实现的声音分类
(转载)YOLOv5 实现目标检测(训练自己的数据集实现猫猫识别)
YOLOv5 实现目标检测(训练自己的数据集实现猫猫识别)
PyTorch深度学习:猫狗情感识别
深度学习卷积神经图像分类实现鸟类识别含训练代码和鸟类数据集(支持repVGG,googlenet, resnet, inception, mobilenet)
猫狗图片分类 03分析图片数据
使用PyTorch实现鸟类音频检测卷积网络模型
基于深度学习的鸟类识别系统(网页版+YOLOv8/v7/v6/v5代码+训练数据集)
网址: Pytorch采用AlexNet实现猫狗数据集分类(训练与预测) https://www.mcbbbk.com/newsview258445.html
上一篇: 鹦鹉家养小贴士:如何让你的羽毛朋 |
下一篇: 当领养它们第五天后,这变化太大了 |
推荐分享

- 1我的狗老公李淑敏33——如何 5096
- 2南京宠物粮食薄荷饼宠物食品包 4363
- 3家养水獭多少钱一只正常 3825
- 4豆柴犬为什么不建议养?可爱的 3668
- 5自制狗狗辅食:棉花面纱犬的美 3615
- 6狗交配为什么会锁住?从狗狗生 3601
- 7广州哪里卖宠物猫狗的选择性多 3535
- 8湖南隆飞尔动物药业有限公司宠 3477
- 9黄金蟒的价格 3396
- 10益和 MATCHWELL 狗 3352