Tensorflow2.0学习笔记
通过MNIST和CIFAR10两个数据集的练习之后,对整个模型的搭建、训练以及测试已经比较熟练了。猫狗分类同样也是比较经典的cv问题,本文再进一步强化之前知识点的同时,学习使用tf.keras.preprocessing.image.ImageDataGenerator来处理数据,并且尝试使用数据增强的方法来解决过拟合问题。
Step1:国际惯例加载tensorflow
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import os
import numpy as np
import matplotlib.pyplot as plt
Step2:加载数据集
数据集来自kaggle
_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
path_to_zip = tf.keras.utils.get_file('cats_and_dogs.zip', origin=_URL, extract=True)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')
Step3:查看训练集内容
为变量分配相应的路径
train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')
train_cats_dir = os.path.join(train_dir, 'cats')
train_dogs_dir = os.path.join(train_dir, 'dogs')
validation_cats_dir = os.path.join(validation_dir, 'cats')
validation_dogs_dir = os.path.join(validation_dir, 'dogs')
查看数据集中的图片数量
num_cats_tr = len(os.listdir(train_cats_dir))
num_dogs_tr = len(os.listdir(train_dogs_dir))
num_cats_val = len(os.listdir(validation_cats_dir))
num_dogs_val = len(os.listdir(validation_dogs_dir))
total_train = num_cats_tr + num_dogs_tr
total_val = num_cats_val + num_dogs_val
print('total training cat images:', num_cats_tr)
print('total training dog images:', num_dogs_tr)
print('total validation cat images:', num_cats_val)
print('total validation dog images:', num_dogs_val)
print("--")
print("Total training images:", total_train)
print("Total validation images:", total_val)
Step4:使用ImageDataGenerator处理数据
先设置一下参数
batch_size = 128
epochs = 15
IMG_HEIGHT = 150
IMG_WIDTH = 150
# 归一化处理
train_image_generator = ImageDataGenerator(rescale=1./255)
validation_image_generator = ImageDataGenerator(rescale=1./255)
#
train_data_gen = train_image_generator.flow_from_directory(batch_size=batch_size,
directory=train_dir,
shuffle=True,
target_size=(IMG_HEIGHT, IMG_WIDTH),
class_mode='binary')
val_data_gen = validation_image_generator.flow_from_directory(batch_size=batch_size,
directory=validation_dir,
target_size=(IMG_HEIGHT, IMG_WIDTH),
class_mode='binary')
Step5:搭建模型并设置训练流程
同样还是使用卷积网络
model = Sequential([
Conv2D(16, 3, padding='same', activation='relu', input_shape=(IMG_HEIGHT, IMG_WIDTH ,3)),
MaxPooling2D(),
Conv2D(32, 3, padding='same', activation='relu'),
MaxPooling2D(),
Conv2D(64, 3, padding='same', activation='relu'),
MaxPooling2D(),
Flatten(),
Dense(512, activation='relu'),
Dense(1)
])
model.compile(optimizer='adam',
loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
metrics=['accuracy'])
显示网络瞧瞧
model.summary()
Step6:启动训练并可视化训练结果
history = model.fit_generator(
train_data_gen,
steps_per_epoch=total_train // batch_size,
epochs=epochs,
validation_data=val_data_gen,
validation_steps=total_val // batch_size
)
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss=history.history['loss']
val_loss=history.history['val_loss']
epochs_range = range(epochs)
plt.figure(figsize=(15, 5))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
相关知识
#学习笔记# 最有...
paddleocr学习笔记(四)评估、推理
【讨论】养猫学习笔记(02)——营养需求标准,什么是NRC、AAFCO
四上语文第6课《夜间飞行的秘密》学习笔记
xhtml & css 简易学习笔记(三)
《python深度学习》笔记
摄影学习笔记 · 用猫的视角里观察世界
IOS学习笔记(六)之UISlider的概念和使用方法
盗墓笔记宠物技能介绍
[半监督学习论文笔记
网址: Tensorflow2.0学习笔记 https://www.mcbbbk.com/newsview628745.html
上一篇: 这些你注意了吗? |
下一篇: Java适配器模式 |
推荐分享

- 1我的狗老公李淑敏33——如何 5096
- 2南京宠物粮食薄荷饼宠物食品包 4363
- 3家养水獭多少钱一只正常 3825
- 4豆柴犬为什么不建议养?可爱的 3668
- 5自制狗狗辅食:棉花面纱犬的美 3615
- 6狗交配为什么会锁住?从狗狗生 3601
- 7广州哪里卖宠物猫狗的选择性多 3535
- 8湖南隆飞尔动物药业有限公司宠 3477
- 9黄金蟒的价格 3396
- 10益和 MATCHWELL 狗 3352