Transport meets Variational Inference: Controlled Monte Carlo Diffusions

Francisco Vargas*, Shreyas Padhy*
University of Cambridge
Cambridge, UK
{fav25,sp2058}@cam.ac.uk

Denis Blessing
KIT
Karlsruhe, Germany
jl8142@kit.edu

Nikolas Nüsken*
King's College London
London, UK
nik.nuesken@gmx.de

Abstract

Connecting optimal transport and variational inference, we present a principled and systematic framework for sampling and generative modelling centred around divergences on path space. Our work culminates in the development of Controlled Monte Carlo Diffusions (CMCD) for sampling and inference, a score-based annealing technique that crucially adapts both forward and backward dynamics in a diffusion model. On the way, we clarify the relationship between the EM-algorithm and iterative proportional fitting (IPF) for Schrödinger bridges, providing a conceptual link between fields. Finally, we show that CMCD has a strong foundation in the Jarzynski and Crooks identities from statistical physics, and that it convincingly outperforms competing approaches across a wide array of experiments.

* Equal contribution.

1 Introduction

Optimal transport (Villani et al., 2009) and variational inference (Blei et al., 2017) have for a long time been separate fields of research. In recent years, many fruitful connections have been established (Liu et al., 2019), in particular based on dynamical formulations (Tzen & Raginsky, 2019a), and in conjunction with time reversals (Huang et al., 2021a; Song et al., 2021). The goal of this paper is twofold: In the first part, we enhance those relationships based on forward and reverse time diffusions, and associated Girsanov transformations, arriving at a unifying framework for generative modeling and sampling. In the second part, we build on this and develop a novel score-based scheme for sampling from unnormalised densities. To set the stage, we recall a classical approach (Kingma & Welling, 2014; Rezende & Mohamed, 2015) towards generating samples from a target distribution $\mu(\bm{x})$, which is the goal both in generative modelling and sampling:

Generative processes, encoders and decoders. We consider methodologies which can be implemented via the following generative process,

$$\bm{z} \sim \nu(\bm{z}), \qquad \bm{x}|\bm{z} \sim p^{\theta}(\bm{x}|\bm{z}), \tag{1}$$

transforming a sample $\bm{z} \sim \nu(\bm{z})$ into a sample $\bm{x} \sim \int p^{\theta}(\bm{x}|\bm{z})\,\nu(\mathrm{d}\bm{z})$. Traditionally, $\nu(\bm{z})$ is a simple auxiliary distribution, and the family of transitions $p^{\theta}(\bm{x}|\bm{z})$ is parameterised flexibly and in such a way that sampling according to (1) is tractable. Then we can frame the tasks of generative modelling and sampling as finding transition densities such that the marginal in $\bm{x}$ matches the target distribution,

$$\mu(\bm{x}) = \int p^{\theta}(\bm{x}|\bm{z})\,\nu(\mathrm{d}\bm{z}). \tag{2}$$
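As a concrete illustration of (1) and (2), the following minimal sketch performs ancestral sampling in a toy linear-Gaussian model. The choices $\nu = \mathcal{N}(0,1)$ and $p^{\theta}(x|z) = \mathcal{N}(az+b, s^2)$, as well as the function name `sample_generative`, are assumptions made purely for illustration; they are not taken from the paper. In this toy case the marginal in (2) is available in closed form, $\mathcal{N}(b, a^2+s^2)$, which makes it easy to check the samples.

```python
import numpy as np

def sample_generative(n, a=2.0, b=1.0, s=0.5, seed=0):
    """Ancestral sampling for (1): draw z ~ nu(z), then x | z ~ p(x | z).

    Toy assumptions: nu = N(0, 1) and p(x | z) = N(a * z + b, s^2).
    """
    rng = np.random.default_rng(seed)
    z = rng.standard_normal(n)                  # z ~ nu(z)
    x = a * z + b + s * rng.standard_normal(n)  # x | z ~ p(x | z)
    return z, x

z, x = sample_generative(200_000)

# For this toy model the marginal in (2) is N(b, a^2 + s^2) = N(1, 4.25),
# so the empirical mean and variance of x should be close to 1 and 4.25.
print(x.mean(), x.var())
```

In this example the marginal happens to be tractable; for flexible parameterisations of $p^{\theta}$ this is typically not the case, which is one motivation for working with the joint distributions instead.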

To learn such a transition, it is helpful to introduce a reversed process

$$\bm{x} \sim \mu(\bm{x}), \qquad \bm{z}|\bm{x} \sim q^{\phi}(\bm{z}|\bm{x}), \tag{3}$$

relying on an appropriately parameterised backward transition $q^{\phi}(\bm{z}|\bm{x})$. We will say that (1) and (3) are reversals of each other in the case when their joint distributions coincide, that is, when

$$q^{\phi}(\bm{z}|\bm{x})\,\mu(\bm{x}) = p^{\theta}(\bm{x}|\bm{z})\,\nu(\bm{z}). \tag{4}$$

To appreciate the significance of (3), notice that if (4) holds, then (2) is implied by integrating both sides with respect to $\bm{z}$. Building on this observation, it is natural to define the loss function

$$\mathcal{L}_{D}(\phi,\theta) := D\left(q^{\phi}(\bm{z}|\bm{x})\,\mu(\bm{x}) \,\big|\big|\, p^{\theta}(\bm{x}|\bm{z})\,\nu(\bm{z})\right), \tag{5}$$

where $D$ is a divergence between distributions yet to be specified (as usual, divergences are characterised by the requirement that $D(\alpha||\beta) \geq 0$, with equality iff $\alpha = \beta$). Along the lines of Bengio et al. (2021); Sohl-Dickstein et al. (2015); Wu et al. (2020); Liu et al. (b), we have now laid the foundations for algorithmic approaches that aim at sampling from $\mu(\bm{x})$ by minimising $\mathcal{L}_{D}(\phi,\theta)$:

Framework 1.

Let $D$ be an arbitrary divergence, and assume that $\mathcal{L}_{D}(\phi,\theta) = 0$. Then we have

$$\mu(\bm{x}) = \int p^{\theta}(\bm{x}|\bm{z})\,\nu(\mathrm{d}\bm{z}) \quad \text{and} \quad \nu(\bm{z}) = \int q^{\phi}(\bm{z}|\bm{x})\,\mu(\mathrm{d}\bm{x}), \tag{6}$$

that is, $\nu(\bm{z})$ is transformed into $\mu(\bm{x})$ by $p^{\theta}(\bm{x}|\bm{z})$, and $\mu(\bm{x})$ is transformed into $\nu(\bm{z})$ by $q^{\phi}(\bm{z}|\bm{x})$.
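For completeness, the short argument behind Framework 1 can be spelled out as follows. This is a sketch that uses only the defining property of a divergence (zero iff both arguments agree) and the fact that $q^{\phi}(\cdot|\bm{x})$ and $p^{\theta}(\cdot|\bm{z})$ are normalised conditional densities.

```latex
\begin{align*}
\mathcal{L}_D(\phi,\theta) = 0
  &\iff q^{\phi}(\bm{z}|\bm{x})\,\mu(\bm{x}) = p^{\theta}(\bm{x}|\bm{z})\,\nu(\bm{z}), \\
\int q^{\phi}(\bm{z}|\bm{x})\,\mu(\bm{x})\,\mathrm{d}\bm{z} = \mu(\bm{x})
  &\implies \mu(\bm{x}) = \int p^{\theta}(\bm{x}|\bm{z})\,\nu(\mathrm{d}\bm{z}), \\
\int p^{\theta}(\bm{x}|\bm{z})\,\nu(\bm{z})\,\mathrm{d}\bm{x} = \nu(\bm{z})
  &\implies \nu(\bm{z}) = \int q^{\phi}(\bm{z}|\bm{x})\,\mu(\mathrm{d}\bm{x}).
\end{align*}
```

The first line is the reversal condition (4); the second and third lines integrate it over $\bm{z}$ and $\bm{x}$, respectively, and together give (6).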

The sampling problem. Let $\nu$ denote a probability density function on $\mathbb{R}^{d}$ of the form $\nu(\bm{z}) = \frac{\hat{\nu}(\bm{z})}{Z}$, $Z = \int_{\mathbb{R}^{d}} \hat{\nu}(\bm{z})\,\mathrm{d}\bm{z}$, where $\hat{\nu}: \mathbb{R}^{d} \rightarrow \mathbb{R}^{+}$ can be differentiated and evaluated pointwise but the normalizing constant $Z$ is intractable. We are interested in both estimating $Z$ and obtaining approximate samples from $\nu$ given we can sample from a more tractable density $\mu$. Framework 1 provides us with an objective to tackle the sampling problem as, once $\mathcal{L}_{D}(\phi,\theta) = 0$, we can generate samples from $\nu(\bm{z})$ via the variational distribution $q^{\phi}(\bm{z}|\bm{x})$. Through variational inference and optimal transport, we discuss relationships to classical methods as well as shortcomings:

KL-divergence, ELBO and variational inference. Choosing $D = D_{\mathrm{KL}}$ in (5), variational inference (VI) and latent variable model based approaches (Dempster et al., 1977; Blei et al., 2017; Kingma & Welling, 2014) can elegantly be placed within Framework 1. Indeed, direct computation (see Appendix B) shows that $\mathcal{L}_{D_{\mathrm{KL}}}(\phi,\theta) = -\mathbb{E}_{\bm{x}\sim\mu(\bm{x})}[\mathrm{ELBO}_{x}(\phi,\theta)] + \mathbb{E}_{\bm{x}\sim\mu(\bm{x})}[\ln\mu(\bm{x})]$, so that minimising $\mathcal{L}_{D_{\mathrm{KL}}}(\phi,\theta)$ is equivalent to maximising the expected evidence lower bound (ELBO), also known as the negative free energy (Blei et al., 2017). This derivation is an alternative to the standard approach via maximum likelihood and convex duality (or Jensen’s inequality) (Kingma et al., 2021, Section 2.2), and directly accommodates various modifications by replacing the $D_{\mathrm{KL}}$-divergence (see Appendix B).
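The computation behind this identity is short enough to sketch here. The definition of $\mathrm{ELBO}_{x}$ used below is the standard one and is an assumption on our part; the presentation in Appendix B may differ in detail.

```latex
\begin{align*}
\mathrm{ELBO}_{x}(\phi,\theta)
  &:= \mathbb{E}_{\bm{z}\sim q^{\phi}(\bm{z}|\bm{x})}
      \left[\ln\frac{p^{\theta}(\bm{x}|\bm{z})\,\nu(\bm{z})}{q^{\phi}(\bm{z}|\bm{x})}\right], \\
\mathcal{L}_{D_{\mathrm{KL}}}(\phi,\theta)
  &= \mathbb{E}_{\bm{x}\sim\mu(\bm{x})}\,\mathbb{E}_{\bm{z}\sim q^{\phi}(\bm{z}|\bm{x})}
      \left[\ln\frac{q^{\phi}(\bm{z}|\bm{x})\,\mu(\bm{x})}{p^{\theta}(\bm{x}|\bm{z})\,\nu(\bm{z})}\right] \\
  &= -\mathbb{E}_{\bm{x}\sim\mu(\bm{x})}\left[\mathrm{ELBO}_{x}(\phi,\theta)\right]
     + \mathbb{E}_{\bm{x}\sim\mu(\bm{x})}\left[\ln\mu(\bm{x})\right].
\end{align*}
```

Since the last term does not depend on $(\phi,\theta)$, minimising the loss and maximising the expected ELBO select the same parameters.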

Couplings, (optimal) transport and nonuniqueness. Assuming (4) holds, it is natural to define the joint distribution $\pi(\bm{x},\bm{z}) := q^{\phi}(\bm{z}|\bm{x})\,\mu(\bm{x}) = p^{\theta}(\bm{x}|\bm{z})\,\nu(\bm{z})$, which is a coupling between $\mu(\bm{x})$ and $\nu(\bm{z})$. Viewed from this angle, the set of minimisers of $\mathcal{L}(\phi,\theta)$ stands in one-to-one correspondence with the set of couplings between $\mu(\bm{x})$ and $\nu(\bm{z})$, provided that the parameterisations are chosen flexibly enough. Under the latter assumption, the objective in (5) admits an infinite number of minimisers, rendering algorithmic approaches solely based on Framework 1 potentially unstable and their output hard to interpret. In the language of optimal transport (Villani, 2003), minimising $\mathcal{L}(\phi,\theta)$ enforces the marginal (‘transport’) constraints in (6) without a selection principle based on an appropriate cost function (‘optimal’).
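The nonuniqueness can be made tangible in a toy example. In the sketch below we assume $\mu = \nu = \mathcal{N}(0,1)$ and consider, for any correlation $\rho \in (-1,1)$, the Gaussian transitions $q(z|x) = \mathcal{N}(\rho x, 1-\rho^2)$ and $p(x|z) = \mathcal{N}(\rho z, 1-\rho^2)$. These choices and all function names are illustrative assumptions, not part of the paper. Every value of $\rho$ satisfies (4), so every $\rho$ yields a different coupling with zero loss.

```python
import numpy as np

def log_normal(v, mean, var):
    """Log-density of the scalar Gaussian N(mean, var), evaluated elementwise."""
    return -0.5 * (np.log(2 * np.pi * var) + (v - mean) ** 2 / var)

def log_joint_forward(x, z, rho):
    """ln[q(z | x) mu(x)] with mu = N(0, 1) and q(z | x) = N(rho * x, 1 - rho^2)."""
    return log_normal(z, rho * x, 1 - rho**2) + log_normal(x, 0.0, 1.0)

def log_joint_backward(x, z, rho):
    """ln[p(x | z) nu(z)] with nu = N(0, 1) and p(x | z) = N(rho * z, 1 - rho^2)."""
    return log_normal(x, rho * z, 1 - rho**2) + log_normal(z, 0.0, 1.0)

x_grid = np.linspace(-2.0, 2.0, 5)
z_grid = np.linspace(-1.0, 3.0, 5)

# Condition (4) holds for every rho: both factorisations give the same joint.
reversal_holds = {
    rho: bool(np.allclose(log_joint_forward(x_grid, z_grid, rho),
                          log_joint_backward(x_grid, z_grid, rho)))
    for rho in (0.0, 0.5, -0.9)
}

# Yet the couplings themselves differ from one rho to the next.
couplings_differ = not np.allclose(log_joint_forward(x_grid, z_grid, 0.0),
                                   log_joint_forward(x_grid, z_grid, 0.5))
print(reversal_holds, couplings_differ)
```

A selection principle, such as a transport cost or a reference process, would single out one member of this family.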

Methods such as VAEs (Kingma & Welling, 2014) parameterise $p^{\theta}(\bm{x}|\bm{z})$ and $q^{\phi}(\bm{z}|\bm{x})$ with a restricted family of distributions (such as Gaussians), thus restricting the set of couplings. Expectation maximisation (EM) minimises $\mathcal{L}(\phi,\theta)$ in a component-wise fashion, resolving nonuniqueness in a procedural manner (see Section 3.1). Common diffusion models fix either $p^{\theta}(\bm{x}|\bm{z})$ or $q^{\phi}(\bm{z}|\bm{x})$, and thus select a coupling (Section 2.2). In this paper, we argue that the full potential of diffusion models can be unleashed by training the forward and backward processes at the same time, but appropriate modifications that resolve the nonuniqueness inherent in Framework 1 need to be imposed. To develop principled approaches towards this, we proceed as follows:

Outline and contributions. In Section 2 we recall hierarchical VAEs (Rezende et al., 2014) and, following Tzen & Raginsky (2019a), proceed to the infinite-depth limit described by the SDEs in (12). Readers more familiar with VI and discrete time might want to take the development in Section 2.1 as an explanation of (12); readers with background in stochastic analysis might take Framework 1 as their starting point. In Proposition 2.2 we provide a generalised form of the Girsanov theorem for forward-reverse time SDEs, crucially incorporating the choice of a reference process that allows us to reason about sampling and generation in a systematic and principled way. We demonstrate that a range of widely used approaches, such as score-based diffusions and path integral samplers, among others, are special cases of our unifying framework (Section 2.2). Similarly in Section 3.1 we unify optimal transport (OT) and VI under our framework by establishing a correspondence between expectation-maximisation (EM) and iterative proportional fitting (IPF). Going further, we show that this framework allows us to derive new methods:

In Section 3.2, we derive a novel score-based annealed flow technique, the Controlled Monte Carlo Diffusion (CMCD) sampler, and show that it may be viewed as an infinitesimal analogue of the method from Section 3.1. Finally, we connect CMCD to the foundational identities by Crooks and Jarzynski in statistical physics, and show that it empirically outperforms a range of state-of-the-art inference methods in sampling and estimating normalizing constants (Section 4).

2 From hierarchical VAEs to forward-reverse time diffusions

2.1 Hierarchical VAEs (Rezende et al., 2014)

A particularly flexible choice of implicitly parameterising $p^{\theta}(\bm{x}|\bm{z})$ and $q^{\phi}(\bm{z}|\bm{x})$ can be achieved via a hierarchical model with intermediate latents: We identify $\bm{x} =: \bm{y}_{0}$ and $\bm{z} =: \bm{y}_{L}$ with the ‘endpoints’ of the layered augmentation $(\bm{y}_{0}, \bm{y}_{1}, \ldots, \bm{y}_{L-1}, \bm{y}_{L}) =: \bm{y}_{0:L}$, and define

$$q^{\phi}(\bm{y}_{L},\ldots,\bm{y}_{1}|\bm{y}_{0}) := \prod_{l=1}^{L} q^{\phi_{l-1}}(\bm{y}_{l}|\bm{y}_{l-1}), \qquad p^{\theta}(\bm{y}_{0},\ldots,\bm{y}_{L-1}|\bm{y}_{L}) := \prod_{l=1}^{L} p^{\theta_{l}}(\bm{y}_{l-1}|\bm{y}_{l}), \tag{7}$$

so that $q^{\phi}(\bm{z}|\bm{x})$ and $p^{\theta}(\bm{x}|\bm{z})$ can be obtained from (7) by marginalising over the auxiliary variables $\bm{y}_{1},\ldots,\bm{y}_{L-1}$. Here, $\phi = (\phi_{0},\ldots,\phi_{L-1})$ and $\theta = (\theta_{1},\ldots,\theta_{L})$ refer to sets of parameters to be specified in more detail below.
Further introducing notation, we write $q^{\mu,\phi}(\bm{y}_{0:L}) := q^{\phi}(\bm{y}_{1:L}|\bm{y}_{0})\,\mu(\bm{y}_{0})$ as well as $p^{\nu,\theta}(\bm{y}_{0:L}) := p^{\theta}(\bm{y}_{0:L-1}|\bm{y}_{L})\,\nu(\bm{y}_{L})$ and think of those implied joint distributions as emanating from $\mu(\bm{x}) = \mu(\bm{y}_{0})$ and $\nu(\bm{z}) = \nu(\bm{y}_{L})$, respectively, moving ‘forwards’ or ‘backwards’ according to the specific choices for $\phi$ and $\theta$.
In the regime when $L$ is large, the models in (7) are very expressive, even if the intermediate transition kernels are parameterised in a simple manner. We hence proceed by assuming Gaussian distributions,

$$q^{\phi_{l-1}}(\bm{y}_{l}|\bm{y}_{l-1}) = \mathcal{N}\left(\bm{y}_{l}\,\big|\,\bm{y}_{l-1} + \delta a^{\phi}_{l-1}(\bm{y}_{l-1}),\, \delta\sigma^{2} I\right), \qquad p^{\theta_{l}}(\bm{y}_{l-1}|\bm{y}_{l}) = \mathcal{N}\left(\bm{y}_{l-1}\,\big|\,\bm{y}_{l} + \delta b^{\theta}_{l}(\bm{y}_{l}),\, \delta\sigma^{2} I\right), \tag{8}$$

where $\sigma > 0$ controls the standard deviation, and $\delta > 0$ is a small parameter, anticipating the limits $L \rightarrow \infty$, $\delta \rightarrow 0$ to be taken in Section 2.2 below. The vector fields $a^{\phi}_{l}(\bm{y}_{l})$ and $b_{l}^{\theta}(\bm{y}_{l})$ introduced in (8) should be thought of as parameterised by $\phi$ and $\theta$, but we will henceforth suppress this for brevity.

The models (7)-(8) could equivalently be defined via the Markov chains

$$\bm{y}_{l+1} = \bm{y}_{l} + \delta a_{l}(\bm{y}_{l}) + \sqrt{\delta}\,\sigma\xi_{l}, \qquad \bm{y}_{0} \sim \mu \implies \bm{y}_{0:L} \sim q^{\mu,\phi}(\bm{y}_{0:L}), \tag{9a}$$
$$\bm{y}_{l-1} = \bm{y}_{l} + \delta b_{l}(\bm{y}_{l}) + \sqrt{\delta}\,\sigma\xi_{l}, \qquad \bm{y}_{L} \sim \nu \implies \bm{y}_{0:L} \sim p^{\nu,\theta}(\bm{y}_{0:L}), \tag{9b}$$

where $(\xi_{l})_{l=1}^{L}$ is an iid sequence of standard Gaussian random variables. As indicated, the forward process in (9a) may serve to define the distribution $q^{\mu,\phi}(\bm{y}_{0:L})$, whilst the backward process in (9b) induces $p^{\nu,\theta}(\bm{y}_{0:L})$. Note that the transition densities $p^{\theta}(\bm{x}|\bm{z})$ and $q^{\phi}(\bm{z}|\bm{x})$ obtained as the marginals of (7) will in general not be available in closed form. However, generalising slightly from Framework 1, we may set out to minimise the extended loss

$$\mathcal{L}^{\mathrm{ext}}_{D}(\phi,\theta) = D\left(q^{\mu,\phi}(\bm{y}_{0:L}) \,\big|\big|\, p^{\nu,\theta}(\bm{y}_{0:L})\right), \tag{10}$$

where $D$ refers to a divergence on the ‘discrete path space’ $\{\bm{y}_{0:L}\}$. Clearly, $\mathcal{L}_{D}^{\mathrm{ext}}(\phi,\theta) = 0$ still implies (6), but is no longer equivalent. More specifically, in the case when $D = D_{\mathrm{KL}}$, the data processing inequality yields

$$D_{\mathrm{KL}}\left(q^{\mu,\phi}(\bm{y}_{0:L}) \,\big|\big|\, p^{\nu,\theta}(\bm{y}_{0:L})\right) \geq D_{\mathrm{KL}}\left(q^{\phi}(\bm{z}|\bm{x})\,\mu(\bm{x}) \,\big|\big|\, p^{\theta}(\bm{x}|\bm{z})\,\nu(\bm{z})\right), \tag{11}$$

so that $\mathcal{L}^{\mathrm{ext}}_{D_{\mathrm{KL}}}(\phi,\theta)$ provides an upper bound for $\mathcal{L}_{D_{\mathrm{KL}}}(\phi,\theta)$ as defined in (5).
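
The bound in (11) can be checked numerically on a toy problem. The sketch below is only an illustration under stated assumptions: the two path laws are arbitrary random probability tables on a three-step discrete path with three states per step (not the Gaussian chains of the text), and the names `q_path`, `p_path` and `kl` are hypothetical. It compares the KL divergence on the full path with the KL divergence after marginalising out the intermediate state.

```python
import numpy as np

def kl(q, p):
    """KL divergence between two strictly positive arrays that each sum to one."""
    return float(np.sum(q * (np.log(q) - np.log(p))))

rng = np.random.default_rng(0)

# Hypothetical joint laws of a path (y_0, y_1, y_2), each y_l taking 3 values.
q_path = rng.random((3, 3, 3)) + 0.1
q_path /= q_path.sum()
p_path = rng.random((3, 3, 3)) + 0.1
p_path /= p_path.sum()

# Divergence on the 'discrete path space'.
kl_path = kl(q_path, p_path)

# Divergence between the endpoint laws, after summing out y_1.
kl_endpoints = kl(q_path.sum(axis=1), p_path.sum(axis=1))

print(kl_path, kl_endpoints)
```

Since marginalisation can only discard information, the first number should never fall below the second, which is the content of the data processing inequality used above.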

2.2 Diffusion models – hierarchical VAEs in the infinite depth limit

Here we take inspiration from Section 2.1 and Tzen & Raginsky (2019a); Li et al. (2020); Huang et al. (2021a) to investigate the $L\rightarrow\infty$ limit, using stochastic differential equations (SDEs). To this end, we think of $l=0,\ldots,L$ as discrete instances in a fixed time interval $[0,T]$, equidistant with time step $\delta$, that is, we set $\delta=TL^{-1}$. The discrete paths ${\bm{y}}_{0:L}$ give rise to continuous paths $({\bm{Y}}_{t})_{0\leq t\leq T}\in C([0,T];\mathbb{R}^{d})$ by setting ${\bm{Y}}_{\delta l}={\bm{y}}_{l}$ and linearly interpolating between ${\bm{Y}}_{\delta l}$ and ${\bm{Y}}_{\delta(l+1)}$.
To complete the set-up, we think of $a^{\phi}=(a^{\phi}_{0},\ldots,a^{\phi}_{L-1})$ and $b^{\theta}=(b_{1}^{\theta},\ldots,b^{\theta}_{L})$ in (8) as arising from time-dependent vector fields $a,b\in C^{\infty}([0,T]\times\mathbb{R}^{d};\mathbb{R}^{d})$ via $a^{\phi}_{l}({\bm{y}}_{l})=a_{t\delta^{-1}}({\bm{Y}}_{\delta l})$ and $b^{\theta}_{l}({\bm{y}}_{l})=b_{t\delta^{-1}}({\bm{Y}}_{\delta l})$.

Taking the limit $\delta\rightarrow 0$, while keeping $T>0$ fixed, transforms the Markov chains in (9) into continuous-time dynamics described by the SDEs (Tzen & Raginsky, 2019a)

$${\mathrm{d}}{\bm{Y}}_{t}=a_{t}({\bm{Y}}_{t})\,{\mathrm{d}}t+\sigma\overrightarrow{\mathrm{d}}{\bm{W}}_{t},\quad{\bm{Y}}_{0}\sim\mu\implies({\bm{Y}}_{t})_{0\leq t\leq T}\sim\mathbb{Q}^{\mu,a}\equiv\overrightarrow{\mathbb{P}}^{\mu,a}, \tag{12a}$$
$${\mathrm{d}}{\bm{Y}}_{t}=b_{t}({\bm{Y}}_{t})\,{\mathrm{d}}t+\sigma\overleftarrow{\mathrm{d}}{\bm{W}}_{t},\quad{\bm{Y}}_{T}\sim\nu\implies({\bm{Y}}_{t})_{0\leq t\leq T}\sim\mathbb{P}^{\nu,b}\equiv\overleftarrow{\mathbb{P}}^{\nu,b}, \tag{12b}$$

where $\overrightarrow{\mathrm{d}}$ and $\overleftarrow{\mathrm{d}}$ denote forward and backward Itô integration (see Appendix A for more details and remarks on the notation), and $({\bm{W}}_{t})_{0\leq t\leq T}$ is a standard Brownian motion. In complete analogy with (9), the SDEs in (12) induce the distributions $\mathbb{Q}^{\mu,a}$ and $\mathbb{P}^{\nu,b}$ on the path space $C([0,T];\mathbb{R}^{d})$. Relating back to the discussion in the introduction, note that we maintain the relations ${\bm{Y}}_{0}={\bm{x}}$ and ${\bm{Y}}_{T}={\bm{z}}$, and the transitions are parameterised by the vector fields $a,b$, in the sense that $p^{\theta}({\bm{x}}|{\bm{z}})={\mathbb{P}}^{\nu,b^{\theta}}_{0}({\bm{x}}|{\bm{Y}}_{T}={\bm{z}})={\mathbb{P}}^{\delta_{{\bm{z}}},b^{\theta}}_{0}({\bm{x}})$ and $q^{\phi}({\bm{z}}|{\bm{x}})={\mathbb{Q}}^{\mu,a^{\phi}}_{T}({\bm{z}}|{\bm{Y}}_{0}={\bm{x}})={\mathbb{Q}}^{\delta_{{\bm{x}}},a^{\phi}}_{T}({\bm{z}})$.
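
To make the passage from Markov chains to the forward SDE (12a) concrete, the following is a minimal Euler–Maruyama sketch. It is an illustration rather than the authors' implementation: the function name `simulate_forward`, the Ornstein–Uhlenbeck drift `ou_drift` and all numerical values are assumptions chosen for the example, and $\mu$ is taken to be a standard Gaussian.

```python
import numpy as np

def simulate_forward(a, y0, sigma, T, L, rng):
    """Euler-Maruyama discretisation of dY = a_t(Y) dt + sigma dW on [0, T].

    a     : callable (t, y) -> drift, same shape as y (assumed vectorised)
    y0    : array of initial states, one row per sample from mu
    Returns an array of shape (L + 1,) + y0.shape holding the discrete paths.
    """
    delta = T / L  # time step, delta = T / L as in the text
    path = np.empty((L + 1,) + y0.shape)
    path[0] = y0
    for l in range(L):
        noise = rng.standard_normal(y0.shape)
        path[l + 1] = (path[l]
                       + a(l * delta, path[l]) * delta
                       + sigma * np.sqrt(delta) * noise)
    return path

def ou_drift(t, y):
    """Hypothetical forward drift a_t(y) = -y (Ornstein-Uhlenbeck)."""
    return -y

rng = np.random.default_rng(0)
y0 = rng.standard_normal((1000, 2))  # samples from an assumed mu = N(0, I)
path = simulate_forward(ou_drift, y0, sigma=np.sqrt(2.0), T=1.0, L=100, rng=rng)
print(path.shape)
```

Each row `path[l]` plays the role of ${\bm{y}}_{l}$, and refining `L` at fixed `T` corresponds to the limit $\delta\rightarrow 0$.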

The following well-known result (Anderson, 1982; Nelson, 1967) allows us to relate forward and backward path measures via a local (score-matching) condition for the reversal relation in (4). (Footnote 2: The global condition $\overrightarrow{{\mathbb{P}}}^{\mu,a}=\overleftarrow{{\mathbb{P}}}^{\nu,b}$ is captured by the local condition (13) due to (12)’s Markovian nature.)

Proposition 2.1 (Nelson’s relation).

For $\mu$ and $a$ of sufficient regularity, denote the time-marginals of the corresponding path measure by $\overrightarrow{{\mathbb{P}}}^{\mu,a}_{t}=:\rho^{\mu,a}_{t}$. Then $\overrightarrow{{\mathbb{P}}}^{\mu,a}=\overleftarrow{{\mathbb{P}}}^{\nu,b}$ if and only if

$$\nu=\overrightarrow{{\mathbb{P}}}^{\mu,a}_{T}\qquad\text{and}\qquad b_{t}=a_{t}-\sigma^{2}\nabla\ln\rho^{\mu,a}_{t},\qquad\text{for all }t\in(0,T]. \tag{13}$$
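
As a sanity check of Proposition 2.1, consider a stationary Ornstein–Uhlenbeck example, chosen here purely for illustration: $a_t(y)=-y$, $\sigma^2=2$ and $\mu=\mathcal{N}(0,1)$, so that $\rho^{\mu,a}_t=\mathcal{N}(0,1)$ for all $t$ and $\nabla\ln\rho^{\mu,a}_t(y)=-y$. Relation (13) then gives $b_t(y)=-y+2y=y$. The sketch below integrates the backward SDE (12b) from $\nu=\mathcal{N}(0,1)$ down to time $0$, using an assumed backward Euler step $y_{l}=y_{l+1}-b\,\delta+\sigma\sqrt{\delta}\,\xi$, and reports the variance at time $0$.

```python
import numpy as np

sigma = np.sqrt(2.0)

def a(t, y):
    """Forward drift of the assumed Ornstein-Uhlenbeck example."""
    return -y

def score(t, y):
    """grad ln rho_t for the stationary marginal rho_t = N(0, 1)."""
    return -y

def b(t, y):
    """Backward drift obtained from Nelson's relation (13)."""
    return a(t, y) - sigma**2 * score(t, y)

rng = np.random.default_rng(0)
T, L = 1.0, 100
delta = T / L

y = rng.standard_normal(20_000)  # Y_T ~ nu = N(0, 1)
for l in reversed(range(L)):
    t_next = (l + 1) * delta
    # Assumed backward Euler step: move from time t_{l+1} to time t_l.
    y = y - b(t_next, y) * delta + sigma * np.sqrt(delta) * rng.standard_normal(y.shape)

var_at_zero = y.var()
print(var_at_zero)
```

If the backward drift were chosen inconsistently with (13), for instance $b=a$, the same loop would inflate the variance instead of preserving the marginal.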
Remark 1.

A similarly clean characterisation of equality between forward and backward path measures is not available for the discrete-time setting as presented in (9). In particular, Gaussianity of the intermediate transitions is not preserved under time-reversal.

A recurring theme in this work and related literature is the interplay between the score-matching condition in (13) and the global condition $D(\overrightarrow{{\mathbb{P}}}^{\mu,a}|\overleftarrow{{\mathbb{P}}}^{\nu,b})=0$, invoking Framework 1. To enable calculations involving the latter, we will rely on the following result:

Proposition 2.2 (forward-backward Radon-Nikodym derivatives).

Let $\overrightarrow{{\mathbb{P}}}^{\Gamma_{0},\gamma^{+}}=\overleftarrow{{\mathbb{P}}}^{\Gamma_{T},\gamma^{-}}$ be a reference path measure (that is, $\Gamma_{0}$, $\Gamma_{T}$ and $\gamma^{\pm}$ define diffusions as in (12) and are related as in Proposition 2.1), absolutely continuous with respect to both $\overrightarrow{{\mathbb{P}}}^{\mu,a}$ and $\overleftarrow{{\mathbb{P}}}^{\nu,b}$. Then, $\overrightarrow{{\mathbb{P}}}^{\mu,a}$-almost surely, the corresponding Radon-Nikodym derivative (RND) can be expressed as follows,

$$\begin{aligned}
\ln\left(\frac{\mathrm{d}\overrightarrow{\mathbb{P}}^{\mu,a}}{\mathrm{d}\overleftarrow{\mathbb{P}}^{\nu,b}}\right)({\bm{Y}}) &=\ln\left(\frac{\mathrm{d}\mu}{\mathrm{d}\Gamma_{0}}\right)({\bm{Y}}_{0})-\ln\left(\frac{\mathrm{d}\nu}{\mathrm{d}\Gamma_{T}}\right)({\bm{Y}}_{T}) && \text{(14a)}\\
&\quad+\tfrac{1}{\sigma^{2}}\int_{0}^{T}\left(a_{t}-\gamma^{+}_{t}\right)({\bm{Y}}_{t})\cdot\left(\overrightarrow{\mathrm{d}}{\bm{Y}}_{t}-\tfrac{1}{2}\left(a_{t}+\gamma^{+}_{t}\right)({\bm{Y}}_{t})\,\mathrm{d}t\right) && \text{(14b)}\\
&\quad-\tfrac{1}{\sigma^{2}}\int_{0}^{T}\left(b_{t}-\gamma^{-}_{t}\right)({\bm{Y}}_{t})\cdot\left(\overleftarrow{\mathrm{d}}{\bm{Y}}_{t}-\tfrac{1}{2}\left(b_{t}+\gamma^{-}_{t}\right)({\bm{Y}}_{t})\,\mathrm{d}t\right). && \text{(14c)}
\end{aligned}$$
Proof.

The proof relies on Girsanov’s theorem (Üstünel & Zakai, 2013), using the reference to relate the forward and backward processes. For details, see Appendix E. ∎

Remark 2 (Role of the reference process).

According to Proposition 2.2, the Radon-Nikodym derivative between $\overrightarrow{{\mathbb{P}}}^{\mu,a}$ and $\overleftarrow{{\mathbb{P}}}^{\nu,b}$ can be decomposed into boundary terms (14a), as well as forward and backward path integrals (14b) and (14c). Since the left-hand side of (14a) does not depend on the reference $\Gamma_{0,T}$, $\gamma^{\pm}$, the expressions in (14) are in principle equivalent for all choices of reference. The freedom in $\Gamma_{0,T}$ and $\gamma^{\pm}$ allows us to ‘reweight’ between (14a), (14b) and (14c), or even cancel terms. A canonical choice is the Lebesgue measure for $\Gamma_{0}$ and $\Gamma_{T}$, and $\gamma^{\pm}=0$, see Appendix C.1.
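
For the canonical reference (Lebesgue measure for $\Gamma_0$, $\Gamma_T$ and $\gamma^{\pm}=0$), the discretised version of (14) can be compared against a direct ratio of Gaussian transition kernels. The sketch below is one-dimensional and rests on assumptions that are not spelled out in this section: the forward kernel is taken as $\mathcal{N}(y_{l+1};\,y_l+a\,\delta,\,\sigma^2\delta)$, the backward kernel as $\mathcal{N}(y_l;\,y_{l+1}-b\,\delta,\,\sigma^2\delta)$, forward integrands are evaluated at left endpoints and backward integrands at right endpoints. The drifts, endpoint densities and function names are hypothetical.

```python
import numpy as np

def gauss_logpdf(x, mean, var):
    """Log-density of a one-dimensional Gaussian N(mean, var) at x."""
    return -0.5 * np.log(2.0 * np.pi * var) - 0.5 * (x - mean) ** 2 / var

def log_rnd_girsanov(path, a, b, log_mu, log_nu, sigma, delta):
    """Discretised (14a)-(14c) with Lebesgue reference and gamma^+ = gamma^- = 0."""
    t = delta * np.arange(len(path))
    dy = np.diff(path)
    a_left = a(t[:-1], path[:-1])   # forward integral: left endpoints
    b_right = b(t[1:], path[1:])    # backward integral: right endpoints
    fwd = np.sum(a_left * (dy - 0.5 * a_left * delta)) / sigma**2
    bwd = np.sum(b_right * (dy - 0.5 * b_right * delta)) / sigma**2
    return log_mu(path[0]) - log_nu(path[-1]) + fwd - bwd

def log_rnd_kernels(path, a, b, log_mu, log_nu, sigma, delta):
    """Log-ratio of the forward and backward Gaussian Markov chain densities."""
    t = delta * np.arange(len(path))
    var = sigma**2 * delta
    fwd = gauss_logpdf(path[1:], path[:-1] + a(t[:-1], path[:-1]) * delta, var).sum()
    bwd = gauss_logpdf(path[:-1], path[1:] - b(t[1:], path[1:]) * delta, var).sum()
    return log_mu(path[0]) + fwd - log_nu(path[-1]) - bwd

# Hypothetical drifts and endpoint densities, chosen only for illustration.
a = lambda t, y: -y
b = lambda t, y: 0.5 * y + np.sin(t)
log_mu = lambda y: gauss_logpdf(y, 0.0, 1.0)
log_nu = lambda y: gauss_logpdf(y, 0.0, 4.0)

rng = np.random.default_rng(0)
delta, sigma = 0.01, 1.3
path = np.cumsum(sigma * np.sqrt(delta) * rng.standard_normal(201))

lhs = log_rnd_girsanov(path, a, b, log_mu, log_nu, sigma, delta)
rhs = log_rnd_kernels(path, a, b, log_mu, log_nu, sigma, delta)
print(lhs, rhs)
```

Under these assumptions the two quantities agree up to floating-point error, because the quadratic terms $|{\bm{y}}_{l+1}-{\bm{y}}_{l}|^2/(2\sigma^2\delta)$ and the normalising constants cancel between the forward and backward kernels.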

Remark 3 (Discretisation and conversion formulae).

The distinction between forward and backward integration in (14) is related to the time points at which the integrands $\left(a_{t}-\gamma^{+}_{t}\right)({\bm{Y}}_{t})$ and $\left(b_{t}-\gamma^{-}_{t}\right)({\bm{Y}}_{t})$ would be evaluated in discrete-time approximations, e.g.,

$$\int_{0}^{T}a_{t}({\bm{Y}}_{t})\cdot\overrightarrow{\mathrm{d}}{\bm{Y}}_{t}\approx\sum_{i}a_{t_{i}}({\bm{Y}}_{t_{i}})\cdot({\bm{Y}}_{t_{i+1}}-{\bm{Y}}_{t_{i}}),\qquad\int_{0}^{T}a_{t}({\bm{Y}}_{t})\cdot\overleftarrow{\mathrm{d}}{\bm{Y}}_{t}\approx\sum_{i}a_{t_{i+1}}({\bm{Y}}_{t_{i+1}})\cdot({\bm{Y}}_{t_{i+1}}-{\bm{Y}}_{t_{i}}).$$

Alternatively, forward and backward integrals can be transformed into each other using the conversion

$$\int_{0}^{T}a_{t}({\bm{Y}}_{t})\cdot\overrightarrow{\mathrm{d}}{\bm{Y}}_{t}=\int_{0}^{T}a_{t}({\bm{Y}}_{t})\cdot\overleftarrow{\mathrm{d}}{\bm{Y}}_{t}-\sigma^{2}\int_{0}^{T}(\nabla\cdot a_{t})({\bm{Y}}_{t})\,\mathrm{d}t. \tag{15}$$

We refer to Kunita (2019) and Appendix A for further details. In passing, we note that (15) allows us to eliminate the Hutchinson estimator (Hutchinson, 1989) from a variety of common score-matching objectives, potentially reducing the variance of gradient estimators; see Appendix C.1.
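
The left-endpoint and right-endpoint sums of Remark 3, together with the conversion formula (15), can be illustrated on a simple case. The sketch below assumes a driftless one-dimensional process ${\bm{Y}}_t=\sigma{\bm{W}}_t$ and the integrand $a_t(y)=y$, for which $\nabla\cdot a_t=1$ and the correction term in (15) equals $\sigma^2 T$. All names and numerical values are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
sigma, T, L = 1.5, 1.0, 100_000
delta = T / L

# Assumed process: Y_t = sigma * W_t, sampled on an equidistant grid.
increments = sigma * np.sqrt(delta) * rng.standard_normal(L)
y = np.concatenate([[0.0], np.cumsum(increments)])

def integrand(y):
    """Illustrative integrand a_t(y) = y, whose divergence is 1."""
    return y

dy = np.diff(y)
forward_sum = np.sum(integrand(y[:-1]) * dy)   # integrand at t_i
backward_sum = np.sum(integrand(y[1:]) * dy)   # integrand at t_{i+1}

# Right-hand side correction in (15): sigma^2 * int_0^T (div a_t) dt.
correction = sigma**2 * T * 1.0

print(forward_sum, backward_sum - correction)
```

The gap between the two sums is the realised quadratic variation $\sum_i({\bm{Y}}_{t_{i+1}}-{\bm{Y}}_{t_i})^2$, which concentrates around $\sigma^2 T$ as the grid is refined; this is the mechanism behind the divergence term in (15).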

Framework 1 can be translated into the setting of (12), noting that (11) continues to hold with appropriate modifications:

Framework 1.

For a divergence $D$ on path space, minimise $D(\overrightarrow{{\mathbb{P}}}^{\mu,a}|\overleftarrow{{\mathbb{P}}}^{\nu,b})$. If $D(\overrightarrow{{\mathbb{P}}}^{\mu,a}|\overleftarrow{{\mathbb{P}}}^{\nu,b})=0$, then (12a) transports $\mu$ to $\nu$, and (12b) transports $\nu$ to $\mu$. (Footnote 3: Concurrently, Richter & Berner (2024) propose a framework akin to ours.)

At optimality, $D(\overrightarrow{{\mathbb{P}}}^{\mu,a}|\overleftarrow{{\mathbb{P}}}^{\nu,b})=0$, Proposition 2.1 allows us to obtain the scores associated to the learned diffusion via $\sigma^{2}\nabla\ln\rho_{t}^{\mu,a}=a_{t}-b_{t}$. In this way, Framework 1 is closely connected to (and in some ways extends) score-matching ideas (Song & Ermon, 2019; Song et al., 2021). Indeed, recent approaches towards generative modeling and sampling can be recovered from Framework 1 by making specific choices for the divergence $D$, the parameterisations for $a$ and $b$, as well as for the reference diffusion $\overrightarrow{{\mathbb{P}}}^{\Gamma_{0},\gamma^{+}}=\overleftarrow{{\mathbb{P}}}^{\Gamma_{T},\gamma^{-}}$ in Proposition 2.2:

Score-based generative modeling: Letting $\mu$ be the target and fixing the forward drift $a_{t}$, and, motivated by Proposition 2.1, parameterising the backward drift as $b_{t}=a_{t}-s_{t}$, we recover the SGM objectives in Hyvärinen & Dayan (2005); Song & Ermon (2019); Song et al. (2021) from $D=D_{\mathrm{KL}}$; when $\overrightarrow{{\mathbb{P}}}^{\mu,a}=\overleftarrow{{\mathbb{P}}}^{\nu,b}$, the variable drift component $s_{t}$ will represent the score $\sigma^{2}\nabla\ln\rho^{\mu,a}_{t}$. Modifications can be obtained from the conversion formula (15), see Appendix C.2.

Score-based sampling – ergodic drift: In this setting, $\nu$ becomes the target and we fix $b_{t}$ to be the drift of an ergodic (backward) process. Then choosing $\Gamma_{0,T}=\mu$, $\gamma^{\pm}=b$ allows us to recover the approaches in Vargas et al. (2023a); Berner et al. (2022). Possible generalisations based on Framework 1 include IWAE-type objectives, see Appendix C.3.

Score-based sampling – Föllmer drift: Finally, choosing $b_{t}(x)=x/t$, we recover Föllmer sampling (Appendix C.3; Follmer, 1984; Vargas et al., 2023b; Zhang & Chen, 2022; Huang et al., 2021b).

3 Learning forward and backward transitions simultaneously

Recall from the introduction that complete flexibility in $a$ and $b$ will render the minima of $D(\overrightarrow{{\mathbb{P}}}^{\mu,a}|\overleftarrow{{\mathbb{P}}}^{\nu,b})$ highly nonunique. Furthermore, the approaches surveyed at the end of the previous section circumvent this problem by fixing either $\overrightarrow{{\mathbb{P}}}^{\mu,a}$ or $\overleftarrow{{\mathbb{P}}}^{\nu,b}$. However, to leverage the full power of diffusion models, both $\overrightarrow{{\mathbb{P}}}^{\mu,a}$ and $\overleftarrow{{\mathbb{P}}}^{\nu,b}$ should be adapted to the problem at hand. In this section, we explore models of this kind, by imposing additional constraints on $a$ and $b$. We end this section by presenting our new CMCD sampler, connecting it to prior methodology within VI (Doucet et al., 2022b; Geffner & Domke, 2023; Papamakarios et al., 2017) and OT, where we can view CMCD as an instance of entropy-regularised OT in the infinite-constraint limit (Bernton et al., 2019).

3.1 Connection to Entropic optimal transport

One way of selecting a particular transition between $\mu$ and $\nu$ is by imposing an entropic penalty, encouraging the dynamics to stay close to a prescribed, oftentimes physically or biologically motivated, reference process. Using the notation employed in Framework 1, the static Schrödinger problem (Schrödinger, 1931; Léonard, 2014a) is given by

$$\pi^{*}(\bm{x},\bm{z}) \in \operatorname*{arg\,min}_{\pi(\bm{x},\bm{z})} \Big\{ D_{\mathrm{KL}}(\pi(\bm{x},\bm{z}) \,||\, r(\bm{x},\bm{z})) : \pi_{\bm{x}}(\bm{x}) = \mu(\bm{x}),\; \pi_{\bm{z}}(\bm{z}) = \nu(\bm{z}) \Big\}, \tag{16}$$

where $r(\bm{x},\bm{z})$ is the Schrödinger prior encoding additional domain-specific information. In an analogous way, we can introduce a regulariser to the path-space approach of Framework 1’ to obtain the dynamic Schrödinger problem

$$\mathbb{P}^{*} \in \operatorname*{arg\,min}_{\overrightarrow{\mathbb{P}}^{\mu,a}_{T} \,=\, \nu} \mathbb{E}_{\bm{Y} \sim \overrightarrow{\mathbb{P}}^{\mu,a}} \left[ \tfrac{1}{2\sigma^{2}} \int_{0}^{T} \|a_t - f_t\|^{2}(\bm{Y}_t)\,\mathrm{d}t \right], \tag{17}$$

that is, the driving vector field $a_t$ determining $\mathbb{P}^{*}$ should be chosen in such a way that (i) the corresponding diffusion transitions from $\mu$ to $\nu$, and (ii) among such diffusions, the vector field $a_t$ remains close to the prescribed vector field $f_t$ in the mean square sense. Under mild conditions, the solutions to (16) and (17) exist and are unique. Further, the static and dynamic viewpoints are related through a mixture-of-bridges construction (assuming that the conditionals $r(\bm{z}|\bm{x})$ correspond to the transitions induced by $f_t$), see (Léonard, 2014a, Section 2).

Iterative proportional fitting (IPF) and the EM algorithm. It is well known that approximate solutions for $\pi^{*}(\bm{x},\bm{z})$ and $\mathbb{P}^{*}$ can be obtained using alternating $D_{\mathrm{KL}}$-projections, keeping one of the marginals fixed in each iteration: under mild conditions, the sequence defined by

$$\pi^{2n+1}(\bm{x},\bm{z}) = \operatorname*{arg\,min}_{\pi(\bm{x},\bm{z})} \left\{ D_{\mathrm{KL}}(\pi(\bm{x},\bm{z}) \,||\, \pi^{2n}(\bm{x},\bm{z})) : \; \pi_{\bm{x}}(\bm{x}) = \mu(\bm{x}) \right\}, \tag{18a}$$
$$\pi^{2n+2}(\bm{x},\bm{z}) = \operatorname*{arg\,min}_{\pi(\bm{x},\bm{z})} \left\{ D_{\mathrm{KL}}(\pi(\bm{x},\bm{z}) \,||\, \pi^{2n+1}(\bm{x},\bm{z})) : \; \pi_{\bm{z}}(\bm{z}) = \nu(\bm{z}) \right\}, \qquad n \geq 0, \tag{18b}$$

with initialisation $\pi^{0}(\bm{x},\bm{z}) = r(\bm{x},\bm{z})$, converges to $\pi^{*}(\bm{x},\bm{z})$ as $n \rightarrow \infty$ (De Bortoli et al., 2021), and this procedure is commonly referred to as iterative proportional fitting (IPF) (Fortet, 1940; Kullback, 1968; Ruschendorf, 1995) or Sinkhorn updates (Cuturi, 2013). IPF can straightforwardly be modified to the path space setting of (17), and the resulting updates coincide with the Föllmer drift updates discussed in Section C.3, see (Vargas et al., 2021a) and Appendix E.4.
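For intuition, the alternating projections in (18) can be sketched on a finite state space, where each $D_{\mathrm{KL}}$-projection reduces to rescaling the rows or columns of the current coupling so that one marginal is matched. The following is a minimal NumPy sketch under that discrete assumption; the function name `ipf`, the random reference coupling and the chosen marginals are illustrative and do not come from the paper.

```python
import numpy as np

def ipf(r, mu, nu, n_iters=200):
    """Alternating KL projections (18a)/(18b) for a discrete reference coupling r.

    r  : (n_x, n_z) strictly positive array, the Schrodinger prior r(x, z)
    mu : (n_x,) target x-marginal, nu : (n_z,) target z-marginal
    """
    pi = r / r.sum()
    for _ in range(n_iters):
        # (18a): match the x-marginal; the conditional pi(z | x) is left unchanged.
        pi = pi * (mu / pi.sum(axis=1))[:, None]
        # (18b): match the z-marginal; the conditional pi(x | z) is left unchanged.
        pi = pi * (nu / pi.sum(axis=0))[None, :]
    return pi

rng = np.random.default_rng(0)
r = rng.random((4, 5)) + 0.1           # hypothetical reference coupling
mu = np.array([0.1, 0.2, 0.3, 0.4])    # hypothetical x-marginal
nu = np.full(5, 0.2)                   # hypothetical z-marginal
pi_star = ipf(r, mu, nu)
```

After the loop, both marginals of `pi_star` agree with `mu` and `nu` up to numerical tolerance. Each half-step only changes one marginal and keeps the corresponding conditional fixed, which is the structure exploited when IPF is compared with EM.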

To further demonstrate the coverage of our framework, we establish a connection between IPF and expectation-maximisation (EM) (Dempster et al., 1977), originally devised for finding maximum likelihood estimates in models with latent (or hidden) variables. According to Neal & Hinton (1998), the EM-algorithm can be described in the setting from the introduction, and written in the form

$$\theta_{n+1} = \operatorname*{arg\,min}_{\theta} \mathcal{L}_{D_{\mathrm{KL}}}(\phi_{n}, \theta), \qquad \phi_{n+1} = \operatorname*{arg\,min}_{\phi} \mathcal{L}_{D_{\mathrm{KL}}}(\phi, \theta_{n+1}), \tag{19}$$

with $\mathcal{L}_{D_{\mathrm{KL}}}$ defined as in (5). If the initialisations are matched appropriately, the following result establishes an exact correspondence between the IPF updates in (18) and the EM updates in (19):

Proposition 3.1 (EM $\iff$ IPF).

Assume that the transition densities $p^{\theta}(\bm{x}|\bm{z})$ and $q^{\phi}(\bm{z}|\bm{x})$ are parameterised with perfect flexibility [Footnote 4: In precise terms, we assume that for any transition densities $p(\bm{x}|\bm{z})$ and $q(\bm{z}|\bm{x})$, there exist $\theta_{*}$ and $\phi_{*}$ such that $p(\bm{x}|\bm{z}) = p^{\theta_{*}}(\bm{x}|\bm{z})$ and $q(\bm{z}|\bm{x}) = q^{\phi_{*}}(\bm{z}|\bm{x})$.], and furthermore that the EM-scheme (19) is initialised at $\phi_{0}$ in such a way that $q^{\phi_{0}}(\bm{z}|\bm{x}) = r(\bm{z}|\bm{x})$.
Then the IPF iterations in (18) agree with the EM iterations in (19) for all $n \geq 1$, in the sense that

$$\pi^{n}(\bm{x},\bm{z}) = q^{\phi_{(n-1)/2}}(\bm{z}|\bm{x})\,\mu(\bm{x}), \quad \text{for } n \text{ odd}, \qquad \pi^{n}(\bm{x},\bm{z}) = p^{\theta_{n/2}}(\bm{x}|\bm{z})\,\nu(\bm{z}), \quad \text{for } n \text{ even}. \tag{20}$$

From the proof (Appendix E), it is clear that flexibility of the parameterisations is crucial, and thus $\text{EM} \iff \text{IPF}$ fails for classical VAEs, but holds up to a negligible error for the SDE-parameterisations from Section 2.2, see also Liu et al. (b). Under this assumption, the key observation is that replacing forward-$D_{\mathrm{KL}}$ by reverse-$D_{\mathrm{KL}}$ in one or both of (18a) and (18b) does not – in theory – change the sequence of minimisers.
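The reason the direction of $D_{\mathrm{KL}}$ does not matter can be seen from the chain rule for $D_{\mathrm{KL}}$. The following is a brief sketch for the step (18a), written in terms of joint densities rather than parameterisations (an assumption made here for brevity; the full argument is the one in Appendix E). For any $\pi$ with $\pi_{\bm{x}} = \mu$:

```latex
\begin{align*}
D_{\mathrm{KL}}(\pi \,||\, \pi^{2n})
  &= D_{\mathrm{KL}}(\mu \,||\, \pi^{2n}_{\bm{x}})
   + \mathbb{E}_{\bm{x} \sim \mu}\Big[ D_{\mathrm{KL}}\big(\pi(\cdot|\bm{x}) \,||\, \pi^{2n}(\cdot|\bm{x})\big) \Big], \\
D_{\mathrm{KL}}(\pi^{2n} \,||\, \pi)
  &= D_{\mathrm{KL}}(\pi^{2n}_{\bm{x}} \,||\, \mu)
   + \mathbb{E}_{\bm{x} \sim \pi^{2n}_{\bm{x}}}\Big[ D_{\mathrm{KL}}\big(\pi^{2n}(\cdot|\bm{x}) \,||\, \pi(\cdot|\bm{x})\big) \Big].
\end{align*}
```

In both lines the first term does not depend on the conditional $\pi(\bm{z}|\bm{x})$, and the second term vanishes exactly when $\pi(\bm{z}|\bm{x}) = \pi^{2n}(\bm{z}|\bm{x})$ (almost everywhere with respect to the relevant marginal). Provided the parameterisation is flexible enough to attain this conditional, both directions are minimised by $\pi^{2n+1}(\bm{x},\bm{z}) = \pi^{2n}(\bm{z}|\bm{x})\,\mu(\bm{x})$, which has the form stated for odd $n$ in (20). The step (18b) is analogous with the roles of $\bm{x}$ and $\bm{z}$ exchanged.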

In practice, favoring the EM objectives over IPF can offer an advantage: optimizing with respect to forward-$D_{\mathrm{KL}}$ and backward-$D_{\mathrm{KL}}$ encourages moment-matching and mode-seeking behavior, respectively, so an alternating scheme as defined in (19) might present a suitable compromise over optimizing a single direction of $D_{\mathrm{KL}}$. Empirical exploration is left for future work.

Whilst EM and IPF might seem appealing for learning a sampler, they both require sequentially solving a series of minimization problems, each of which we can only solve approximately; this is not only slow but also causes errors from each iterate to accumulate (Vargas et al., 2021a; Fernandes et al., 2021). In order to address both issues, we will present a novel approach (CMCD) that, similarly to IPF, learns both the forward and backward processes whilst preserving the desired uniqueness property. However, in contrast to IPF, it does so in an end-to-end fashion and performs updates simultaneously. As an alternative, in Appendix E.5 we also discuss a regularised IPF objective and leave further empirical exploration for future work.

3.2 Score-based annealing: the Controlled Monte Carlo Diffusion sampler

In this section, we fix a prescribed curve of distributions $(\pi_t)_{t \in [0,T]}$, whose scores $\nabla \ln \pi_t$ (and unnormalised densities $\hat{\pi}_t$) are assumed to be available in tractable form; this is the scenario typically encountered in annealed importance sampling (IS) and related approaches towards computing posterior expectations (Neal, 2001; Reich, 2011; Heng et al., 2021; 2020; Arbel et al., 2021; Doucet et al., 2022a). The Controlled Monte Carlo Diffusion sampler (CMCD) learns the vector field $\nabla \phi_t$ in

$$\mathrm{d}\bm{Y}_t = \left( \sigma^{2} \nabla \ln \pi_t(\bm{Y}_t) + \nabla \phi_t(\bm{Y}_t) \right) \mathrm{d}t + \sigma\sqrt{2}\,\overrightarrow{\mathrm{d}}\bm{W}_t, \qquad \bm{Y}_0 \sim \pi_0, \tag{21}$$

so that (21) produces the interpolation from the prior $\pi_0$ to the posterior $\pi_T$, i.e., $\overrightarrow{\mathbb{P}}^{\pi_0,\,\sigma^{2}\nabla \ln \pi + \nabla \phi}_t = \pi_t$, for all $t \in [0,T]$. Note that if $\pi_t$ were constant in time ($\pi_t = \pi_0$), then $\phi = 0$ would reduce (21) to equilibrium overdamped Langevin dynamics, preserving $\pi_0$. With $\pi_t$ varying in time, $\nabla \phi_t$ can be thought of as a control enabling transitions between neighbouring densities $\pi_t$ and $\pi_{t+\delta t}$.
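The role of the control can be illustrated with a toy example in which it is known in closed form. The sketch below assumes a one-dimensional Gaussian path $\pi_t = \mathcal{N}(vt, 1)$, for which $\nabla \ln \pi_t(y) = -(y - vt)$ and the constant control $\nabla \phi_t = v$ makes (21) track $\pi_t$ exactly in continuous time. The path, the function name `simulate_cmcd_forward` and all parameter values are illustrative assumptions; in CMCD the control is learned by a network rather than prescribed.

```python
import numpy as np

def simulate_cmcd_forward(score, control, y0, sigma, T, K, rng):
    """Euler-Maruyama discretisation of (21) with K equal steps of size T / K."""
    dt = T / K
    y = y0.copy()
    for k in range(K):
        t = k * dt
        # Drift of (21): sigma^2 * grad ln pi_t + grad phi_t.
        drift = sigma**2 * score(t, y) + control(t, y)
        # Noise scaling sigma * sqrt(2) dW_t, i.e. variance 2 * sigma^2 * dt per step.
        y = y + drift * dt + sigma * np.sqrt(2.0 * dt) * rng.standard_normal(y.shape)
    return y

# Hypothetical toy path pi_t = N(v * t, 1) with its exact score and control.
v, sigma, T, K = 2.0, 1.0, 1.0, 200
score = lambda t, y: -(y - v * t)
control = lambda t, y: np.full_like(y, v)       # grad phi_t = v transports the mean
no_control = lambda t, y: np.zeros_like(y)      # phi = 0: plain annealed Langevin

rng = np.random.default_rng(0)
y0 = rng.standard_normal(20000)                 # Y_0 ~ pi_0 = N(0, 1)
yT = simulate_cmcd_forward(score, control, y0, sigma, T, K, rng)
yT_uncontrolled = simulate_cmcd_forward(score, no_control, y0, sigma, T, K, rng)
```

With the control, the empirical mean and standard deviation of `yT` stay close to those of $\pi_T = \mathcal{N}(vT, 1)$, up to discretisation and Monte Carlo error. Without it, the Langevin dynamics lag behind the moving target and the mean of `yT_uncontrolled` falls well short of $vT$.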

To obtain $\nabla \phi_t$ we invoke Framework 1, but restrict $\overleftarrow{\mathbb{P}}^{\pi_T, b}$ to retain uniqueness. Proposition 2.1 motivates the choice $b_t = (\sigma^{2} \nabla \ln \pi_t + \nabla \phi_t) - 2\sigma^{2} \nabla \ln \pi_t = -\sigma^{2} \nabla \ln \pi_t + \nabla \phi_t$ [Footnote 5: Note the additional factor of $2$ in Nelson’s relation due to the noise scaling $\sigma\sqrt{2}\,\overrightarrow{\mathrm{d}}\bm{W}_t$ in (21).], leading to

$$\mathcal{L}^{\mathrm{CMCD}}_{D}(\phi) := D\left( \overrightarrow{\mathbb{P}}^{\pi_0,\,\sigma^{2}\nabla \ln \pi + \nabla \phi},\; \overleftarrow{\mathbb{P}}^{\pi_T,\,-\sigma^{2}\nabla \ln \pi + \nabla \phi} \right), \tag{22}$$

which is valid for any choice of divergence $D$. The additional score constraint $b_t = a_t - 2\sigma^{2} \nabla \ln \pi_t$ restores uniqueness in Framework 1 (see Appendix D for a proof):

Algorithm 1 Controlled Monte Carlo Diffusions - Sampling and normalizing constant estimation
Require: $\pi_0$, $\pi_T$, $\pi_t$, $\sigma$, $K$ step-sizes $\Delta t_k$, network $f^{\phi}$ trained via minimising Eq. 24
$\bm{Y}_0 \sim \pi_0$
$\ln \bm{W} = -\ln \pi_0(\bm{Y}_0)$
for $k = 0$ to $K-1$ do
     $\bm{Y}_{t_{k+1}} \sim \mathcal{N}\Big( \bm{Y}_{t_{k+1}} \,\big|\, \bm{Y}_{t_k} + (\sigma^{2} \nabla \ln \pi_{t_k} + f^{\phi}_{t_k})(\bm{Y}_{t_k})\,\Delta t_k,\; 2\sigma^{2} \Delta t_k \Big)$
     $\ln \bm{W} = \ln \bm{W} + \ln \dfrac{\mathcal{N}\big( \bm{Y}_{t_k} \,\big|\, \bm{Y}_{t_{k+1}} + (\sigma^{2} \nabla \ln \pi_{t_{k+1}} - f^{\phi}_{t_{k+1}})(\bm{Y}_{t_{k+1}})\,\Delta t_k,\; 2\sigma^{2} \Delta t_k \big)}{\mathcal{N}\big( \bm{Y}_{t_{k+1}} \,\big|\, \bm{Y}_{t_k} + (\sigma^{2} \nabla \ln \pi_{t_k} + f^{\phi}_{t_k})(\bm{Y}_{t_k})\,\Delta t_k,\; 2\sigma^{2} \Delta t_k \big)}$