From 5412f4b942ca0df9a8fe5fbfba062ceec60c999c Mon Sep 17 00:00:00 2001
From: entrpn
Date: Thu, 17 Nov 2022 08:15:44 -0800
Subject: [PATCH 1/3] support negative prompts in sd jax pipeline

---
 .../pipeline_flax_stable_diffusion.py        | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
index 02943997d928..01e1c363a87b 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
@@ -165,6 +165,7 @@ def _generate(
         guidance_scale: float = 7.5,
         latents: Optional[jnp.array] = None,
         debug: bool = False,
+        neg_prompt_ids: jnp.array = [""]
     ):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -178,7 +179,7 @@ def _generate(
 
         max_length = prompt_ids.shape[-1]
         uncond_input = self.tokenizer(
-            [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
+            neg_prompt_ids * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
         )
         uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0]
         context = jnp.concatenate([uncond_embeddings, text_embeddings])
@@ -251,6 +252,7 @@ def __call__(
         return_dict: bool = True,
         jit: bool = False,
         debug: bool = False,
+        neg_prompt_ids: jnp.array = [""],
         **kwargs,
     ):
         r"""
@@ -298,11 +300,11 @@ def __call__(
         """
         if jit:
             images = _p_generate(
-                self, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
+                self, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug, neg_prompt_ids
             )
         else:
             images = self._generate(
-                prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
+                prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug, neg_prompt_ids
             )
 
         if self.safety_checker is not None:
@@ -333,10 +335,10 @@ def __call__(
 
 
 # TODO: maybe use a config dict instead of so many static argnums
 @partial(jax.pmap, static_broadcasted_argnums=(0, 4, 5, 6, 7, 9))
 def _p_generate(
-    pipe, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
+    pipe, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug, neg_prompt_ids
 ):
     return pipe._generate(
-        prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
+        prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug, neg_prompt_ids
     )
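Two details of PATCH 1/3 matter for the follow-up commits: despite the `jnp.array` annotation, `neg_prompt_ids` is at this stage a plain Python list of strings (default `[""]`), and the tokenizer call broadcasts it across the batch via Python list repetition. A small sketch of those list semantics (plain Python, nothing pipeline-specific):

    batch_size = 2
    print([""] * batch_size)        # ['', '']: one empty negative prompt per sample
    print(["blurry"] * batch_size)  # ['blurry', 'blurry']: one prompt repeated for all samples
    print(["a", "b"] * batch_size)  # ['a', 'b', 'a', 'b']: a list that is already
                                    # batch-sized gets over-repeated

A list of strings is also not a valid non-static argument to `jax.pmap` (and `neg_prompt_ids` is not listed in `static_broadcasted_argnums`), so the `jit=True` path cannot accept this default, which is presumably why the series ends with pre-tokenized ids.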
From e69c68b55bef68e318af4685fe217543a383087c Mon Sep 17 00:00:00 2001
From: entrpn
Date: Thu, 17 Nov 2022 11:05:51 -0800
Subject: [PATCH 2/3] pass batched neg_prompt

---
 .../stable_diffusion/pipeline_flax_stable_diffusion.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
index 01e1c363a87b..b9425c283d7c 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
@@ -165,7 +165,7 @@ def _generate(
         guidance_scale: float = 7.5,
         latents: Optional[jnp.array] = None,
         debug: bool = False,
-        neg_prompt_ids: jnp.array = [""]
+        neg_prompt_ids: jnp.array = None
     ):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -178,6 +178,9 @@ def _generate(
         batch_size = prompt_ids.shape[0]
 
         max_length = prompt_ids.shape[-1]
+
+        if neg_prompt_ids is None:
+            neg_prompt_ids = [""] * batch_size
         uncond_input = self.tokenizer(
             neg_prompt_ids * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
         )
         uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0]
@@ -252,7 +255,7 @@ def __call__(
         return_dict: bool = True,
         jit: bool = False,
         debug: bool = False,
-        neg_prompt_ids: jnp.array = [""],
+        neg_prompt_ids: jnp.array = None,
         **kwargs,
     ):
         r"""
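After PATCH 2/3 the default is `None` and the fallback builds one empty prompt per sample, but the tokenizer call still multiplies by `batch_size`, so the fallback list is repeated a second time:

    batch_size = 2
    neg_prompt_ids = [""] * batch_size  # ['', '']
    print(neg_prompt_ids * batch_size)  # ['', '', '', '']: batch_size ** 2 prompts

That would produce `batch_size ** 2` unconditional embeddings against `batch_size` text embeddings. PATCH 3/3 removes the extra repetition by tokenizing only in the `None` branch and treating a caller-supplied `neg_prompt_ids` as token ids that are already batched.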
From 20b740c4b619d46d074c99075dfca3884dec0cee Mon Sep 17 00:00:00 2001
From: Juan Acevedo
Date: Thu, 17 Nov 2022 23:11:37 +0000
Subject: [PATCH 3/3] only encode when negative prompt is None

---
 .../pipeline_flax_stable_diffusion.py         | 59 +++++++++++++++----
 1 file changed, 49 insertions(+), 10 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
index b9425c283d7c..a2f0f73dbf1f 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
@@ -165,7 +165,7 @@ def _generate(
         guidance_scale: float = 7.5,
         latents: Optional[jnp.array] = None,
         debug: bool = False,
-        neg_prompt_ids: jnp.array = None
+        neg_prompt_ids: jnp.array = None,
     ):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -180,11 +180,12 @@ def _generate(
         max_length = prompt_ids.shape[-1]
 
         if neg_prompt_ids is None:
-            neg_prompt_ids = [""] * batch_size
-        uncond_input = self.tokenizer(
-            neg_prompt_ids * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
-        )
-        uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0]
+            uncond_input = self.tokenizer(
+                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
+            ).input_ids
+        else:
+            uncond_input = neg_prompt_ids
+        uncond_embeddings = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
         context = jnp.concatenate([uncond_embeddings, text_embeddings])
 
         latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
@@ -303,11 +304,30 @@ def __call__(
         """
         if jit:
             images = _p_generate(
-                self, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug, neg_prompt_ids
+                self,
+                prompt_ids,
+                params,
+                prng_seed,
+                num_inference_steps,
+                height,
+                width,
+                guidance_scale,
+                latents,
+                debug,
+                neg_prompt_ids,
             )
         else:
             images = self._generate(
-                prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug, neg_prompt_ids
+                prompt_ids,
+                params,
+                prng_seed,
+                num_inference_steps,
+                height,
+                width,
+                guidance_scale,
+                latents,
+                debug,
+                neg_prompt_ids,
             )
 
         if self.safety_checker is not None:
@@ -338,10 +358,29 @@ def __call__(
 
 
 # TODO: maybe use a config dict instead of so many static argnums
 @partial(jax.pmap, static_broadcasted_argnums=(0, 4, 5, 6, 7, 9))
 def _p_generate(
-    pipe, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug, neg_prompt_ids
+    pipe,
+    prompt_ids,
+    params,
+    prng_seed,
+    num_inference_steps,
+    height,
+    width,
+    guidance_scale,
+    latents,
+    debug,
+    neg_prompt_ids,
 ):
     return pipe._generate(
-        prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug, neg_prompt_ids
+        prompt_ids,
+        params,
+        prng_seed,
+        num_inference_steps,
+        height,
+        width,
+        guidance_scale,
+        latents,
+        debug,
+        neg_prompt_ids,
     )
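With the full series applied, `neg_prompt_ids` is expected to be already-tokenized, batch-sized input ids rather than raw strings. A minimal usage sketch of the resulting interface (the checkpoint name and revision are placeholders, and the per-device sharding needed for `jit=True` is omitted):

    import jax
    from diffusers import FlaxStableDiffusionPipeline

    # Load Flax weights; substitute whatever checkpoint/revision you actually use.
    pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4", revision="flax"
    )

    prompt_ids = pipeline.prepare_inputs(["an astronaut riding a horse"])

    # Tokenize the negative prompt the same way the pipeline tokenizes the
    # empty unconditional input: pad to the positive prompt's max_length so
    # the conditional and unconditional batches can be concatenated.
    neg_prompt_ids = pipeline.tokenizer(
        ["blurry, low quality"],
        padding="max_length",
        max_length=prompt_ids.shape[-1],
        return_tensors="np",
    ).input_ids

    images = pipeline(
        prompt_ids,
        params,
        jax.random.PRNGKey(0),
        num_inference_steps=50,
        neg_prompt_ids=neg_prompt_ids,
    ).images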