Source code for caloutils.distances.scaled_w1

from typing import Optional

import numpy as np
import torch

from ..utils import scale_b_to_a
from .wasserstein import wasserstein_distance


[docs]def scaled_w1_distance( r: torch.Tensor, f: torch.Tensor, rw: Optional[torch.Tensor] = None, fw: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Calculate the scaled Wasserstein-1 Distance between two 1D samples. The scaled Wasserstein-1 Distance measures the difference between the distributions of two samples while accounting for differences in scale and location. Args: r (torch.Tensor): Reference sample. f (torch.Tensor): Comparison sample. rw (Optional[torch.Tensor]): Weights for the reference sample (optional). fw (Optional[torch.Tensor]): Weights for the comparison sample (optional). Returns: torch.Tensor: Calculated scaled Wasserstein-1 Distance between the two samples. """ assert r.dim() == f.dim() == 1 rs, fs = scale_b_to_a(r, f) return wasserstein_distance(rs, fs, rw, fw)
[docs]def calc_sw1_dist( r: torch.Tensor, f: torch.Tensor, rw: Optional[torch.Tensor] = None, fw: Optional[torch.Tensor] = None, ) -> np.ndarray: """ Calculate scaled Wasserstein-1 distance metrics between two samples over multiple dimensions. Args: r (torch.Tensor): Reference sample. f (torch.Tensor): Comparison sample. rw (Optional[torch.Tensor]): Weights for the reference sample (optional). fw (Optional[torch.Tensor]): Weights for the comparison sample (optional). Returns: np.ndarray: Array of scaled Wasserstein-1 distance metrics calculated for each dimension. """ assert r.shape == f.shape assert (rw is None) == (fw is None) dists = [] for iftx in range(r.shape[-1]): cdfdist = scaled_w1_distance(r[..., iftx], f[..., iftx], rw, fw) dists.append(cdfdist) return torch.stack(dists, axis=0).cpu().numpy()