# 从推箱子游戏理解Wasserstein距离:用Python可视化scipy.stats的EMD计算过程
想象一下,你面前有两个仓库,每个仓库都有几个固定的货架位置,货架上堆着不同数量的箱子。你的任务是把左边仓库的箱子重新摆放,让它看起来和右边仓库一模一样。但有个限制:你只能把箱子从一个货架搬到另一个货架,而且搬动的总距离要尽可能小。这个“最小搬动距离”就是**Wasserstein距离**的核心思想。
我第一次接触这个概念是在处理一些图像风格迁移的项目中,当时需要量化两种颜色分布的差异。传统的KL散度或JS散度在面对分布没有重叠区域时会失效,而Wasserstein距离却能给出一个有意义的数值。后来我发现,这个看似抽象的数学概念,其实可以用生活中常见的“推箱子”游戏来直观理解。今天,我们就抛开复杂的公式,用Python和Matplotlib,像玩推箱子一样,一步步拆解Wasserstein距离的计算过程。
这篇文章适合对数学可视化感兴趣的朋友,无论你是数据科学家、机器学习工程师,还是单纯想理解最优传输理论的爱好者。我们将从零开始,构建一个动态的可视化系统,让你亲眼看到“箱子”是如何被“搬运”的,以及`scipy.stats.wasserstein_distance`这个函数背后究竟发生了什么。
## 1. 环境准备与核心概念拆解
在开始写代码之前,我们需要先搭建好实验环境。我习惯用Anaconda管理Python环境,这样依赖包冲突的问题会少很多。如果你还没有安装SciPy和Matplotlib,可以通过下面的命令快速安装。
```bash
pip install numpy scipy matplotlib
```
对于更复杂的动画生成,我们可能还会用到`matplotlib.animation`模块,不过别担心,它已经包含在Matplotlib里了。接下来,我们得先搞清楚几个关键术语,不然代码写起来会一头雾水。
在Wasserstein距离的语境下,我们常说的“推土机距离”或“最优传输距离”,其实描述的是同一件事:**用最小的成本把一个概率分布变成另一个**。在离散的一维情况下,这个“成本”就是每个概率质量(箱子)移动的距离乘以它的质量(箱子数量)。
`scipy.stats.wasserstein_distance`函数接受四个主要参数:
- `u_values`:第一个分布的支撑点(货架位置)
- `v_values`:第二个分布的支撑点(货架位置)
- `u_weights`:第一个分布在各个支撑点上的概率质量(箱子数量)
- `v_weights`:第二个分布在各个支撑点上的概率质量(箱子数量)
如果不指定权重,函数会默认每个支撑点上的质量都是1。这里有个容易踩坑的地方:`u_values`和`v_values`的长度可以不同,但`u_values`和`u_weights`的长度必须一致,`v_values`和`v_weights`也是如此。
为了让大家对这个概念有更具体的感受,我设计了一个简单的对照表,用推箱子的场景来类比函数的各个参数:
| 数学概念 | 推箱子类比 | 在代码中的体现 |
| :--- | :--- | :--- |
| 支撑点 (`u_values`, `v_values`) | 货架或槽位的固定位置 | 一维数组,如 `[0, 1, 2, 3]` |
| 概率质量 (`u_weights`, `v_weights`) | 每个货架上箱子的数量 | 一维数组,如 `[4, 2, 1, 3]` |
| 传输计划 | 具体的搬箱方案:从A货架搬多少箱子到B货架 | 一个矩阵,描述质量如何流动 |
| Wasserstein距离 | 所有箱子移动的**总距离×箱子数**的最小和 | 函数返回的一个浮点数 |
> 注意:权重数组的和不需要是1。如果权重和不为1,函数内部会先将其归一化,再计算距离。这意味着`[4, 2, 1, 3]`和`[0.4, 0.2, 0.1, 0.3]`作为权重输入,计算出的距离是一样的。
理解了这些,我们就可以开始动手,用代码创建第一个“仓库”和“箱子”了。
## 2. 构建基础可视化:静态分布与距离计算
我们先从最简单的静态图开始。假设有两个分布,它们的支撑点都在`[0, 1, 2, 3]`这四个位置,但箱子堆放的方式不同。分布A的箱子数是`[4, 2, 1, 3]`,分布B是`[3, 1, 2, 4]`。我们的目标是计算将它们对齐所需的最小成本。
首先,导入必要的库,并计算一下Wasserstein距离。
```python
import numpy as np
from scipy.stats import wasserstein_distance
import matplotlib.pyplot as plt
# 定义两个分布
u_positions = np.array([0, 1, 2, 3]) # 仓库A的货架位置
u_boxes = np.array([4, 2, 1, 3]) # 仓库A各货架的箱子数
v_positions = np.array([0, 1, 2, 3]) # 仓库B的货架位置
v_boxes = np.array([3, 1, 2, 4]) # 仓库B各货架的箱子数
# 计算Wasserstein距离
distance = wasserstein_distance(u_positions, v_positions, u_boxes, v_boxes)
print(f"Wasserstein距离为: {distance}")
```
运行这段代码,你会得到结果`0.4`。这个数字本身可能有点抽象,我们把它画出来看看。下面的代码会生成一张图,用柱状图表示两个分布的箱子堆放情况。
```python
def plot_static_distributions(u_pos, u_w, v_pos, v_w, distance):
"""
绘制两个静态分布的对比图
"""
fig, ax = plt.subplots(figsize=(10, 6))
# 设置柱状图的宽度和位置
width = 0.35
x_u = np.arange(len(u_pos))
x_v = np.arange(len(v_pos))
# 绘制分布A(上方,向下生长)
bars_u = ax.bar(x_u - width/2, u_w, width, label='分布 A (源)', color='skyblue', edgecolor='black')
# 绘制分布B(下方,向上生长)
bars_v = ax.bar(x_v + width/2, v_w, width, label='分布 B (目标)', color='lightcoral', edgecolor='black', bottom=0)
# 在柱子上方标注箱子数量
for bar, weight in zip(bars_u, u_w):
height = bar.get_height()
ax.text(bar.get_x() + bar.get_width()/2., height + 0.05,
f'{weight}', ha='center', va='bottom', fontsize=9)
for bar, weight in zip(bars_v, v_w):
height = bar.get_height()
ax.text(bar.get_x() + bar.get_width()/2., 0.05,
f'{weight}', ha='center', va='bottom', fontsize=9, color='white')
# 美化图表
ax.set_xlabel('货架位置', fontsize=12)
ax.set_ylabel('箱子数量', fontsize=12)
ax.set_title(f'两个分布的箱子堆放情况 | Wasserstein距离 = {distance:.2f}', fontsize=14, pad=15)
ax.set_xticks(x_u)
ax.set_xticklabels([str(pos) for pos in u_pos])
ax.legend()
ax.grid(axis='y', linestyle='--', alpha=0.7)
plt.tight_layout()
plt.show()
# 调用绘图函数
plot_static_distributions(u_positions, u_boxes, v_positions, v_boxes, distance)
```
生成的图表会清晰地显示,在位置0,我们需要把1个箱子搬走(因为A有4个,B只需要3个);在位置2,我们需要搬来1个箱子(A只有1个,B需要2个)。但问题来了:这些箱子具体从哪里搬到哪里?这就是**最优传输计划**要解决的问题。
仅仅看静态图,我们无法知道成本最低的搬箱方案是什么。也许从位置0搬1个箱子到位置2是最直接的,但距离是2。有没有可能通过“中转”来减少成本?比如从位置0搬1个箱子到位置1,再从位置1搬1个箱子到位置2?这样每步距离是1,总距离还是2,似乎没区别。但在更复杂的分布中,最优路径往往不是显而易见的。
为了找到这个最优计划,我们需要深入算法的内部。`scipy.stats.wasserstein_distance`在一维情况下使用了一个非常高效的算法,它基于累积分布函数(CDF)的差异。对于离散分布,Wasserstein距离有一个等价的简洁公式:
\[
l_1(u, v) = \sum_{i} |U_i - V_i|
\]
其中 \( U_i \) 和 \( V_i \) 分别是两个分布的累积概率质量。也就是说,我们可以通过计算两个“累积箱子堆”之间区域的面积来得到距离。下面这个函数演示了如何手动计算这个值,并与SciPy的结果进行对比验证。
```python
def manual_wasserstein_1d(u_pos, u_w, v_pos, v_w):
"""
手动计算一维Wasserstein距离,通过累积分布函数(CDF)方法。
此方法仅当两个分布的支撑点相同时简便。
"""
# 归一化权重,使其成为概率质量函数
u_prob = u_w / np.sum(u_w)
v_prob = v_w / np.sum(v_w)
# 计算累积分布函数 (CDF)
u_cdf = np.cumsum(u_prob)
v_cdf = np.cumsum(v_prob)
# 计算CDF之间差的绝对值之和
# 注意:这里假设支撑点已按相同顺序排序且一一对应
distance = np.sum(np.abs(u_cdf - v_cdf))
return distance
# 计算并对比
manual_dist = manual_wasserstein_1d(u_positions, u_boxes, v_positions, v_boxes)
print(f"SciPy计算的距离: {distance}")
print(f"手动计算的距离: {manual_dist}")
print(f"两者是否接近: {np.isclose(distance, manual_dist)}")
```
如果一切正常,两个结果应该是一致的。这个手动计算的过程揭示了Wasserstein距离的几何意义:它就是两个累积分布曲线之间区域的面积。理解这一点对我们后续制作动画至关重要,因为我们可以把“搬箱子”的过程看作是逐步填平这两个CDF曲线之间落差的过程。
## 3. 从静态到动态:模拟最优传输过程
静态图告诉我们起点和终点,但过程的魅力在于观察变化。接下来,我们要模拟箱子被搬运的中间状态。这不是简单的插值,而是需要根据最优传输计划,计算出在“工作”完成到某个百分比时,箱子应该如何分布。
首先,我们需要找出完整的最优传输计划。对于一维且支撑点相同的情况,有一个贪心策略:从左到右处理,每个位置的盈余或赤字依次向相邻位置传递。这就像水流一样,多余的会往低处流。我们可以实现一个函数来模拟这个“流动”过程。
```python
def compute_optimal_transport_flow(u_pos, u_w, v_pos, v_w):
"""
计算一维情况下(支撑点相同且有序)的最优传输流。
返回一个流矩阵,其中flow[i]表示从位置i移动到位置i+1的质量(正数表示向右)。
"""
# 确保输入是numpy数组并按位置排序
u_pos_sorted = np.array(u_pos)
v_pos_sorted = np.array(v_pos)
u_w_sorted = np.array(u_w)
v_w_sorted = np.array(v_w)
# 归一化质量,使其总和为1
u_mass = u_w_sorted / np.sum(u_w_sorted)
v_mass = v_w_sorted / np.sum(v_w_sorted)
# 计算每个位置上的质量盈余(正数表示有多余,负数表示有缺口)
surplus = u_mass - v_mass
# 初始化流数组,长度比位置数少1
n = len(surplus)
flow = np.zeros(n - 1)
# 从左到右传递盈余
current_carry = 0.0
for i in range(n - 1):
# 当前位置的净盈余加上之前传递过来的
current_carry += surplus[i]
# 流向下一个位置的质量就是当前携带的量
flow[i] = current_carry
# 注意:current_carry在循环中会持续累积,直到最后位置
return flow, surplus
# 计算我们示例中的流
flow, surplus = compute_optimal_transport_flow(u_positions, u_boxes, v_positions, v_boxes)
print("每个位置的质量盈余 (u - v):", surplus)
print("相邻位置间的传输流 (向右为正):", flow)
```
运行后,你可能会看到类似这样的输出:
```
每个位置的质量盈余 (u - v): [ 0.1 -0.1 -0.1 0.1]
相邻位置间的传输流 (向右为正): [ 0.1 0.0 -0.1]
```
解读一下:在位置0,分布A比分布B多0.1的质量(因为归一化了),所以有0.1的质量需要向右运走。到了位置1,盈余变成了-0.1(A比B少),但加上从左边运来的0.1,净携带量变为0,所以位置1到位置2的流为0。这个流矩阵清晰地刻画了质量是如何像接力棒一样在相邻货架间传递的。
有了传输流,我们就可以模拟任意中间时刻的状态了。假设整个搬运工作完成了比例 `t`(0到1之间),那么每个位置上的瞬时质量就是起始质量减去已经流出的质量加上已经流入的质量。下面的函数计算中间状态,并生成一帧图像。
```python
def compute_intermediate_mass(u_mass, flow, t):
"""
根据传输流和完成比例t,计算中间状态的质量分布。
u_mass: 起始归一化质量
flow: 传输流数组
t: 完成比例,0 <= t <= 1
"""
n = len(u_mass)
intermediate = u_mass.copy()
# 根据流和比例t,调整质量
# 流flow[i]表示从位置i到i+1的传输量。
# 在时间t,有 t * flow[i] 的质量已经从i离开,但尚未全部到达i+1。
# 我们采用一个简单的线性插值:离开的质量为 t*flow[i],到达下一个位置的质量也是 t*flow[i](假设瞬时到达)。
# 更精确的模拟需要跟踪在途质量,这里为简化做此假设。
for i in range(n - 1):
if flow[i] > 0: # 向右流动
mass_to_move = t * flow[i]
intermediate[i] -= mass_to_move
intermediate[i + 1] += mass_to_move
elif flow[i] < 0: # 向左流动
mass_to_move = t * abs(flow[i])
intermediate[i + 1] -= mass_to_move # 从i+1向左移动
intermediate[i] += mass_to_move
return intermediate
def plot_intermediate_state(u_pos, u_w, v_pos, v_w, flow, t, ax):
"""
在给定的axes上绘制时间t时的中间状态。
"""
# 归一化质量
u_mass = u_w / np.sum(u_w)
v_mass = v_w / np.sum(v_w)
# 计算中间状态质量
inter_mass = compute_intermediate_mass(u_mass, flow, t)
# 转换回箱子数量以便绘图(按原始总箱子数缩放)
total_boxes = np.sum(u_w)
inter_boxes = inter_mass * total_boxes
ax.clear()
width = 0.35
x = np.arange(len(u_pos))
# 绘制原始分布A(半透明,作为参考)
ax.bar(x - width/2, u_w, width, alpha=0.3, label='分布 A (初始)', color='skyblue')
# 绘制目标分布B(半透明,作为参考)
ax.bar(x + width/2, v_w, width, alpha=0.3, label='分布 B (目标)', color='lightcoral')
# 绘制当前中间状态(实心)
ax.bar(x, inter_boxes, width*0.7, label=f'传输中 (t={t:.2f})', color='gold', edgecolor='darkorange')
# 标注当前箱子数
for i, (pos, boxes) in enumerate(zip(x, inter_boxes)):
ax.text(pos, boxes + 0.1, f'{boxes:.1f}', ha='center', va='bottom', fontsize=9)
ax.set_xlabel('货架位置')
ax.set_ylabel('箱子数量')
ax.set_title(f'最优传输过程模拟 (t={t:.2f})')
ax.set_xticks(x)
ax.set_xticklabels([str(pos) for pos in u_pos])
ax.legend(loc='upper left')
ax.grid(axis='y', linestyle='--', alpha=0.5)
ax.set_ylim(0, max(np.max(u_w), np.max(v_w)) * 1.2)
# 生成一个从0到1的动画序列(这里先展示几个关键帧)
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
axes = axes.flatten()
time_points = [0.0, 0.33, 0.67, 1.0]
for ax, t in zip(axes, time_points):
plot_intermediate_state(u_positions, u_boxes, v_positions, v_boxes, flow, t, ax)
plt.tight_layout()
plt.show()
```
这几张分帧图会显示,箱子数量如何从初始的蓝色分布,经过金色的中间状态,最终变成红色的目标分布。你可以清楚地看到,在位置0多余的箱子逐渐减少,而在位置2短缺的箱子逐渐补上。但真正的动画应该让这些柱子“动起来”。这就需要用到Matplotlib的动画模块了。
## 4. 生成完整动画与交互式探索
创建动画的核心是定义一个更新函数,它会在每一帧被调用,改变图形对象的数据。然后使用`FuncAnimation`将这些帧串联起来。下面是一个完整的动画生成脚本,它会产生一个MP4视频文件,展示连续的传输过程。
```python
import matplotlib.animation as animation
from matplotlib.animation import FuncAnimation, PillowWriter, FFMpegWriter
def create_wasserstein_animation(u_pos, u_w, v_pos, v_w, output_path='wasserstein_transport.mp4'):
"""
创建并保存Wasserstein距离最优传输过程的动画。
"""
# 计算最优传输流
u_mass_norm = u_w / np.sum(u_w)
v_mass_norm = v_w / np.sum(v_w)
flow, surplus = compute_optimal_transport_flow(u_pos, u_w, v_pos, v_w)
fig, ax = plt.subplots(figsize=(10, 6))
# 初始化图形元素
width = 0.35
x = np.arange(len(u_pos))
# 绘制初始和目标分布(半透明参考)
bars_u_ref = ax.bar(x - width/2, u_w, width, alpha=0.3, label='分布 A (初始)', color='skyblue')
bars_v_ref = ax.bar(x + width/2, v_w, width, alpha=0.3, label='分布 B (目标)', color='lightcoral')
# 初始化中间状态柱状图(实心)
inter_bars = ax.bar(x, u_w, width*0.7, label='传输中', color='gold', edgecolor='darkorange', linewidth=1.5)
# 初始化文本标注(显示当前箱子数)
text_annotations = []
for i, bar in enumerate(inter_bars):
height = bar.get_height()
text = ax.text(bar.get_x() + bar.get_width()/2., height + 0.1, f'{height:.1f}',
ha='center', va='bottom', fontsize=10, fontweight='bold')
text_annotations.append(text)
# 在顶部添加一个进度条
progress_bar = ax.axhline(y=max(np.max(u_w), np.max(v_w)) * 1.3, xmin=0, xmax=0,
color='green', linewidth=5, solid_capstyle='round')
ax.set_xlabel('货架位置', fontsize=12)
ax.set_ylabel('箱子数量', fontsize=12)
ax.set_title('Wasserstein距离最优传输过程模拟', fontsize=14, pad=15)
ax.set_xticks(x)
ax.set_xticklabels([str(pos) for pos in u_pos])
ax.legend(loc='upper left')
ax.grid(axis='y', linestyle='--', alpha=0.5)
ax.set_ylim(0, max(np.max(u_w), np.max(v_w)) * 1.4)
# 计算总距离,并显示在图表上
total_distance = wasserstein_distance(u_pos, v_pos, u_w, v_w)
distance_text = ax.text(0.02, 0.98, f'Wasserstein距离 = {total_distance:.2f}',
transform=ax.transAxes, fontsize=12,
verticalalignment='top',
bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))
# 更新函数,用于每一帧
def update(frame):
t = frame / 100 # 假设总共100帧,t从0到1
inter_mass = compute_intermediate_mass(u_mass_norm, flow, t)
inter_boxes = inter_mass * np.sum(u_w)
# 更新中间状态柱子的高度
for i, bar in enumerate(inter_bars):
bar.set_height(inter_boxes[i])
# 更新文本标注
for i, text in enumerate(text_annotations):
text.set_position((inter_bars[i].get_x() + inter_bars[i].get_width()/2.,
inter_boxes[i] + 0.1))
text.set_text(f'{inter_boxes[i]:.1f}')
# 更新进度条
progress_bar.set_xdata([0, t])
# 更新标题,显示当前进度
ax.set_title(f'Wasserstein距离最优传输过程模拟 (进度: {t:.0%})', fontsize=14, pad=15)
return inter_bars + text_annotations + [progress_bar]
# 创建动画
ani = FuncAnimation(fig, update, frames=101, interval=50, blit=False) # 101帧包含t=0和t=1
# 保存为MP4视频(需要安装ffmpeg)
try:
writer = FFMpegWriter(fps=20, metadata=dict(artist='Me'), bitrate=1800)
ani.save(output_path, writer=writer)
print(f"动画已保存至: {output_path}")
except Exception as e:
print(f"保存MP4失败,尝试保存为GIF: {e}")
# 回退方案:保存为GIF
gif_path = output_path.replace('.mp4', '.gif')
writer = PillowWriter(fps=20)
ani.save(gif_path, writer=writer)
print(f"动画已保存至GIF: {gif_path}")
plt.close(fig)
return ani
# 生成动画
ani = create_wasserstein_animation(u_positions, u_boxes, v_positions, v_boxes)
```
运行这个脚本,你会得到一个视频文件,其中金色的柱子平滑地变化,顶部的绿色进度条逐渐延伸,最终金色的分布与红色的目标分布重合。这个过程直观地展示了“最小成本”的含义:质量总是沿着最短的路径(相邻位置)进行转移,没有不必要的长途搬运。
对于喜欢动手尝试的朋友,静态图片和预渲染动画可能还不够过瘾。我们可以利用Jupyter Notebook的交互式控件,创建一个可以实时调节参数的可视化工具。下面这段代码使用了`ipywidgets`库,允许你滑动滑块来改变传输进度,甚至动态修改分布。
```python
# 注意:此部分代码需要在Jupyter Notebook环境中运行
try:
import ipywidgets as widgets
from IPython.display import display, clear_output
%matplotlib inline
# 创建交互式控件
t_slider = widgets.FloatSlider(value=0.0, min=0.0, max=1.0, step=0.01,
description='进度 t:', continuous_update=True)
# 分布参数的可调输入(示例:改变位置2的箱子数)
box_at_pos2 = widgets.IntSlider(value=u_boxes[2], min=0, max=10, step=1,
description='位置2箱子数:')
output = widgets.Output()
def update_interactive_plot(t, new_box_count):
with output:
clear_output(wait=True)
# 更新分布A的箱子数
u_w_modified = u_boxes.copy()
u_w_modified[2] = new_box_count
# 重新计算流和距离
flow_mod, _ = compute_optimal_transport_flow(u_positions, u_w_modified, v_positions, v_boxes)
dist_mod = wasserstein_distance(u_positions, v_positions, u_w_modified, v_boxes)
# 计算中间状态
u_mass_norm = u_w_modified / np.sum(u_w_modified)
inter_mass = compute_intermediate_mass(u_mass_norm, flow_mod, t)
inter_boxes = inter_mass * np.sum(u_w_modified)
# 绘图
fig, ax = plt.subplots(figsize=(10, 6))
width = 0.35
x = np.arange(len(u_positions))
ax.bar(x - width/2, u_w_modified, width, alpha=0.3, label='分布 A (修改后)', color='skyblue')
ax.bar(x + width/2, v_boxes, width, alpha=0.3, label='分布 B (目标)', color='lightcoral')
ax.bar(x, inter_boxes, width*0.7, label=f'传输中 (t={t:.2f})', color='gold', edgecolor='darkorange')
for i, (pos, boxes) in enumerate(zip(x, inter_boxes)):
ax.text(pos, boxes + 0.1, f'{boxes:.1f}', ha='center', va='bottom', fontsize=9)
ax.set_xlabel('货架位置')
ax.set_ylabel('箱子数量')
ax.set_title(f'交互式模拟 | 实时Wasserstein距离 = {dist_mod:.2f}')
ax.set_xticks(x)
ax.set_xticklabels([str(pos) for pos in u_positions])
ax.legend()
ax.grid(axis='y', linestyle='--', alpha=0.5)
ax.set_ylim(0, max(np.max(u_w_modified), np.max(v_boxes)) * 1.2)
plt.show()
# 将控件与更新函数绑定
widgets.interactive(update_interactive_plot, t=t_slider, new_box_count=box_at_pos2)
except ImportError:
print("ipywidgets 未安装。要运行交互式示例,请在Jupyter中安装:`pip install ipywidgets`")
```
通过这个交互式工具,你可以实时看到,当改变某个货架上的箱子数量时,最优传输流如何变化,总距离如何响应。这种即时反馈对于培养对Wasserstein距离的“直觉”非常有帮助。比如,你会发现当两个分布在某个位置的质量差变大时,需要移动的总质量增加,距离通常会变大;但如果这个位置离另一个有相反质量差的位置很近,成本可能并不会增加太多。
## 5. 超越一维:挑战与多维可视化思路
到目前为止,我们都在讨论一维的“货架”。但在现实世界中,数据往往是多维的。例如,图像的颜色可以用RGB三维空间中的点表示,两个图像的颜色分布差异就需要用多维Wasserstein距离来衡量。SciPy从1.13版本开始,在`scipy.stats`模块中引入了`wasserstein_distance_nd`函数,专门用于计算N维离散分布之间的距离。
多维情况下的直观理解就困难多了。你不能再简单地想象一条线上的货架,而是要想象一个空间中的网格,每个格子有一定数量的箱子,你需要把它们重新排列成另一个形状。最优传输计划不再是一个简单的流数组,而是一个复杂的流矩阵,描述从每个源格子到每个目标格子的运输量。
计算多维Wasserstein距离本质上是一个线性规划问题。`wasserstein_distance_nd`的实现就是将其转化为线性规划问题,然后利用SciPy的线性规划求解器找到最优解。虽然我们无法像一维那样制作出简单明了的完整动画,但可以尝试对低维情况(如二维)进行可视化。
假设我们有两个二维分布,支撑点是二维平面上的点。我们可以用散点图的大小表示权重,然后用箭头表示主要的传输路径。下面是一个简化的概念性代码,展示如何可视化二维分布及其之间的“搬运”。
```python
def visualize_2d_wasserstein(u_points, u_weights, v_points, v_weights):
"""
可视化二维分布及它们之间的主要传输关系(概念性)。
u_points: 形状为 (n, 2) 的数组,表示n个二维点
u_weights: 长度为n的数组,表示每个点的权重
v_points: 形状为 (m, 2) 的数组
v_weights: 长度为m的数组
"""
from scipy.stats import wasserstein_distance_nd
# 计算距离
distance = wasserstein_distance_nd(u_points, v_points, u_weights, v_weights)
fig, ax = plt.subplots(figsize=(10, 8))
# 归一化权重以控制散点大小
u_size = u_weights / np.max(u_weights) * 500
v_size = v_weights / np.max(v_weights) * 500
# 绘制分布U的点
scatter_u = ax.scatter(u_points[:, 0], u_points[:, 1], s=u_size,
alpha=0.7, label='分布 U', color='blue', edgecolors='black')
# 绘制分布V的点
scatter_v = ax.scatter(v_points[:, 0], v_points[:, 1], s=v_size,
alpha=0.7, label='分布 V', color='red', edgecolors='black')
# 简单演示:假设每个点都移动到最近的点(这不是最优传输,仅为示意)
# 在实际应用中,需要求解线性规划得到最优传输计划
for i, u_pt in enumerate(u_points):
# 找到V中最近的点(按欧氏距离)
distances = np.linalg.norm(v_points - u_pt, axis=1)
nearest_idx = np.argmin(distances)
v_pt = v_points[nearest_idx]
# 绘制箭头,透明度根据权重比例调整
arrow_alpha = min(u_weights[i] / np.max(u_weights), 0.5)
ax.arrow(u_pt[0], u_pt[1],
v_pt[0] - u_pt[0], v_pt[1] - u_pt[1],
head_width=0.05, head_length=0.1,
fc='gray', ec='gray', alpha=arrow_alpha, linestyle='--')
ax.set_xlabel('维度 1')
ax.set_ylabel('维度 2')
ax.set_title(f'二维分布可视化 | Wasserstein距离 ≈ {distance:.2f}', fontsize=14)
ax.legend()
ax.grid(True, alpha=0.3)
ax.set_aspect('equal', adjustable='box')
plt.tight_layout()
plt.show()
# 生成一些示例二维数据
np.random.seed(42)
n_points = 8
# 分布U:大致围绕(0,0)的点
u_pts = np.random.randn(n_points, 2) * 0.5
u_wts = np.random.randint(1, 6, size=n_points)
# 分布V:大致围绕(2,2)的点
v_pts = np.random.randn(n_points, 2) * 0.5 + np.array([2, 2])
v_wts = np.random.randint(1, 6, size=n_points)
visualize_2d_wasserstein(u_pts, u_wts, v_pts, v_wts)
```
这张图会显示两簇点,并用虚线箭头连接每个蓝点到最近的红点。**需要强调的是,这展示的不是最优传输计划**,因为最优解通常不是简单的“最近点匹配”。真正的多维最优传输计算量很大,涉及到求解一个可能非常庞大的线性规划问题。箭头只是为了给读者一个空间传输的感性认识。
在实际项目中,当维度升高或支撑点数量很大时,直接计算精确的Wasserstein距离会变得非常昂贵。这时人们通常会采用近似算法,如Sinkhorn迭代(熵正则化最优传输),它在许多机器学习框架中都有实现。
从一维推箱子到多维空间的质量搬运,Wasserstein距离为我们提供了一种强大而直观的工具来度量分布间的差异。通过今天的可视化探索,我希望你不再把它看作一个黑盒函数,而是能想象出那些“箱子”是如何被聪明地搬运,从而最小化总成本的。下次当你调用`wasserstein_distance`时,不妨在脑海中播放一下我们刚刚制作的动画,也许会对你的数据有新的洞察。