首页 分享 CNN简单实战:PyTorch搭建CNN对猫狗图片进行分类

CNN简单实战:PyTorch搭建CNN对猫狗图片进行分类

来源:萌宠菠菠乐园 时间:2024-09-17 02:26

在上一篇文章:CNN训练前的准备:PyTorch处理自己的图像数据(Dataset和Dataloader),大致介绍了怎么利用pytorch把猫狗图片处理成CNN需要的数据,今天就用该数据对自己定义的CNN模型进行训练及测试。

首先导入需要的包:

import torch from torch import optim import torch.nn as nn from torch.autograd import Variable from torchvision import transforms from torch.utils.data import Dataset, DataLoader from PIL import Image 1234567定义自己的CNN网络

class cnn(nn.Module): def __init__(self): super(cnn, self).__init__() self.relu = nn.ReLU() self.sigmoid = nn.Sigmoid() self.conv1 = nn.Sequential( nn.Conv2d( in_channels=3, out_channels=16, kernel_size=3, stride=2, ), nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d(kernel_size=2), ) # self.conv2 = nn.Sequential( nn.Conv2d( in_channels=16, out_channels=32, kernel_size=3, stride=2, ), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(kernel_size=2), ) # self.conv3 = nn.Sequential( nn.Conv2d( in_channels=32, out_channels=64, kernel_size=3, stride=2, ), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(kernel_size=2), ) self.fc1 = nn.Linear(3 * 3 * 64, 64) self.fc2 = nn.Linear(64, 10) self.out = nn.Linear(10, 2) def forward(self, x): x = self.conv1(x) x = self.conv2(x) x = self.conv3(x) # print(x.size()) x = x.view(x.shape[0], -1) x = self.relu(self.fc1(x)) x = self.relu(self.fc2(x)) x = self.out(x) # x = F.log_softmax(x, dim=1) return x 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455训练(GPU)

def train(): Dtr, Val, Dte = load_data() print('train...') epoch_num = 30 best_model = None min_epochs = 5 min_val_loss = 5 model = cnn().to(device) optimizer = optim.Adam(model.parameters(), lr=0.0008) criterion = nn.CrossEntropyLoss().to(device) # criterion = nn.BCELoss().to(device) for epoch in tqdm(range(epoch_num), ascii=True): train_loss = [] for batch_idx, (data, target) in enumerate(Dtr, 0): data, target = Variable(data).to(device), Variable(target.long()).to(device) # target = target.view(target.shape[0], -1) # print(target) optimizer.zero_grad() output = model(data) # print(output) loss = criterion(output, target) loss.backward() optimizer.step() train_loss.append(loss.cpu().item()) # validation val_loss = get_val_loss(model, Val) model.train() if epoch + 1 > min_epochs and val_loss < min_val_loss: min_val_loss = val_loss best_model = copy.deepcopy(model) tqdm.write('Epoch {:03d} train_loss {:.5f} val_loss {:.5f}'.format(epoch, np.mean(train_loss), val_loss)) torch.save(best_model.state_dict(), "model/cnn.pkl") 12345678910111213141516171819202122232425262728293031323334

一共训练30轮,训练的步骤如下:

初始化模型:

model = cnn().to(device) 1选择优化器以及优化算法,这里选择了Adam:

optimizer = optim.Adam(model.parameters(), lr=0.00005) 1选择损失函数,这里选择了交叉熵:

criterion = nn.CrossEntropyLoss().to(device) 1对每一个batch里的数据,先将它们转成能被GPU计算的类型:

data, target = Variable(data).to(device), Variable(target.long()).to(device) 1梯度清零、前向传播、计算误差、反向传播、更新参数:

optimizer.zero_grad() # 梯度清0 output = model(data)[0] # 前向传播 loss = criterion(output, target) # 计算误差 loss.backward() # 反向传播 optimizer.step() # 更新参数 12345测试(GPU)

def test(): Dtr, Val, Dte = load_data() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = cnn().to(device) model.load_state_dict(torch.load("model/cnn.pkl"), False) model.eval() total = 0 current = 0 for (data, target) in Dte: data, target = data.to(device), target.to(device) outputs = model(data) predicted = torch.max(outputs.data, 1)[1].data total += target.size(0) current += (predicted == target).sum() print('Accuracy:%d%%' % (100 * current / total)) 12345678910111213141516

结果:80%
在这里插入图片描述
如果需要更高的准确率,可以使用一些预训练的模型,详见:
PyTorch搭建预训练AlexNet、DenseNet、ResNet、VGG实现猫狗图片分类

完整代码:cnn-dogs-vs-cats。原创不易,下载时请给个follow和star!感谢!!

相关知识

CNN简单实战:PyTorch搭建CNN对猫狗图片进行分类
PyTorch深度学习:猫狗情感识别
PyTorch猫狗:深度学习在宠物识别中的应用
CNN参数设置经验
基于CNN的狗叫,猫叫语音分类
基于Python的图像分类 项目实践——图像分类项目
基于Pytorch框架的深度学习densenet121神经网络鸟类行为识别分类系统源码
深度学习的艺术:从理论到实践
web网页html版通过CNN卷积神经网络的宠物行为训练识别
推荐几个提供免费GPU计算资源的平台,助力你的AI之路

原文链接: CNN简单实战:PyTorch搭建CNN对猫狗图片进行分类 https://www.mcbbbk.com/newsview171061.html

分类:萌宠日常
上一篇: 基于Python的图像分类 项目...下一篇: iOS17宠物识别功能在iPho...

推荐分享