# Argoverse v2 轨迹预测实战:从数据加载到模型输入的完整工程化指南
如果你正在为自动驾驶轨迹预测项目寻找一个既权威又实用的数据集,Argoverse v2 大概率已经进入了你的视野。这个由 Argo AI 开源的大规模数据集,以其丰富的传感器数据、精细的地图标注和极具挑战性的城市驾驶场景,迅速成为了学术界和工业界进行运动预测研究的基准。然而,从官网下载的原始数据包到最终能喂给深度学习模型的标准张量,中间隔着一道不浅的“工程鸿沟”。数据格式陌生、坐标系复杂、地图信息庞大,这些都可能让急于验证算法的工程师感到头疼。
这篇文章的目的,就是充当你的“开箱即用”手册。我们不谈空洞的理论,只聚焦于最实际的工程问题:如何用最高效的方式,把 Argoverse v2 的轨迹预测数据(Motion Forecasting Dataset)处理成模型友好的格式。我会分享一套经过实战检验的 Python 处理流程,涵盖数据读取、关键信息提取、坐标系对齐、特征工程以及最终的数据加载器构建。无论你是想快速跑通一个基线模型,还是为自己的新算法搭建数据管道,这里的内容都能让你节省大量摸索时间。
## 1. 环境搭建与数据获取:避开初学者的坑
在开始写任何代码之前,一个稳定、可复现的环境是高效工作的基石。Argoverse v2 的官方 API (`av2`) 仍在积极开发中,依赖管理需要一些技巧。
### 1.1 创建隔离的 Python 环境
我强烈建议使用 `conda` 或 `venv` 创建一个独立的环境,避免与系统或其他项目的包发生冲突。这里以 `conda` 为例:
```bash
conda create -n argoverse2 python=3.9 -y
conda activate argoverse2
```
选择 Python 3.8 或 3.9 是比较稳妥的,对大多数深度学习框架兼容性最好。
### 1.2 安装核心依赖
接下来安装 `av2` API 及其依赖。官方推荐通过 pip 从 GitHub 安装最新版。
```bash
pip install "git+https://github.com/argoverse/av2-api.git"
```
这个命令会自动处理大部分依赖。但根据我的经验,你很可能还需要手动安装以下关键包,以确保所有功能(特别是与地图和可视化相关的)正常工作:
```bash
pip install pandas pyarrow scikit-learn matplotlib opencv-python
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 # 根据你的CUDA版本选择
```
> **注意**:如果安装 `av2` 过程中遇到 `shapely` 等地理信息库的编译错误,可以尝试先通过 `conda install shapely` 安装二进制版本,再安装 `av2`。
### 1.3 下载数据集:选择正确的“姿势”
Argoverse v2 运动预测数据集体积庞大(约 200 GB)。官方提供了 AWS S3 的下载链接。对于国内用户,直接使用 `wget` 或浏览器下载可能速度缓慢且不稳定。
**推荐方案:使用 `s5cmd` 进行高速下载**
这是一个用 Go 编写的高性能 S3 命令行工具,支持并行传输,能极大提升从 S3 下载大文件的速度。
1. **安装 s5cmd**:前往其 [GitHub 发布页](https://github.com/peak/s5cmd/releases) 下载对应操作系统的二进制文件,或通过包管理器安装(如 `brew install s5cmd`)。
2. **执行下载命令**:
```bash
s5cmd cp "s3://argoverse/motion-forecasting/*" ./argoverse2_data/
```
这条命令会递归下载整个运动预测数据集到本地 `argoverse2_data` 目录。你可以通过 `--concurrency` 参数调整并发数来优化下载速度。
下载后的目录结构如下所示,清晰地区分了训练、验证和测试集:
```
argoverse2_data/
├── train/
│ ├── 0000b0f9-99f9-4a1f-a231-5be9e4c523f7/
│ │ ├── log_map_archive_0000b0f9-99f9-4a1f-a231-5be9e4c523f7.json
│ │ └── scenario_0000b0f9-99f9-4a1f-a231-5be9e4c523f7.parquet
│ ├── 0000b6ab-e100-4f6b-aee8-b520b57c0530/
│ └── ...
├── val/
└── test/
```
每个场景(Scenario)由一个唯一的 UUID 命名的文件夹包含,里面有两个核心文件:`.parquet` 格式的场景轨迹文件和 `.json` 格式的高精地图文件。
## 2. 深入解析数据:读懂每一个字段的含义
拿到数据后,第一步不是急着写预处理,而是先彻底理解数据的组织结构。这能帮你避免后续很多因误解数据含义而导致的错误。
### 2.1 场景文件 (.parquet):轨迹数据的核心
Parquet 是一种高效的列式存储格式。我们用 `pandas` 可以轻松读取。但更重要的是理解每一列代表什么。
```python
import pandas as pd
from pathlib import Path
def inspect_scenario_file(scenario_path: Path):
"""查看一个场景文件的基本信息"""
df = pd.read_parquet(scenario_path)
print(f"数据形状: {df.shape}")
print(f"列名: {df.columns.tolist()}")
print("\n前几行数据:")
print(df.head())
print(f"\n‘track_id’的唯一值数量 (场景中所有物体): {df['track_id'].nunique()}")
print(f"‘focal_track_id’的唯一值 (本场景预测目标): {df['focal_track_id'].unique()}")
print(f"时间步 (timestep) 范围: {df['timestep'].min()} 到 {df['timestep'].max()}")
return df
# 示例:查看第一个训练场景
sample_scenario = Path("./argoverse2_data/train/0000b0f9-99f9-4a1f-a231-5be9e4c523f7/scenario_0000b0f9-99f9-4a1f-a231-5be9e4c523f7.parquet")
df_sample = inspect_scenario_file(sample_scenario)
```
运行上述代码,你会对数据有一个直观认识。下面这个表格总结了 `.parquet` 文件中最关键字段的含义,这是构建特征的基础:
| 字段名 | 数据类型 | 描述与关键点 |
| :--- | :--- | :--- |
| **scenario_id** | str | 场景的唯一标识符 (UUID)。 |
| **track_id** | str | 场景中每个可追踪物体(车辆、行人等)的唯一ID。 |
| **focal_track_id** | str | **核心字段**。指定本场景中需要进行轨迹预测的“目标物体”的 `track_id`。每个场景只有一个。 |
| **timestep** | int | 时间步索引,从 0 到 110。**0-49 是历史观测部分,50-110 是未来真值部分**。这是模型输入和标签分割的依据。 |
| **position_x, position_y** | float | 物体边界框中心在**地图坐标系**下的坐标(单位:米)。 |
| **heading** | float | 物体的朝向角(偏航角),单位是**弧度**。0弧度表示朝向地图坐标系X轴正方向。 |
| **velocity_x, velocity_y** | float | 物体在**地图坐标系**下X和Y方向的速度分量(米/秒)。 |
| **object_type** | str | 物体类型,如 `VEHICLE`, `PEDESTRIAN`, `CYCLIST` 等。 |
| **object_category** | str | **轨迹质量标签**。`FOCAL_TRACK`是预测目标;`SCORED_TRACK`是高质量背景车轨迹;`UNSCORED_TRACK`和`TRACK_FRAGMENT`质量较低。 |
| **observed** | bool | 指示该物体在当前时间步是否被观测到。对于`focal_track`,历史段(0-49)均为True。 |
### 2.2 地图文件 (.json):理解驾驶环境
轨迹不是发生在真空中的,周围的车道线、路口、可行驶区域构成了关键的上下文信息。`av2` 的 `ArgoverseStaticMap` 类让加载和查询地图变得简单。
```python
from av2.map.map_api import ArgoverseStaticMap
def load_map_for_scenario(scenario_path: Path):
"""加载并查看场景对应的地图"""
scenario_id = scenario_path.stem.split('_')[-1]
map_json_path = scenario_path.parent / f"log_map_archive_{scenario_id}.json"
static_map = ArgoverseStaticMap.from_json(map_json_path)
# 查看地图包含的元素
print(f"车道段 (lane_segments) 数量: {len(static_map.lane_segments)}")
print(f"可行驶区域 (drivable_areas) 数量: {len(static_map.vector_drivable_areas)}")
print(f"人行横道 (pedestrian_crossings) 数量: {len(static_map.vector_pedestrian_crossings)}")
# 获取地图的边界,用于可视化或坐标归一化
all_lane_points = []
for lane_id, lane_seg in static_map.lane_segments.items():
all_lane_points.extend(lane_seg.centerline.xyz)
# 转换为numpy数组后计算min/max
# ... (具体代码略)
return static_map
static_map_sample = load_map_for_scenario(sample_scenario)
```
地图数据非常丰富,但对于初期的轨迹预测模型,我们通常最关心**车道中心线 (centerline)**,因为它定义了车辆最可能的行驶路径。每个 `lane_segment` 对象包含了丰富的属性,例如 `is_intersection`(是否在路口内)、`lane_type`(车道类型)、`left_neighbor_id`/`right_neighbor_id`(相邻车道)等,这些都可以作为强大的上下文特征。
## 3. 核心预处理流程:从原始数据到模型输入
理解了数据结构后,我们就可以设计预处理流水线了。一个健壮的流水线应该完成以下转换:原始数据 -> 提取目标及周围物体轨迹 -> 坐标系处理与特征计算 -> 标准化/归一化 -> 组织成样本。
### 3.1 步骤一:提取场景核心数据
我们需要从庞大的 DataFrame 中,抽取出与当前预测任务最相关的数据:目标物体(focal)的历史轨迹,以及周围重要物体(如附近的车辆)的历史轨迹作为上下文。
```python
import numpy as np
from typing import Dict, List, Tuple, Optional
def extract_scenario_data(df: pd.DataFrame, static_map: ArgoverseStaticMap) -> Dict:
"""
从一个场景的DataFrame中提取结构化数据。
返回一个字典,包含目标轨迹、周围物体轨迹、地图信息等。
"""
scenario_id = df['scenario_id'].iloc[0]
focal_track_id = df['focal_track_id'].iloc[0]
# 1. 提取目标物体 (focal agent) 的完整轨迹 (0-110 timesteps)
focal_mask = df['track_id'] == focal_track_id
focal_df = df[focal_mask].sort_values('timestep')
# 分离历史 (0-49) 和未来 (50-110)
hist_mask = focal_df['timestep'] <= 49
fut_mask = focal_df['timestep'] >= 50
focal_history = focal_df[hist_mask][['timestep', 'position_x', 'position_y', 'heading', 'velocity_x', 'velocity_y']].values
focal_future = focal_df[fut_mask][['position_x', 'position_y']].values # 未来通常只关心位置
# 2. 提取周围物体 (neighbor agents) 的历史轨迹
# 策略:选择在历史最后一帧 (timestep=49) 时,距离目标一定范围内的物体
last_hist_step_df = df[df['timestep'] == 49]
focal_pos_last = focal_history[-1, 1:3] # 获取目标在t=49的位置 (x, y)
neighbor_data = {}
for _, row in last_hist_step_df.iterrows():
if row['track_id'] == focal_track_id:
continue # 跳过自己
# 计算与目标的欧氏距离
neighbor_pos = np.array([row['position_x'], row['position_y']])
dist = np.linalg.norm(neighbor_pos - focal_pos_last)
# 只保留距离小于阈值的物体,例如50米
if dist < 50.0 and row['object_category'] in ['SCORED_TRACK', 'FOCAL_TRACK']:
track_id = row['track_id']
# 获取该物体全部历史轨迹
neighbor_track_df = df[df['track_id'] == track_id]
neighbor_track_df = neighbor_track_df[neighbor_track_df['timestep'] <= 49].sort_values('timestep')
# 对齐时间步:有些物体可能不是在所有50帧都出现,需要插值或填充
# 这里简化处理,只取存在的帧
if len(neighbor_track_df) >= 10: # 至少需要一定长度的轨迹才有意义
neighbor_traj = neighbor_track_df[['position_x', 'position_y', 'heading', 'velocity_x', 'velocity_y']].values
neighbor_data[track_id] = {
'trajectory': neighbor_traj,
'object_type': row['object_type'],
'last_position': neighbor_pos
}
# 3. 提取局部地图信息 (例如,目标终点附近的车道)
# 获取目标历史轨迹的终点作为查询点
query_point = focal_pos_last
nearby_lanes = []
for lane_id, lane_seg in static_map.lane_segments.items():
centerline = lane_seg.centerline.xyz
# 计算车道中心线到查询点的最近距离
distances = np.linalg.norm(centerline - query_point, axis=1)
if np.min(distances) < 20.0: # 保留20米内的车道
nearby_lanes.append(centerline)
return {
'scenario_id': scenario_id,
'focal_track_id': focal_track_id,
'focal_history': focal_history, # [T_hist, 6] (ts, x, y, heading, vx, vy)
'focal_future': focal_future, # [T_fut, 2] (x, y)
'neighbors': neighbor_data, # Dict{track_id: {traj, type, ...}}
'nearby_lane_centerlines': nearby_lanes, # List of [N_i, 3] arrays
'city_name': df['city'].iloc[0]
}
```
这个函数是预处理的核心,它执行了关键的信息筛选和聚合。注意其中对周围物体的选择策略——基于距离和轨迹质量,这是平衡计算开销和信息完整性的常用方法。
### 3.2 步骤二:坐标系转换与特征工程
原始数据中的位置和速度是基于**全局地图坐标系**的。对于模型来说,使用以目标物体为原点的相对坐标系通常更有效,因为它具有平移不变性,模型更容易学习运动模式。
```python
def convert_to_agent_centric_coords(scene_data: Dict) -> Dict:
"""
将全局坐标转换为以目标物体在最后一帧观测时刻(t=49)的状态为原点的坐标系。
新的坐标系:X轴指向目标车头的方向,Y轴指向左侧。
"""
focal_hist = scene_data['focal_history']
focal_fut = scene_data['focal_future']
# 获取参考状态:历史最后一帧 (t=49) 的位置和朝向
ref_x, ref_y = focal_hist[-1, 1], focal_hist[-1, 2] # position_x, position_y
ref_heading = focal_hist[-1, 3] # heading
# 计算旋转矩阵 (用于将全局坐标旋转到以目标车头方向为X轴的坐标系)
cos_ref, sin_ref = np.cos(ref_heading), np.sin(ref_heading)
rotation_matrix = np.array([[cos_ref, sin_ref],
[-sin_ref, cos_ref]]) # 注意:这是将点从全局系转到自车系的旋转
# 1. 转换目标自身的历史和未来轨迹
def transform_trajectory(traj_xy):
"""转换轨迹的xy坐标"""
# traj_xy: [N, 2]
translated = traj_xy - np.array([ref_x, ref_y])
rotated = translated @ rotation_matrix.T # 等价于 R * translated^T 再转置
return rotated
focal_hist_xy = focal_hist[:, 1:3] # 提取xy
focal_hist_xy_local = transform_trajectory(focal_hist_xy)
focal_fut_local = transform_trajectory(focal_fut)
# 速度也需要同样的旋转(平移不变,但方向随坐标系旋转)
focal_hist_vel = focal_hist[:, 4:6] # velocity_x, velocity_y
focal_hist_vel_local = focal_hist_vel @ rotation_matrix.T
# 组装新的历史特征:[x_local, y_local, vx_local, vy_local, (可选:加速度)]
focal_hist_local = np.column_stack([focal_hist_xy_local, focal_hist_vel_local])
# 2. 转换所有周围物体的历史轨迹
neighbors_local = {}
for track_id, neighbor_info in scene_data['neighbors'].items():
traj_global = neighbor_info['trajectory']
traj_xy_global = traj_global[:, :2]
traj_xy_local = transform_trajectory(traj_xy_global)
traj_vel_global = traj_global[:, 3:5]
traj_vel_local = traj_vel_global @ rotation_matrix.T
traj_local = np.column_stack([traj_xy_local, traj_vel_local])
neighbors_local[track_id] = {
'trajectory': traj_local,
'object_type': neighbor_info['object_type']
}
# 3. 转换附近的车道线中心点
lanes_local = []
for lane_centerline in scene_data['nearby_lane_centerlines']:
lane_xy_global = lane_centerline[:, :2] # 忽略z坐标
lane_xy_local = transform_trajectory(lane_xy_global)
lanes_local.append(lane_xy_local)
# 更新后的数据字典
scene_data_local = scene_data.copy()
scene_data_local.update({
'focal_history_tensor': focal_hist_local.astype(np.float32), # [50, 4]
'focal_future_tensor': focal_fut_local.astype(np.float32), # [60, 2]
'neighbors_tensor': neighbors_local,
'lane_centerlines_tensor': lanes_local,
'origin_state': np.array([ref_x, ref_y, ref_heading], dtype=np.float32) # 保存转换原点,用于后续反变换
})
return scene_data_local
```
经过这个转换,所有物体的位置和速度都变成了相对于目标最后状态的量。此时,目标物体自身的历史轨迹在最后一帧的位置是(0,0),速度也反映了相对于自身坐标系的速度。这种表示极大地简化了学习问题。
### 3.3 步骤三:构建 PyTorch Dataset 数据加载器
将上述处理流程封装成一个 `torch.utils.data.Dataset` 类,是集成到训练循环的标准做法。
```python
from torch.utils.data import Dataset, DataLoader
import torch
class Argoverse2ForecastingDataset(Dataset):
"""Argoverse 2 轨迹预测数据集类"""
def __init__(self, data_root: Path, split: str = 'train', transform=None):
"""
参数:
data_root: 数据集根目录路径
split: 'train', 'val', 或 'test'
transform: 可选的额外数据变换函数
"""
self.data_root = Path(data_root) / split
self.split = split
self.transform = transform
# 收集所有场景的parquet文件路径
self.scenario_files = list(self.data_root.glob('*/scenario_*.parquet'))
print(f"在 {split} 集中找到 {len(self.scenario_files)} 个场景。")
def __len__(self):
return len(self.scenario_files)
def __getitem__(self, idx):
scenario_path = self.scenario_files[idx]
# 1. 加载原始数据
df = pd.read_parquet(scenario_path)
scenario_id = scenario_path.stem.split('_')[-1]
map_path = scenario_path.parent / f"log_map_archive_{scenario_id}.json"
static_map = ArgoverseStaticMap.from_json(map_path)
# 2. 提取场景核心数据 (使用之前定义的函数)
raw_scene_data = extract_scenario_data(df, static_map)
# 3. 坐标系转换 (使用之前定义的函数)
scene_data_local = convert_to_agent_centric_coords(raw_scene_data)
# 4. 整理为模型输入格式
# 假设我们的模型需要:目标历史、邻居历史列表、车道线列表
focal_past = scene_data_local['focal_history_tensor'] # [50, 4]
# 邻居轨迹:处理成固定数量的张量,不足的用零填充
neighbor_trajs = []
neighbor_types = []
for neighbor_info in scene_data_local['neighbors_tensor'].values():
neighbor_trajs.append(neighbor_info['trajectory'])
# 将物体类型转换为数字编码
type_map = {'VEHICLE': 0, 'PEDESTRIAN': 1, 'CYCLIST': 2, 'MOTORCYCLIST': 3, 'BUS': 4}
neighbor_types.append(type_map.get(neighbor_info['object_type'], 5)) # 未知类型为5
# 假设我们最多考虑10个邻居,每个邻居轨迹长度50
max_neighbors = 10
hist_len = 50
feat_dim = 4 # x, y, vx, vy
neighbor_tensor = torch.zeros((max_neighbors, hist_len, feat_dim), dtype=torch.float32)
neighbor_mask = torch.zeros(max_neighbors, dtype=torch.bool) # True表示有效邻居
neighbor_type_tensor = torch.full((max_neighbors,), 5, dtype=torch.long) # 默认未知类型
num_valid = min(len(neighbor_trajs), max_neighbors)
for i in range(num_valid):
traj = neighbor_trajs[i]
traj_len = traj.shape[0]
# 注意:邻居轨迹可能不满50帧,我们需要对齐到最后一帧 (右对齐)
if traj_len < hist_len:
# 前面用第一帧填充,或者用零填充。这里采用复制第一帧的简单策略。
padded_traj = np.zeros((hist_len, feat_dim))
padded_traj[-traj_len:, :] = traj
padded_traj[:hist_len-traj_len, :] = traj[0]
neighbor_tensor[i] = torch.from_numpy(padded_traj)
else:
neighbor_tensor[i] = torch.from_numpy(traj[-hist_len:]) # 取最后50帧
neighbor_mask[i] = True
neighbor_type_tensor[i] = neighbor_types[i]
# 车道线:处理成固定数量的点云
lane_centerlines = scene_data_local['lane_centerlines_tensor']
# 将所有车道线点拼接起来,并添加一个特征维度(例如,点所属车道ID的嵌入,这里简化)
lane_points_list = []
for i, lane in enumerate(lane_centerlines):
# 为每个点添加一个车道索引作为特征
lane_feat = np.column_stack([lane, np.full((lane.shape[0], 1), i)]) # [N_points, 3] (x, y, lane_id)
lane_points_list.append(lane_feat)
if lane_points_list:
all_lane_points = np.vstack(lane_points_list).astype(np.float32)
# 随机采样或截断,控制点数
max_lane_points = 500
if all_lane_points.shape[0] > max_lane_points:
indices = np.random.choice(all_lane_points.shape[0], max_lane_points, replace=False)
all_lane_points = all_lane_points[indices]
lane_points_tensor = torch.from_numpy(all_lane_points) # [N, 3]
else:
lane_points_tensor = torch.zeros((1, 3), dtype=torch.float32) # 占位符
# 未来真值 (仅在训练和验证集有)
if self.split != 'test':
focal_future = scene_data_local['focal_future_tensor'] # [60, 2]
future_tensor = torch.from_numpy(focal_future)
else:
future_tensor = torch.zeros((60, 2), dtype=torch.float32) # 测试集没有真值
# 可选的数据增强变换
if self.transform:
focal_past, neighbor_tensor, lane_points_tensor, future_tensor = self.transform(
focal_past, neighbor_tensor, lane_points_tensor, future_tensor
)
# 返回一个字典,包含所有必要数据
data_item = {
'focal_past': torch.from_numpy(focal_past), # [50, 4]
'neighbor_past': neighbor_tensor, # [10, 50, 4]
'neighbor_mask': neighbor_mask, # [10]
'neighbor_type': neighbor_type_tensor, # [10]
'lane_points': lane_points_tensor, # [N, 3]
'future_gt': future_tensor, # [60, 2] (测试集为0)
'scenario_id': scenario_id,
'origin_state': torch.from_numpy(scene_data_local['origin_state']) # 用于后处理转换回全局坐标
}
return data_item
# 使用示例
if __name__ == '__main__':
dataset = Argoverse2ForecastingDataset(data_root='./argoverse2_data', split='train')
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)
# 取一个batch看看
for batch in dataloader:
print(f"目标历史轨迹形状: {batch['focal_past'].shape}")
print(f"邻居历史轨迹形状: {batch['neighbor_past'].shape}")
print(f"邻居有效掩码: {batch['neighbor_mask'].sum(dim=1)}") # 每个样本的有效邻居数
print(f"车道点云形状: {batch['lane_points'].shape}")
break
```
这个 `Dataset` 类完成了从文件路径到模型可直接使用的张量的完整转换。它处理了变长数据(邻居数量、车道点数)到定长张量的填充,并封装了所有必要的预处理逻辑。
## 4. 高级技巧与实战注意事项
在基础流程之上,还有一些技巧能显著提升你的数据管道质量和模型性能。
### 4.1 数据增强:提升模型泛化能力
对于轨迹预测,在自车坐标系下进行数据增强既安全又有效。因为坐标系已经归一化,增强操作不会破坏物理逻辑。
```python
class TrajectoryAugmentation:
"""简单的轨迹数据增强类"""
def __init__(self, aug_prob=0.5):
self.aug_prob = aug_prob
def __call__(self, focal_past, neighbor_past, lane_points, future_gt):
# 1. 随机水平翻转 (镜像场景)
if torch.rand(1) < self.aug_prob:
# 翻转x坐标和x方向速度
focal_past[:, 0] = -focal_past[:, 0] # x
focal_past[:, 2] = -focal_past[:, 2] # vx
neighbor_past[:, :, 0] = -neighbor_past[:, :, 0]
neighbor_past[:, :, 2] = -neighbor_past[:, :, 2]
lane_points[:, 0] = -lane_points[:, 0] # 车道线x坐标
future_gt[:, 0] = -future_gt[:, 0]
# 注意:heading也需要调整 (heading = pi - heading),但我们的特征里没有直接存储heading,已隐含在速度中。
# 2. 随机轻微旋转 (模拟不同的初始朝向观测噪声)
if torch.rand(1) < self.aug_prob:
angle = torch.rand(1) * 0.1 - 0.05 # [-0.05, 0.05] 弧度
cos_a, sin_a = torch.cos(angle), torch.sin(angle)
rot_mat = torch.tensor([[cos_a, sin_a], [-sin_a, cos_a]])
# 旋转位置和速度
focal_past[:, :2] = focal_past[:, :2] @ rot_mat.T
focal_past[:, 2:4] = focal_past[:, 2:4] @ rot_mat.T
neighbor_past[:, :, :2] = neighbor_past[:, :, :2] @ rot_mat.T.unsqueeze(0)
neighbor_past[:, :, 2:4] = neighbor_past[:, :, 2:4] @ rot_mat.T.unsqueeze(0)
lane_points[:, :2] = lane_points[:, :2] @ rot_mat.T
future_gt = future_gt @ rot_mat.T
# 3. 随机缩放 (模拟速度感知的轻微变化)
if torch.rand(1) < self.aug_prob:
scale = 0.9 + torch.rand(1) * 0.2 # [0.9, 1.1]
focal_past[:, :2] *= scale
focal_past[:, 2:4] *= scale
neighbor_past[:, :, :2] *= scale
neighbor_past[:, :, 2:4] *= scale
lane_points[:, :2] *= scale
future_gt *= scale
return focal_past, neighbor_past, lane_points, future_gt
```
将这些增强集成到 Dataset 中,可以显著增加数据的多样性,尤其是在训练集规模有限的情况下。
### 4.2 处理地图信息的更优策略
之前我们只是简单地将附近的车道点作为点云输入。更高级的做法是使用**车道图 (Lane Graph)**。你可以利用 `av2` API 提供的 `lane_segments` 及其 `predecessors`/`successors` 属性来构建一个局部有向图,其中节点是车道段,边表示连通性。然后使用图神经网络 (GNN) 来编码地图信息,这比原始点云能更好地捕捉车道拓扑结构。
```python
def build_local_lane_graph(static_map, query_point, radius=50.0):
"""在查询点周围构建一个局部车道图"""
from collections import defaultdict
import networkx as nx
G = nx.DiGraph()
# 1. 找到半径内的所有车道段
nearby_lane_ids = []
for lane_id, lane_seg in static_map.lane_segments.items():
centerline = lane_seg.centerline.xyz[:, :2]
distances = np.linalg.norm(centerline - query_point, axis=1)
if np.min(distances) < radius:
nearby_lane_ids.append(lane_id)
# 将车道段作为节点,属性包括中心线坐标、类型、是否在路口等
G.add_node(lane_id,
centerline=centerline,
lane_type=lane_seg.lane_type,
is_intersection=lane_seg.is_intersection)
# 2. 根据前后继关系添加边
for lane_id in nearby_lane_ids:
lane_seg = static_map.lane_segments[lane_id]
for pred_id in lane_seg.predecessors:
if pred_id in nearby_lane_ids:
G.add_edge(pred_id, lane_id)
for succ_id in lane_seg.successors:
if succ_id in nearby_lane_ids:
G.add_edge(lane_id, succ_id)
return G
```
构建好的图可以作为后续 GNN 层的输入,极大地提升模型对结构化道路信息的理解能力。
### 4.3 性能优化与缓存
预处理所有场景的计算量可能很大。一个实用的技巧是**预处理并缓存**。你可以运行一次完整的提取和转换流程,将每个场景处理后的数据(`focal_past`, `neighbor_past` 等)以 `.npz` 或 `.pkl` 格式保存到磁盘。然后在 `Dataset` 的 `__getitem__` 中直接加载缓存文件,速度会快几个数量级。
```python
def preprocess_and_cache_all(data_root, split, cache_dir):
"""预处理整个数据集分片并缓存"""
dataset_raw = Argoverse2ForecastingDataset(data_root, split, transform=None)
cache_dir = Path(cache_dir) / split
cache_dir.mkdir(parents=True, exist_ok=True)
for i in tqdm(range(len(dataset_raw))):
data_item = dataset_raw[i]
cache_path = cache_dir / f"{data_item['scenario_id']}.npz"
# 将数据项中的张量转换为numpy保存
np.savez_compressed(cache_path,
focal_past=data_item['focal_past'].numpy(),
neighbor_past=data_item['neighbor_past'].numpy(),
neighbor_mask=data_item['neighbor_mask'].numpy(),
future_gt=data_item['future_gt'].numpy(),
origin_state=data_item['origin_state'].numpy())
print(f"缓存完成,保存在 {cache_dir}")
```
之后,你可以创建一个 `CachedArgoverseDataset`,它直接从缓存文件加载,从而在迭代训练时实现极高的数据读取速度。
经过以上四个部分的拆解,你应该已经掌握了从零开始处理 Argoverse v2 轨迹预测数据的全链路技能。从环境配置、数据理解,到核心的坐标转换和特征工程,再到最终封装成高性能的 DataLoader,每一步都结合了代码实例和实战经验。这套流程不是一成不变的模板,你可以根据自己的模型需求,灵活调整特征提取的维度、邻居选择策略和地图编码方式。例如,如果你想尝试最新的基于 Transformer 的预测模型,可能需要将所有的智能体轨迹和车道线序列化成一个长序列;而如果使用基于 CNN 的方法,则可能需要将场景栅格化为 BEV 图像。无论后续模型如何变化,坚实、高效、可扩展的数据预处理管道都是成功的第一步。