In this work, we propose a prompt adaptation framework for automatic prompt engineering via reinforcement learning. Specifically, we first perform supervised fine-tuning with a pretrained language model (e.g., GPT) on a small collection of manually engineered prompts. The fine-tuned model is used to initialize the prompt policy network for reinforcement learning. Next, the model is trained by exploring optimized prompts for user inputs, where diverse beam search [VCS+16] is used to ensure generation quality and diversity. The training objective is to maximize the reward, defined as a combination of relevance scores and aesthetic scores of the generated images. The relevance score reflects how much of the original user intention is retained after prompt adaptation, while the aesthetic score indicates to what degree the generated images are aesthetically pleasing.
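As a concrete illustration of the decoding strategy, the sketch below generates several candidate prompts with diverse beam search through the Hugging Face transformers API. The base checkpoint ("gpt2"), the example user input, and all decoding hyperparameters are illustrative assumptions, not the settings used in our experiments.

```python
# Minimal sketch: candidate prompt generation with diverse beam search [VCS+16].
# The checkpoint and hyperparameters below are illustrative assumptions.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

user_input = "a cat sitting on a windowsill"
inputs = tokenizer(user_input, return_tensors="pt")

# Beams are split into groups; a diversity penalty discourages the groups from
# producing near-identical continuations, so the returned candidates cover
# different prompt-engineering styles.
outputs = model.generate(
    **inputs,
    max_new_tokens=64,
    num_beams=8,
    num_beam_groups=4,
    diversity_penalty=1.0,
    num_return_sequences=4,
)
candidates = tokenizer.batch_decode(outputs, skip_special_tokens=True)
for candidate in candidates:
    print(candidate)
```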

We conduct experiments with the publicly available Stable Diffusion models [RBL+22]. We evaluate our method using both the automatic reward metric and human preference ratings. Moreover, we find that reinforcement learning is more favorable than supervised fine-tuning, especially on out-of-domain user inputs. Overall, we show that language models can serve as a prompt interface that optimizes user input into model-preferred prompts.

Methods

[Figure 1: Overview of the prompt adaptation framework PROMPTIST.]

The goal of our prompt adaptation framework is to automatically perform prompt engineering. Given user input for the text-to-image generator, our model learns to generate model-preferred prompts that obtain better output images while preserving the original intentions. Figure 1 presents an overview of our method. The prompt optimization model is named PROMPTIST, which is built upon a pretrained language model, such as GPT [BMR+20].

Supervised Fine-tuning

Initialized with a pretrained generative language model, the policy model is first finetuned on a set of prompt pairs before reinforcement learning. A parallel prompt corpus D = {(x, y)} contains prompt pairs of original user inputs x and manually engineered examples y. The training objective is to maximize the log-likelihood with teacher forcing:

$$\mathcal{L}_{\text{SFT}} = \mathbb{E}_{(x, y) \sim \mathcal{D}} \left[ \log p(y \mid x) \right] = \mathbb{E}_{(x, y) \sim \mathcal{D}} \left[ \sum_{t} \log p(y_t \mid y_{<t}, x) \right]$$
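A minimal sketch of this objective in code, assuming a Hugging Face causal language model and tokenizer; the "Rephrase:" separator string and the masking scheme are assumptions for illustration, not the exact preprocessing used in our experiments.

```python
# Minimal sketch of the supervised fine-tuning loss: teacher forcing on
# (user input x, manually engineered prompt y) pairs. The separator string
# and masking scheme are illustrative assumptions.
import torch

def sft_loss(model, tokenizer, x: str, y: str) -> torch.Tensor:
    """Per-token negative log-likelihood of y given x; minimizing it maximizes log p(y | x)."""
    prompt_ids = tokenizer(x + " Rephrase:", return_tensors="pt").input_ids
    target_ids = tokenizer(" " + y + tokenizer.eos_token, return_tensors="pt").input_ids
    input_ids = torch.cat([prompt_ids, target_ids], dim=-1)

    # Only the target tokens contribute to the loss; positions labeled -100
    # are ignored by the cross-entropy computed inside the model.
    labels = input_ids.clone()
    labels[:, : prompt_ids.shape[-1]] = -100

    return model(input_ids=input_ids, labels=labels).loss
```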

Reward Definition

We measure the quality of optimized prompts from two aspects, namely relevance and aesthetics, which motivates us to define the reward function R(·) from these two perspectives.

First, we measure whether the generated images remain relevant to the original input prompt after prompt adaptation. To be specific, we first sample images from the text-to-image model conditioned on the optimized prompt. Then, we compute CLIP [RKH+21] similarity scores to measure how relevant the generated images are to the original input prompt. The resulting relevance score is defined as:

$$f_{\text{rel}}(x, y) = \mathbb{E}_{i_y \sim G(y)} \left[ g_{\text{CLIP}}(x, i_y) \right]$$

where i_y ∼ G(y) means sampling images i_y from the text-to-image model G with y as the input prompt, and g_CLIP(·, ·) stands for the CLIP similarity function.
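A minimal sketch of this score, assuming the open-source CLIP ViT-B/32 checkpoint from transformers; the images argument stands in for samples i_y drawn from the text-to-image model, and the averaging is a Monte Carlo estimate of the expectation.

```python
# Minimal sketch of the relevance score g_CLIP(x, i_y), averaged over sampled
# images. The CLIP checkpoint is an illustrative choice.
import torch
from transformers import CLIPModel, CLIPProcessor

clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

@torch.no_grad()
def relevance_score(x: str, images) -> float:
    """Average CLIP similarity between the original prompt x and images i_y ~ G(y)."""
    batch = processor(text=[x], images=images, return_tensors="pt",
                      padding=True, truncation=True)
    img_emb = clip.get_image_features(pixel_values=batch["pixel_values"])
    txt_emb = clip.get_text_features(input_ids=batch["input_ids"],
                                     attention_mask=batch["attention_mask"])
    img_emb = img_emb / img_emb.norm(dim=-1, keepdim=True)
    txt_emb = txt_emb / txt_emb.norm(dim=-1, keepdim=True)
    sims = img_emb @ txt_emb.T        # cosine similarity g_CLIP(x, i_y)
    return sims.mean().item()         # Monte Carlo estimate of the expectation
```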

Second, we employ an aesthetic predictor to quantify aesthetic preferences. The predictor builds a linear estimator on top of a frozen CLIP model and is trained on human ratings from the Aesthetic Visual Analysis [MMP12] dataset. The aesthetic score is defined as:

$$f_{\text{aes}}(x, y) = \mathbb{E}_{i_y \sim G(y),\, i_x \sim G(x)} \left[ g_{\text{aes}}(i_y) - g_{\text{aes}}(i_x) \right]$$

where g_aes(·) denotes the aesthetic predictor, and i_y, i_x are the images generated from the prompts y and x, respectively. Notice that both g_CLIP(·, ·) and g_aes(·) require the CLIP model, so we can share the CLIP forward pass during reward computation.
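The sketch below models the aesthetic predictor as a linear head over the same CLIP image embeddings used above, reusing the clip and processor objects from the previous sketch; the randomly initialized aesthetic_head is a placeholder, not the actual trained predictor.

```python
# Minimal sketch of the aesthetic score. `aesthetic_head` stands in for the
# trained linear estimator; its weights here are placeholders for illustration.
import torch

aesthetic_head = torch.nn.Linear(512, 1)   # 512-dim features for CLIP ViT-B/32

@torch.no_grad()
def aesthetic_score(images_y, images_x) -> float:
    """E[g_aes(i_y)] - E[g_aes(i_x)] for images from the optimized and original prompts."""
    def g_aes(images):
        batch = processor(images=images, return_tensors="pt")
        emb = clip.get_image_features(pixel_values=batch["pixel_values"])
        emb = emb / emb.norm(dim=-1, keepdim=True)   # same CLIP features as g_CLIP
        return aesthetic_head(emb).mean()
    return (g_aes(images_y) - g_aes(images_x)).item()
```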

Finally, we define the overall reward by combining the above scores with an additional KL penalty between the policy model π_θ and the supervised fine-tuned model π_SFT, weighted by a coefficient η:

$$R(x, y) = f_{\text{rel}}(x, y) + f_{\text{aes}}(x, y) - \eta \log \frac{\pi_{\theta}(y \mid x)}{\pi_{\text{SFT}}(y \mid x)}$$
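Putting the pieces together, a rough sketch of the reward computation follows; log_prob is a hypothetical helper returning the sequence log-probability Σ_t log π(y_t | x, y_<t), and the default η is an arbitrary illustrative value rather than a tuned coefficient.

```python
# Minimal sketch of the overall reward R(x, y). `log_prob` is a hypothetical
# helper and eta=0.02 is an illustrative coefficient, not a tuned value.

def overall_reward(x, y, images_x, images_y, policy, sft_model, eta=0.02) -> float:
    rel = relevance_score(x, images_y)           # see the relevance sketch above
    aes = aesthetic_score(images_y, images_x)    # see the aesthetic sketch above
    kl = log_prob(policy, x, y) - log_prob(sft_model, x, y)  # log[pi_theta / pi_SFT]
    return rel + aes - eta * float(kl)
```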

Reinforcement Learning

Starting from the supervised fine-tuned model, we further finetune our model with reinforcement learning. We employ proximal policy optimization (PPO) [SWD+17], which is empirically data-efficient and delivers reliable performance. As a text generation problem, prompt optimization can be viewed as a Markov decision process (MDP) ⟨S, A, r, f_st, γ⟩ with a finite state space S, action space A, reward function r, state-transition probability function f_st, and discount factor γ. In an episode of prompt adaptation, the initial state x ∈ S is the input prompt with n tokens x = (x_1, . . . , x_n), where each token x_i comes from a finite vocabulary V. At the t-th time step, the agent selects an action y_t ∈ V according to the current policy model, y_t ∼ π(y | x, y_<t). With deterministic state transitions, the next state is (x, y_<t+1) = (x_1, . . . , x_n, y_1, . . . , y_t). The episode ends when the agent selects an end-of-sentence action. The goal of the agent is to maximize the accumulated expected reward:

$$\max_{\theta} \; \mathbb{E}_{x \sim \mathcal{D},\, y \sim \pi_{\theta}(\cdot \mid x)} \left[ R(x, y) \right]$$
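As a rough illustration of this episode structure (not our released implementation), the sketch below generates a prompt with the current policy, queries the text-to-image model, scores the result with the reward defined above, and hands everything to a PPO step; text_to_image and ppo_update are hypothetical helpers, and the PPO clipping and advantage estimation details are omitted.

```python
# Minimal sketch of one prompt-adaptation episode. `text_to_image` and
# `ppo_update` are hypothetical helpers; `overall_reward` is the sketch above.

def run_episode(policy, sft_model, tokenizer, text_to_image, x: str):
    # State: the user input x = (x_1, ..., x_n); actions: next tokens y_t in V.
    query = tokenizer(x, return_tensors="pt").input_ids

    # The episode ends when the policy emits the end-of-sentence token.
    response = policy.generate(query, max_new_tokens=64, do_sample=True,
                               eos_token_id=tokenizer.eos_token_id)
    y = tokenizer.decode(response[0, query.shape[-1]:], skip_special_tokens=True)

    # Terminal reward R(x, y): relevance + aesthetics - KL penalty (defined above).
    images_x, images_y = text_to_image(x), text_to_image(y)
    reward = overall_reward(x, y, images_x, images_y, policy, sft_model)

    # One PPO update on the clipped surrogate objective [SWD+17] (hypothetical helper).
    ppo_update(policy, query, response, reward)
```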