Commit 1d70b0e

add AudioDiffusionPipeline and LatentAudioDiffusionPipeline huggingface#1334 (huggingface#1426)
* add AudioDiffusionPipeline and LatentAudioDiffusionPipeline
* add docs to toc
* fix tests
* fix tests
* fix tests
* fix tests
* fix tests
* Update pr_tests.yml
  Fix tests
* parent 499ff34
  author teticio <[email protected]> 1668765652 +0000
  committer teticio <[email protected]> 1669041721 +0000

  parent 499ff34
  author teticio <[email protected]> 1668765652 +0000
  committer teticio <[email protected]> 1669041704 +0000

  add colab notebook

  [Flax] Fix loading scheduler from subfolder (huggingface#1319)
  [FLAX] Fix loading scheduler from subfolder

  Fix/Enable all schedulers for in-painting (huggingface#1331)
  * inpaint fix k lms
  * onnox as well
  * up

  Correct path to schedlure (huggingface#1322)
  * [Examples] Correct path
  * uP

  Avoid nested fix-copies (huggingface#1332)
  * Avoid nested `# Copied from` statements during `make fix-copies`
  * style

  Fix img2img speed with LMS-Discrete Scheduler (huggingface#896)
  Casting `self.sigmas` into a different dtype (the one of original_samples) is not advisable. In my img2img pipeline this leads to a long running time in the `integrate.quad` call later on; by long I mean more than 10x slower.
  Co-authored-by: Anton Lozhkov <[email protected]>

  Fix the order of casts for onnx inpainting (huggingface#1338)

  Legacy Inpainting Pipeline for Onnx Models (huggingface#1237)
  * Add legacy inpainting pipeline compatibility for onnx
  * remove commented out line
  * Add onnx legacy inpainting test
  * Fix slow decorators
  * pep8 styling
  * isort styling
  * dummy object
  * ordering consistency
  * style
  * docstring styles
  * Refactor common prompt encoding pattern
  * Update tests to permanent repository home
  * support all available schedulers until ONNX IO binding is available
  Co-authored-by: Anton Lozhkov <[email protected]>
  * updated styling from PR suggested feedback
  Co-authored-by: Anton Lozhkov <[email protected]>

  Jax infer support negative prompt (huggingface#1337)
  * support negative prompts in sd jax pipeline
  * pass batched neg_prompt
  * only encode when negative prompt is None
  Co-authored-by: Juan Acevedo <[email protected]>

  Update README.md: Minor change to Imagic code snippet, missing dir error (huggingface#1347)
  Minor change to Imagic Readme; a missing dir causes an error when running the example code.

  make style

  change the sample model (huggingface#1352)
  * Update alt_diffusion.mdx
  * Update alt_diffusion.mdx

  Add bit diffusion [WIP] (huggingface#971)
  * Create bit_diffusion.py: bit diffusion based on the paper arXiv:2208.04202 (Chen2022AnalogBG)
  * adding bit diffusion to new branch; ran tests
  * tests
  * tests
  * tests
  * tests
  * removed test folders + added to README
  * Update README.md
  Co-authored-by: Patrick von Platen <[email protected]>
* move Mel to module in pipeline construction, make librosa optional
* fix imports
* fix copy & paste error in comment
* fix style
* add missing register_to_config
* fix class docstrings
* fix class docstrings
* tweak docstrings
* tweak docstrings
* update slow test
* put trailing commas back
* respect alphabetical order
* remove LatentAudioDiffusion, make vqvae optional
* move Mel from models back to pipelines :-)
* allow loading of pretrained audiodiffusion models
* fix tests
* fix dummies
* remove reference to latent_audio_diffusion in docs
* unused import
* inherit from SchedulerMixin to make loadable
* Apply suggestions from code review
* Apply suggestions from code review

Co-authored-by: Patrick von Platen <[email protected]>
1 parent ddf18c6 commit 1d70b0e

10 files changed: +500, -2 lines

__init__.py

Lines changed: 2 additions & 0 deletions

@@ -30,12 +30,14 @@
     )
     from .pipeline_utils import DiffusionPipeline
     from .pipelines import (
+        AudioDiffusionPipeline,
         DanceDiffusionPipeline,
         DDIMPipeline,
         DDPMPipeline,
         KarrasVePipeline,
         LDMPipeline,
         LDMSuperResolutionPipeline,
+        Mel,
         PNDMPipeline,
         RePaintPipeline,
         ScoreSdeVePipeline,
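
With these two exports added, the new objects become part of the package's top-level API. A minimal sketch of the resulting import (per pipelines/__init__.py below, the names resolve to the real classes when torch and librosa are installed, and to dummy placeholders otherwise):

from diffusers import AudioDiffusionPipeline, Mel  # real classes only if torch and librosa are installed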

dependency_versions_table.py

Lines changed: 1 addition & 0 deletions

@@ -15,6 +15,7 @@
     "isort": "isort>=5.5.4",
     "jax": "jax>=0.2.8,!=0.3.2",
     "jaxlib": "jaxlib>=0.1.65",
+    "librosa": "librosa",
     "modelcards": "modelcards>=0.1.4",
     "numpy": "numpy",
     "parameterized": "parameterized",

pipelines/__init__.py

Lines changed: 12 additions & 1 deletion

@@ -1,4 +1,10 @@
-from ..utils import is_flax_available, is_onnx_available, is_torch_available, is_transformers_available
+from ..utils import (
+    is_flax_available,
+    is_librosa_available,
+    is_onnx_available,
+    is_torch_available,
+    is_transformers_available,
+)


 if is_torch_available():
@@ -14,6 +20,11 @@
 else:
     from ..utils.dummy_pt_objects import *  # noqa F403

+if is_torch_available() and is_librosa_available():
+    from .audio_diffusion import AudioDiffusionPipeline, Mel
+else:
+    from ..utils.dummy_torch_and_librosa_objects import AudioDiffusionPipeline, Mel  # noqa F403
+
 if is_torch_available() and is_transformers_available():
     from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
     from .latent_diffusion import LDMTextToImagePipeline
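
Since the commit also enables loading of pretrained audiodiffusion models, here is a hedged sketch of what that looks like once both guards above pass; the checkpoint id is illustrative and not part of this diff, and from_pretrained is the standard DiffusionPipeline loader that AudioDiffusionPipeline inherits:

from diffusers import AudioDiffusionPipeline

# Hypothetical Hub id, shown for illustration; any audio-diffusion checkpoint saved
# in the diffusers format should load the same way.
pipe = AudioDiffusionPipeline.from_pretrained("teticio/audio-diffusion-256")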

pipelines/audio_diffusion/__init__.py

Lines changed: 3 additions & 0 deletions (new file)

# flake8: noqa
from .mel import Mel
from .pipeline_audio_diffusion import AudioDiffusionPipeline

pipelines/audio_diffusion/mel.py

Lines changed: 150 additions & 0 deletions (new file)

# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import warnings

from ...configuration_utils import ConfigMixin, register_to_config
from ...schedulers.scheduling_utils import SchedulerMixin


warnings.filterwarnings("ignore")

import numpy as np  # noqa: E402

import librosa  # noqa: E402
from PIL import Image  # noqa: E402


class Mel(ConfigMixin, SchedulerMixin):
    """
    Parameters:
        x_res (`int`): x resolution of spectrogram (time)
        y_res (`int`): y resolution of spectrogram (frequency bins)
        sample_rate (`int`): sample rate of audio
        n_fft (`int`): number of Fast Fourier Transforms
        hop_length (`int`): hop length (a higher number is recommended when y_res is below 256)
        top_db (`int`): loudest value in decibels
        n_iter (`int`): number of iterations for Griffin-Lim mel inversion
    """

    config_name = "mel_config.json"

    @register_to_config
    def __init__(
        self,
        x_res: int = 256,
        y_res: int = 256,
        sample_rate: int = 22050,
        n_fft: int = 2048,
        hop_length: int = 512,
        top_db: int = 80,
        n_iter: int = 32,
    ):
        self.hop_length = hop_length
        self.sr = sample_rate
        self.n_fft = n_fft
        self.top_db = top_db
        self.n_iter = n_iter
        self.set_resolution(x_res, y_res)
        self.audio = None

    def set_resolution(self, x_res: int, y_res: int):
        """Set resolution.

        Args:
            x_res (`int`): x resolution of spectrogram (time)
            y_res (`int`): y resolution of spectrogram (frequency bins)
        """
        self.x_res = x_res
        self.y_res = y_res
        self.n_mels = self.y_res
        self.slice_size = self.x_res * self.hop_length - 1

    def load_audio(self, audio_file: str = None, raw_audio: np.ndarray = None):
        """Load audio.

        Args:
            audio_file (`str`): must be a file on disk due to a Librosa limitation, or
            raw_audio (`np.ndarray`): audio as numpy array
        """
        if audio_file is not None:
            self.audio, _ = librosa.load(audio_file, mono=True, sr=self.sr)
        else:
            self.audio = raw_audio

        # Pad with silence if necessary.
        if len(self.audio) < self.x_res * self.hop_length:
            self.audio = np.concatenate([self.audio, np.zeros((self.x_res * self.hop_length - len(self.audio),))])

    def get_number_of_slices(self) -> int:
        """Get number of slices in audio.

        Returns:
            `int`: number of spectrograms audio can be sliced into
        """
        return len(self.audio) // self.slice_size

    def get_audio_slice(self, slice: int = 0) -> np.ndarray:
        """Get slice of audio.

        Args:
            slice (`int`): slice number of audio (out of get_number_of_slices())

        Returns:
            `np.ndarray`: audio as numpy array
        """
        return self.audio[self.slice_size * slice : self.slice_size * (slice + 1)]

    def get_sample_rate(self) -> int:
        """Get sample rate.

        Returns:
            `int`: sample rate of audio
        """
        return self.sr

    def audio_slice_to_image(self, slice: int) -> Image.Image:
        """Convert slice of audio to spectrogram.

        Args:
            slice (`int`): slice number of audio to convert (out of get_number_of_slices())

        Returns:
            `PIL Image`: grayscale image of x_res x y_res
        """
        S = librosa.feature.melspectrogram(
            y=self.get_audio_slice(slice), sr=self.sr, n_fft=self.n_fft, hop_length=self.hop_length, n_mels=self.n_mels
        )
        log_S = librosa.power_to_db(S, ref=np.max, top_db=self.top_db)
        bytedata = (((log_S + self.top_db) * 255 / self.top_db).clip(0, 255) + 0.5).astype(np.uint8)
        image = Image.fromarray(bytedata)
        return image

    def image_to_audio(self, image: Image.Image) -> np.ndarray:
        """Converts spectrogram to audio.

        Args:
            image (`PIL Image`): x_res x y_res grayscale image

        Returns:
            audio (`np.ndarray`): raw audio
        """
        bytedata = np.frombuffer(image.tobytes(), dtype="uint8").reshape((image.height, image.width))
        log_S = bytedata.astype("float") * self.top_db / 255 - self.top_db
        S = librosa.db_to_power(log_S)
        audio = librosa.feature.inverse.mel_to_audio(
            S, sr=self.sr, n_fft=self.n_fft, hop_length=self.hop_length, n_iter=self.n_iter
        )
        return audio
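
A minimal usage sketch of the Mel helper above, exercising only the methods defined in this file; the synthetic waveform stands in for a real recording (load_audio also accepts an audio_file path):

import numpy as np

from diffusers import Mel

mel = Mel()  # defaults: 256 x 256 spectrograms at 22050 Hz
mel.load_audio(raw_audio=np.random.uniform(-1, 1, mel.x_res * mel.hop_length).astype(np.float32))
print(mel.get_number_of_slices())    # how many x_res-wide slices the clip yields
image = mel.audio_slice_to_image(0)  # PIL grayscale spectrogram, x_res x y_res
audio = mel.image_to_audio(image)    # Griffin-Lim reconstruction back to a waveform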
