we propose a two-stage distillation approach for improving the sampling efficiency of classifier-free guided models. In the first stage, we introduce a single student model that matches the combined output of the two diffusion models of the teacher. In the second stage, we progressively distill the model learned in the first stage into a fewer-step model using the approach introduced in [33]. With our approach, a single distilled model is able to handle a wide range of guidance strengths, enabling an efficient trade-off between sample quality and diversity.
Progressive distillation
Our approach is inspired by progressive distillation [33], an effective method for improving the sampling speed of (unguided) diffusion models by repeated distillation.
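The core idea of progressive distillation is that the student learns to match, in a single deterministic sampler step, what the teacher produces in two. The sketch below illustrates this with a toy DDIM-style update; the function names and the scalar `schedule(t) -> (alpha, sigma)` interface are our own illustrative assumptions, not the notation of [33].

```python
import numpy as np

def ddim_step(z, x_pred, alpha_t, sigma_t, alpha_s, sigma_s):
    """One deterministic DDIM step from time t to an earlier time s."""
    eps = (z - alpha_t * x_pred) / sigma_t        # noise implied by the x-prediction
    return alpha_s * x_pred + sigma_s * eps       # move along the sampling trajectory

def distillation_target(z_t, teacher, schedule, t, t_mid, s):
    """Target the student must match in ONE step: the result of TWO
    teacher DDIM steps (t -> t_mid -> s), re-expressed as the x-prediction
    that a single step from t directly to s would need."""
    a_t, s_t = schedule(t)
    a_m, s_m = schedule(t_mid)
    a_s, s_s = schedule(s)
    z_mid = ddim_step(z_t, teacher(z_t, t), a_t, s_t, a_m, s_m)
    z_s = ddim_step(z_mid, teacher(z_mid, t_mid), a_m, s_m, a_s, s_s)
    # Invert the one-step update: solve z_s = a_s*x + s_s*(z_t - a_t*x)/s_t for x.
    return (z_s - (s_s / s_t) * z_t) / (a_s - (s_s / s_t) * a_t)
```

After the student is trained on these targets, it becomes the teacher for the next round, halving the number of sampling steps each time.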
Latent diffusion models (LDMs)
improve the training and inference efficiency of diffusion models, which are otherwise learned directly in pixel space, by modeling images in the latent space of a pre-trained regularized autoencoder, whose latent representations are usually of lower dimensionality than the pixel space. Latent diffusion models can be considered an alternative to cascaded diffusion approaches [5], which rely on one or more super-resolution diffusion models to upsample a low-resolution image to the desired target resolution.
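The structure of an LDM training step can be sketched as follows; `encode` stands in for the encoder of the pre-trained regularized autoencoder (hypothetical here), and the epsilon-prediction loss is just one common choice of training objective.

```python
import numpy as np

def ldm_train_step(x, encode, denoiser, schedule, rng):
    """One toy LDM training step: the diffusion model acts on latents,
    not pixels. `encode` is a stand-in for the pre-trained autoencoder."""
    z0 = encode(x)                                 # pixel image -> lower-dim latent
    t = rng.uniform(0.0, 1.0)                      # random diffusion time
    alpha, sigma = schedule(t)                     # noise schedule at t
    eps = rng.standard_normal(z0.shape)
    zt = alpha * z0 + sigma * eps                  # noised latent
    return np.mean((denoiser(zt, t) - eps) ** 2)   # eps-prediction loss
```

Because `z0` is smaller than `x`, every forward pass of the denoiser is cheaper than it would be in pixel space, which is the source of the efficiency gain.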
we discuss our approach for distilling a classifier-free guided diffusion model [6] into a student model that requires fewer sampling steps. Using a single distilled model conditioned on the guidance strength, our model captures a wide range of classifier-free guidance levels, enabling an efficient trade-off between sample quality and diversity.
our approach can be decomposed into two stages.
In the first stage, we introduce a student model $\hat{x}_{\eta_1}(z_t, w)$ with learnable parameters $\eta_1$
A key functionality of classifier-free guidance [6] is its ability to easily trade off between sample quality and diversity, controlled by a “guidance strength” parameter. Thus, we would also want our distilled model to maintain this property. Given a range of guidance strengths $[w_{\min}, w_{\max}]$ we are interested in, we optimize the student model using the following objective
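The first-stage objective can be sketched as follows: the student regresses onto the combined classifier-free-guided output of the teacher's conditional and unconditional models, with the guidance strength drawn uniformly from the range of interest. The uniform sampling of $w$ and the unweighted squared error are shown as plain illustrative choices here.

```python
import numpy as np

def guided_teacher(z_t, t, cond_model, uncond_model, w):
    """Combined classifier-free-guided teacher output:
    (1 + w) * conditional prediction - w * unconditional prediction."""
    return (1.0 + w) * cond_model(z_t, t) - w * uncond_model(z_t, t)

def stage_one_loss(student, cond_model, uncond_model, z_t, t,
                   w_min, w_max, rng):
    """Student x_hat_eta1(z_t, w) matches the guided teacher, with the
    guidance strength w drawn uniformly from [w_min, w_max]."""
    w = rng.uniform(w_min, w_max)
    target = guided_teacher(z_t, t, cond_model, uncond_model, w)
    return np.mean((student(z_t, t, w) - target) ** 2)
```

Because $w$ is an input rather than a fixed hyperparameter, a single trained student covers the whole guidance range.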
Note that here, our distilled model $\hat{x}_{\eta_1}(z_t, w)$ is also conditioned on the context c (e.g., text prompt), but we drop c from the notation for simplicity. We provide the detailed training algorithm in Algorithm 1.
To incorporate the guidance weight w, we introduce a w-conditioned model, where w is fed as an input to the student model. To better capture the conditioning signal, we apply a Fourier embedding to w, which is then incorporated into the diffusion model backbone in a way similar to how the time-step is incorporated in [10, 33]. As initialization plays a key role in performance [33], we initialize the student model with the same parameters as the conditional model of the teacher, except for the newly introduced parameters related to w-conditioning.
Supplement
we also make the model take w as input. Specifically, we apply a Fourier embedding to w before combining it with the model backbone; w is incorporated in the same way as the time-step in [10, 33]. We parameterize the model to predict v as discussed in [33], and train the distilled model with the SNR loss [10, 33] using Algorithm 2.
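The v-parameterization and a common SNR-based weighting can be sketched as follows; the "truncated SNR" weighting shown is one of the weightings proposed in [10, 33] and is included here as an assumption about the exact choice.

```python
import numpy as np

def v_from_x_eps(x, eps, alpha, sigma):
    """v-parameterization: v = alpha * eps - sigma * x."""
    return alpha * eps - sigma * x

def x_from_v(z, v, alpha, sigma):
    """Recover the x-prediction from a v-prediction, using
    z = alpha * x + sigma * eps with alpha^2 + sigma^2 = 1."""
    return alpha * z - sigma * v

def snr_weighted_loss(x_pred, x_target, alpha, sigma):
    """Truncated-SNR weighting: max(SNR, 1) * ||x_pred - x_target||^2."""
    snr = alpha ** 2 / sigma ** 2
    return max(snr, 1.0) * np.mean((x_pred - x_target) ** 2)
```

Predicting v keeps the implied x-prediction stable across the whole noise range, which is why [33] recommends it for distilled few-step samplers.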