
Commit 81fa2d6

Avoid nested fix-copies (#1332)
* Avoid nested `# Copied from` statements during `make fix-copies`
* style
1 parent 195e437 · commit 81fa2d6

2 files changed: +4 −11 lines

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py

Lines changed: 0 additions & 11 deletions
@@ -81,7 +81,6 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.AltDiffusionPipeline.__init__
     def __init__(
         self,
         vae: AutoencoderKL,
@@ -148,7 +147,6 @@ def __init__(
             feature_extractor=feature_extractor,
         )

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.AltDiffusionPipeline.enable_attention_slicing
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
         Enable sliced attention computation.
@@ -168,7 +166,6 @@ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto
             slice_size = self.unet.config.attention_head_dim // 2
         self.unet.set_attention_slice(slice_size)

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.AltDiffusionPipeline.disable_attention_slicing
     def disable_attention_slicing(self):
         r"""
         Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
@@ -177,7 +174,6 @@ def disable_attention_slicing(self):
         # set slice_size = `None` to disable `attention slicing`
         self.enable_attention_slicing(None)

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.AltDiffusionPipeline.enable_sequential_cpu_offload
     def enable_sequential_cpu_offload(self, gpu_id=0):
         r"""
         Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
@@ -196,7 +192,6 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
             cpu_offload(cpu_offloaded_model, device)

     @property
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.AltDiffusionPipeline._execution_device
     def _execution_device(self):
         r"""
         Returns the device on which the pipeline's models will be executed. After calling
@@ -214,7 +209,6 @@ def _execution_device(self):
                 return torch.device(module._hf_hook.execution_device)
         return self.device

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.AltDiffusionPipeline.enable_xformers_memory_efficient_attention
     def enable_xformers_memory_efficient_attention(self):
         r"""
         Enable memory efficient attention as implemented in xformers.
@@ -227,14 +221,12 @@ def enable_xformers_memory_efficient_attention(self):
         """
         self.unet.set_use_memory_efficient_attention_xformers(True)

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.AltDiffusionPipeline.disable_xformers_memory_efficient_attention
     def disable_xformers_memory_efficient_attention(self):
         r"""
         Disable memory efficient attention as implemented in xformers.
         """
         self.unet.set_use_memory_efficient_attention_xformers(False)

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.AltDiffusionPipeline._encode_prompt
     def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -340,7 +332,6 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr

         return text_embeddings

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.AltDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype):
         if self.safety_checker is not None:
             safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
@@ -351,7 +342,6 @@ def run_safety_checker(self, image, device, dtype):
             has_nsfw_concept = None
         return image, has_nsfw_concept

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.AltDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / 0.18215 * latents
         image = self.vae.decode(latents).sample
@@ -360,7 +350,6 @@ def decode_latents(self, latents):
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
         return image

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.AltDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
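
Note: the `# Copied from` comments deleted above are themselves artifacts of `make fix-copies`. The tool keeps any object annotated with `# Copied from <source>` in sync by pasting in the source object's code; when the pasted code carries its own `# Copied from` annotations, those get duplicated into the target as well, and a `with Old->New` replace pattern can rewrite their paths, which is a plausible reading of why the comments above point at a nonexistent `...pipeline_stable_diffusion.AltDiffusionPipeline`. A hypothetical sketch of the failure mode (class and path names are illustrative, not the exact diffusers layout):

    # Source file: this method is itself annotated as a copy.
    class StableDiffusionImg2ImgPipeline:
        # Copied from ...pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
        def decode_latents(self, latents):
            ...

    # Target file, maintained as a copy of the class above with a replace
    # pattern such as `with StableDiffusion->AltDiffusion`. A naive sync
    # pastes the source verbatim and applies the rename everywhere, so the
    # nested comment survives and now names a class that is not at that path:
    class AltDiffusionImg2ImgPipeline:
        # Copied from ...pipeline_stable_diffusion.AltDiffusionPipeline.decode_latents
        def decode_latents(self, latents):
            ...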

utils/check_copies.py

Lines changed: 4 additions & 0 deletions
@@ -153,6 +153,10 @@ def is_copy_consistent(filename, overwrite=False):
         observed_code_lines = lines[start_index:line_index]
         observed_code = "".join(observed_code_lines)

+        # Remove any nested `Copied from` comments to avoid circular copies
+        theoretical_code = [line for line in theoretical_code.split("\n") if _re_copy_warning.search(line) is None]
+        theoretical_code = "\n".join(theoretical_code)
+
         # Before comparing, use the `replace_pattern` on the original code.
         if len(replace_pattern) > 0:
             patterns = replace_pattern.replace("with", "").split(",")
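
Here `theoretical_code` is the code the annotation says the target should contain, and `observed_code` is what the target file actually holds; the four added lines strip any nested `Copied from` comments out of the theoretical code before the comparison, so they can never be written into the target when the script runs with `overwrite`. A minimal, self-contained sketch of that filtering step, assuming a regex modeled on the script's `_re_copy_warning` (the exact pattern lives in utils/check_copies.py):

    import re

    # Assumed stand-in for `_re_copy_warning` from utils/check_copies.py.
    _re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+(\S+\.\S+)(\s+with\s+(.*))?\s*$")

    theoretical_code = (
        "    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents\n"
        "    def decode_latents(self, latents):\n"
        "        latents = 1 / 0.18215 * latents"
    )

    # The same two lines the commit adds: drop every `Copied from` line
    # before comparing, so nested annotations are never propagated.
    theoretical_code = [line for line in theoretical_code.split("\n") if _re_copy_warning.search(line) is None]
    theoretical_code = "\n".join(theoretical_code)

    print(theoretical_code)
    # Prints only:
    #     def decode_latents(self, latents):
    #         latents = 1 / 0.18215 * latents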
