<think>
我们正在处理一个复杂的问题:在Jupyter Notebook中使用PyTorch实现CNN来识别手机照片中的多个数字(类似于MNIST),并使用Matplotlib进行可视化。
步骤分解:
1. 环境设置:确保Jupyter Notebook可以正确使用PyTorch环境(参考引用[2]中的问题,但我们已经假设环境配置好了)
2. 数据准备:手机照片中的多个数字,需要先进行预处理(如分割成单个数字)并转换成MNIST格式(灰度、28x28、归一化等)
3. 模型构建:构建CNN模型(类似于LeNet或更复杂的结构)
4. 训练模型:使用MNIST数据集训练模型(这里我们假设使用MNIST训练,然后用于手机照片的识别,注意域适应问题可能需要微调)
5. 测试:使用手机照片,进行分割和预处理,然后输入模型进行预测
6. 可视化:使用Matplotlib展示原始图片、分割后的单个数字以及预测结果
然而,问题中提到“手机照片中的多个数字”,所以我们需要额外的步骤:数字检测和分割。但原问题主要关注识别,我们可以假设已经分割好了?或者我们需要一个简单的分割方法?
考虑到时间,我们将重点放在识别上,并假设图片已经被分割成单个数字(每个数字一张图)。如果原图有多个数字,我们可以使用轮廓检测等方法进行分割(但这部分可能需要OpenCV,不是重点)。这里我们只处理单个数字的识别,多个数字则循环处理每个分割后的数字。
步骤详细:
一、环境配置(已解决,跳过)
二、数据准备
1. 加载MNIST数据集进行训练
2. 对手机照片进行预处理(假设已经分割成单个数字):
- 转换为灰度图
- 调整大小为28x28
- 归一化(和MNIST相同:0-1范围,然后标准化为均值为0.1307,标准差为0.3081?或者使用训练集的均值和标准差)
- 反转颜色(因为MNIST是黑底白字,而手机照片可能是白底黑字,需要反转)
三、构建CNN模型
使用一个简单的CNN结构,例如:
Conv1: 输入1通道,输出32通道,5x5卷积,ReLU
MaxPool: 2x2
Conv2: 输入32,输出64,5x5,ReLU
MaxPool: 2x2
FC1: 输入64*4*4(经过两次池化后为4x4? 需要计算尺寸),输出1024,ReLU
Dropout
FC2: 输出10(10个数字)
四、训练模型(使用MNIST训练集)
五、测试模型(使用MNIST测试集评估)
六、使用模型预测手机照片中的数字(单个)
七、可视化:使用Matplotlib展示原始手机图片(分割前的)、分割后的每个数字以及预测结果。
由于在Jupyter Notebook中,我们可以逐步执行代码并显示图像。
但是,问题要求识别多个数字,所以我们需要在预处理步骤中分割图片中的多个数字。这里我们使用OpenCV进行简单的分割(如果环境没有,需要安装,或者尝试用PIL和numpy实现)。如果用户不希望使用OpenCV,我们可以尝试用scikit-image等,但为了简化,我们使用OpenCV(因为常用且方便)。
如果无法使用OpenCV,我们可以提供另一种方法:使用连通区域分析(通过scipy.ndimage)?但考虑到通用性,我们假设可以安装OpenCV。
因此,整体步骤:
1. 安装必要的库(如果还没有):
!pip install opencv-python torch torchvision matplotlib numpy
2. 导入库
3. 加载和预处理MNIST数据,训练CNN模型(或者直接加载预训练模型,如果已有)
4. 手机照片预处理(读取、灰度化、二值化、找轮廓、分割每个数字)
注意:手机照片可能背景复杂,所以需要简单假设:数字是深色,背景是浅色,且数字之间不粘连。
5. 对每个分割的数字进行预处理(调整为28x28,归一化,反转颜色等),然后输入模型预测。
6. 使用Matplotlib展示原图、分割后的每个数字(小图)并在旁边标注预测结果。
由于代码较长,我们分块展示,并给出关键步骤的注释。
注意:由于Jupyter Notebook环境,我们可以使用%matplotlib inline
开始代码实现:
</think>
下面我将详细指导你在 Jupyter Notebook 中使用 PyTorch 实现 CNN 识别手机照片中的多个数字,并使用 Matplotlib 进行可视化:
### 1. 环境配置与库导入
```python
%matplotlib inline
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image
```
### 2. 数据预处理流程
```python
# 定义图像转换管道
transform = transforms.Compose([
transforms.Grayscale(num_output_channels=1), # 转为灰度图
transforms.Resize((28, 28)), # 调整尺寸
transforms.ToTensor(), # 转为Tensor
transforms.Normalize((0.1307,), (0.3081,)) # MNIST标准化参数
])
```
### 3. CNN 模型构建 (LeNet 改进版)
```python
class DigitCNN(nn.Module):
def __init__(self):
super(DigitCNN, self).__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=5, padding=2)
self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
self.fc1 = nn.Linear(64*5*5, 256)
self.fc2 = nn.Linear(256, 10)
self.dropout = nn.Dropout(0.5)
def forward(self, x):
x = torch.relu(self.conv1(x)) # (32, 28, 28)
x = torch.max_pool2d(x, 2) # (32, 14, 14)
x = torch.relu(self.conv2(x)) # (64, 10, 10)
x = torch.max_pool2d(x, 2) # (64, 5, 5)
x = x.view(-1, 64*5*5) # 展平
x = torch.relu(self.fc1(x))
x = self.dropout(x)
x = self.fc2(x)
return x
model = DigitCNN()
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
```
### 4. 训练模型 (使用MNIST数据集)
```python
# 加载MNIST数据集
train_set = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
# 训练循环
for epoch in range(5):
for images, labels in train_loader:
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
print(f'Epoch [{epoch+1}/5], Loss: {loss.item():.4f}')
```
### 5. 手机照片处理与数字识别
```python
def recognize_digits(image_path, model):
# 1. 读取并预处理手机照片
img = cv2.imread(image_path)
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# 2. 二值化处理
_, thresh = cv2.threshold(gray, 120, 255, cv2.THRESH_BINARY_INV)
# 3. 查找数字轮廓
contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# 4. 数字识别与可视化
plt.figure(figsize=(15, 5))
plt.subplot(131), plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
plt.title('原始图片'), plt.axis('off')
plt.subplot(132), plt.imshow(thresh, cmap='gray')
plt.title('二值化处理'), plt.axis('off')
# 5. 对每个数字单独处理
results = []
for i, cnt in enumerate(contours):
x, y, w, h = cv2.boundingRect(cnt)
if w > 10 and h > 20: # 过滤噪点
# 裁剪单个数字
digit_img = gray[y:y+h, x:x+w]
# 转换为MNIST格式
digit_pil = Image.fromarray(digit_img).resize((20, 20))
digit_tensor = transform(digit_pil).unsqueeze(0)
# 预测数字
with torch.no_grad():
output = model(digit_tensor)
pred = output.argmax(dim=1).item()
results.append((x, pred))
# 在原图上标注
cv2.rectangle(img, (x, y), (x+w, y+h), (0, 255, 0), 2)
cv2.putText(img, str(pred), (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
plt.subplot(133), plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
plt.title('识别结果'), plt.axis('off')
plt.tight_layout()
plt.show()
return sorted(results, key=lambda x: x[0]) # 按x坐标排序
# 使用示例
results = recognize_digits('mobile_photo.jpg', model)
print("识别结果:", [digit for _, digit in results])
```
### 关键要点说明
1. **预处理流程**:
- 手机照片需转换为灰度图并二值化
- 通过轮廓检测分割多个数字
- 每个数字需调整为28×28像素并标准化
2. **模型优化**:
- 使用Dropout(0.5)防止过拟合
- Adam优化器加速收敛
- 双卷积层+池化提取特征
3. **可视化设计**:
- 原始图片、二值化效果和识别结果三图对比
- 绿色框标记数字区域
- 蓝色文本标注识别结果
4. **手机照片要求**:
- 数字与背景需有足够对比度
- 数字间保持适当间距
- 避免复杂背景干扰
### 常见问题解决
1. **数字分割失败**:
- 调整二值化阈值(120)
- 添加形态学操作(膨胀/腐蚀)
```python
kernel = np.ones((3,3), np.uint8)
thresh = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel)
```
2. **识别准确率低**:
- 增加训练轮次(5→10)
- 使用数据增强
```python
transform = transforms.Compose([
transforms.RandomRotation(10),
transforms.RandomAffine(0, scale=(0.9, 1.1)),
# ...其他转换
])
```
3. **环境配置问题**:
- 确保Jupyter内核使用正确环境
```bash
# 创建环境
conda create -n pytorch_env python=3.8
conda activate pytorch_env
conda install pytorch torchvision -c pytorch
conda install jupyter
python -m ipykernel install --user --name=pytorch_env
```
### 效果展示
执行后生成三幅图像:
1. 原始手机照片
2. 二值化处理结果
3. 带识别结果的标注图
最终输出按从左到右排序的数字序列