一、案例介绍
- CIFAR10数据集介绍:
- 数据集中每张图片的尺寸是3 * 32 * 32, 代表彩色3通道
- CIFAR10数据集总共有10种不同的分类:
- 分别是”airplane”, “automobile”, “bird”, “cat”, “deer”, “dog”, “frog”, “horse”, “ship”, “truck”.
二、调优思路
- 数据预处理
- 数据归一化
- 减小不同维度量纲影响
- 降低模型误差,加速收敛
- 图像增强
- 原始数据+增强后的数据,比例建议2:1
- 图像放大(双线性差值)
- 调整图片尺寸,以适应resnet18的输入大小
- 数据归一化
- 模型优化
- 迁移学习(模型+预训练权重)
- 使用torch官方定义的模型以及imagenet上生成的权重,进行微调(站在巨人的肩膀上)
- 由于CIFAR10数据集是32*32,网络结构不需要特别复杂
- 主要使用模型:(注意调整output大小)
- torchvision.models.vgg16
- torchvision.models.resnet18
- 自适应学习率
- Nesterov动量算法
- Adam
- 二阶近似方法(不太会用)
- 牛顿法
- 求Hessian矩阵复杂度较高
- 拟牛顿法
- 牛顿法
- 迁移学习(模型+预训练权重)
- 其他
- 增加epoch
- 增加训练轮数,直到loss收敛到较小值(0.02左右)
- 如果数据量足够大,并不需要太多epoch
- GPU加速
- 主要提升训练速度,并不影响预测结果
- 需要在定义模型、加载batch_size数据时,xxx.to(device)即可
- 增加epoch
三、代码
1. 读取数据
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
# 查看是否能使用GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
transform = transforms.Compose([
# 数据增强:
# 增加随机水平旋转、垂直旋转、角度旋转、色度、亮度、饱和度、对比度的变化,以及随机的灰度化
# transforms.RandomHorizontalFlip(),
# transforms.RandomVerticalFlip(),
# transforms.RandomRotation(degrees=(0, 60)),
# transforms.ColorJitter(brightness=(0.7,1.3),contrast=(0.7,1.3),saturation=(0.7,1.3),hue=(-0.2,0.2)),
# transforms.RandomGrayscale(p=0.1),
# 调整图片尺寸,以适应resnet18的输入大小
transforms.Resize((224, 224)),
transforms.ToTensor(),
# 数据归一化:因为torchvision数据集的输出是PILImage格式, 数据域在[0, 1]. 我们将其转换为标准数据域[-1, 1]的张量格式
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32,
shuffle=True, num_workers=48)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=32,
shuffle=False, num_workers=48)
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
2. 导入模型
# model = torchvision.models.resnet34().to(device)
# model = torchvision.models.vgg16(pretrained=False, progress=True, num_classes=10).to(device)
model = torchvision.models.resnet18(pretrained=True).to(device)
num_ftrs = model.fc.in_features
# 调整模型output,因为只有10分类
model.fc = nn.Linear(num_ftrs, 10).to(device)
# 定义优化器
import torch.optim as optim
criterion = nn.CrossEntropyLoss()
# 梯度下降法
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, nesterov=True)
# 自适应学习率:Adam
# optimizer = optim.Adam(model.parameters(), lr=0.001)
# optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001,momentum=0.9)
3. 模型训练
for epoch in range(20):
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
# data中包含输入图像张量inputs, 标签张量labels
inputs, labels = data[0].to(device), data[1].to(device)
# 首先将优化器梯度归零
optimizer.zero_grad()
# 输入图像张量进网络, 得到输出张量outputs
predicts = model(inputs)
# 利用网络的输出outputs和标签labels计算损失值
loss = criterion(predicts, labels)
loss.backward()
optimizer.step()
# 打印轮次和损失值
running_loss += loss.item()
if (i + 1) % 600 == 0:
print('[%d, %5d] loss: %.5f' %
(epoch + 1, i + 1, running_loss / 600))
running_loss = 0.0
print('Finished Training')
4. 模型评估
# 查看训练集上的ACC
correct = 0
total = 0
with torch.no_grad():
for data in testloader:
images, labels = data[0].to(device),data[1].to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (
100 * correct / total))
# 查看在每个分类上的ACC
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
for data in testloader:
images, labels = data[0].to(device),data[1].to(device)
outputs = model(images)
_, predicted = torch.max(outputs, 1)
c = (predicted == labels).squeeze()
for i in range(4):
label = labels[i]
class_correct[label] += c[i].item()
class_total[label] += 1
for i in range(10):
print('Accuracy of %5s : %2d %%' % (
classes[i], 100 * class_correct[i] / class_total[i]))
# 展示预测图片和结果
dataiter = iter(testloader)
images, labels = dataiter.next()
# 打印原始图片
imshow(torchvision.utils.make_grid(images))
# 打印真实的标签
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(32)))
5. 模型保存与加载
# 模型保存
# 首先设定模型的保存路径
PATH = './cifar_net.pth'
# 保存模型的状态字典
torch.save(model.state_dict(), PATH)
# 模型加载
PATH = './cifar_net.pth'
model.load_state_dict(torch.load(PATH))
四、预测结果展示
- 查看训练集上的ACC
- 查看在每个分类上的ACC
五、模型下载链接
cifar_net.pth 提取码: i7hu