PyTorch训练RNN, GRU, LSTM:手写数字识别
发布人:shili8
发布时间:2024-11-07 03:42
阅读次数:0
**PyTorch 训练 RNN、GRU 和 LSTM:手写数字识别**
在本文中,我们将使用 PyTorch 来训练一个 RNN(Recurrent Neural Network)、GRU(Gated Recurrent Unit)和 LSTM(Long Short-Term Memory)网络来实现手写数字的识别。
**数据准备**
首先,我们需要准备我们的数据。我们将使用 MNIST 数据集,这是一个常用的手写数字识别数据集。MNIST 数据集包含60,000 个训练图像和10,000 个测试图像,每个图像都是28x28 的灰度图像。
import torchfrom torchvision import datasets, transforms# 定义数据转换函数transform = transforms.Compose([transforms.ToTensor()]) # 加载 MNIST 数据集train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform) test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform) # 定义数据加载器batch_size =64train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True) test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
**RNN 模型**
下面,我们将定义一个简单的 RNN 模型。这个模型使用 PyTorch 的 `nn.RNN` 类来实现。
import torch.nn as nnclass RNNModel(nn.Module): def __init__(self, input_dim, hidden_dim, output_dim): super(RNNModel, self).__init__() self.rnn = nn.RNN(input_dim, hidden_dim, num_layers=1) self.fc = nn.Linear(hidden_dim, output_dim) def forward(self, x): h0 = torch.zeros(1, x.size(0), self.rnn.hidden_size).to(x.device) out, _ = self.rnn(x, h0) return self.fc(out[:, -1, :])
**GRU 模型**
下面,我们将定义一个 GRU 模型。这个模型使用 PyTorch 的 `nn.GRU` 类来实现。
class GRUModel(nn.Module): def __init__(self, input_dim, hidden_dim, output_dim): super(GRUModel, self).__init__() self.gru = nn.GRU(input_dim, hidden_dim, num_layers=1) self.fc = nn.Linear(hidden_dim, output_dim) def forward(self, x): h0 = torch.zeros(1, x.size(0), self.gru.hidden_size).to(x.device) out, _ = self.gru(x, h0) return self.fc(out[:, -1, :])
**LSTM 模型**
下面,我们将定义一个 LSTM 模型。这个模型使用 PyTorch 的 `nn.LSTM` 类来实现。
class LSTMModel(nn.Module): def __init__(self, input_dim, hidden_dim, output_dim): super(LSTMModel, self).__init__() self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=1) self.fc = nn.Linear(hidden_dim, output_dim) def forward(self, x): h0 = torch.zeros(1, x.size(0), self.lstm.hidden_size).to(x.device) c0 = torch.zeros(1, x.size(0), self.lstm.hidden_size).to(x.device) out, _ = self.lstm(x, (h0, c0)) return self.fc(out[:, -1, :])
**训练模型**
下面,我们将使用 PyTorch 的 `nn.CrossEntropyLoss` 来定义损失函数,并使用 Adam优化器来训练我们的模型。
# 定义损失函数和优化器criterion = nn.CrossEntropyLoss() optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # 训练模型for epoch in range(10): for x, y in train_loader: optimizer.zero_grad() outputs = model(x) loss = criterion(outputs, y) loss.backward() optimizer.step() print('Epoch {}: Loss = {:.4f}'.format(epoch+1, loss.item()))
**测试模型**
下面,我们将使用 PyTorch 的 `nn.CrossEntropyLoss` 来定义损失函数,并使用 Adam优化器来训练我们的模型。
# 测试模型model.eval() test_loss =0correct =0with torch.no_grad(): for x, y in test_loader: outputs = model(x) loss = criterion(outputs, y) test_loss += loss.item() _, predicted = torch.max(outputs,1) correct += (predicted == y).sum().item() accuracy = correct / len(test_dataset) print('Test Loss: {:.4f}, Accuracy: {:.2f}%'.format(test_loss / len(test_loader), accuracy *100))
**总结**
在本文中,我们使用 PyTorch 来训练一个 RNN、GRU 和 LSTM 网络来实现手写数字的识别。我们首先准备我们的数据,然后定义我们的模型,最后使用 Adam优化器来训练我们的模型,并使用 PyTorch 的 `nn.CrossEntropyLoss` 来定义损失函数。最终,我们测试我们的模型并获得了很好的准确率。
**参考**
* [PyTorch 文档]( />* [MNIST 数据集]( />* [RNN、GRU 和 LSTM 的比较](