# LSTM用于分类任务的完整指南
## 1. LSTM分类基础架构
### 1.1 基本分类流程
```python
import torch
import torch.nn as nn
class LSTMClassifier(nn.Module):
    """Sequence classifier: stacked LSTM, then dropout and a linear head.

    Args:
        input_size: Number of features per time step.
        hidden_size: Size of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of output classes.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # nn.LSTM applies dropout only between stacked layers; requesting it
        # for a single layer has no effect and makes PyTorch emit a warning.
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
            dropout=0.2 if num_layers > 1 else 0.0,
        )
        # Classification head
        self.fc = nn.Linear(hidden_size, num_classes)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Return class logits for a batch of sequences.

        Args:
            x: Input of shape (batch, seq_len, input_size).

        Returns:
            Logits of shape (batch, num_classes).
        """
        # Create the initial states from x so they inherit its device and
        # dtype; plain torch.zeros would stay on the CPU in float32.
        state_shape = (self.num_layers, x.size(0), self.hidden_size)
        h0 = x.new_zeros(state_shape)
        c0 = x.new_zeros(state_shape)
        out, _ = self.lstm(x, (h0, c0))
        # Keep only the output of the last time step
        last_output = out[:, -1, :]
        return self.fc(self.dropout(last_output))
```
## 2. 不同序列分类策略
### 2.1 输出策略对比
| 策略 | 适用场景 | 代码实现 |
|------|----------|----------|
| **最后时间步** | 序列整体分类 | `out[:, -1, :]` |
| **时间步平均** | 平稳序列分类 | `torch.mean(out, dim=1)` |
| **注意力池化** | 重要特征提取 | 自定义注意力机制 |
| **最大池化** | 突出显著特征 | `torch.max(out, dim=1)[0]` |
### 2.2 注意力池化实现
```python
class AttentionLSTMClassifier(nn.Module):
    """Bidirectional LSTM classifier that pools time steps with attention.

    Args:
        input_size: Number of features per time step.
        hidden_size: Hidden size of each LSTM direction.
        num_classes: Number of output classes.
        attention_size: Width of the hidden layer in the attention scorer.
    """

    def __init__(self, input_size, hidden_size, num_classes, attention_size=64):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True, bidirectional=True)
        # Both directions are concatenated, so features are twice hidden_size.
        feature_size = hidden_size * 2
        self.attention = nn.Sequential(
            nn.Linear(feature_size, attention_size),
            nn.Tanh(),
            nn.Linear(attention_size, 1),
        )
        self.fc = nn.Linear(feature_size, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes).

        Args:
            x: Input of shape (batch, seq_len, input_size).
        """
        # LSTM output: (batch, seq_len, hidden_size * 2)
        lstm_out, _ = self.lstm(x)
        # One score per time step, normalised over the time axis:
        # (batch, seq_len, 1)
        scores = self.attention(lstm_out)
        weights = torch.softmax(scores, dim=1)
        # Attention-weighted sum over time: (batch, hidden_size * 2)
        context_vector = torch.sum(weights * lstm_out, dim=1)
        return self.fc(context_vector)
```
## 3. 计算机视觉中的LSTM分类应用
### 3.1 视频动作识别
```python
class VideoActionClassifier(nn.Module):
    """Classify a video clip with a per-frame CNN followed by an LSTM.

    Args:
        cnn_feature_size: Length of the feature vector produced per frame,
            which is also the LSTM input size.
        lstm_hidden: Hidden size of the LSTM.
        num_classes: Number of output classes.
    """

    def __init__(self, cnn_feature_size, lstm_hidden, num_classes):
        super().__init__()
        # Small frame encoder trained from scratch. Its last convolution
        # emits cnn_feature_size channels so that its output always matches
        # the LSTM input size.
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, cnn_feature_size, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
        )
        # LSTM over the sequence of frame features
        self.lstm = nn.LSTM(cnn_feature_size, lstm_hidden, 2,
                            batch_first=True, dropout=0.3)
        self.classifier = nn.Linear(lstm_hidden, num_classes)

    def forward(self, video_clip):
        """Return class logits of shape (batch, num_classes).

        Args:
            video_clip: Tensor of shape (batch, frames, 3, H, W).
        """
        batch_size, num_frames = video_clip.shape[0], video_clip.shape[1]
        # Fold the frame axis into the batch axis so the CNN runs once over
        # all frames instead of once per time step.
        frames = video_clip.flatten(0, 1)      # (batch * frames, 3, H, W)
        features = self.cnn(frames)            # (batch * frames, C, 1, 1)
        # Restore the time axis: (batch, frames, C)
        sequence = features.reshape(batch_size, num_frames, -1)
        lstm_out, _ = self.lstm(sequence)
        # Classify from the last time step
        last_output = lstm_out[:, -1, :]
        return self.classifier(last_output)
```
### 3.2 时序图像分类(医疗影像)
```python
class MedicalSequenceClassifier(nn.Module):
    """LSTM classifier for CT/MRI sequence analysis.

    Args:
        input_size: Number of features per time step.
        hidden_size: Hidden size of each LSTM direction.
        num_classes: Number of output classes.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, 2,
                            batch_first=True, bidirectional=True)
        self.fc1 = nn.Linear(hidden_size * 2, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_classes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.4)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes).

        Args:
            x: Input of shape (batch, time_steps, feature_size).
        """
        _, (hn, _) = self.lstm(x)
        # The last two entries of hn belong to the top layer: hn[-2] is the
        # forward direction's final state, hn[-1] the backward direction's.
        forward_final = hn[-2]
        backward_final = hn[-1]
        final_state = torch.cat([forward_final, backward_final], dim=1)
        # Two-layer classification head
        hidden = self.relu(self.fc1(final_state))
        hidden = self.dropout(hidden)
        return self.fc2(hidden)
```
## 4. 实际训练技巧
### 4.1 数据准备与预处理
```python
def prepare_sequence_data(data, labels, sequence_length):
    """Cut every sample into sliding windows for LSTM training.

    Each sample yields one window per start offset (stride 1), and every
    window inherits the label of the sample it was cut from.

    Args:
        data: Samples of shape [num_samples, total_timesteps, features].
        labels: One label per sample, shape [num_samples].
        sequence_length: Number of time steps per window; must be >= 1.

    Returns:
        A pair (sequences, target_labels): the stacked windows and a tensor
        holding one label per window.

    Raises:
        ValueError: If sequence_length is < 1, or if no sample is long
            enough to produce a single window.
    """
    if sequence_length < 1:
        raise ValueError(f"sequence_length must be >= 1, got {sequence_length}")

    sequences = []
    target_labels = []
    for i, sample in enumerate(data):
        # Samples shorter than sequence_length give a non-positive count and
        # therefore contribute no windows.
        num_windows = sample.shape[0] - sequence_length + 1
        for start_idx in range(num_windows):
            sequences.append(sample[start_idx:start_idx + sequence_length])
            target_labels.append(labels[i])

    if not sequences:
        # Fail with a clear message instead of torch.stack's empty-list error.
        raise ValueError(
            f"no sample has at least sequence_length={sequence_length} time steps"
        )
    return torch.stack(sequences), torch.tensor(target_labels)
# Example usage
# sequences, labels = prepare_sequence_data(training_data, training_labels, sequence_length=10)
```
### 4.2 训练循环示例
```python
def train_lstm_classifier(model, train_loader, val_loader, num_epochs=100):
    """Train a classifier with Adam and validate after every epoch.

    Uses cross-entropy loss, clips the gradient norm to 1.0 on every step,
    and lowers the learning rate when the validation loss plateaus. One
    summary line is printed per epoch.

    Args:
        model: Module mapping a batch of inputs to class logits.
        train_loader: Iterable of (data, target) training batches.
        val_loader: Iterable of (data, target) validation batches.
        num_epochs: Number of passes over train_loader.

    Returns:
        A list with one dict per epoch holding the keys 'train_loss',
        'val_loss' (both averaged per batch) and 'val_acc' (percent).
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5)
    history = []

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        for data, target in train_loader:
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            # Clip gradients to guard against exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            train_loss += loss.item()

        # Validation phase
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for data, target in val_loader:
                output = model(data)
                val_loss += criterion(output, target).item()
                predicted = output.argmax(dim=1)
                total += target.size(0)
                correct += (predicted == target).sum().item()

        # An empty loader must not abort training with a ZeroDivisionError.
        val_acc = 100 * correct / total if total else 0.0
        avg_train_loss = train_loss / max(len(train_loader), 1)
        avg_val_loss = val_loss / max(len(val_loader), 1)
        scheduler.step(val_loss)
        print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, '
              f'Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.2f}%')
        history.append({
            'train_loss': avg_train_loss,
            'val_loss': avg_val_loss,
            'val_acc': val_acc,
        })
    return history
```
## 5. 超参数调优指南
### 5.1 关键超参数
```python
# Example hyperparameter search space: each key maps to its candidate values.
hyperparameters = {
    'hidden_size': [64, 128, 256],      # LSTM hidden state dimension
    'num_layers': [1, 2, 3],            # number of stacked LSTM layers
    'learning_rate': [0.001, 0.0005, 0.0001],
    'batch_size': [16, 32, 64],
    'dropout_rate': [0.2, 0.3, 0.5],    # regularisation against overfitting
    'sequence_length': [10, 20, 30],    # time steps per input sequence
}
```
### 5.2 模型评估指标
```python
from sklearn.metrics import classification_report, confusion_matrix
def evaluate_model(model, test_loader, class_names):
    """Print a classification report and a confusion matrix for a test set.

    Args:
        model: Module mapping a batch of inputs to class logits.
        test_loader: Iterable of (data, target) batches.
        class_names: Display name of each class, ordered by class index.
    """
    model.eval()
    all_predictions = []
    all_targets = []
    with torch.no_grad():
        for data, target in test_loader:
            predicted = model(data).argmax(dim=1)
            all_predictions.extend(predicted.cpu().tolist())
            all_targets.extend(target.cpu().tolist())

    # Pass the full label set explicitly: otherwise the report raises when a
    # class is missing from the test data, because the number of observed
    # labels no longer matches len(class_names).
    # NOTE(review): assumes targets are class indices 0..len(class_names)-1.
    class_ids = list(range(len(class_names)))
    print("Classification Report:")
    print(classification_report(all_targets, all_predictions,
                                labels=class_ids, target_names=class_names))
    print("\nConfusion Matrix:")
    print(confusion_matrix(all_targets, all_predictions, labels=class_ids))
```
## 6. 常见问题与解决方案
### 6.1 过拟合处理
- **数据增强**:对序列数据进行时间扭曲、添加噪声
- **正则化**:L2正则化、Dropout、Early Stopping
- **简化模型**:减少LSTM层数或隐藏单元数
### 6.2 梯度问题
```python
# 梯度裁剪和监控
def check_gradients(model):
    """Return the global L2 norm of all gradients currently stored on model.

    Parameters whose ``grad`` is None are skipped; if no parameter has a
    gradient the result is 0.
    """
    squared_sum = sum(
        p.grad.detach().norm(2).item() ** 2
        for p in model.parameters()
        if p.grad is not None
    )
    return squared_sum ** 0.5
# Add inside the training loop, after loss.backward() and before
# optimizer.step(): measure the gradient norm and clip it when it exceeds 1.0.
grad_norm = check_gradients(model)
if grad_norm > 1.0:
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```
### 6.3 序列长度选择
- **短序列**:计算效率高,可能丢失长期依赖
- **长序列**:能捕获长期模式,但计算与显存开销更大,且更容易出现梯度消失或过拟合
- **动态长度**:使用`pack_padded_sequence`处理变长序列(此时应按每个样本的真实长度取最后一个有效时间步,而不是直接取`out[:, -1, :]`,否则会取到填充位置的输出)
LSTM在分类任务中的优势在于能够建模时序依赖;与CNN等特征提取器结合后,还可以同时利用空间特征,因此特别适合处理视频、医疗影像序列、传感器数据等具有时间维度的任务。关键是根据具体任务选择合适的输出策略和网络架构。