首页 分享 pytorch训练网络冻结某些层

pytorch训练网络冻结某些层

来源:萌宠菠菠乐园 时间:2024-12-01 09:32

引言:首先我们应该很清楚地知道冻结网络中的某些层有什么作用?如何进行相关的冻结设置?代码何如呢?

话不多说说,首先我们探讨第一个问题:

1.冻结网络的某些层有什么作用?

        这个问题顾名思义就是冻结网络中的某些层,使网络在训练过程中,这些层都在不参与的状态,即网络中的某些参数设置就不会更改(已有的训练模型,类似于基于迁移学习的过程),如此大大加快了网络的训练过程,减少了训练的时间。此方法多用于基于迁移学习的模型训练与同时分别训练不同的网络。

2.冻结网络中的某些层应该如何设置?

        下面给出的是关于YOLO网络在冻结层的设置:

代码:

freeze = ['', ]

if any(freeze):

for k, v in model.named_parameters():

if any(x in k for x in freeze):

print('freezing %s' % k)

v.requires_grad = False

        最为明显的设置是最后一句,即   v.requires_grad = False   此语句就是冻结的核心,将梯度设置为False,因为在网络中,训练的基本思想就是利用梯度下降法寻找全局最优解,而此时将梯度关闭,就等价于梯度为0,如此就无法训练该层,所以在训练网络时只能忽略该层,并且保证该层的参数设置不改变。

3.代码何如?

        注意到冻结网络中的某些层,就在上述代码中的 freeze = ['', ]  # parameter names to freeze (full or partial),添加某些层,例如最后一层卷积层,就添加  freeze = ['last_conv', ] ,你学废了吗?

(补)4.优化器设置

        上述只是对网络中的某层进行冻结(在下不才,理解为使梯度为0),还有重要的是如何使优化器知道呢?

        就要重新设定网络了,提示优化器在网络中,这些层已经被我冻住了,你走吧!不需要你优化了,于是优化器一看,果然如此,于是绕道而行。下面是相关的优化器设置代码,如下:

代码:

pg0, pg1, pg2 = [], [], [] # optimizer parameter groups

for k, v in model.named_parameters():

v.requires_grad = True

if '.bias' in k:

pg2.append(v) # biases

elif '.weight' in k and '.bn' not in k:

pg1.append(v) # apply weight decay

else:

pg0.append(v) # all else

if opt.adam:

optimizer = optim.Adam(pg0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum

else:

optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)

optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']}) # add pg1 with weight_decay

optimizer.add_param_group({'params': pg2}) # add pg2 (biases)

logger.info('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))

del pg0, pg1, pg2

        优化器看一眼就知道,记住了哪些网络已不需要优化了,采用分解记忆,在偏置、权重、批量归一化中记忆,采用Adam或者SGD优化器。最后把中间的一些变量删除,防止占用内存空间。

5.结束

GAME OVER

如果小伙伴梦有疑问欢迎在评论区留言哦!!!

如果感觉不错的话!你懂得(O(∩_∩)O哈哈~)

欢迎和小伙伴梦一起学习+探讨问题,共同努力,加油!!!

欢迎转载,记得通知我哦!!!

网址是:

https://blog.csdn.net/m0_56654441/article/details/120610487

相关知识

pytorch训练网络冻结某些层
Pytorch与深度学习自查手册4
使用PyTorch实现鸟类音频检测卷积网络模型
详解pytorch实现猫狗识别98%附代码
Pytorch 使用Pytorch Lightning DDP时记录日志的正确方法
猫狗分类PyTorch:深度学习与迁移学习的探索
Pytorch采用AlexNet实现猫狗数据集分类(训练与预测)
PyTorch HuggingFace Trainer 训练数据的日志记录
【深度学习】AlexNet网络实现猫狗分类
宠物行为识别教程:基于DenseNet模型与PyTorch框架

网址: pytorch训练网络冻结某些层 https://www.mcbbbk.com/newsview673888.html

所属分类:萌宠日常
上一篇: 学习笔记 c++ (编一个程序求
下一篇: scala案例练习

推荐分享