您发现了一个重要的问题!确实缺少了 `four_point_transform` 函数的定义。让我提供一个完整的、修正后的版本,包含所有必要的函数:
```python
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import numpy as np
import cv2
import os
import matplotlib.pyplot as plt
from torchvision import transforms
# === 1. 模型定义(保持不变)===
class EnhancedCNN(nn.Module):
    """CNN classifier for 28x28 grayscale digit images (10 classes).

    Three conv/BN/ReLU/maxpool stages shrink the spatial size
    28 -> 14 -> 7 -> 3, followed by a three-layer fully connected
    head with dropout that emits raw logits.
    """

    def __init__(self):
        super(EnhancedCNN, self).__init__()
        # Build the conv stages from a channel plan; layer order (and
        # therefore state_dict keys) matches the trained checkpoint.
        stages = []
        in_ch = 1
        for out_ch in (64, 128, 256):
            stages += [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
            ]
            in_ch = out_ch
        self.features = nn.Sequential(*stages)
        self.classifier = nn.Sequential(
            nn.Linear(256 * 3 * 3, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 10),
        )

    def forward(self, x):
        """Return class logits of shape (batch, 10)."""
        feats = self.features(x)
        flat = torch.flatten(feats, 1)
        return self.classifier(flat)
# === 2. 四点排序函数(新增)===
def order_points(pts):
    """Order four corner points as [top-left, top-right, bottom-right, bottom-left].

    Uses the classic heuristic: the top-left corner has the smallest
    x + y and the bottom-right the largest; the top-right has the
    smallest y - x and the bottom-left the largest.
    """
    corners = pts.reshape(4, 2)
    coord_sum = corners.sum(axis=1)        # x + y per point
    coord_diff = np.diff(corners, axis=1)  # y - x per point
    ordered = np.zeros((4, 2), dtype="float32")
    ordered[0] = corners[np.argmin(coord_sum)]   # top-left
    ordered[1] = corners[np.argmin(coord_diff)]  # top-right
    ordered[2] = corners[np.argmax(coord_sum)]   # bottom-right
    ordered[3] = corners[np.argmax(coord_diff)]  # bottom-left
    return ordered
# === 3. 透视变换函数(新增)===
def four_point_transform(image, pts):
    """Warp the quadrilateral *pts* in *image* to an axis-aligned rectangle.

    Returns (warped, width, height) where width/height are derived from
    the longest pair of opposing edges of the ordered quadrilateral.
    """
    quad = order_points(pts)
    tl, tr, br, bl = quad

    def _edge_len(p, q):
        # Euclidean distance between two 2-D points.
        return np.sqrt((p[0] - q[0]) ** 2 + (p[1] - q[1]) ** 2)

    # Output size: the larger of each pair of opposite edge lengths.
    out_w = int(max(_edge_len(br, bl), _edge_len(tr, tl)))
    out_h = int(max(_edge_len(tr, br), _edge_len(tl, bl)))
    # Destination corners of the upright rectangle.
    target = np.array([
        [0, 0],
        [out_w - 1, 0],
        [out_w - 1, out_h - 1],
        [0, out_h - 1]
    ], dtype="float32")
    matrix = cv2.getPerspectiveTransform(quad, target)
    warped = cv2.warpPerspective(image, matrix, (out_w, out_h))
    return warped, out_w, out_h
# === 4. 查找文档轮廓 ===
def find_document_contour(edged_img):
    """Return the largest 4-vertex contour (document outline), or None.

    Scans external contours from largest to smallest area, skips tiny
    regions, and returns the first polygon approximation with exactly
    four corners.
    """
    found, _ = cv2.findContours(
        edged_img.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    for candidate in sorted(found, key=cv2.contourArea, reverse=True):
        # Skip contours too small to be the page.
        if cv2.contourArea(candidate) < 5000:
            continue
        perimeter = cv2.arcLength(candidate, True)
        poly = cv2.approxPolyDP(candidate, 0.02 * perimeter, True)
        if len(poly) == 4:
            return poly
    return None
# === 5. 增强版图像预处理 ===
def enhance_image_advanced(gray):
    """Boost local contrast and suppress noise in a grayscale image.

    Pipeline: CLAHE (local contrast) -> morphological opening (remove
    small specks) -> bilateral filter (edge-preserving smoothing).
    """
    # Local-contrast equalisation.
    contrast = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8)).apply(gray)
    # Opening with a small elliptical kernel removes isolated noise dots.
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    denoised = cv2.morphologyEx(contrast, cv2.MORPH_OPEN, ellipse)
    # Bilateral filter smooths flat regions while keeping edges sharp.
    return cv2.bilateralFilter(denoised, 9, 75, 75)
# === 6. 改进的内容检测 ===
def detect_content_advanced(roi_gray, margin_ratio=0.15):
    """Analyse the interior of a cell ROI, ignoring a border margin.

    Returns a dict with the interior mean/std intensity, the Otsu
    foreground-pixel ratio, plus the cropped interior ('inner_roi')
    and its binary mask ('binary').
    """
    h, w = roi_gray.shape
    top = int(h * margin_ratio)
    left = int(w * margin_ratio)
    # Fall back to the whole ROI when the margins would overlap.
    if top >= h // 2 or left >= w // 2:
        inner = roi_gray
    else:
        inner = roi_gray[top:h - top, left:w - left]
    # Otsu threshold (inverted) marks dark ink as foreground.
    _, mask = cv2.threshold(inner, 0, 255,
                            cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    return {
        'mean': np.mean(inner),
        'std': np.std(inner),
        'foreground_ratio': cv2.countNonZero(mask) / inner.size,
        'inner_roi': inner,
        'binary': mask
    }
# === 7. 智能裁剪和标准化 ===
def smart_crop_and_normalize(roi_gray, target_size=(28, 28)):
    """Locate the digit in *roi_gray*, crop it with padding, and resize
    to *target_size* with a final local-contrast boost.

    Returns an all-zero image of *target_size* when too few foreground
    pixels are found (likely an empty cell or pure noise).
    """
    src = roi_gray.copy()
    # Inverted Otsu threshold: dark ink becomes white foreground.
    _, fg_mask = cv2.threshold(src, 0, 255,
                               cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    pixels = cv2.findNonZero(fg_mask)
    # Fewer than 10 foreground pixels -> treat as blank/noise.
    if pixels is None or len(pixels) < 10:
        return np.zeros(target_size, dtype=np.uint8)
    x, y, w, h = cv2.boundingRect(pixels)
    # Pad the box by 1/8 of its larger side, clamped to the image.
    pad = max(w, h) // 8
    x = max(0, x - pad)
    y = max(0, y - pad)
    w = min(src.shape[1] - x, w + 2 * pad)
    h = min(src.shape[0] - y, h + 2 * pad)
    digit = src[y:y + h, x:x + w]
    shrunk = cv2.resize(digit, target_size, interpolation=cv2.INTER_AREA)
    # Final contrast boost on the small crop.
    return cv2.createCLAHE(clipLimit=2.0, tileGridSize=(2, 2)).apply(shrunk)
# === 8. 改进的预处理可视化 ===
def advanced_preprocess_visualization(roi_gray, cell_index):
    """Render the full preprocessing pipeline for one cell as a 3x4
    diagnostic figure (saved as ``advanced_preprocess_<idx>.png``) and
    return the final 28x28 model input (white digit on black).

    NOTE(review): the figure titles are Chinese; matplotlib needs a CJK
    font configured or they render as boxes — confirm environment.
    """
    fig, axes = plt.subplots(3, 4, figsize=(16, 12))
    fig.suptitle(f'单元格 {cell_index} 高级预处理流程', fontsize=16)
    # Row 0: original ROI, its histogram, CLAHE result, its histogram.
    axes[0, 0].imshow(roi_gray, cmap='gray')
    axes[0, 0].set_title('原始ROI')
    axes[0, 0].axis('off')
    axes[0, 1].hist(roi_gray.ravel(), bins=256, range=[0, 256], alpha=0.7)
    axes[0, 1].set_title('原始直方图')
    axes[0, 1].set_xlim([0, 256])
    # CLAHE local-contrast enhancement.
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    enhanced = clahe.apply(roi_gray)
    axes[0, 2].imshow(enhanced, cmap='gray')
    axes[0, 2].set_title('CLAHE增强')
    axes[0, 2].axis('off')
    axes[0, 3].hist(enhanced.ravel(), bins=256, range=[0, 256], alpha=0.7, color='orange')
    axes[0, 3].set_title('增强后直方图')
    axes[0, 3].set_xlim([0, 256])
    # Row 1: interior crop (border excluded) and threshold variants.
    h, w = roi_gray.shape
    margin_h, margin_w = h//6, w//6
    inner = roi_gray[margin_h:h-margin_h, margin_w:w-margin_w]
    axes[1, 0].imshow(inner, cmap='gray')
    axes[1, 0].set_title('内部区域\n(排除边缘)')
    axes[1, 0].axis('off')
    # Adaptive (Gaussian-weighted local) threshold.
    adaptive = cv2.adaptiveThreshold(enhanced, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                     cv2.THRESH_BINARY, 11, 2)
    axes[1, 1].imshow(adaptive, cmap='gray')
    axes[1, 1].set_title('自适应阈值')
    axes[1, 1].axis('off')
    # Global Otsu threshold.
    _, otsu = cv2.threshold(enhanced, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    axes[1, 2].imshow(otsu, cmap='gray')
    axes[1, 2].set_title('Otsu阈值')
    axes[1, 2].axis('off')
    # Morphological closing fills small gaps in the Otsu mask.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    closed = cv2.morphologyEx(otsu, cv2.MORPH_CLOSE, kernel)
    axes[1, 3].imshow(closed, cmap='gray')
    axes[1, 3].set_title('形态学闭运算')
    axes[1, 3].axis('off')
    # Row 2: smart-crop input, digit bounding box, output, final input.
    axes[2, 0].imshow(enhanced, cmap='gray')
    axes[2, 0].set_title('智能裁剪输入')
    axes[2, 0].axis('off')
    # Bounding box of the foreground pixels (if enough are present).
    _, binary_for_bbox = cv2.threshold(enhanced, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    coords = cv2.findNonZero(binary_for_bbox)
    if coords is not None and len(coords) > 10:
        x, y, w, h = cv2.boundingRect(coords)
        # Rectangle is drawn in BGR then shown via imshow; green is the
        # same in BGR and RGB, so the box displays correctly.
        bbox_img = cv2.cvtColor(enhanced, cv2.COLOR_GRAY2BGR)
        cv2.rectangle(bbox_img, (x, y), (x+w, y+h), (0, 255, 0), 2)
        axes[2, 1].imshow(bbox_img)
    else:
        axes[2, 1].imshow(enhanced, cmap='gray')
    axes[2, 1].set_title('数字边界框')
    axes[2, 1].axis('off')
    # Smart crop to a centred 28x28 digit.
    cropped = smart_crop_and_normalize(enhanced)
    axes[2, 2].imshow(cropped, cmap='gray')
    axes[2, 2].set_title('智能裁剪输出\n(28x28)')
    axes[2, 2].axis('off')
    # Invert so the digit is white on black, matching MNIST convention.
    final_input = 255 - cropped
    axes[2, 3].imshow(final_input, cmap='gray')
    axes[2, 3].set_title('最终输入\n(黑底白字)')
    axes[2, 3].axis('off')
    plt.tight_layout()
    plt.savefig(f'advanced_preprocess_{cell_index}.png', dpi=300, bbox_inches='tight')
    plt.show()
    return final_input
# === 9. 网格分割 ===
def split_into_grid(img_gray, num_rows=13, num_cols=9):
    """Partition an image into num_rows x num_cols cell rectangles.

    Returns (cells, cell_w, cell_h) where cells is a row-major list of
    (x, y, w, h) tuples; cells in the last row/column absorb any
    remainder pixels so the grid covers the whole image.
    """
    height, width = img_gray.shape
    cell_h = height // num_rows
    cell_w = width // num_cols
    cells = []
    for row in range(num_rows):
        y = row * cell_h
        # Last row stretches to the bottom edge.
        h = cell_h if row < num_rows - 1 else height - y
        for col in range(num_cols):
            x = col * cell_w
            # Last column stretches to the right edge.
            w = cell_w if col < num_cols - 1 else width - x
            cells.append((x, y, w, h))
    return cells, cell_w, cell_h
# === 10. 加载模型 ===
def load_model(model_path, device):
    """Load trained EnhancedCNN weights from *model_path* onto *device*.

    The returned model is switched to eval mode, ready for inference.
    """
    net = EnhancedCNN().to(device)
    net.load_state_dict(torch.load(model_path, map_location=device))
    net.eval()
    return net
# === 11. 改进的预测函数 ===
def predict_digit_improved(model, device, roi_gray, cell_index):
    """Classify the digit in one cell ROI, with console diagnostics.

    Runs border-excluding content detection, skips near-empty cells,
    visualises preprocessing, then predicts with the CNN and logs the
    top-3 classes. Returns (label, confidence), or (None, 0.0) when
    the cell is judged empty.
    """
    print(f"\n--- 分析单元格 {cell_index} ---")
    # Content statistics computed on the interior only (border excluded).
    content_info = detect_content_advanced(roi_gray)
    print(f"内部区域统计:")
    print(f" 平均强度: {content_info['mean']:.1f}")
    print(f" 标准差: {content_info['std']:.1f}")
    print(f" 前景比例: {content_info['foreground_ratio']:.3f}")
    # Near-empty interior: treat as blank and skip prediction.
    if content_info['foreground_ratio'] < 0.02:
        print(" ⚠️ 内容太少,可能是空白或噪声")
        return None, 0.0
    # Full preprocessing (also saves a diagnostic figure to disk).
    final_input = advanced_preprocess_visualization(roi_gray, cell_index)
    # To tensor, normalised with the standard MNIST mean/std.
    tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])(final_input).unsqueeze(0).to(device)
    # Forward pass without gradients; softmax for class probabilities.
    with torch.no_grad():
        output = model(tensor)
        prob = torch.nn.functional.softmax(output, dim=1)
        confidence, pred_label = torch.max(prob, 1)
        # Top-3 most probable classes for diagnostics.
        top3_prob, top3_labels = torch.topk(prob, 3)
    print(f"\n预测结果:")
    for i in range(3):
        print(f" Top-{i+1}: 数字 {top3_labels[0][i].item()} "
              f"(置信度: {top3_prob[0][i].item():.3f})")
    return pred_label.item(), confidence.item()
# === 12. 主函数 ===
def improved_recognition_pipeline(image_path, model_path):
    """End-to-end pipeline: load model, rectify the document, split it
    into a grid, and predict the first cell (among the first five) that
    contains plausible content.

    Raises FileNotFoundError when the image cannot be read; model-load
    failures are reported and abort the run.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")
    # Load trained weights; abort early on failure.
    try:
        model = load_model(model_path, device)
    except Exception as e:
        print(f"加载模型失败: {e}")
        return
    # Read the input image from disk.
    orig = cv2.imread(image_path)
    if orig is None:
        raise FileNotFoundError(f"无法读取图像: {image_path}")
    # Grayscale -> contrast/denoise enhancement -> Canny edges.
    gray = cv2.cvtColor(orig, cv2.COLOR_BGR2GRAY)
    enhanced = enhance_image_advanced(gray)
    edged = cv2.Canny(enhanced, 50, 150)
    # Locate the document outline; rectify perspective when found,
    # otherwise fall back to the whole enhanced image.
    doc_contour = find_document_contour(edged)
    if doc_contour is None:
        print("⚠️ 未检测到外框,使用全图处理")
        corrected_gray = enhanced
    else:
        corrected_gray, _, _ = four_point_transform(enhanced, doc_contour)
    # Split the rectified page into a 13x9 grid of cells.
    cells, _, _ = split_into_grid(corrected_gray, 13, 9)
    # Scan only the first 5 cells; recognise the first with content.
    for idx, (x, y, w, h) in enumerate(cells[:5]):
        # Shrink each cell by 10% per side to avoid grid-line artifacts.
        margin = int(min(w, h) * 0.1)
        ix, iy = x + margin, y + margin
        iw, ih = w - 2 * margin, h - 2 * margin
        if iw <= 5 or ih <= 5:
            continue
        roi = corrected_gray[iy:iy + ih, ix:ix + iw]
        # Border-excluding content detection.
        content_info = detect_content_advanced(roi)
        # Foreground ratio in [0.02, 0.8]: plausibly a handwritten digit.
        if 0.02 <= content_info['foreground_ratio'] <= 0.8:
            print(f"\n找到待识别的单元格: {idx+1}")
            pred, conf = predict_digit_improved(model, device, roi, idx+1)
            if pred is not None:
                print(f"最终预测: {pred}, 置信度: {conf:.3f}")
            break
# === 运行入口 ===
if __name__ == "__main__":
    # Resolve the input image and trained-model paths relative to CWD.
    IMAGE_PATH = os.path.abspath("111.jpg")
    MODEL_PATH = os.path.abspath("enhanced_mnist_cnn.pth")
    image_ok = os.path.exists(IMAGE_PATH)
    model_ok = os.path.exists(MODEL_PATH)
    if not image_ok:
        print(f"错误: 图片不存在 {IMAGE_PATH}")
    elif not model_ok:
        print(f"错误: 模型不存在 {MODEL_PATH}")
    else:
        improved_recognition_pipeline(IMAGE_PATH, MODEL_PATH)
```
### 主要修复内容:
1. **添加了缺失的函数**:
- `order_points(pts)`:对四个点进行排序
- `four_point_transform(image, pts)`:执行透视变换
2. **完善了所有依赖关系**:
- 确保每个函数都在使用前已定义
- 添加了详细的注释说明
3. **保持功能完整性**:
- 所有原有的高级功能都保留
- 包括对比度增强、边缘排除、智能裁剪等
现在这段代码应该可以完整运行,解决了您提到的所有问题。