Transport meets Variational Inference:
Controlled Monte Carlo Diffusions
Abstract
Connecting optimal transport and variational inference, we present a principled and systematic framework for sampling and generative modelling centred around divergences on path space. Our work culminates in the development of Controlled Monte Carlo Diffusions (CMCD) for sampling and inference, a score-based annealing technique that crucially adapts both forward and backward dynamics in a diffusion model. On the way, we clarify the relationship between the EM-algorithm and iterative proportional fitting (IPF) for Schrödinger bridges, providing a conceptual link between the two fields. Finally, we show that CMCD has a strong foundation in the Jarzynski and Crooks identities from statistical physics, and that it convincingly outperforms competing approaches across a wide array of experiments.
1 Introduction
Optimal transport (Villani et al., 2009) and variational inference (Blei et al., 2017) have long been separate fields of research. In recent years, many fruitful connections have been established (Liu et al., 2019), in particular based on dynamical formulations (Tzen & Raginsky, 2019a) and in conjunction with time reversals (Huang et al., 2021a; Song et al., 2021). The goal of this paper is twofold: in the first part, we strengthen those relationships based on forward and reverse time diffusions and the associated Girsanov transformations, arriving at a unifying framework for generative modelling and sampling. In the second part, we build on this and develop a novel score-based scheme for sampling from unnormalised densities. To set the stage, we recall a classical approach (Kingma & Welling, 2014; Rezende & Mohamed, 2015) towards generating samples from a target distribution $\nu$, which is the goal both in generative modelling and in sampling:
Generative processes, encoders and decoders. We consider methodologies which can be implemented via the following generative process,
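$$X_0 \sim \mu, \qquad X_1 \sim \overrightarrow{p}_\theta(\cdot \mid X_0), \tag{1}$$
transforming a sample $X_0 \sim \mu$ into a sample $X_1$. Traditionally, $\mu$ is a simple auxiliary distribution, and the family of transitions $\overrightarrow{p}_\theta$ is parameterised flexibly and in such a way that sampling according to (1) is tractable. We can then frame the tasks of generative modelling and sampling as finding transition densities $\overrightarrow{p}_\theta$ such that the marginal in $X_1$ matches the target distribution,
$$\int \overrightarrow{p}_\theta(x_1 \mid x_0)\,\mu(x_0)\,\mathrm{d}x_0 = \nu(x_1). \tag{2}$$
To learn such a transition, it is helpful to introduce a reversed process
$$X_1 \sim \nu, \qquad X_0 \sim \overleftarrow{p}_\phi(\cdot \mid X_1), \tag{3}$$
relying on an appropriately parameterised backward transition $\overleftarrow{p}_\phi$. We will say that (1) and (3) are reversals of each other when their joint distributions coincide, that is, when
$$\mu(x_0)\,\overrightarrow{p}_\theta(x_1 \mid x_0) = \nu(x_1)\,\overleftarrow{p}_\phi(x_0 \mid x_1). \tag{4}$$
To appreciate the significance of (3), notice that if (4) holds, then (2) follows by integrating both sides with respect to $x_0$. Building on this observation, it is natural to define the loss function
$$\mathcal{L}_D(\theta,\phi) = D\big(\mu(x_0)\,\overrightarrow{p}_\theta(x_1 \mid x_0) \,\big\|\, \nu(x_1)\,\overleftarrow{p}_\phi(x_0 \mid x_1)\big), \tag{5}$$
where $D$ is a divergence between distributions yet to be specified (as usual, divergences are characterised by the requirement that $D(p\,\|\,q) \geq 0$, with equality iff $p = q$). Along the lines of Bengio et al. (2021); Sohl-Dickstein et al. (2015); Wu et al. (2020); Liu et al. (b), we have now laid the foundations for algorithmic approaches that aim at sampling from $\nu$ by minimising $\mathcal{L}_D$:
Framework 1.
Let $D$ be an arbitrary divergence, and assume that $\mathcal{L}_D(\theta^*,\phi^*) = 0$. Then we have
$$\int \overrightarrow{p}_{\theta^*}(x_1 \mid x_0)\,\mu(x_0)\,\mathrm{d}x_0 = \nu(x_1), \qquad \int \overleftarrow{p}_{\phi^*}(x_0 \mid x_1)\,\nu(x_1)\,\mathrm{d}x_1 = \mu(x_0), \tag{6}$$
that is, $\mu$ is transformed into $\nu$ by $\overrightarrow{p}_{\theta^*}$, and $\nu$ is transformed into $\mu$ by $\overleftarrow{p}_{\phi^*}$.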
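As a minimal sketch of Framework 1, assuming finite state spaces purely for illustration (all arrays below are hypothetical), the following NumPy snippet conditions an arbitrary joint distribution to obtain a forward/backward pair, evaluates the loss (5) with $D = D_{\mathrm{KL}}$, and checks that a vanishing loss yields the marginal constraints (6), whereas a perturbed backward transition gives a strictly positive loss.

```python
import numpy as np


def kl(p, q):
    """KL divergence between two discrete distributions stored as arrays of equal shape."""
    mask = p > 0
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))


def loss_kl(mu, p_fwd, nu, p_bwd):
    """Loss (5) with D = KL: KL( mu(x0) p_fwd(x1|x0) || nu(x1) p_bwd(x0|x1) )."""
    forward_joint = mu[:, None] * p_fwd     # indexed [x0, x1]
    backward_joint = nu[None, :] * p_bwd    # indexed [x0, x1]
    return kl(forward_joint, backward_joint)


rng = np.random.default_rng(0)

# Hypothetical joint distribution on a 4 x 5 grid; its marginals play the roles of mu and nu.
gamma = rng.random((4, 5))
gamma /= gamma.sum()
mu = gamma.sum(axis=1)
nu = gamma.sum(axis=0)

# Conditioning gamma gives a forward/backward pair satisfying the reversal relation (4).
p_fwd = gamma / mu[:, None]   # p_fwd[x0, x1] = p(x1 | x0); each row sums to 1
p_bwd = gamma / nu[None, :]   # p_bwd[x0, x1] = p(x0 | x1); each column sums to 1
loss_opt = loss_kl(mu, p_fwd, nu, p_bwd)

# Zero loss implies the marginal constraints (6).
pushed_mu = mu @ p_fwd        # sum over x0 of mu(x0) p(x1 | x0), should equal nu
pulled_nu = p_bwd @ nu        # sum over x1 of p(x0 | x1) nu(x1), should equal mu

# A perturbed (renormalised) backward transition breaks (4) and gives a positive loss.
p_bad = p_bwd * rng.uniform(0.5, 1.5, size=p_bwd.shape)
p_bad /= p_bad.sum(axis=0, keepdims=True)
loss_bad = loss_kl(mu, p_fwd, nu, p_bad)

print(f"loss at a reversal pair: {loss_opt:.2e}, after perturbation: {loss_bad:.4f}")
```

In this discrete setting the transitions are tabulated exactly; in practice they would be parameterised (e.g. by neural networks) and the loss estimated from samples.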
The sampling problem. Let $\pi$ denote a probability density function on $\mathbb{R}^d$ of the form $\pi(x) = \hat\pi(x)/Z$, where $\hat\pi$ can be differentiated and evaluated pointwise but the normalizing constant $Z = \int_{\mathbb{R}^d} \hat\pi(x)\,\mathrm{d}x$ is intractable. We are interested both in estimating $Z$ and in obtaining approximate samples from $\pi$, given that we can sample from a more tractable density. Framework 1 provides us with an objective for the sampling problem: once $\mathcal{L}_D(\theta^*,\phi^*) = 0$, we can generate samples from $\pi$ via the learned variational distribution. Through variational inference and optimal transport, we discuss relationships to classical methods as well as their shortcomings:
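To make the two goals concrete, here is a plain importance sampling baseline (not the method of this paper) for a hypothetical one-dimensional target $\hat\pi(x) = \exp(-(x-2)^2/2)$, whose normalizing constant $Z = \sqrt{2\pi}$ is known in closed form so that the estimate can be checked. The Gaussian proposal $\mathcal{N}(0, 3^2)$ plays the role of the tractable density and is an arbitrary choice.

```python
import numpy as np

rng = np.random.default_rng(1)


def log_pi_hat(x):
    """Hypothetical unnormalised target pi_hat(x) = exp(-(x - 2)^2 / 2); its Z is sqrt(2 pi)."""
    return -0.5 * (x - 2.0) ** 2


# Tractable proposal density: N(0, 3^2), which we can both sample and evaluate.
scale = 3.0


def log_proposal(x):
    return -0.5 * (x / scale) ** 2 - np.log(scale * np.sqrt(2.0 * np.pi))


n = 200_000
x = rng.normal(0.0, scale, size=n)
log_w = log_pi_hat(x) - log_proposal(x)   # log importance weights pi_hat / proposal

z_hat = np.exp(log_w).mean()              # unbiased estimate of Z

w = np.exp(log_w - log_w.max())           # rescaled weights for numerical stability
mean_hat = np.sum(w * x) / np.sum(w)      # self-normalised estimate of E_pi[X] (true value 2)

print(f"Z estimate {z_hat:.4f} (exact {np.sqrt(2 * np.pi):.4f}), mean estimate {mean_hat:.4f}")
```

Such a static proposal degrades quickly as the mismatch with the target grows, which is what motivates learning transport between the tractable density and $\pi$.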
KL-divergence, ELBO and variational inference. Choosing $D = D_{\mathrm{KL}}$ in (5), variational inference (VI) and latent variable model based approaches (Dempster et al., 1977; Blei et al., 2017; Kingma & Welling, 2014) can elegantly be placed within Framework 1. Indeed, direct computation (see Appendix B) shows that $\mathcal{L}_{D_{\mathrm{KL}}}$ coincides with the negative expected evidence lower bound (ELBO) up to an additive constant, so that minimising $\mathcal{L}_{D_{\mathrm{KL}}}$ is equivalent to maximising the expected ELBO, also known as the negative free energy (Blei et al., 2017). This derivation is an alternative to the standard approach via maximum likelihood and convex duality (or Jensen’s inequality) (Kingma et al., 2021, Section 2.2), and directly accommodates various modifications by replacing the KL-divergence (see Appendix B).
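The ELBO statement can be checked numerically in a small discrete model. The sketch below assumes finite state spaces and randomly drawn distributions (all hypothetical), and verifies the decomposition $\mathcal{L}_{D_{\mathrm{KL}}} = -\mathbb{E}_{x_0\sim\mu}[\mathrm{ELBO}(x_0)] - H(\mu)$, where $H(\mu)$ is the entropy of $\mu$ and does not depend on $(\theta,\phi)$. This is a short derivation for the discrete case, not a transcription of Appendix B.

```python
import numpy as np

rng = np.random.default_rng(2)


def normalise(a, axis):
    return a / a.sum(axis=axis, keepdims=True)


# Hypothetical discrete latent variable model: x0 takes 3 values, x1 takes 4 values.
mu = normalise(rng.random(3), 0)            # distribution of X_0 (e.g. the data)
nu = normalise(rng.random(4), 0)            # distribution of X_1 (e.g. the prior)
p_fwd = normalise(rng.random((3, 4)), 1)    # p_fwd[x0, x1] = p_theta(x1 | x0), the "encoder"
p_bwd = normalise(rng.random((3, 4)), 0)    # p_bwd[x0, x1] = p_phi(x0 | x1), the "decoder"

forward_joint = mu[:, None] * p_fwd
backward_joint = nu[None, :] * p_bwd
loss_kl = float(np.sum(forward_joint * np.log(forward_joint / backward_joint)))

# ELBO(x0) = E_{x1 ~ p_fwd(.|x0)}[ log nu(x1) + log p_bwd(x0|x1) - log p_fwd(x1|x0) ]
elbo = np.sum(p_fwd * (np.log(nu)[None, :] + np.log(p_bwd) - np.log(p_fwd)), axis=1)
log_evidence = np.log(backward_joint.sum(axis=1))   # log sum_x1 nu(x1) p_bwd(x0|x1)
entropy_mu = -float(np.sum(mu * np.log(mu)))

# L_KL = -E_mu[ELBO] - H(mu): the entropy term does not depend on (theta, phi).
identity_gap = loss_kl - (-float(mu @ elbo) - entropy_mu)
print(f"L_KL = {loss_kl:.4f}, expected ELBO = {mu @ elbo:.4f}, gap = {identity_gap:.1e}")
```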
Couplings, (optimal) transport and nonuniqueness. Assuming (4) holds, it is natural to define the joint distribution $\gamma(x_0,x_1) = \mu(x_0)\,\overrightarrow{p}_\theta(x_1 \mid x_0) = \nu(x_1)\,\overleftarrow{p}_\phi(x_0 \mid x_1)$, which is a coupling between $\mu$ and $\nu$. Viewed from this angle, the set of minimisers of $\mathcal{L}_D$ stands in one-to-one correspondence with the set of couplings between $\mu$ and $\nu$, provided that the parameterisations are chosen flexibly enough. Under the latter assumption, the objective in (5) admits infinitely many minimisers, rendering algorithmic approaches based solely on Framework 1 potentially unstable and their output hard to interpret. In the language of optimal transport (Villani, 2003), minimising $\mathcal{L}_D$ enforces the marginal (‘transport’) constraints in (6) without a selection principle based on an appropriate cost function (‘optimal’).
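The nonuniqueness is easy to see on a finite grid. In the sketch below (hypothetical marginals, perfectly flexible tabulated transitions assumed), two different couplings of the same $\mu$ and $\nu$, namely the independent coupling and a perturbation of it with zero row and column sums, both yield zero KL loss, while their forward transitions differ.

```python
import numpy as np


def transitions_from_coupling(gamma):
    """Forward p(x1|x0) and backward p(x0|x1) obtained by conditioning a coupling gamma[x0, x1]."""
    return gamma / gamma.sum(axis=1, keepdims=True), gamma / gamma.sum(axis=0, keepdims=True)


def loss_kl(mu, p_fwd, nu, p_bwd):
    """Loss (5) with D = KL on a finite grid."""
    f = mu[:, None] * p_fwd
    b = nu[None, :] * p_bwd
    return float(np.sum(f * np.log(f / b)))


mu = np.array([0.2, 0.3, 0.5])
nu = np.array([0.25, 0.25, 0.5])

# Coupling 1: the independent coupling mu x nu.
gamma_1 = np.outer(mu, nu)
# Coupling 2: add outer(u, v) with sum(u) = sum(v) = 0, which leaves both marginals unchanged.
u = np.array([1.0, -1.0, 0.0])
v = np.array([0.0, 1.0, -1.0])
gamma_2 = gamma_1 + 0.04 * np.outer(u, v)
assert np.all(gamma_2 > 0), "perturbation too large"

losses = []
forwards = []
for gamma in (gamma_1, gamma_2):
    p_fwd, p_bwd = transitions_from_coupling(gamma)
    losses.append(loss_kl(mu, p_fwd, nu, p_bwd))
    forwards.append(p_fwd)

differ = float(np.abs(forwards[0] - forwards[1]).max())
print(f"losses: {losses}, max difference between forward transitions: {differ:.2f}")
```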
Methods such as VAEs (Kingma & Welling, 2014) parameterise $\overrightarrow{p}_\theta$ and $\overleftarrow{p}_\phi$ with a restricted family of distributions (such as Gaussians), thus restricting the set of couplings. Expectation maximisation (EM) minimises $\mathcal{L}_{D_{\mathrm{KL}}}$ in a component-wise fashion, resolving nonuniqueness in a procedural manner (see Section 3.1). Common diffusion models fix either $\overrightarrow{p}_\theta$ or $\overleftarrow{p}_\phi$, and thus select a coupling (Section 2.2). In this paper, we argue that the full potential of diffusion models can be unleashed by training the forward and backward processes at the same time, provided that appropriate modifications resolving the nonuniqueness inherent in Framework 1 are imposed. To develop principled approaches towards this, we proceed as follows:
Outline and contributions. In Section 2 we recall hierarchical VAEs (Rezende et al., 2014) and, following Tzen & Raginsky (2019a), proceed to the infinite-depth limit described by the SDEs in (12). Readers more familiar with VI and discrete time might want to take the development in Section 2.1 as an explanation of (12); readers with background in stochastic analysis might take Framework 1′ as their starting point. In Proposition 2.2 we provide a generalised form of the Girsanov theorem for forward-reverse time SDEs, crucially incorporating the choice of a reference process that allows us to reason about sampling and generation in a systematic and principled way. We demonstrate that a range of widely used approaches, such as score-based diffusions and path integral samplers, among others, are special cases of our unifying framework (Section 2.2). Similarly in Section 3.1 we unify optimal transport (OT) and VI under our framework by establishing a correspondence between expectation-maximisation (EM) and iterative proportional fitting (IPF). Going further, we show that this framework allows us to derive new methods:
In Section 3.2, we derive a novel score-based annealed flow technique, the Controlled Monte Carlo Diffusion (CMCD) sampler, and show that it may be viewed as an infinitesimal analogue of the method from Section 3.1. Finally, we connect CMCD to the foundational identities of Crooks and Jarzynski in statistical physics, and show that it empirically outperforms a range of state-of-the-art inference methods in sampling and estimating normalizing constants (Section 4).
2 From hierarchical VAEs to forward-reverse time diffusions
2.1 Hierarchical VAEs (Rezende et al., 2014)
A particularly flexible choice of implicitly parameterising and can be achieved via a hierarchical model with intermediate latents: We identify and with the ‘endpoints’ of the layered augmentation , and define
(7) |
so that and can be obtained from (7) by marginalising over the auxiliary variables . Here, and refer to sets of parameters to be specified in more detail below. Further introducing notation, we write as well as and think of those implied joint distributions as emanating from and , respectively, moving ‘forwards’ or ‘backwards’ according to the specific choices for and . In the regime when is large, the models in (7) are very expressive, even if the intermediate transition kernels are parameterised in a simple manner. We hence proceed by assuming Gaussian distributions,
(8) |
where controls the standard deviation, and is a small parameter, anticipating the limits , to be taken in Section 2.2 below. The vector fields and introduced in (8) should be thought of as parameterised by and , but we will henceforth suppress this for brevity.
The models (7)-(8) could equivalently be defined via the Markov chains
(9a) | ||||
(9b) |
where is an iid sequence of standard Gaussian random variables. As indicated, the forward process in (9a) may serve to define the distribution , whilst the backward process in (9b) induces . Note that the transition densities and obtained as the marginals of (7) will in general not be available in closed form. However, generalising slightly from Framework 1, we may set out to minimise the extended loss
(10) |
where refers to a divergence on the ‘discrete path space’ . Clearly, still implies (6), but is no longer equivalent. More specifically, in the case when , the data processing inequality yields
(11) |
so that provides an upper bound for as defined in (5).
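A small numerical illustration of the data processing inequality (11), assuming finite state spaces, $D = D_{\mathrm{KL}}$ and a chain with a single intermediate latent (all kernels are drawn at random for illustration): the KL divergence between the laws of the full discrete paths dominates the KL divergence between the joint laws of the endpoints, obtained by marginalising out the intermediate latent.

```python
import numpy as np

rng = np.random.default_rng(3)


def normalise(a, axis):
    return a / a.sum(axis=axis, keepdims=True)


def kl(p, q):
    return float(np.sum(p * np.log(p / q)))


S = 3  # number of states per latent
mu = normalise(rng.random(S), 0)
nu = normalise(rng.random(S), 0)
# Forward kernels P[i, j] = p(x_{k+1}=j | x_k=i); backward kernels Q[i, j] = p(x_k=i | x_{k+1}=j).
P1, P2 = normalise(rng.random((S, S)), 1), normalise(rng.random((S, S)), 1)
Q1, Q2 = normalise(rng.random((S, S)), 0), normalise(rng.random((S, S)), 0)

# Joint laws of the discrete path (x0, x1, x2), indexed [x0, x1, x2].
forward_path = np.einsum("i,ij,jk->ijk", mu, P1, P2)
backward_path = np.einsum("k,jk,ij->ijk", nu, Q2, Q1)

# Marginalising out the intermediate latent x1 gives the endpoint joints of (x0, x2).
forward_ends = forward_path.sum(axis=1)
backward_ends = backward_path.sum(axis=1)

kl_path = kl(forward_path, backward_path)
kl_ends = kl(forward_ends, backward_ends)
print(f"path-space KL {kl_path:.4f} >= endpoint KL {kl_ends:.4f}")
```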
2.2 Diffusion models – hierarchical VAEs in the infinite depth limit
Here we take inspiration from Section 2.1 and Tzen & Raginsky (2019a); Li et al. (2020); Huang et al. (2021a) to investigate the limit, using stochastic differential equations (SDEs). To this end, we think of as discrete instances in a fixed time interval , equidistant with time step , that is, we set . The discrete paths give rise to continuous paths by setting and linearly interpolating and . To complete the set-up, we think of and in (8) as arising from time-dependent vector fields via and .
Taking the limit , while keeping fixed, transforms the Markov chains in (9) into continuous-time dynamics described by the SDEs (Tzen & Raginsky, 2019a)
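$$\mathrm{d}X_t = a_t(X_t)\,\mathrm{d}t + \sigma\,\overrightarrow{\mathrm{d}W_t}, \qquad X_0 \sim \mu, \tag{12a}$$
$$\mathrm{d}X_t = b_t(X_t)\,\mathrm{d}t + \sigma\,\overleftarrow{\mathrm{d}W_t}, \qquad X_T \sim \nu, \tag{12b}$$
where $\overrightarrow{\mathrm{d}W_t}$ and $\overleftarrow{\mathrm{d}W_t}$ denote forward and backward Itô integration (see Appendix A for more details and remarks on the notation), and $(W_t)_{t\in[0,T]}$ is a standard Brownian motion. In complete analogy with (9), the SDEs in (12) induce the distributions $\overrightarrow{\mathbb{P}}^{\mu,a}$ and $\overleftarrow{\mathbb{P}}^{\nu,b}$ on the path space $C([0,T];\mathbb{R}^d)$. Relating back to the discussion in the introduction, note that we maintain the relations $X_0 \sim \mu$ and $X_T \sim \nu$, and the transitions are parameterised by the vector fields $a$ and $b$, in the sense that $\overrightarrow{p}_\theta$ is the conditional law of $X_T$ given $X_0$ under $\overrightarrow{\mathbb{P}}^{\mu,a}$, and $\overleftarrow{p}_\phi$ is the conditional law of $X_0$ given $X_T$ under $\overleftarrow{\mathbb{P}}^{\nu,b}$.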
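To make the forward dynamics concrete, the following sketch simulates an SDE of the form $\mathrm{d}X_t = a_t(X_t)\,\mathrm{d}t + \sigma\,\mathrm{d}W_t$ with the Euler-Maruyama scheme, i.e. a Gaussian Markov chain of the kind discussed in Section 2.1, with the drift evaluated at the left end point of each step. The Ornstein-Uhlenbeck drift $a_t(x) = -x$, $\sigma = \sqrt{2}$ and $\mu = \mathcal{N}(3, 0.5^2)$ are hypothetical choices for which the time-$T$ marginal is known in closed form.

```python
import numpy as np

rng = np.random.default_rng(4)


def simulate_forward(x0, drift, sigma, T, n_steps, rng):
    """Euler-Maruyama scheme for dX_t = drift(t, X_t) dt + sigma dW_t (forward Ito)."""
    dt = T / n_steps
    x = x0.copy()
    for k in range(n_steps):
        t = k * dt
        # The drift is evaluated at the left end point t_k, as for forward Ito integrals.
        x = x + drift(t, x) * dt + sigma * np.sqrt(dt) * rng.standard_normal(x.shape)
    return x


# Ornstein-Uhlenbeck example: a_t(x) = -x, sigma = sqrt(2), mu = N(3, 0.5^2).
T, n_steps, n_samples = 1.0, 1000, 20_000
x0 = rng.normal(3.0, 0.5, size=n_samples)
xT = simulate_forward(x0, lambda t, x: -x, np.sqrt(2.0), T, n_steps, rng)

# Exact time-T marginal of this OU process, for comparison.
mean_T = 3.0 * np.exp(-T)
var_T = 0.25 * np.exp(-2 * T) + (1 - np.exp(-2 * T))
print(f"mean {xT.mean():.3f} vs {mean_T:.3f}, variance {xT.var():.3f} vs {var_T:.3f}")
```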
The following well-known result (Anderson, 1982; Nelson, 1967) allows us to relate forward and backward path measures via a local (score-matching) condition for the reversal relation in (4); the global condition is captured by the local condition (13) owing to the Markovian nature of (12).
Proposition 2.1 (Nelson’s relation).
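For $a$ and $b$ of sufficient regularity, denote the time-marginals of the path measure $\overrightarrow{\mathbb{P}}^{\mu,a}$ by $(\rho_t)_{t\in[0,T]}$. Then $\overrightarrow{\mathbb{P}}^{\mu,a} = \overleftarrow{\mathbb{P}}^{\nu,b}$ if and only if $\nu = \rho_T$ and
$$a_t - b_t = \sigma^2\,\nabla\log\rho_t, \qquad t \in [0,T]. \tag{13}$$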
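Nelson's relation can be checked by simulation in a case where the marginals $\rho_t$ are available in closed form. The sketch below assumes the hypothetical forward process $\mathrm{d}X_t = -X_t\,\mathrm{d}t + \sqrt{2}\,\mathrm{d}W_t$ with $X_0 \sim \mathcal{N}(3, 0.5^2)$ and the convention that the backward SDE reads $\mathrm{d}X_t = b_t(X_t)\,\mathrm{d}t + \sigma\,\overleftarrow{\mathrm{d}W_t}$. It sets $b_t = a_t - \sigma^2\nabla\log\rho_t$ and runs the backward dynamics from $\nu = \rho_T$, evaluating the drift at the right end point of each step; the samples at time $0$ should then approximately follow $\mu$.

```python
import numpy as np

rng = np.random.default_rng(5)

# Hypothetical forward process: dX = -X dt + sigma dW with sigma^2 = 2 and X_0 ~ N(3, 0.5^2).
sigma, T, n_steps, n_samples = np.sqrt(2.0), 1.0, 1000, 20_000
m0, v0 = 3.0, 0.25


def forward_drift(t, x):
    return -x


def marginal(t):
    """Mean and variance of the Gaussian time-t marginal rho_t of the forward process."""
    m = m0 * np.exp(-t)
    v = v0 * np.exp(-2 * t) + (1 - np.exp(-2 * t))
    return m, v


def backward_drift(t, x):
    """Nelson's relation: b_t = a_t - sigma^2 * grad log rho_t."""
    m, v = marginal(t)
    score = -(x - m) / v
    return forward_drift(t, x) - sigma**2 * score


# Start from nu = rho_T and step backwards in time, evaluating the drift at the
# right end point t_k of each interval [t_{k-1}, t_k], as for backward Ito integrals.
dt = T / n_steps
mT, vT = marginal(T)
x = rng.normal(mT, np.sqrt(vT), size=n_samples)
for k in range(n_steps, 0, -1):
    t = k * dt
    x = x - backward_drift(t, x) * dt + sigma * np.sqrt(dt) * rng.standard_normal(n_samples)

# The samples at time 0 should approximately follow mu = N(3, 0.5^2).
print(f"mean {x.mean():.3f} (target 3), std {x.std():.3f} (target 0.5)")
```

Replacing the score by zero in `backward_drift` would instead push the samples away from $\mu$, which illustrates why the score term is essential for the time reversal.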
Remark 1.
A similarly clean characterisation of equality between forward and backward path measures is not available for the discrete-time setting as presented in (9). In particular, Gaussianity of the intermediate transitions is not preserved under time-reversal.
A recurring theme in this work and related literature is the interplay between the score-matching condition in (13) and the global condition $\overrightarrow{\mathbb{P}}^{\mu,a} = \overleftarrow{\mathbb{P}}^{\nu,b}$, invoking Framework 1. To enable calculations involving the latter, we will rely on the following result:
Proposition 2.2 (forward-backward Radon-Nikodym derivatives).
Proof.
Remark 2 (Role of the reference process).
According to Proposition 2.2, the Radon-Nikodym derivative between $\overrightarrow{\mathbb{P}}^{\mu,a}$ and $\overleftarrow{\mathbb{P}}^{\nu,b}$ can be decomposed into boundary terms (14a), as well as forward and backward path integrals (14b) and (14c). Since the left-hand side of (14a) does not depend on the reference, the expressions in (14) are in principle equivalent for all choices of reference. The freedom in the choice of reference allows us to ‘reweight’ between (14a), (14b) and (14c), or even to cancel terms. A canonical choice, based on the Lebesgue measure, is discussed in Appendix C.1.
Remark 3 (Discretisation and conversion formulae).
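The distinction between forward and backward integration in (14) is related to the time points at which the integrands would be evaluated in discrete-time approximations; e.g., for a vector field $f$ and a grid $0 = t_0 < t_1 < \dots < t_N = T$,
$$\int_0^T f(X_t)\cdot\overrightarrow{\mathrm{d}X_t} \approx \sum_{k=0}^{N-1} f(X_{t_k})\cdot\big(X_{t_{k+1}} - X_{t_k}\big), \qquad \int_0^T f(X_t)\cdot\overleftarrow{\mathrm{d}X_t} \approx \sum_{k=0}^{N-1} f(X_{t_{k+1}})\cdot\big(X_{t_{k+1}} - X_{t_k}\big).$$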
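Alternatively, forward and backward integrals can be transformed into each other using the conversion formula
$$\int_0^T f(X_t)\cdot\overleftarrow{\mathrm{d}X_t} = \int_0^T f(X_t)\cdot\overrightarrow{\mathrm{d}X_t} + \sigma^2\int_0^T (\nabla\cdot f)(X_t)\,\mathrm{d}t. \tag{15}$$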
We refer to Kunita (2019) and Appendix A for further details. In passing, we note that (15) allows us to eliminate the Hutchinson estimator (Hutchinson, 1989) from a variety of common score-matching objectives, potentially reducing the variance of gradient estimators; see Appendix C.1.
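The relation between left- and right-end-point sums can be checked on a single simulated path. By a standard Itô-calculus argument, the backward integral exceeds the forward one by $\sigma^2\int_0^T(\nabla\cdot f)(X_t)\,\mathrm{d}t$, and we assume this is the content of (15) for the processes in (12). The sketch below uses the simplest case $X_t = \sigma W_t$ in one dimension and the hypothetical integrand $f = \sin$ (so $\nabla\cdot f = \cos$); a drift would not change the correction, since only the quadratic variation of $X$ enters it.

```python
import numpy as np

rng = np.random.default_rng(6)

sigma, T, n_steps = 1.0, 1.0, 100_000
dt = T / n_steps
# One path of dX_t = sigma dW_t started at 0, on the grid t_k = k * dt.
increments = sigma * np.sqrt(dt) * rng.standard_normal(n_steps)
x = np.concatenate([[0.0], np.cumsum(increments)])

f = np.sin          # integrand f
div_f = np.cos      # its divergence (the derivative, in one dimension)

forward_integral = np.sum(f(x[:-1]) * increments)   # evaluate f at left end points t_k
backward_integral = np.sum(f(x[1:]) * increments)   # evaluate f at right end points t_{k+1}
correction = sigma**2 * np.sum(div_f(x[:-1])) * dt  # sigma^2 * int_0^T div f(X_t) dt

gap = backward_integral - (forward_integral + correction)
print(f"forward {forward_integral:.4f}, backward {backward_integral:.4f}, gap {gap:.1e}")
```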
Framework 1 can be translated into the setting of (12), noting that (11) continues to hold with appropriate modifications:
Framework 1′.
At optimality, $\overrightarrow{\mathbb{P}}^{\mu,a^*} = \overleftarrow{\mathbb{P}}^{\nu,b^*}$, Proposition 2.1 allows us to obtain the scores associated with the learned diffusion via $\nabla\log\rho_t = (a^*_t - b^*_t)/\sigma^2$. In this way, Framework 1′ is closely connected to (and in some ways extends) score-matching ideas (Song & Ermon, 2019; Song et al., 2021). Indeed, recent approaches towards generative modeling and sampling can be recovered from Framework 1′ by making specific choices for the divergence $D$, the parameterisations of $a$ and $b$, as well as the reference diffusion in Proposition 2.2:
Score-based generative modeling: Letting $\mu$ be the target, fixing the forward drift $a$ and, motivated by Proposition 2.1, parameterising the backward drift as $b = a - \sigma^2 s$ with a learned vector field $s$, we recover the SGM objectives in Hyvärinen & Dayan (2005); Song & Ermon (2019); Song et al. (2021) from the path-space loss; when this loss vanishes, the variable drift component $s$ represents the score $\nabla\log\rho_t$. Modifications can be obtained from the conversion formula (15), see Appendix C.2.
Score-based sampling – ergodic drift: In this setting, $\nu$ becomes the target, and we fix $b$ to be the drift of an ergodic (backward) process. Suitable choices of the divergence and the reference process then allow us to recover the approaches in Vargas et al. (2023a); Berner et al. (2022). Possible generalisations based on Framework 1′ include IWAE-type objectives, see Appendix C.3.
3 Learning forward and backward transitions simultaneously
Recall from the introduction that complete flexibility in $a$ and $b$ renders the minimisers of the objective in Framework 1′ highly nonunique. The approaches surveyed at the end of the previous section circumvent this problem by fixing either $a$ or $b$. However, to leverage the full power of diffusion models, both $a$ and $b$ should be adapted to the problem at hand. In this section, we explore models of this kind by imposing additional constraints on $a$ and $b$. We end this section by presenting our new CMCD sampler and connecting it to prior methodology within VI (Doucet et al., 2022b; Geffner & Domke, 2023; Papamakarios et al., 2017) and OT, where CMCD can be viewed as an instance of entropy-regularised OT in the infinite constraint limit (Bernton et al., 2019).
3.1 Connection to entropic optimal transport
One way of selecting a particular transition between $\mu$ and $\nu$ is by imposing an entropic penalty, encouraging the dynamics to stay close to a prescribed reference process, which is often physically or biologically motivated. Using the notation employed in Framework 1, the static Schrödinger problem (Schrödinger, 1931; Léonard, 2014a) is given by
(16) |
where the reference distribution in (16) is the Schrödinger prior, encoding additional domain-specific information. In an analogous way, we can add a regulariser to the path-space approach of Framework 1′ to obtain the dynamic Schrödinger problem
(17) |
that is, the driving vector field should be chosen in such a way that (i) the corresponding diffusion transitions from $\mu$ to $\nu$, and (ii) among such diffusions, the vector field remains close, in the mean-square sense, to the prescribed vector field of the reference process. Under mild conditions, the solutions to (16) and (17) exist and are unique. Furthermore, the static and dynamic viewpoints are related through a mixture-of-bridges construction (assuming that the conditionals of the Schrödinger prior correspond to the transitions induced by the reference process), see (Léonard, 2014a, Section 2).
Iterative proportional fitting (IPF) and the EM algorithm. It is well known that approximate solutions to (16) can be obtained using alternating KL-projections, keeping one of the marginals fixed in each iteration: under mild conditions, the sequence defined by
(18a) | ||||
(18b) |
with initialisation given by the Schrödinger prior, converges to the solution of (16) as $n \to \infty$ (De Bortoli et al., 2021). This procedure is commonly referred to as iterative proportional fitting (IPF) (Fortet, 1940; Kullback, 1968; Rüschendorf, 1995) or Sinkhorn updates (Cuturi, 2013). IPF can straightforwardly be adapted to the path space setting of (17), and the resulting updates coincide with the Föllmer drift updates discussed in Appendix C.3, see (Vargas et al., 2021a) and Appendix E.4.
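For intuition, the static IPF (Sinkhorn) iteration on a finite grid fits in a few lines. In this sketch the Schrödinger prior is a hypothetical Gaussian-kernel coupling, $\mu$ and $\nu$ are discretised Gaussians, and each half-step is the closed-form KL-projection onto couplings with one prescribed marginal, which amounts to rescaling the columns or the rows.

```python
import numpy as np


def ipf(prior, mu, nu, n_iter=500):
    """Iterative proportional fitting for the static Schroedinger problem on a finite grid.

    Alternates KL projections onto {pi : pi_1 = nu} and {pi : pi_0 = mu},
    starting from the (normalised) reference coupling `prior`.
    """
    pi = prior / prior.sum()
    for _ in range(n_iter):
        # Projection onto couplings with second marginal nu: rescale the columns.
        pi = pi * (nu / pi.sum(axis=0))[None, :]
        # Projection onto couplings with first marginal mu: rescale the rows.
        pi = pi * (mu / pi.sum(axis=1))[:, None]
    return pi


# Hypothetical example: Gaussian-kernel prior on a 1D grid, as in entropic OT.
grid = np.linspace(-2.0, 2.0, 30)
eps = 2.0
prior = np.exp(-((grid[:, None] - grid[None, :]) ** 2) / (2 * eps))
mu = np.exp(-((grid + 1.0) ** 2) / 0.5)
mu /= mu.sum()
nu = np.exp(-((grid - 1.0) ** 2) / 0.5)
nu /= nu.sum()

pi_star = ipf(prior, mu, nu)
print(np.abs(pi_star.sum(axis=1) - mu).max(), np.abs(pi_star.sum(axis=0) - nu).max())
```

Each half-step solves its projection exactly here; in the parameterised and path-space settings these projections can only be solved approximately, which is the source of the error accumulation discussed below.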
To further demonstrate the coverage of our framework, we establish a connection between IPF and expectation-maximisation (EM) (Dempster et al., 1977), originally devised for finding maximum likelihood estimates in models with latent (or hidden) variables. According to Neal & Hinton (1998), the EM-algorithm can be described in the setting from the introduction, and written in the form
(19) |
with $\mathcal{L}_{D_{\mathrm{KL}}}$ defined as in (5). If the initialisations are matched appropriately, the following result establishes an exact correspondence between the IPF updates in (18) and the EM updates in (19):
Proposition 3.1 (EM IPF).
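Assume that the transition densities $\overrightarrow{p}_\theta$ and $\overleftarrow{p}_\phi$ are parameterised with perfect flexibility (in precise terms, we assume that for any transition densities $\overrightarrow{p}$ and $\overleftarrow{p}$, there exist $\theta$ and $\phi$ such that $\overrightarrow{p}_\theta = \overrightarrow{p}$ and $\overleftarrow{p}_\phi = \overleftarrow{p}$), and furthermore that the EM-scheme (19) is initialised consistently with the initialisation of the IPF scheme (18). Then the IPF iterations in (18) agree with the EM iterations in (19) for all $n$, in the sense that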
(20) |
From the proof (Appendix E), it is clear that flexibility of the parameterisations is crucial: the correspondence therefore fails for classical VAEs, but holds up to a negligible error for the SDE parameterisations from Section 2.2, see also Liu et al. (b). Under this assumption, the key observation is that replacing the forward KL by the reverse KL in one or both of (18a) and (18b) does not, in theory, change the sequence of minimisers.
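The key observation can be illustrated on a finite grid (an assumption made for illustration): the projection of a coupling $\gamma$ onto $\{\pi : \pi_0 = \mu\}$ is $\mu(x_0)\,\gamma(x_1 \mid x_0)$ whichever direction of the KL divergence is used. The sketch below checks this against random candidates with first marginal $\mu$; it is a sanity check of this single step, not a proof of Proposition 3.1.

```python
import numpy as np

rng = np.random.default_rng(7)


def kl(p, q):
    return float(np.sum(p * np.log(p / q)))


# Hypothetical current iterate gamma[x0, x1] and prescribed first marginal mu.
gamma = rng.random((3, 4))
gamma /= gamma.sum()
mu = np.array([0.5, 0.3, 0.2])

# Projection onto {pi : pi_0 = mu}: keep the conditional gamma(x1 | x0), replace the marginal.
conditional = gamma / gamma.sum(axis=1, keepdims=True)
pi_proj = mu[:, None] * conditional


def random_candidate():
    """A random joint with first marginal mu (arbitrary conditionals, i.e. perfect flexibility)."""
    c = rng.random((3, 4))
    return mu[:, None] * (c / c.sum(axis=1, keepdims=True))


candidates = [random_candidate() for _ in range(2000)]
best_forward = min(kl(c, gamma) for c in candidates)   # direction KL(pi || gamma)
best_reverse = min(kl(gamma, c) for c in candidates)   # direction KL(gamma || pi)

print(f"KL(pi_proj||gamma) = {kl(pi_proj, gamma):.4f} <= best candidate {best_forward:.4f}")
print(f"KL(gamma||pi_proj) = {kl(gamma, pi_proj):.4f} <= best candidate {best_reverse:.4f}")
```

Both directions are minimised by the same conditional because, with the marginal fixed, each KL reduces to an $x_0$-weighted sum of conditional KL divergences (or cross-entropies) that vanish exactly when the conditionals agree.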
In practice, favoring the EM objectives over IPF can offer an advantage: optimizing the forward and the reverse KL encourages moment-matching and mode-seeking behavior, respectively, so the alternating scheme in (19) might present a suitable compromise compared with optimizing a single direction of the KL. We leave empirical exploration for future work.
Whilst EM and IPF might seem appealing for learning a sampler, they both require sequentially solving a series of minimisation problems that can only be solved approximately; this is not only slow but also causes the errors arising from each iterate to accumulate (Vargas et al., 2021a; Fernandes et al., 2021). To address both issues, we present a novel approach (CMCD) that, similarly to IPF, learns both the forward and backward processes whilst preserving the desired uniqueness property. In contrast to IPF, however, it does so in an end-to-end fashion and performs the updates simultaneously. As an alternative, Appendix E.5 discusses a regularised IPF objective, whose empirical exploration we leave for future work.
3.2 Score-based annealing: the Controlled Monte Carlo Diffusion sampler
In this section, we fix a prescribed curve of distributions $(\pi_t)_{t\in[0,T]}$, whose scores $\nabla\log\pi_t$ (and unnormalised densities $\hat\pi_t$) are assumed to be available in tractable form; this is the scenario typically encountered in annealed importance sampling (IS) and related approaches towards computing posterior expectations (Neal, 2001; Reich, 2011; Heng et al., 2021; 2020; Arbel et al., 2021; Doucet et al., 2022a). The Controlled Monte Carlo Diffusion sampler (CMCD) learns the vector field $\nabla\phi_t$ in
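$$\mathrm{d}X_t = \sigma^2\nabla\log\pi_t(X_t)\,\mathrm{d}t + \sigma^2\nabla\phi_t(X_t)\,\mathrm{d}t + \sigma\sqrt{2}\,\overrightarrow{\mathrm{d}W_t}, \qquad X_0 \sim \pi_0, \tag{21}$$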
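so that (21) produces the interpolation $(\pi_t)_{t\in[0,T]}$ from the prior $\pi_0$ to the posterior $\pi_T$, i.e., $X_t \sim \pi_t$ for all $t \in [0,T]$. Note that if $\pi_t$ were constant in time ($\pi_t = \pi$), then setting $\nabla\phi_t = 0$ would reduce (21) to equilibrium overdamped Langevin dynamics, preserving $\pi$. With $\pi_t$ varying in time, $\nabla\phi_t$ can be thought of as a control enabling transitions between neighbouring densities $\pi_t$ and $\pi_{t+\mathrm{d}t}$.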
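The role of the control can be seen in a toy example where it is available in closed form. The sketch assumes that (21) takes the form $\mathrm{d}X_t = \sigma^2\nabla\log\pi_t(X_t)\,\mathrm{d}t + \sigma^2\nabla\phi_t(X_t)\,\mathrm{d}t + \sigma\sqrt{2}\,\mathrm{d}W_t$ and uses the hypothetical path $\pi_t = \mathcal{N}(m_t, s^2)$ with a linearly moving mean, for which $\sigma^2\nabla\phi_t = \dot m_t$ transports $\pi_t$ exactly (CMCD would learn $\phi$ instead). It compares the controlled dynamics with uncontrolled annealed Langevin dynamics ($\nabla\phi_t = 0$), and also runs the backward process with drift $-\sigma^2\nabla\log\pi_t + \sigma^2\nabla\phi_t$ from $\pi_T$, which should return to $\pi_0$.

```python
import numpy as np

rng = np.random.default_rng(8)

# Hypothetical annealing path pi_t = N(m_t, s^2) whose mean moves linearly from m0 to mT.
m0, mT, s, T = 0.0, 4.0, 0.5, 1.0
sigma, n_steps, n_samples = 1.0, 1000, 20_000
dt = T / n_steps


def mean_path(t):
    return m0 + (mT - m0) * t / T


def score(t, x):
    """grad log pi_t(x) for the Gaussian path."""
    return -(x - mean_path(t)) / s**2


def grad_phi(t, x):
    """Closed-form control for this path: sigma^2 * grad phi_t equals dm_t/dt (constant in x)."""
    return np.full_like(x, (mT - m0) / T) / sigma**2


def run_forward(control):
    """Euler-Maruyama for dX = sigma^2 (score + grad_phi) dt + sigma sqrt(2) dW, X_0 ~ pi_0."""
    x = rng.normal(m0, s, size=n_samples)
    for k in range(n_steps):
        t = k * dt
        drift = sigma**2 * score(t, x)
        if control:
            drift = drift + sigma**2 * grad_phi(t, x)
        x = x + drift * dt + sigma * np.sqrt(2 * dt) * rng.standard_normal(n_samples)
    return x


x_controlled = run_forward(control=True)     # should be close to pi_T = N(4, 0.5^2)
x_uncontrolled = run_forward(control=False)  # annealed Langevin without control lags behind m_t

# Backward process with b_t = a_t - 2 sigma^2 grad log pi_t, started from pi_T and stepped
# backwards in time with the drift evaluated at the right end point of each step.
x_back = rng.normal(mT, s, size=n_samples)
for k in range(n_steps, 0, -1):
    t = k * dt
    b = -sigma**2 * score(t, x_back) + sigma**2 * grad_phi(t, x_back)
    x_back = x_back - b * dt + sigma * np.sqrt(2 * dt) * rng.standard_normal(n_samples)

print(f"controlled:   mean {x_controlled.mean():.3f}, std {x_controlled.std():.3f}")
print(f"uncontrolled: mean {x_uncontrolled.mean():.3f}")
print(f"backward:     mean {x_back.mean():.3f}, std {x_back.std():.3f}")
```

The choice of control follows from a short calculation: when $X_t \sim \pi_t$, the Langevin part leaves $\pi_t$ invariant, so the control only has to supply a velocity field solving the continuity equation $\partial_t\pi_t + \nabla\cdot(\pi_t\,\sigma^2\nabla\phi_t) = 0$, and for a translating Gaussian this velocity is the constant $\dot m_t$.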
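To obtain $\phi$, we invoke Framework 1′, but restrict the backward drift to retain uniqueness. Proposition 2.1 motivates the choice $b_t = a_t - 2\sigma^2\nabla\log\pi_t = -\sigma^2\nabla\log\pi_t + \sigma^2\nabla\phi_t$, where $a_t = \sigma^2\nabla\log\pi_t + \sigma^2\nabla\phi_t$ is the forward drift in (21) (note the additional factor of $2$ in Nelson’s relation due to the noise scaling in (21)), leading to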
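$$\mathcal{L}^{D}_{\mathrm{CMCD}}(\phi) = D\Big(\overrightarrow{\mathbb{P}}^{\pi_0,\;\sigma^2\nabla\log\pi + \sigma^2\nabla\phi} \,\Big\|\, \overleftarrow{\mathbb{P}}^{\pi_T,\;-\sigma^2\nabla\log\pi + \sigma^2\nabla\phi}\Big), \tag{22}$$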
which is valid for any choice of divergence . The additional score constraint restores uniqueness in Framework 1′ (see Appendix D for a proof):