# 深度学习项目训练环境实操手册:matplotlib/seaborn绘图脚本修改与结果导出
## 1. 环境准备与快速上手
深度学习项目训练完成后,结果可视化是分析模型性能的关键步骤。本镜像预装了完整的深度学习开发环境,集成了训练、推理及评估所需的所有依赖,包括matplotlib和seaborn等绘图库,开箱即用。
使用前只需激活conda环境,环境名称为dl:
```bash
conda activate dl
```
上传训练代码到数据盘后,进入代码目录:
```bash
cd /root/workspace/你的源码文件夹名称
```
环境已预装主要依赖:`pytorch==1.13.0`、`torchvision==0.14.0`、`matplotlib`、`seaborn`、`pandas`、`numpy`等,如需额外库可自行安装。
## 2. 训练结果可视化基础
### 2.1 理解训练日志文件
深度学习训练过程通常会产生日志文件,记录损失值、准确率等关键指标。常见的日志格式包括:
- CSV文件:包含epoch、loss、accuracy等列
- JSON文件:结构化存储训练指标
- 文本日志:自定义格式的训练输出
```python
import pandas as pd
# 读取训练日志CSV文件
log_data = pd.read_csv('runs/train/log.csv')
print(log_data.head()) # 查看前几行数据
```
### 2.2 matplotlib基础绘图
matplotlib是Python最常用的绘图库,提供灵活的绘图功能:
```python
import matplotlib.pyplot as plt
import numpy as np
# 创建简单折线图
epochs = range(1, len(log_data) + 1)
plt.figure(figsize=(10, 6))
plt.plot(epochs, log_data['train_loss'], label='Training Loss')
plt.plot(epochs, log_data['val_loss'], label='Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.grid(True)
plt.show()
```
### 2.3 seaborn高级可视化
seaborn基于matplotlib,提供更美观的统计图形:
```python
import seaborn as sns
# 设置seaborn样式
sns.set_style("whitegrid")
sns.set_palette("deep")
# 创建多子图对比
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
# 损失曲线
sns.lineplot(data=log_data, x='epoch', y='train_loss', ax=ax1, label='Train')
sns.lineplot(data=log_data, x='epoch', y='val_loss', ax=ax1, label='Validation')
ax1.set_title('Loss Curve')
# 准确率曲线
sns.lineplot(data=log_data, x='epoch', y='train_acc', ax=ax2, label='Train')
sns.lineplot(data=log_data, x='epoch', y='val_acc', ax=ax2, label='Validation')
ax2.set_title('Accuracy Curve')
plt.tight_layout()
plt.show()
```
## 3. 绘图脚本修改实战
### 3.1 修改现有绘图脚本
大多数深度学习项目都提供了绘图脚本,通常需要修改以下部分:
```python
# 原脚本可能是这样的
results_path = 'path/to/your/results' # 需要修改的路径
log_file = 'results.csv' # 需要修改的文件名
# 修改为你的实际路径
results_path = '/root/workspace/your_project/runs/train'
log_file = 'exp3/log.csv' # 根据你的实际结构修改
```
### 3.2 适应不同日志格式
如果你的日志格式与脚本预期不同,需要调整数据读取方式:
```python
# 如果你的日志格式不同,可以这样调整
def load_training_log(file_path):
try:
# 尝试多种格式
if file_path.endswith('.csv'):
data = pd.read_csv(file_path)
elif file_path.endswith('.json'):
data = pd.read_json(file_path)
else:
# 自定义文本格式处理
data = parse_custom_log(file_path)
# 统一列名
column_mapping = {
'epoch': 'epoch',
'loss': 'train_loss',
'val_loss': 'val_loss',
'accuracy': 'train_acc',
'val_accuracy': 'val_acc'
}
data = data.rename(columns=column_mapping)
return data
except Exception as e:
print(f"Error loading log file: {e}")
return None
```
### 3.3 自定义绘图样式
修改绘图样式以适应论文或报告要求:
```python
# 自定义绘图样式
plt.style.use('default') # 重置样式
# 设置中文字体(如果需要)
plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False
# 自定义颜色 palette
custom_palette = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#F9A602', '#9B59B6']
sns.set_palette(custom_palette)
# 设置全局字体大小
plt.rcParams['font.size'] = 12
plt.rcParams['axes.titlesize'] = 16
plt.rcParams['axes.labelsize'] = 14
```
## 4. 高级可视化技巧
### 4.1 多实验对比可视化
当有多个训练实验时,对比分析很重要:
```python
# 比较多个实验结果
experiments = {
'Baseline': 'runs/exp1/log.csv',
'Data Augmentation': 'runs/exp2/log.csv',
'Modified Architecture': 'runs/exp3/log.csv'
}
plt.figure(figsize=(12, 5))
for exp_name, log_path in experiments.items():
data = pd.read_csv(log_path)
sns.lineplot(data=data, x='epoch', y='val_acc', label=exp_name)
plt.title('Validation Accuracy Comparison')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.grid(True)
plt.show()
```
### 4.2 混淆矩阵可视化
对于分类任务,混淆矩阵是重要的评估工具:
```python
from sklearn.metrics import confusion_matrix
import itertools
def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion Matrix'):
"""
绘制混淆矩阵
"""
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='.2f' if normalize else 'd',
cmap='Blues', xticklabels=classes, yticklabels=classes)
plt.title(title)
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.tight_layout()
# 使用示例
y_true = [0, 1, 2, 0, 1, 2] # 替换为你的真实标签
y_pred = [0, 1, 1, 0, 0, 2] # 替换为你的预测标签
class_names = ['Class 0', 'Class 1', 'Class 2'] # 替换为你的类别名称
cm = confusion_matrix(y_true, y_pred)
plot_confusion_matrix(cm, class_names, normalize=True)
```
### 4.3 特征可视化
可视化模型学到的特征:
```python
# 特征分布可视化
def plot_feature_distributions(features, labels, class_names):
"""
绘制不同类别的特征分布
"""
n_features = features.shape[1]
n_cols = 3
n_rows = (n_features + n_cols - 1) // n_cols
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, n_rows*4))
axes = axes.flatten()
for i in range(n_features):
for class_idx in range(len(class_names)):
class_mask = (labels == class_idx)
sns.histplot(features[class_mask, i], ax=axes[i],
label=class_names[class_idx], alpha=0.6, kde=True)
axes[i].set_title(f'Feature {i+1}')
axes[i].legend()
# 隐藏多余的子图
for j in range(i+1, len(axes)):
axes[j].set_visible(False)
plt.tight_layout()
return fig
```
## 5. 结果导出与保存
### 5.1 高质量图片导出
导出适合论文发表的高质量图片:
```python
def save_high_quality_plot(filename, dpi=300, format='png'):
"""
保存高质量图片
"""
plt.savefig(
filename,
dpi=dpi,
format=format,
bbox_inches='tight',
pad_inches=0.1,
facecolor='white',
edgecolor='none'
)
print(f"Plot saved as {filename}")
# 使用示例
plt.figure(figsize=(10, 6))
sns.lineplot(data=log_data, x='epoch', y='val_acc')
plt.title('Validation Accuracy')
save_high_quality_plot('val_accuracy.png', dpi=300)
```
### 5.2 多种格式导出
根据不同需求导出不同格式:
```python
def export_plots_multiple_formats(base_filename, plots_dir='plots'):
"""
导出多种格式的图片
"""
import os
os.makedirs(plots_dir, exist_ok=True)
formats = ['png', 'pdf', 'svg'] # 支持的格式
for fmt in formats:
filename = os.path.join(plots_dir, f"{base_filename}.{fmt}")
plt.savefig(
filename,
dpi=300,
format=fmt,
bbox_inches='tight',
facecolor='white'
)
print(f"Exported: {filename}")
# 使用示例
export_plots_multiple_formats('training_curves')
```
### 5.3 自动化导出脚本
创建自动化导出所有结果的脚本:
```python
def export_all_results(experiment_dir):
"""
自动导出实验的所有可视化结果
"""
import glob
import os
# 创建输出目录
output_dir = os.path.join(experiment_dir, 'visualizations')
os.makedirs(output_dir, exist_ok=True)
# 找到所有日志文件
log_files = glob.glob(os.path.join(experiment_dir, '**', '*.csv'), recursive=True)
for log_file in log_files:
# 生成对应的输出文件名
rel_path = os.path.relpath(log_file, experiment_dir)
base_name = os.path.splitext(rel_path)[0].replace(os.path.sep, '_')
# 生成并保存图表
try:
data = pd.read_csv(log_file)
# 损失曲线
plt.figure(figsize=(10, 6))
plt.plot(data['epoch'], data['train_loss'], label='Train Loss')
plt.plot(data['epoch'], data['val_loss'], label='Val Loss')
plt.legend()
plt.savefig(os.path.join(output_dir, f'{base_name}_loss.png'))
plt.close()
# 准确率曲线
plt.figure(figsize=(10, 6))
plt.plot(data['epoch'], data['train_acc'], label='Train Acc')
plt.plot(data['epoch'], data['val_acc'], label='Val Acc')
plt.legend()
plt.savefig(os.path.join(output_dir, f'{base_name}_accuracy.png'))
plt.close()
except Exception as e:
print(f"Error processing {log_file}: {e}")
print(f"All visualizations exported to {output_dir}")
```
## 6. 实用技巧与问题解决
### 6.1 常见绘图问题解决
```python
# 1. 处理内存不足问题
plt.figure(figsize=(10, 6))
# 绘制图表后及时关闭
plt.close('all') # 关闭所有图表释放内存
# 2. 处理中文显示问题
def setup_chinese_font():
"""
设置中文字体支持
"""
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False
# 3. 批量处理多个图表
def batch_process_plots(plot_function, data_list, output_dir):
"""
批量处理多个图表
"""
import os
os.makedirs(output_dir, exist_ok=True)
for i, data in enumerate(data_list):
plt.figure(figsize=(10, 6))
plot_function(data)
plt.savefig(os.path.join(output_dir, f'plot_{i}.png'))
plt.close()
```
### 6.2 性能优化技巧
```python
# 对于大数据集,使用采样提高绘图性能
def plot_large_dataset_sampled(data, sample_frac=0.1, **kwargs):
"""
对大数据集进行采样后绘图
"""
if len(data) > 10000: # 如果数据量很大
sampled_data = data.sample(frac=sample_frac, random_state=42)
sns.lineplot(data=sampled_data, **kwargs)
else:
sns.lineplot(data=data, **kwargs)
# 使用示例
plt.figure(figsize=(12, 6))
plot_large_dataset_sampled(large_log_data, x='epoch', y='loss')
plt.title('Training Loss (Sampled)')
plt.show()
```
### 6.3 交互式可视化
```python
# 使用Plotly创建交互式图表(需要安装plotly)
def create_interactive_plot(data):
"""
创建交互式可视化图表
"""
try:
import plotly.express as px
fig = px.line(data, x='epoch', y=['train_loss', 'val_loss'],
title='Training and Validation Loss')
fig.show()
# 保存为HTML
fig.write_html('interactive_plot.html')
except ImportError:
print("Plotly not installed. Using matplotlib instead.")
# 回退到matplotlib
plt.figure(figsize=(10, 6))
plt.plot(data['epoch'], data['train_loss'], label='Train')
plt.plot(data['epoch'], data['val_loss'], label='Validation')
plt.legend()
plt.show()
```
## 7. 总结
通过本实操手册,你学会了如何在深度学习项目训练环境中使用matplotlib和seaborn进行结果可视化。关键要点包括:
1. **环境准备**:使用预配置的深度学习环境,快速开始可视化工作
2. **基础绘图**:掌握matplotlib和seaborn的基本绘图方法
3. **脚本修改**:学会调整现有绘图脚本以适应你的项目需求
4. **高级技巧**:使用多实验对比、混淆矩阵等高级可视化方法
5. **结果导出**:导出高质量图片用于论文、报告等不同场景
记住这些实用技巧:
- 总是及时关闭图表释放内存:`plt.close('all')`
- 使用`bbox_inches='tight'`避免图片边缘被裁剪
- 根据用途选择适当的导出格式和DPI设置
- 对于大数据集,使用采样提高绘图性能
可视化不仅是展示结果的手段,更是理解模型行为、发现问题、指导改进的重要工具。熟练掌握这些技能将大大提升你的深度学习项目开发效率。
---
> **获取更多AI镜像**
>
> 想探索更多AI镜像和应用场景?访问 [CSDN星图镜像广场](https://ai.csdn.net/?utm_source=mirror_blog_end),提供丰富的预置镜像,覆盖大模型推理、图像生成、视频生成、模型微调等多个领域,支持一键部署。