Optimal Transport Tools (OTT)
Introduction
OTT is a JAX package that bundles
a few utilities to compute, and differentiate as needed, the solution to optimal
transport (OT) problems, taken in a fairly wide sense. For instance, OTT can
compute the Wasserstein distance
(or Gromov-Wasserstein distance) between weighted point clouds
(or histograms) in a wide variety of scenarios.
It can also estimate a Monge map or a Wasserstein barycenter, or
help with simpler tasks such as differentiable approximations to ranking or
clustering.
To achieve this, OTT rests on two families of tools:
the first family consists of discrete solvers that compute transport between two point clouds or histograms, using e.g. the Sinkhorn algorithm [Cuturi, 2013] or low-rank solvers [Scetbon et al., 2021], with further extensions to more advanced scenarios such as the Gromov-Wasserstein problem [Mémoli, 2011, Peyré et al., 2016];
the second family consists of continuous solvers, whose goal is to output, given two point cloud samples, a function that approximates a Monge map, i.e. a transport map that efficiently pushes the first measure onto the second. Such functions can be recovered directly with the discrete tools above, notably the family of entropic map approximations. They can also be parameterized as neural architectures, such as an MLP or the gradient of an input convex neural network [Amos et al., 2017], trained with advanced SGD approaches [Amos, 2023, Korotin et al., 2021, Makkuva et al., 2020]. Finally, they can be parameterized as neural ODEs, in which an input source point is mapped to the end point of the trajectory obtained by integrating a time-parameterized velocity field. Such velocity fields can be parameterized as neural networks and learned from data, using the framework of flow matching [Albergo et al., n.d., Lipman et al., 2022] augmented with OT couplings between noise and data [Mousavi-Hosseini et al., 2025, Pooladian et al., 2023, Tong et al., 2023, Zhang et al., 2025].
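To make the first family concrete, here is a minimal sketch of a log-domain Sinkhorn iteration written in plain JAX. It is not OTT's implementation: the function name `sinkhorn`, its arguments, the squared Euclidean cost, and the fixed iteration count are all assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def sinkhorn(x, y, a, b, epsilon=0.1, num_iters=200):
    """Illustrative entropic OT between weighted point clouds (x, a) and (y, b)."""
    # Squared Euclidean cost matrix of shape (n, m); an assumed choice of cost.
    cost = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    log_a, log_b = jnp.log(a), jnp.log(b)

    def body(_, potentials):
        f, g = potentials
        # Update f so that the rows of the coupling sum to a.
        f = -epsilon * logsumexp((g[None, :] - cost) / epsilon + log_b[None, :], axis=1)
        # Update g so that the columns of the coupling sum to b.
        g = -epsilon * logsumexp((f[:, None] - cost) / epsilon + log_a[:, None], axis=0)
        return f, g

    init = (jnp.zeros_like(a), jnp.zeros_like(b))
    f, g = jax.lax.fori_loop(0, num_iters, body, init)
    # Coupling P_ij = a_i * b_j * exp((f_i + g_j - C_ij) / epsilon).
    log_coupling = (f[:, None] + g[None, :] - cost) / epsilon
    coupling = jnp.exp(log_coupling + log_a[:, None] + log_b[None, :])
    return coupling, jnp.sum(coupling * cost)


key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (5, 2))
y = jax.random.normal(key_y, (7, 2)) + 1.0
a = jnp.full(5, 1 / 5)
b = jnp.full(7, 1 / 7)

coupling, transport_cost = sinkhorn(x, y, a, b)
# Differentiate the transport cost with respect to the source points
# by backpropagating through the unrolled iterations.
grad_x = jax.grad(lambda pts: sinkhorn(pts, y, a, b)[1])(x)
```

Because the loop bounds are Python integers, `jax.lax.fori_loop` supports reverse-mode differentiation here. This sketch differentiates through every iteration, which is only one of the options: the design choices of OTT also mention implicit differentiation.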
Installation
Install OTT from PyPI as:
pip install ott-jax
or with the neural OT dependencies:
pip install 'ott-jax[neural]'
or using conda as:
conda install -c conda-forge ott-jax
Design Choices
OTT is designed with the following choices:
Take advantage whenever possible of JAX features, such as just-in-time (JIT) compilation, auto-vectorization (VMAP) and both automatic and implicit differentiation.
Split geometry from OT solvers in the discrete case: you will find one, and only one, implementation of every major OT algorithm (Sinkhorn, Gromov-Wasserstein, barycenters, etc.), each agnostic to the geometric setup (i.e. the cost function) and to the speedups that a specific cost may allow. To give a concrete example, if the inner operations in the Sinkhorn algorithm can be run more efficiently (because e.g. the cost function is low-rank, or the cost is a separable function for points supported on a separable grid [Solomon et al., 2015]), this should not trigger a separate reimplementation of the Sinkhorn algorithm.
As a consequence, and to minimize code copy/pasting, use object hierarchies as often as possible, and interleave outer solvers (such as quadratic, a.k.a. Gromov-Wasserstein, solvers) with inner solvers (e.g., low-rank Sinkhorn). This choice ensures that speedups achieved at lower computation levels (e.g. low-rank factorization of squared Euclidean distances) propagate seamlessly and automatically to higher-level calls (e.g. updates in Gromov-Wasserstein), without requiring any attention from the user.
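The split between geometry and solver can be sketched as follows. The names `DenseGeometry`, `GridGeometry`, `apply_kernel` and `sinkhorn` are hypothetical and do not reflect OTT's actual class hierarchy. The sketch also assumes that both measures live on the same grid with a symmetric cost, so that the kernel equals its transpose.

```python
import jax.numpy as jnp


class DenseGeometry:
    """Hypothetical geometry that stores the full kernel matrix."""

    def __init__(self, cost, epsilon):
        self.kernel = jnp.exp(-cost / epsilon)

    def apply_kernel(self, v):
        return self.kernel @ v


class GridGeometry:
    """Hypothetical geometry for a separable cost on a d1 x d2 grid."""

    def __init__(self, cost_1, cost_2, epsilon):
        self.k1 = jnp.exp(-cost_1 / epsilon)  # shape (d1, d1)
        self.k2 = jnp.exp(-cost_2 / epsilon)  # shape (d2, d2)

    def apply_kernel(self, v):
        # Computes kron(k1, k2) @ v without forming the (d1*d2, d1*d2) matrix.
        d1, d2 = self.k1.shape[0], self.k2.shape[0]
        return (self.k1 @ v.reshape(d1, d2) @ self.k2.T).reshape(-1)


def sinkhorn(geom, a, b, num_iters=100):
    """One solver for every geometry: it only ever calls geom.apply_kernel."""
    u = jnp.ones_like(a)
    for _ in range(num_iters):
        v = b / geom.apply_kernel(u)  # kernel assumed symmetric: K^T u = K u
        u = a / geom.apply_kernel(v)
    return u, v


# A 4 x 3 grid with a squared-distance cost along each axis.
t1, t2 = jnp.linspace(0.0, 1.0, 4), jnp.linspace(0.0, 1.0, 3)
c1 = (t1[:, None] - t1[None, :]) ** 2
c2 = (t2[:, None] - t2[None, :]) ** 2
full_cost = (c1[:, None, :, None] + c2[None, :, None, :]).reshape(12, 12)

epsilon = 0.5
a = jnp.full(12, 1 / 12)
b = jnp.arange(1.0, 13.0)
b = b / b.sum()

dense_geom = DenseGeometry(full_cost, epsilon)
grid_geom = GridGeometry(c1, c2, epsilon)
u_dense, v_dense = sinkhorn(dense_geom, a, b)
u_grid, v_grid = sinkhorn(grid_geom, a, b)
```

Both calls run the same solver code and should agree up to floating-point error. The grid variant never materializes the 12 x 12 kernel, which is the kind of speedup a geometry can supply without the solver knowing about it.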
Packages
ott.geometry contains classes that instantiate the ground cost matrix used to specify a Kantorovich problem. Here, "cost matrix" can be understood either in a literal sense (by instantiating a matrix) or in an abstract sense (by storing information that is sufficient to recreate that matrix, apply all or parts of it, or apply its kernel). An important case is handled by the PointCloud class, which specifies two point clouds paired with a cost function (to be chosen within ott.geometry.costs). Geometry objects are used to describe OT problems, solved next by solvers.
ott.problems is used to describe the interactions between multiple measures, to define linear (a.k.a. Kantorovich), quadratic (a.k.a. Gromov-Wasserstein) or Wasserstein barycenter problems.
ott.solvers solves a problem instantiated with ott.problems using one among many implemented approaches.
ott.initializers implements simple strategies to initialize the solvers above. When the problem is convex, as is the case for a LinearProblem solved with a Sinkhorn solver, initialization is mostly useful to speed up convergence. When the problem is not convex, which is the case for most other uses of this toolbox, the initialization can play a decisive role in reaching a useful solution.
ott.experimental lists tools whose API is not yet mature enough to be included in the main toolbox, with changes expected in the near future, but which might still prove useful for users. At the moment, this includes the MMSinkhorn solver class to compute an optimal multimarginal coupling.
ott.neural provides tools to parameterize and compute an optimal transport map as a neural network. Such networks can be parameterized as an input convex neural network, and trained to approximate the Brenier potential between two measures. Alternatively, one can parameterize that map as a neural ODE, using a time-dependent velocity field trained with flow_matching [Albergo et al., n.d., Lipman et al., 2022] using independent couplings or more advanced OT couplings [Mousavi-Hosseini et al., 2025, Pooladian et al., 2023, Tong et al., 2023, Zhang et al., 2025].
ott.tools provides an interface to exploit OT solutions produced by solvers from the ott.solvers module. Such tasks include computing approximations to Wasserstein distances [Genevay et al., 2018, Séjourné et al., 2019], approximating OT between GMMs, or computing differentiable sort and quantile operations [Cuturi et al., 2019]. That module also provides plotting tools to display OT solutions.
ott.math holds low-level miscellaneous mathematical primitives, such as an implementation of the matrix square-root, or the Legendre transform.
ott.utils provides miscellaneous helper functions.