import torch
import torch.nn as nn
import torch.optim as optim
# 定义一个简单的CNN模型
class SimpleCNN(nn.Module):
def __init__(self, num_classes=10):
super(SimpleCNN, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(16, 32, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2, 2),
nn.Conv2d(32, 64, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2, 2)
)
self.classifier = nn.Sequential(
nn.Linear(64 * 8 * 8, 512), # 假设经过卷积和池化后特征图大小为8x8
nn.ReLU(),
nn.Dropout(p=0.5), # 添加dropout以防止过拟合
nn.Linear(512, num_classes) # 输出层,节点数量等于类别数
)
def forward(self, x):
x = self.features(x)
x = x.view(-1, 64 * 8 * 8) # 将卷积后的输出展平
x = self.classifier(x)
return x
# 创建模型实例
model = SimpleCNN()
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# 在训练过程中,你需要按照以下步骤进行:
# 1. 前向传播
# 2. 计算损失
# 3. 反向传播和优化
# 4. 更新模型参数
# 这些步骤通常会在一个训练循环中完成,循环遍历你的整个数据集。