一、PyTorch 中的 nn.Module
nn.Module
是 PyTorch 框架中的核心类,用于定义和创建神经网络的基本结构。它可以帮助我们组织和管理网络中的各种组件,比如神经网络的层。在使用 PyTorch 创建神经网络时,通常都会创建一个继承自 nn.Module
的类。
二、为什么使用 nn.Module
?
- 组织网络:通过
nn.Module
,我们可以简单地定义并组织神经网络的各个层。它提供了一种结构化的方式来把常见的计算组件(如卷积层、线性层等)拼接在一起。参数管理:nn.Module
自动管理所有定义的层和它们的可训练参数,这让我们免去手动处理每个参数。设备管理:可以轻松地把整个网络模型放到 GPU 或 CPU 上运行三、使用步骤
1.创建一个新类继承 nn.Module
import torchimport torch.nn as nnclass SimpleNet(nn.Module): def __init__(self): super(SimpleNet, self).__init__() # 调用父类的初始化方法 self.fc1 = nn.Linear(784, 128) # 第一层,全连接层 self.fc2 = nn.Linear(128, 64) # 第二层,全连接层 self.out = nn.Linear(64, 10) # 输出层,全连接层
2.定义前向传播 forward
方法
def forward(self, x): x = torch.flatten(x, 1) # 展平成一维 x = torch.relu(self.fc1(x)) # 通过第一层并进行 ReLU 激活 x = torch.relu(self.fc2(x)) # 通过第二层并进行 ReLU 激活 x = self.out(x) # 计算输出 return x
3.使用模型
# 创建模型实例model = SimpleNet()# 将模型移到 GPU(如果可用)if torch.cuda.is_available(): model = model.to('cuda')