This paper treats the design of fast samplers for diffusion models as a differentiable optimization problem, and proposes Differentiable Diffusion Sampler Search (DDSS). Our key observation is that one can unroll the sampling chain of a diffusion model and use the reparametrization trick (Kingma & Welling, 2013) and gradient rematerialization (Kumar et al., 2019a) to optimize over a class of parametric few-step samplers with respect to a global objective function.
An important challenge for fast DDPM sampling is the mismatch between the training objective (e.g., ELBO or weighted ELBO) and sample quality. Prior work (Watson et al., 2021; Song et al., 2021a) finds that samplers that are optimal with respect to the ELBO often lead to worse sample quality and Fréchet Inception Distance (FID) scores (Heusel et al., 2017), especially with few inference steps. We propose the use of a perceptual loss within the DDSS framework to find high-fidelity diffusion samplers, motivated by prior work showing that optimizing such losses leads to solutions that correlate better with human perception of quality. We empirically find that using DDSS with the Kernel Inception Distance (KID) (Bińkowski et al., 2018) as the perceptual loss indeed yields fast samplers with significantly better image quality than prior work (see Figure 1). Moreover, our method is robust to different choices of kernels for KID.
Our main contributions are as follows:
We now describe DDSS, our approach to searching for fast high-fidelity samplers with a limited budget of K < T steps. Our key observation is that one can backpropagate through the sampling process of a diffusion model via the reparametrization trick (Kingma & Welling, 2013). Equipped with this, we can now use stochastic gradient descent to learn fast samplers by optimizing any given differentiable loss function over a minibatch of model samples.
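As a concrete illustration (a minimal numpy sketch; the toy denoiser `eps_model`, the update rule, and the parameter names in `psi` are ours, not the paper's actual parametrization), the reparametrization trick amounts to drawing the initial noise and all per-step noises up front, so that the entire unrolled chain becomes a deterministic, differentiable function of the sampler parameters:

```python
import numpy as np

def eps_model(x, t):
    # Stand-in for a pre-trained DDPM noise-prediction network (hypothetical).
    return 0.1 * x * np.cos(t)

def unrolled_sampler(psi, z, noises):
    """Unrolled K-step sampling chain. All randomness (z and noises) is
    drawn outside the chain, so each step is a deterministic function of
    the trainable sampler parameters psi; the final sample is therefore
    differentiable w.r.t. psi under any autodiff framework."""
    x = z
    for k in range(len(psi["timesteps"])):
        t = psi["timesteps"][k]        # learnable timestep
        eps = eps_model(x, t)
        x = x - psi["step_sizes"][k] * eps + psi["sigmas"][k] * noises[k]
    return x

rng = np.random.default_rng(0)
K, d = 4, 8
psi = {
    "timesteps": np.linspace(1.0, 0.0, K),
    "step_sizes": np.full(K, 0.25),
    "sigmas": np.full(K, 0.05),
}
z = rng.standard_normal(d)              # reparametrized initial noise
noises = rng.standard_normal((K, d))    # reparametrized per-step noise
sample = unrolled_sampler(psi, z, noises)
```

In practice the chain is unrolled inside an autodiff framework, and gradient rematerialization (checkpointing) recomputes intermediate states during the backward pass so that memory does not grow linearly in K.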
As we will show in our experiments, learning the timesteps remains helpful even though our pre-trained DDPMs were trained with discrete timesteps.
we instead design a perceptual loss that simply compares mean statistics between model samples and real samples in a neural network feature space. Objectives of this type have been shown in prior work to correlate better with human perception of sample quality (Johnson et al., 2016; Heusel et al., 2017), which we also confirm in our experiments.
We rely on the representations of the penultimate layer of a pre-trained InceptionV3 classifier (Szegedy et al., 2016) and optimize the Kernel Inception Distance (KID) (Bińkowski et al., 2018). Let φ(x) denote the Inception features of an image x and pψ represent a diffusion sampler with trainable parameters ψ. For a linear kernel, which works best in our experiments, the objective is:
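For reference (a minimal numpy sketch; `phi` is a stand-in for the Inception feature extractor and the function name is ours), KID with a linear kernel reduces to the squared L2 distance between the mean feature vectors of model samples and real samples:

```python
import numpy as np

def phi(x):
    # Stand-in for InceptionV3 penultimate-layer features (hypothetical).
    return np.concatenate([x, x ** 2], axis=-1)

def linear_kid_loss(model_samples, real_samples):
    """Linear-kernel KID objective: squared L2 distance between the mean
    feature vectors of model samples (from p_psi) and real samples (from q)."""
    mu_p = phi(model_samples).mean(axis=0)
    mu_q = phi(real_samples).mean(axis=0)
    return np.sum((mu_p - mu_q) ** 2)
```

The loss is zero exactly when the mean features match, and gradients with respect to the model samples (and hence the sampler parameters, via the unrolled chain) are straightforward to compute.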
More generally, KID can be expressed as:
where $f^*(x) = \mathbb{E}_{x'_p \sim p_\psi}\left[k_\phi(x, x'_p)\right] - \mathbb{E}_{x'_q \sim q(x_0)}\left[k_\phi(x, x'_q)\right]$ is the witness function for any differentiable, positive definite kernel $k$, and $k_\phi(x, y) = k(\phi(x), \phi(y))$. Note that $f^*$ attains the supremum of the MMD. To enable stochastic gradient descent, we use an unbiased estimator of KID using a minibatch of $n$ model samples $x_p^{(1)}, \ldots, x_p^{(n)} \sim p_\psi$ and $n$ real samples $x_q^{(1)}, \ldots, x_q^{(n)} \sim q$:
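The standard unbiased (U-statistic) estimator of squared MMD drops the diagonal terms of the within-set kernel sums; a minimal numpy sketch (the function name and toy inputs are ours; in our setting the inputs would be Inception features of model and real samples):

```python
import numpy as np

def mmd2_unbiased(feat_p, feat_q, kernel=lambda a, b: a @ b.T):
    """Unbiased U-statistic estimator of squared MMD between n model-sample
    features feat_p and n real-sample features feat_q, for a given kernel."""
    n = feat_p.shape[0]
    k_pp = kernel(feat_p, feat_p)
    k_qq = kernel(feat_q, feat_q)
    k_pq = kernel(feat_p, feat_q)
    # Exclude i == j terms in the within-set sums to remove the bias.
    term_pp = (k_pp.sum() - np.trace(k_pp)) / (n * (n - 1))
    term_qq = (k_qq.sum() - np.trace(k_qq)) / (n * (n - 1))
    term_pq = 2.0 * k_pq.mean()
    return term_pp + term_qq - term_pq
```

With the linear kernel k(a, b) = aᵀb this recovers the linear-kernel objective; because the estimator is unbiased, its value on a finite minibatch can be slightly negative.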