This paper treats the design of fast samplers for diffusion models as a differentiable optimization problem, and proposes Differentiable Diffusion Sampler Search (DDSS). Our key observation is that one can unroll the sampling chain of a diffusion model and use the reparameterization trick (Kingma & Welling, 2013) and gradient rematerialization (Kumar et al., 2019a) to optimize over a class of parametric few-step samplers with respect to a global objective function.

An important challenge for fast DDPM sampling is the mismatch between the training objective (e.g., ELBO or weighted ELBO) and sample quality. Prior work (Watson et al., 2021; Song et al., 2021a) finds that samplers that are optimal with respect to the ELBO often lead to worse sample quality and Fréchet Inception Distance (FID) scores (Heusel et al., 2017), especially with few inference steps. We propose the use of a perceptual loss within the DDSS framework to find high-fidelity diffusion samplers, motivated by prior work showing that optimizing perceptual losses leads to solutions that correlate better with human perception of quality. We empirically find that using DDSS with the Kernel Inception Distance (KID) (Bińkowski et al., 2018) as the perceptual loss indeed leads to fast samplers with significantly better image quality than prior work (see Figure 1). Moreover, our method is robust to different choices of kernels for KID.

Our main contributions are as follows:

  1. We propose Differentiable Diffusion Sampler Search (DDSS), which uses the reparametrization trick and gradient rematerialization to optimize over a parametric family of fast samplers for diffusion models.
  2. We identify a parametric family of Generalized Gaussian Diffusion Model (GGDM) that admits high-fidelity fast samplers for diffusion models.
  3. We show that using DDSS to optimize samplers by minimizing the Kernel Inception Distance leads to fast diffusion model samplers with state-of-the-art sample quality scores.

1 DIFFERENTIABLE DIFFUSION SAMPLER SEARCH (DDSS)

We now describe DDSS, our approach to searching for fast, high-fidelity samplers with a limited budget of K < T steps. Our key observation is that one can backpropagate through the sampling process of a diffusion model via the reparameterization trick (Kingma & Welling, 2013). Equipped with this, we can use stochastic gradient descent to learn fast samplers by optimizing any given differentiable loss function over a minibatch of model samples.

As we will show in the experiments, despite the fact that our pre-trained DDPMs are trained with discrete timesteps, learning the timesteps is still helpful.
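As a minimal illustration of the pathwise (reparameterized) gradients that DDSS relies on, the 1-D toy below writes a Gaussian sample as x = μ + σε and differentiates a sample-based loss directly with respect to μ and σ. The analytic pathwise gradients here stand in for what automatic differentiation computes when unrolling a real sampler; the loss, learning rate, and variable names are illustrative, not from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 0.0, 1.0      # trainable sampler parameters
target, lr = 3.0, 0.1     # toy objective: minimize E[(x - target)^2]

for step in range(200):
    eps = rng.standard_normal(256)
    x = mu + sigma * eps  # reparameterized sample: gradients flow through mu, sigma
    # Pathwise gradients of the minibatch loss mean((x - target)^2):
    grad_mu = np.mean(2.0 * (x - target))
    grad_sigma = np.mean(2.0 * (x - target) * eps)
    mu -= lr * grad_mu
    sigma -= lr * grad_sigma
# mu converges toward the target and sigma collapses toward 0,
# since the optimum of this toy loss is a point mass at the target.
```

Because the noise ε is sampled independently of the parameters, the gradient of the minibatch loss is an unbiased estimate of the gradient of the expected loss, which is exactly what lets SGD optimize an unrolled sampling chain.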

DIFFERENTIABLE SAMPLE QUALITY SCORES

We instead design a perceptual loss that simply compares mean statistics between model samples and real samples in the feature space of a neural network. Objectives of this type have been shown in prior work to correlate better with human perception of sample quality (Johnson et al., 2016; Heusel et al., 2017), which we also confirm in our experiments.

We rely on the representations of the penultimate layer of a pre-trained InceptionV3 classifier (Szegedy et al., 2016) and optimize the Kernel Inception Distance (KID) (Bińkowski et al., 2018). Let $\phi(x)$ denote the Inception features of an image $x$ and $p_\psi$ represent a diffusion sampler with trainable parameters $\psi$. For a linear kernel, which works best in our experiments, the objective is:

$$\mathcal{L}_{\mathrm{KID}}(\psi) \;=\; \left\lVert\, \mathbb{E}_{x_p \sim p_\psi}\!\left[\phi(x_p)\right] \;-\; \mathbb{E}_{x_q \sim q}\!\left[\phi(x_q)\right] \,\right\rVert_2^2$$
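With a linear kernel, this objective reduces to the squared Euclidean distance between the empirical feature means of a model minibatch and a real minibatch. A minimal sketch, with the Inception feature extractor abstracted away as precomputed arrays (the function name is ours, not the paper's):

```python
import numpy as np

def linear_kid(feats_p, feats_q):
    """Squared L2 distance between empirical feature means.

    feats_p: (n, d) array of features phi(x) for model samples x ~ p_psi.
    feats_q: (n, d) array of features phi(x) for real samples x ~ q.
    """
    diff = feats_p.mean(axis=0) - feats_q.mean(axis=0)
    return float(diff @ diff)
```

Identical feature sets give a distance of exactly zero, and any shift between the two distributions shows up directly in the mean difference, which is what makes this estimator cheap to differentiate through.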

More generally, KID can be expressed as:

$$\mathrm{KID}(p_\psi, q) \;=\; \mathbb{E}_{x \sim p_\psi}\!\left[f^*(x)\right] \;-\; \mathbb{E}_{x \sim q}\!\left[f^*(x)\right]$$

where $f^*(x) = \mathbb{E}_{x'_p \sim p_\psi}\, k_\phi(x, x'_p) - \mathbb{E}_{x'_q \sim q(x_0)}\, k_\phi(x, x'_q)$ is the witness function for any differentiable, positive definite kernel $k$, and $k_\phi(x, y) = k(\phi(x), \phi(y))$. Note that $f^*$ attains the supremum of the MMD. To enable stochastic gradient descent, we use an unbiased estimator of KID computed from a minibatch of $n$ model samples $x_p^{(1)}, \ldots, x_p^{(n)} \sim p_\psi$ and $n$ real samples $x_q^{(1)}, \ldots, x_q^{(n)} \sim q$:
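The unbiased minibatch estimator referenced above follows the standard unbiased MMD² form: average the within-set kernel values over pairs $i \neq j$ (dropping the diagonal removes the bias) and subtract twice the average cross-kernel term. A sketch with a linear kernel on precomputed Inception features, assuming equal minibatch sizes (the function name is ours):

```python
import numpy as np

def kid_unbiased(fp, fq):
    """Unbiased linear-kernel KID (MMD^2) estimate from feature minibatches.

    fp: (n, d) features of model samples x_p^(1..n) ~ p_psi.
    fq: (n, d) features of real samples  x_q^(1..n) ~ q.
    """
    n = fp.shape[0]
    kpp = fp @ fp.T  # k_phi between model samples
    kqq = fq @ fq.T  # k_phi between real samples
    kpq = fp @ fq.T  # cross terms
    # Exclude diagonal (i == j) terms within each set for unbiasedness.
    term_pp = (kpp.sum() - np.trace(kpp)) / (n * (n - 1))
    term_qq = (kqq.sum() - np.trace(kqq)) / (n * (n - 1))
    return term_pp + term_qq - 2.0 * kpq.mean()
```

Each kernel entry is differentiable in the features, so with a differentiable feature extractor this estimate can be backpropagated through the unrolled sampler; swapping in another positive definite kernel only changes the three Gram-matrix computations.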