掘金 人工智能 05月09日 14:45
手写 VGG 网络模型实现 CIFAR10 数据集分类
index_new5.html
../../../zaker_core/zaker_tpl_static/wap/tpl_guoji1.html

 

本文介绍了VGG卷积神经网络的架构特点,包括其深度、小卷积核、连续卷积层、全连接层、ReLU激活函数以及归一化处理。同时,详细阐述了CIFAR-10数据集的构成、特点及其在图像识别领域的重要性。文章还提供了使用PyTorch实现VGG网络在CIFAR-10数据集上进行分类的完整代码,包括数据预处理、模型定义、训练、测试以及推理过程。代码实现部分,展示了如何加载数据集、定义VGG16网络、设置损失函数和优化器,并最终训练和测试模型。

🏢VGG网络是一种经典的深度卷积神经网络,以其深度和使用3x3小卷积核为特点,通过堆叠多个卷积层来提取图像特征,并在图像识别领域表现出色。VGG网络有不同深度版本,如VGG11、VGG16和VGG19,它们的主要区别在于卷积层的数量。

🖼️CIFAR-10数据集是一个包含10个类别的彩色图像数据集,常用于评估图像识别算法的性能。数据集包含60000张32x32的RGB图像,每个类别有6000张图像,分为50000张训练集和10000张测试集。图像类别包括飞机、汽车、鸟类等。

💻文章提供了一份使用PyTorch实现VGG网络在CIFAR-10数据集上进行分类的完整Python代码。代码展示了如何进行数据预处理(包括标准化和数据增强)、定义VGG16网络结构、设置损失函数和优化器,以及训练和测试模型。

⏱️代码中包含训练模型的函数`train_model`,该函数使用训练数据集训练VGG模型,并在每个epoch结束后保存模型的状态字典。训练过程中,会打印每个batch的损失值,并记录每个epoch所花费的时间。训练完成后,模型会被保存到指定的路径。

VGG 网络

VGG是一个经典的卷积神经网络(CNN)架构,由牛津大学的视觉几何组(Visual Geometry Group)在2014年提出。VGG网络因其简单而有效的设计而闻名,在图像识别领域取得了很好的效果。VGG网络的主要特点是:

    深度:VGG网络非常深,原始的VGG网络有16层(包括卷积层和全连接层),后来简化为19层和11层的版本。

    小卷积核:VGG网络只使用3x3的卷积核,这与其他网络(如AlexNet)使用的更大卷积核不同。

    连续的卷积层:VGG网络在最大池化层之间堆叠多个卷积层,这样可以增加网络的深度。

    全连接层:在卷积层之后,VGG网络使用几个全连接层来进一步处理特征。

    激活函数:VGG网络使用ReLU作为激活函数。

    归一化:VGG网络在每个卷积层之后使用局部响应归一化(Local Response Normalization, LRN),但在后来的一些实现中,这个步骤被省略了。

VGG网络的架构如下:

VGG网络的一个关键贡献是证明了网络的深度对于性能的重要性,并且通过实验表明,更深的网络可以通过堆叠简单的3x3卷积层来构建。

CIFAR10 数据集

CIFAR-10数据集是一个广泛用于机器学习和计算机视觉研究的图像数据集。它由加拿大高级研究院(Canadian Institute For Advanced Research)和多伦多大学(University of Toronto)的计算机科学系提供。CIFAR-10数据集包含10个类别的60000张32x32的RGB彩色图像,每个类别有6000张图像。这些类别包括:

    飞机(Airplane)汽车(Automobile)鸟类(Bird)猫(Cat)鹿(Deer)狗(Dog)蛙(Frog)马(Horse)船(Ship)卡车(Truck)

每个类别的图像都是随机拍摄的,并且图像中的对象可能会出现不同的方向、大小和背景。CIFAR-10数据集被分为训练集和测试集,其中训练集包含50000张图像,测试集包含10000张图像。

CIFAR-10数据集的特点包括:

CIFAR-10数据集常用于评估图像识别算法的性能,尤其是在卷积神经网络(CNN)的研究中。由于其图像尺寸较小,CIFAR-10也适合用于测试算法的计算效率和内存使用情况。此外,CIFAR-10数据集也是许多深度学习框架和库中内置的标准数据集,方便研究人员和开发者快速开始实验和模型训练。

使用CIFAR-10数据集时,研究人员通常会进行数据增强、归一化处理等预处理步骤,以提高模型的性能和泛化能力。此外,CIFAR-10数据集也常用于比较不同深度学习架构的效果,如VGG、ResNet、AlexNet等。

VGG 网络模型实现 CIFAR10 数据集分类

    keras 中已经有CIFAR10 数据集。PyTorch的torchvision.models模块已经实现了VGG网络,可以直接调用。手写 vgg 网络训练CIFAR10 数据集完整Python代码:
import torchimport torch.nn as nnimport torch.optim as optimimport torchvisionimport torchvision.transforms as transformsfrom torch.utils.data import DataLoaderimport time# ===========================================训练===================================================# 数据预处理:标准化到[0, 1]范围,并做数据增强transform_train = transforms.Compose([    transforms.RandomHorizontalFlip(),    transforms.RandomCrop(32, padding=4),    transforms.ToTensor(),    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),])transform_test = transforms.Compose([    transforms.ToTensor(),    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),])# 加载训练集和测试集trainset = torchvision.datasets.CIFAR10(root='E:/datasets/cifar-10-batches-py', train=True, download=True, transform=transform_train)testset = torchvision.datasets.CIFAR10(root='E:/datasets/cifar-10-batches-py', train=False, download=True, transform=transform_test)trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)testloader = DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)# 定义vgg16 网络class VGG(nn.Module):    def __init__(self, num_classes=10):        super(VGG, self).__init__()        self.features = nn.Sequential(            nn.Conv2d(3, 64, kernel_size=3, padding=1),            nn.ReLU(inplace=True),            nn.Conv2d(64, 64, kernel_size=3, padding=1),            nn.ReLU(inplace=True),            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),                        nn.Conv2d(64, 128, kernel_size=3, padding=1),            nn.ReLU(inplace=True),            nn.Conv2d(128, 128, kernel_size=3, padding=1),            nn.ReLU(inplace=True),            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),                        nn.Conv2d(128, 256, kernel_size=3, padding=1),            nn.ReLU(inplace=True),            nn.Conv2d(256, 256, kernel_size=3, padding=1),            nn.ReLU(inplace=True),            nn.Conv2d(256, 256, kernel_size=3, padding=1),            nn.ReLU(inplace=True),            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),                        nn.Conv2d(256, 512, kernel_size=3, padding=1),            nn.ReLU(inplace=True),            nn.Conv2d(512, 512, kernel_size=3, padding=1),            nn.ReLU(inplace=True),            nn.Conv2d(512, 512, kernel_size=3, padding=1),            nn.ReLU(inplace=True),            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),                        nn.Conv2d(512, 512, kernel_size=3, padding=1),            nn.ReLU(inplace=True),            nn.Conv2d(512, 512, kernel_size=3, padding=1),            nn.ReLU(inplace=True),            nn.Conv2d(512, 512, kernel_size=3, padding=1),            nn.ReLU(inplace=True),            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),        )        self.classifier = nn.Sequential(            nn.Linear(512 * 1 * 1, 4096),            nn.ReLU(inplace=True),            nn.Dropout(p=0.5),            nn.Linear(4096, 4096),            nn.ReLU(inplace=True),            nn.Dropout(p=0.5),            nn.Linear(4096, num_classes),        )    def forward(self, x):        x = self.features(x)        x = torch.flatten(x, 1)        x = self.classifier(x)        return x# 初始化模型,损失函数,优化器device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = VGG(num_classes=10).to(device)# 定义损失函数和优化器criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练模型并保存模型文件def train_model(model, trainloader, criterion, optimizer, epochs=1, model_save_path='vgg_cifar10.pth'):    model.train()    for epoch in range(epochs):        running_loss = 0.0        start_time = time.time()        for i, (inputs, labels) in enumerate(trainloader, 0):            inputs, labels = inputs.to(device), labels.to(device)            optimizer.zero_grad()            outputs = model(inputs)            loss = criterion(outputs, labels)            loss.backward()            optimizer.step()            running_loss += loss.item()            if i % 100 == 99:  # 每100个小批量打印一次                print(f"[{epoch + 1}, {i + 1}] loss: {running_loss / 100:.3f}")                running_loss = 0.0        print(f"Epoch {epoch + 1} took {time.time() - start_time:.2f} seconds")    # 保存训练好的模型    torch.save(model.state_dict(), model_save_path)    print(f"Model saved to {model_save_path}")    # =========================================测试====================================================# 加载保存的模型#model = VGG(num_classes=10).to(device)#model.load_state_dict(torch.load('vgg_cifar10.pth'))#model.eval()  # 设置为评估模式def test_model(model, testloader):    correct = 0    total = 0    with torch.no_grad():  # 在测试时不需要计算梯度        for inputs, labels in testloader:            inputs, labels = inputs.to(device), labels.to(device)            outputs = model(inputs)            _, predicted = torch.max(outputs, 1)            total += labels.size(0)            correct += (predicted == labels).sum().item()    print(f"Accuracy on the 10000 test images: {100 * correct / total:.2f}%")# ===========================================推理===============================================# 1. 加载模型def load_model(model_path):    # 加载预训练的VGG16模型,替换最后的分类层以适应CIFAR-10的10个类别    model = models.vgg16(pretrained=False)    num_features = model.classifier[6].in_features    model.classifier[6] = torch.nn.Linear(num_features, 10)    model.load_state_dict(torch.load(model_path))    model.eval()  # 设置为评估模式    return model# 2. 定义预处理步骤def transform_image(image_path):    preprocess = transforms.Compose([        transforms.Resize((224, 224)),  # 调整图像大小以匹配VGG输入        transforms.ToTensor(),  # 将图像转换为Tensor        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # 归一化    ])    image = Image.open(image_path).convert('RGB')  # 打开图像并转换为RGB    image_t = preprocess(image)  # 预处理图像    batch_t = torch.unsqueeze(image_t, 0)  # 创建一个批次作为模型输入    return batch_t# 3. 进行推理def infer(model, input_tensor):    with torch.no_grad():  # 在这个上下文中,不计算梯度        output = model(input_tensor)    _, predicted_class = torch.max(output, 1)  # 获取预测结果    return predicted_class.item()# 主函数def main():    model_path = 'vgg_cifar10.pth'  # 模型文件路径    image_path = 'path_to_your_image.jpg'  # 替换为你的图片文件路径    model = load_model(model_path)  # 加载模型    input_tensor = transform_image(image_path)  # 图像预处理    predicted_class = infer(model, input_tensor)  # 推理    print(f'Predicted class index: {predicted_class}')# ===============================================主函数=====================================================if __name__ == '__main__':    # 训练    train_model(model, trainloader, criterion, optimizer, epochs=1, model_save_path='vgg_cifar10.pth')    # 测试    # test_model(model, testloader)    # 推理    # main()     
    训练好的模型:只进行了一次epoch

Fish AI Reader

Fish AI Reader

AI辅助创作,多种专业模板,深度分析,高质量内容生成。从观点提取到深度思考,FishAI为您提供全方位的创作支持。新版本引入自定义参数,让您的创作更加个性化和精准。

FishAI

FishAI

鱼阅,AI 时代的下一个智能信息助手,助你摆脱信息焦虑

联系邮箱 441953276@qq.com

相关标签

VGG网络 CIFAR-10 PyTorch 图像分类
相关文章