Optimal Transport Tools (OTT)

Introduction

OTT is a JAX package that bundles a few utilities to compute, and differentiate as needed, the solution to optimal transport (OT) problems, taken in a fairly wide sense. For instance, OTT can compute the Wasserstein distance (or Gromov-Wasserstein distance) between weighted point clouds (or histograms) in a wide variety of scenarios, but also estimate a Monge map, Wasserstein barycenter, or even help with simpler tasks such as differentiable approximations to ranking or clustering.

To achieve this, OTT rests on two families of tools:

Installation

Install OTT from PyPI as:

pip install ott-jax

or with the neural OT dependencies:

pip install 'ott-jax[neural]'

or using conda as:

conda install -c conda-forge ott-jax

Design Choices

OTT is designed with the following choices:

  • Take advantage of JAX features whenever possible, such as just-in-time (JIT) compilation, auto-vectorization (VMAP) and both automatic and implicit differentiation.

  • Split geometry from OT solvers in the discrete case: you will find one, and only one, implementation of every major OT algorithm (Sinkhorn, Gromov-Wasserstein, barycenters, etc.), each agnostic to the geometric setup (i.e. the cost function) and to whatever speedups a specific cost may allow. To give a concrete example, if the inner operations in the Sinkhorn algorithm can be run more efficiently (because e.g. the cost function is low-rank, or the cost is a separable function for points supported on a separable grid [Solomon et al., 2015]), this should not trigger a separate reimplementation of the Sinkhorn algorithm.

  • As a consequence, and to minimize code duplication, use object hierarchies as often as possible, and interleave outer solvers (such as quadratic, aka Gromov-Wasserstein, solvers) with inner solvers (e.g., low-rank Sinkhorn). This choice ensures that speedups achieved at lower computation levels (e.g. low-rank factorization of squared Euclidean distances) propagate seamlessly and automatically to higher-level calls (e.g. updates in Gromov-Wasserstein), without requiring any attention from the user.
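
The split between geometry and solver described in the list above can be sketched in plain JAX. This is not OTT's implementation: the classes `DenseGeometry` and `SeparableGridGeometry`, the `apply_kernel` method and the `sinkhorn` function are hypothetical names introduced only for illustration, and the grid sizes, `epsilon` and iteration count are arbitrary. The sketch shows a single Sinkhorn loop that touches the cost only through the geometry object, so a separable grid cost can be applied axis by axis without any change to the solver.

```python
import jax
import jax.numpy as jnp


class DenseGeometry:
    """Hypothetical geometry storing the full Gibbs kernel exp(-C / epsilon)."""

    def __init__(self, cost_matrix, epsilon):
        self.kernel = jnp.exp(-cost_matrix / epsilon)

    def apply_kernel(self, vec, transpose=False):
        kernel = self.kernel.T if transpose else self.kernel
        return kernel @ vec


class SeparableGridGeometry:
    """Hypothetical geometry for a 2D grid with cost (s - s')^2 + (t - t')^2.

    The kernel is a Kronecker product of two small per-axis kernels, so it is
    applied one axis at a time and the full matrix is never built.
    """

    def __init__(self, axis1, axis2, epsilon):
        self.shape = (axis1.shape[0], axis2.shape[0])
        self.k1 = jnp.exp(-((axis1[:, None] - axis1[None, :]) ** 2) / epsilon)
        self.k2 = jnp.exp(-((axis2[:, None] - axis2[None, :]) ** 2) / epsilon)

    def apply_kernel(self, vec, transpose=False):
        # Both per-axis kernels are symmetric, so `transpose` changes nothing.
        grid = vec.reshape(self.shape)
        return (self.k1 @ grid @ self.k2.T).reshape(-1)


def sinkhorn(geom, a, b, num_iters=300):
    """One Sinkhorn loop; it only ever calls `geom.apply_kernel`."""

    def body(_, scalings):
        _, v = scalings
        u = a / geom.apply_kernel(v)
        v = b / geom.apply_kernel(u, transpose=True)
        return u, v

    init = (jnp.ones_like(a), jnp.ones_like(b))
    return jax.lax.fori_loop(0, num_iters, body, init)


axis1 = jnp.linspace(0.0, 1.0, 5)
axis2 = jnp.linspace(0.0, 1.0, 6)
epsilon = 0.5

# Dense cost over all 30 grid points, flattened in row-major order.
points = jnp.stack(jnp.meshgrid(axis1, axis2, indexing="ij"), axis=-1)
points = points.reshape(-1, 2)
cost = jnp.sum((points[:, None, :] - points[None, :, :]) ** 2, axis=-1)

# Two random histograms on the grid (illustrative data only).
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.uniform(key_a, (30,), minval=0.5, maxval=1.5)
b = jax.random.uniform(key_b, (30,), minval=0.5, maxval=1.5)
a, b = a / a.sum(), b / b.sum()

# Same solver, two geometries.
grid_geom = SeparableGridGeometry(axis1, axis2, epsilon)
u_dense, v_dense = sinkhorn(DenseGeometry(cost, epsilon), a, b)
u_grid, v_grid = sinkhorn(grid_geom, a, b)
```

Both calls run the same loop and should agree up to floating-point error. Only the grid geometry avoids building the 30 x 30 kernel, applying a 5 x 5 and a 6 x 6 kernel instead, which is the kind of speedup the solver does not need to know about.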

Packages