diff --git a/.gitignore b/.gitignore index cf8183463613..f018a111ea33 100644 --- a/.gitignore +++ b/.gitignore @@ -163,4 +163,6 @@ tags *.lock # DS_Store (MacOS) -.DS_Store \ No newline at end of file +.DS_Store +# RL pipelines may produce mp4 outputs +*.mp4 \ No newline at end of file diff --git a/docs/source/api/models.mdx b/docs/source/api/models.mdx index 2e1e8798a7ef..7c1faa847440 100644 --- a/docs/source/api/models.mdx +++ b/docs/source/api/models.mdx @@ -22,12 +22,15 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module ## UNet2DOutput [[autodoc]] models.unet_2d.UNet2DOutput -## UNet1DModel -[[autodoc]] UNet1DModel - ## UNet2DModel [[autodoc]] UNet2DModel +## UNet1DOutput +[[autodoc]] models.unet_1d.UNet1DOutput + +## UNet1DModel +[[autodoc]] UNet1DModel + ## UNet2DConditionOutput [[autodoc]] models.unet_2d_condition.UNet2DConditionOutput diff --git a/examples/README.md b/examples/README.md index 29872a7a164e..06ce06b9e3bc 100644 --- a/examples/README.md +++ b/examples/README.md @@ -42,7 +42,7 @@ Training examples show how to pretrain or fine-tune diffusion models for a varie | [**Text-to-Image fine-tuning**](./text_to_image) | โœ… | โœ… | | [**Textual Inversion**](./textual_inversion) | โœ… | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb) | [**Dreambooth**](./dreambooth) | โœ… | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_dreambooth_training.ipynb) - +| [**Reinforcement Learning for Control**](https://github.com/huggingface/diffusers/blob/main/examples/rl/run_diffusers_locomotion.py) | - | - | coming soon. ## Community diff --git a/examples/rl/README.md b/examples/rl/README.md new file mode 100644 index 000000000000..d68f2bf780e3 --- /dev/null +++ b/examples/rl/README.md @@ -0,0 +1,19 @@ +# Overview + +These examples show how to run (Diffuser)[https://arxiv.org/abs/2205.09991] in Diffusers. +There are four scripts, +1. `run_diffuser_locomotion.py` to sample actions and run them in the environment, +2. and `run_diffuser_gen_trajectories.py` to just sample actions from the pre-trained diffusion model. + +You will need some RL specific requirements to run the examples: + +``` +pip install -f https://download.pytorch.org/whl/torch_stable.html \ + free-mujoco-py \ + einops \ + gym==0.24.1 \ + protobuf==3.20.1 \ + git+https://github.com/rail-berkeley/d4rl.git \ + mediapy \ + Pillow==9.0.0 +``` diff --git a/examples/rl/run_diffuser_gen_trajectories.py b/examples/rl/run_diffuser_gen_trajectories.py new file mode 100644 index 000000000000..5bb068cc9fc7 --- /dev/null +++ b/examples/rl/run_diffuser_gen_trajectories.py @@ -0,0 +1,57 @@ +import d4rl # noqa +import gym +import tqdm +from diffusers.experimental import ValueGuidedRLPipeline + + +config = dict( + n_samples=64, + horizon=32, + num_inference_steps=20, + n_guide_steps=0, + scale_grad_by_std=True, + scale=0.1, + eta=0.0, + t_grad_cutoff=2, + device="cpu", +) + + +if __name__ == "__main__": + env_name = "hopper-medium-v2" + env = gym.make(env_name) + + pipeline = ValueGuidedRLPipeline.from_pretrained( + "bglick13/hopper-medium-v2-value-function-hor32", + env=env, + ) + + env.seed(0) + obs = env.reset() + total_reward = 0 + total_score = 0 + T = 1000 + rollout = [obs.copy()] + try: + for t in tqdm.tqdm(range(T)): + # Call the policy + denorm_actions = pipeline(obs, planning_horizon=32) + + # execute action in environment + next_observation, reward, terminal, _ = env.step(denorm_actions) + score = env.get_normalized_score(total_reward) + # update return + total_reward += reward + total_score += score + print( + f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:" + f" {total_score}" + ) + # save observations for rendering + rollout.append(next_observation.copy()) + + obs = next_observation + except KeyboardInterrupt: + pass + + print(f"Total reward: {total_reward}") diff --git a/examples/rl/run_diffuser_locomotion.py b/examples/rl/run_diffuser_locomotion.py new file mode 100644 index 000000000000..e89181610b33 --- /dev/null +++ b/examples/rl/run_diffuser_locomotion.py @@ -0,0 +1,57 @@ +import d4rl # noqa +import gym +import tqdm +from diffusers.experimental import ValueGuidedRLPipeline + + +config = dict( + n_samples=64, + horizon=32, + num_inference_steps=20, + n_guide_steps=2, + scale_grad_by_std=True, + scale=0.1, + eta=0.0, + t_grad_cutoff=2, + device="cpu", +) + + +if __name__ == "__main__": + env_name = "hopper-medium-v2" + env = gym.make(env_name) + + pipeline = ValueGuidedRLPipeline.from_pretrained( + "bglick13/hopper-medium-v2-value-function-hor32", + env=env, + ) + + env.seed(0) + obs = env.reset() + total_reward = 0 + total_score = 0 + T = 1000 + rollout = [obs.copy()] + try: + for t in tqdm.tqdm(range(T)): + # call the policy + denorm_actions = pipeline(obs, planning_horizon=32) + + # execute action in environment + next_observation, reward, terminal, _ = env.step(denorm_actions) + score = env.get_normalized_score(total_reward) + # update return + total_reward += reward + total_score += score + print( + f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:" + f" {total_score}" + ) + # save observations for rendering + rollout.append(next_observation.copy()) + + obs = next_observation + except KeyboardInterrupt: + pass + + print(f"Total reward: {total_reward}") diff --git a/scripts/convert_models_diffuser_to_diffusers.py b/scripts/convert_models_diffuser_to_diffusers.py new file mode 100644 index 000000000000..9475f7da93fb --- /dev/null +++ b/scripts/convert_models_diffuser_to_diffusers.py @@ -0,0 +1,100 @@ +import json +import os + +import torch + +from diffusers import UNet1DModel + + +os.makedirs("hub/hopper-medium-v2/unet/hor32", exist_ok=True) +os.makedirs("hub/hopper-medium-v2/unet/hor128", exist_ok=True) + +os.makedirs("hub/hopper-medium-v2/value_function", exist_ok=True) + + +def unet(hor): + if hor == 128: + down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D") + block_out_channels = (32, 128, 256) + up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D") + + elif hor == 32: + down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D") + block_out_channels = (32, 64, 128, 256) + up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D") + model = torch.load(f"/Users/bglickenhaus/Documents/diffuser/temporal_unet-hopper-mediumv2-hor{hor}.torch") + state_dict = model.state_dict() + config = dict( + down_block_types=down_block_types, + block_out_channels=block_out_channels, + up_block_types=up_block_types, + layers_per_block=1, + use_timestep_embedding=True, + out_block_type="OutConv1DBlock", + norm_num_groups=8, + downsample_each_block=False, + in_channels=14, + out_channels=14, + extra_in_channels=0, + time_embedding_type="positional", + flip_sin_to_cos=False, + freq_shift=1, + sample_size=65536, + mid_block_type="MidResTemporalBlock1D", + act_fn="mish", + ) + hf_value_function = UNet1DModel(**config) + print(f"length of state dict: {len(state_dict.keys())}") + print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}") + mapping = dict((k, hfk) for k, hfk in zip(model.state_dict().keys(), hf_value_function.state_dict().keys())) + for k, v in mapping.items(): + state_dict[v] = state_dict.pop(k) + hf_value_function.load_state_dict(state_dict) + + torch.save(hf_value_function.state_dict(), f"hub/hopper-medium-v2/unet/hor{hor}/diffusion_pytorch_model.bin") + with open(f"hub/hopper-medium-v2/unet/hor{hor}/config.json", "w") as f: + json.dump(config, f) + + +def value_function(): + config = dict( + in_channels=14, + down_block_types=("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"), + up_block_types=(), + out_block_type="ValueFunction", + mid_block_type="ValueFunctionMidBlock1D", + block_out_channels=(32, 64, 128, 256), + layers_per_block=1, + downsample_each_block=True, + sample_size=65536, + out_channels=14, + extra_in_channels=0, + time_embedding_type="positional", + use_timestep_embedding=True, + flip_sin_to_cos=False, + freq_shift=1, + norm_num_groups=8, + act_fn="mish", + ) + + model = torch.load("/Users/bglickenhaus/Documents/diffuser/value_function-hopper-mediumv2-hor32.torch") + state_dict = model + hf_value_function = UNet1DModel(**config) + print(f"length of state dict: {len(state_dict.keys())}") + print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}") + + mapping = dict((k, hfk) for k, hfk in zip(state_dict.keys(), hf_value_function.state_dict().keys())) + for k, v in mapping.items(): + state_dict[v] = state_dict.pop(k) + + hf_value_function.load_state_dict(state_dict) + + torch.save(hf_value_function.state_dict(), "hub/hopper-medium-v2/value_function/diffusion_pytorch_model.bin") + with open("hub/hopper-medium-v2/value_function/config.json", "w") as f: + json.dump(config, f) + + +if __name__ == "__main__": + unet(32) + # unet(128) + value_function() diff --git a/src/diffusers/experimental/README.md b/src/diffusers/experimental/README.md new file mode 100644 index 000000000000..81a9de81c737 --- /dev/null +++ b/src/diffusers/experimental/README.md @@ -0,0 +1,5 @@ +# ๐Ÿงจ Diffusers Experimental + +We are adding experimental code to support novel applications and usages of the Diffusers library. +Currently, the following experiments are supported: +* Reinforcement learning via an implementation of the [Diffuser](https://arxiv.org/abs/2205.09991) model. \ No newline at end of file diff --git a/src/diffusers/experimental/__init__.py b/src/diffusers/experimental/__init__.py new file mode 100644 index 000000000000..ebc815540301 --- /dev/null +++ b/src/diffusers/experimental/__init__.py @@ -0,0 +1 @@ +from .rl import ValueGuidedRLPipeline diff --git a/src/diffusers/experimental/rl/__init__.py b/src/diffusers/experimental/rl/__init__.py new file mode 100644 index 000000000000..7b338d3173e1 --- /dev/null +++ b/src/diffusers/experimental/rl/__init__.py @@ -0,0 +1 @@ +from .value_guided_sampling import ValueGuidedRLPipeline diff --git a/src/diffusers/experimental/rl/value_guided_sampling.py b/src/diffusers/experimental/rl/value_guided_sampling.py new file mode 100644 index 000000000000..8d5062e3d4c5 --- /dev/null +++ b/src/diffusers/experimental/rl/value_guided_sampling.py @@ -0,0 +1,129 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch + +import tqdm + +from ...models.unet_1d import UNet1DModel +from ...pipeline_utils import DiffusionPipeline +from ...utils.dummy_pt_objects import DDPMScheduler + + +class ValueGuidedRLPipeline(DiffusionPipeline): + def __init__( + self, + value_function: UNet1DModel, + unet: UNet1DModel, + scheduler: DDPMScheduler, + env, + ): + super().__init__() + self.value_function = value_function + self.unet = unet + self.scheduler = scheduler + self.env = env + self.data = env.get_dataset() + self.means = dict() + for key in self.data.keys(): + try: + self.means[key] = self.data[key].mean() + except: + pass + self.stds = dict() + for key in self.data.keys(): + try: + self.stds[key] = self.data[key].std() + except: + pass + self.state_dim = env.observation_space.shape[0] + self.action_dim = env.action_space.shape[0] + + def normalize(self, x_in, key): + return (x_in - self.means[key]) / self.stds[key] + + def de_normalize(self, x_in, key): + return x_in * self.stds[key] + self.means[key] + + def to_torch(self, x_in): + if type(x_in) is dict: + return {k: self.to_torch(v) for k, v in x_in.items()} + elif torch.is_tensor(x_in): + return x_in.to(self.unet.device) + return torch.tensor(x_in, device=self.unet.device) + + def reset_x0(self, x_in, cond, act_dim): + for key, val in cond.items(): + x_in[:, key, act_dim:] = val.clone() + return x_in + + def run_diffusion(self, x, conditions, n_guide_steps, scale): + batch_size = x.shape[0] + y = None + for i in tqdm.tqdm(self.scheduler.timesteps): + # create batch of timesteps to pass into model + timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long) + for _ in range(n_guide_steps): + with torch.enable_grad(): + x.requires_grad_() + y = self.value_function(x.permute(0, 2, 1), timesteps).sample + grad = torch.autograd.grad([y.sum()], [x])[0] + + posterior_variance = self.scheduler._get_variance(i) + model_std = torch.exp(0.5 * posterior_variance) + grad = model_std * grad + grad[timesteps < 2] = 0 + x = x.detach() + x = x + scale * grad + x = self.reset_x0(x, conditions, self.action_dim) + prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1) + x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"] + + # apply conditions to the trajectory + x = self.reset_x0(x, conditions, self.action_dim) + x = self.to_torch(x) + return x, y + + def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1): + # normalize the observations and create batch dimension + obs = self.normalize(obs, "observations") + obs = obs[None].repeat(batch_size, axis=0) + + conditions = {0: self.to_torch(obs)} + shape = (batch_size, planning_horizon, self.state_dim + self.action_dim) + + # generate initial noise and apply our conditions (to make the trajectories start at current state) + x1 = torch.randn(shape, device=self.unet.device) + x = self.reset_x0(x1, conditions, self.action_dim) + x = self.to_torch(x) + + # run the diffusion process + x, y = self.run_diffusion(x, conditions, n_guide_steps, scale) + + # sort output trajectories by value + sorted_idx = y.argsort(0, descending=True).squeeze() + sorted_values = x[sorted_idx] + actions = sorted_values[:, :, : self.action_dim] + actions = actions.detach().cpu().numpy() + denorm_actions = self.de_normalize(actions, key="actions") + + # select the action with the highest value + if y is not None: + selected_index = 0 + else: + # if we didn't run value guiding, select a random action + selected_index = np.random.randint(0, batch_size) + denorm_actions = denorm_actions[selected_index, 0] + return denorm_actions diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index b09d43fc2edf..0221d891f171 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -62,14 +62,21 @@ def get_timestep_embedding( class TimestepEmbedding(nn.Module): - def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"): + def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None): super().__init__() - self.linear_1 = nn.Linear(channel, time_embed_dim) + self.linear_1 = nn.Linear(in_channels, time_embed_dim) self.act = None if act_fn == "silu": self.act = nn.SiLU() - self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim) + elif act_fn == "mish": + self.act = nn.Mish() + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) def forward(self, sample): sample = self.linear_1(sample) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 7bb5416adf24..52d056ae96fb 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -5,6 +5,75 @@ import torch.nn.functional as F +class Upsample1D(nn.Module): + """ + An upsampling layer with an optional convolution. + + Parameters: + channels: channels in the inputs and outputs. + use_conv: a bool determining if a convolution is applied. + use_conv_transpose: + out_channels: + """ + + def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + self.conv = None + if use_conv_transpose: + self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1) + elif use_conv: + self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.use_conv_transpose: + return self.conv(x) + + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + + if self.use_conv: + x = self.conv(x) + + return x + + +class Downsample1D(nn.Module): + """ + A downsampling layer with an optional convolution. + + Parameters: + channels: channels in the inputs and outputs. + use_conv: a bool determining if a convolution is applied. + out_channels: + padding: + """ + + def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if use_conv: + self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding) + else: + assert self.channels == self.out_channels + self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.conv(x) + + class Upsample2D(nn.Module): """ An upsampling layer with an optional convolution. @@ -12,7 +81,8 @@ class Upsample2D(nn.Module): Parameters: channels: channels in the inputs and outputs. use_conv: a bool determining if a convolution is applied. - dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions. + use_conv_transpose: + out_channels: """ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): @@ -80,7 +150,8 @@ class Downsample2D(nn.Module): Parameters: channels: channels in the inputs and outputs. use_conv: a bool determining if a convolution is applied. - dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions. + out_channels: + padding: """ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): @@ -415,6 +486,69 @@ def forward(self, hidden_states): return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) +# unet_rl.py +def rearrange_dims(tensor): + if len(tensor.shape) == 2: + return tensor[:, :, None] + if len(tensor.shape) == 3: + return tensor[:, :, None, :] + elif len(tensor.shape) == 4: + return tensor[:, :, 0, :] + else: + raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.") + + +class Conv1dBlock(nn.Module): + """ + Conv1d --> GroupNorm --> Mish + """ + + def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): + super().__init__() + + self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2) + self.group_norm = nn.GroupNorm(n_groups, out_channels) + self.mish = nn.Mish() + + def forward(self, x): + x = self.conv1d(x) + x = rearrange_dims(x) + x = self.group_norm(x) + x = rearrange_dims(x) + x = self.mish(x) + return x + + +# unet_rl.py +class ResidualTemporalBlock1D(nn.Module): + def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5): + super().__init__() + self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size) + self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size) + + self.time_emb_act = nn.Mish() + self.time_emb = nn.Linear(embed_dim, out_channels) + + self.residual_conv = ( + nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity() + ) + + def forward(self, x, t): + """ + Args: + x : [ batch_size x inp_channels x horizon ] + t : [ batch_size x embed_dim ] + + returns: + out : [ batch_size x out_channels x horizon ] + """ + t = self.time_emb_act(t) + t = self.time_emb(t) + out = self.conv_in(x) + rearrange_dims(t) + out = self.conv_out(out) + return out + self.residual_conv(x) + + def upsample_2d(hidden_states, kernel=None, factor=2, gain=1): r"""Upsample2D a batch of 2D images with the given filter. Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py index cc0685deb933..29d1d707f55a 100644 --- a/src/diffusers/models/unet_1d.py +++ b/src/diffusers/models/unet_1d.py @@ -1,3 +1,17 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from dataclasses import dataclass from typing import Optional, Tuple, Union @@ -8,7 +22,7 @@ from ..modeling_utils import ModelMixin from ..utils import BaseOutput from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps -from .unet_1d_blocks import get_down_block, get_mid_block, get_up_block +from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block @dataclass @@ -30,11 +44,11 @@ class UNet1DModel(ModelMixin, ConfigMixin): implements for all the model (such as downloading or saving, etc.) Parameters: - sample_size (`int`, *optionl*): Default length of sample. Should be adaptable at runtime. + sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime. in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample. out_channels (`int`, *optional*, defaults to 2): Number of channels in the output. time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use. - freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding. + freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for fourier time embedding. flip_sin_to_cos (`bool`, *optional*, defaults to : obj:`False`): Whether to flip sin to cos for fourier time embedding. down_block_types (`Tuple[str]`, *optional*, defaults to : @@ -43,6 +57,13 @@ class UNet1DModel(ModelMixin, ConfigMixin): obj:`("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`): Tuple of upsample block types. block_out_channels (`Tuple[int]`, *optional*, defaults to : obj:`(32, 32, 64)`): Tuple of block output channels. + mid_block_type (`str`, *optional*, defaults to "UNetMidBlock1D"): block type for middle of UNet. + out_block_type (`str`, *optional*, defaults to `None`): optional output processing of UNet. + act_fn (`str`, *optional*, defaults to None): optional activitation function in UNet blocks. + norm_num_groups (`int`, *optional*, defaults to 8): group norm member count in UNet blocks. + layers_per_block (`int`, *optional*, defaults to 1): added number of layers in a UNet block. + downsample_each_block (`int`, *optional*, defaults to False: + experimental feature for using a UNet without upsampling. """ @register_to_config @@ -54,16 +75,20 @@ def __init__( out_channels: int = 2, extra_in_channels: int = 0, time_embedding_type: str = "fourier", - freq_shift: int = 0, flip_sin_to_cos: bool = True, use_timestep_embedding: bool = False, + freq_shift: float = 0.0, down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"), - mid_block_type: str = "UNetMidBlock1D", up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"), + mid_block_type: Tuple[str] = "UNetMidBlock1D", + out_block_type: str = None, block_out_channels: Tuple[int] = (32, 32, 64), + act_fn: str = None, + norm_num_groups: int = 8, + layers_per_block: int = 1, + downsample_each_block: bool = False, ): super().__init__() - self.sample_size = sample_size # time @@ -73,12 +98,19 @@ def __init__( ) timestep_input_dim = 2 * block_out_channels[0] elif time_embedding_type == "positional": - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + self.time_proj = Timesteps( + block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift + ) timestep_input_dim = block_out_channels[0] if use_timestep_embedding: time_embed_dim = block_out_channels[0] * 4 - self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + self.time_mlp = TimestepEmbedding( + in_channels=timestep_input_dim, + time_embed_dim=time_embed_dim, + act_fn=act_fn, + out_dim=block_out_channels[0], + ) self.down_blocks = nn.ModuleList([]) self.mid_block = None @@ -94,38 +126,66 @@ def __init__( if i == 0: input_channel += extra_in_channels + is_final_block = i == len(block_out_channels) - 1 + down_block = get_down_block( down_block_type, + num_layers=layers_per_block, in_channels=input_channel, out_channels=output_channel, + temb_channels=block_out_channels[0], + add_downsample=not is_final_block or downsample_each_block, ) self.down_blocks.append(down_block) # mid self.mid_block = get_mid_block( - mid_block_type=mid_block_type, - mid_channels=block_out_channels[-1], + mid_block_type, in_channels=block_out_channels[-1], - out_channels=None, + mid_channels=block_out_channels[-1], + out_channels=block_out_channels[-1], + embed_dim=block_out_channels[0], + num_layers=layers_per_block, + add_downsample=downsample_each_block, ) # up reversed_block_out_channels = list(reversed(block_out_channels)) output_channel = reversed_block_out_channels[0] + if out_block_type is None: + final_upsample_channels = out_channels + else: + final_upsample_channels = block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else out_channels + output_channel = ( + reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels + ) + + is_final_block = i == len(block_out_channels) - 1 up_block = get_up_block( up_block_type, + num_layers=layers_per_block, in_channels=prev_output_channel, out_channels=output_channel, + temb_channels=block_out_channels[0], + add_upsample=not is_final_block, ) self.up_blocks.append(up_block) prev_output_channel = output_channel - # TODO(PVP, Nathan) placeholder for RL application to be merged shortly - # Totally fine to add another layer with a if statement - no need for nn.Identity here + # out + num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32) + self.out_block = get_out_block( + out_block_type=out_block_type, + num_groups_out=num_groups_out, + embed_dim=block_out_channels[0], + out_channels=out_channels, + act_fn=act_fn, + fc_dim=block_out_channels[-1] // 4, + ) def forward( self, @@ -144,12 +204,20 @@ def forward( [`~models.unet_1d.UNet1DOutput`] or `tuple`: [`~models.unet_1d.UNet1DOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ - # 1. time - if len(timestep.shape) == 0: - timestep = timestep[None] - timestep_embed = self.time_proj(timestep)[..., None] - timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype) + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + timestep_embed = self.time_proj(timesteps) + if self.config.use_timestep_embedding: + timestep_embed = self.time_mlp(timestep_embed) + else: + timestep_embed = timestep_embed[..., None] + timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype) # 2. down down_block_res_samples = () @@ -158,13 +226,18 @@ def forward( down_block_res_samples += res_samples # 3. mid - sample = self.mid_block(sample) + if self.mid_block: + sample = self.mid_block(sample, timestep_embed) # 4. up for i, upsample_block in enumerate(self.up_blocks): res_samples = down_block_res_samples[-1:] down_block_res_samples = down_block_res_samples[:-1] - sample = upsample_block(sample, res_samples) + sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed) + + # 5. post-process + if self.out_block: + sample = self.out_block(sample, timestep_embed) if not return_dict: return (sample,) diff --git a/src/diffusers/models/unet_1d_blocks.py b/src/diffusers/models/unet_1d_blocks.py index 9009071d1e78..fc758ebbb044 100644 --- a/src/diffusers/models/unet_1d_blocks.py +++ b/src/diffusers/models/unet_1d_blocks.py @@ -17,6 +17,256 @@ import torch.nn.functional as F from torch import nn +from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims + + +class DownResnetBlock1D(nn.Module): + def __init__( + self, + in_channels, + out_channels=None, + num_layers=1, + conv_shortcut=False, + temb_channels=32, + groups=32, + groups_out=None, + non_linearity=None, + time_embedding_norm="default", + output_scale_factor=1.0, + add_downsample=True, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.time_embedding_norm = time_embedding_norm + self.add_downsample = add_downsample + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + # there will always be at least one resnet + resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)] + + for _ in range(num_layers): + resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels)) + + self.resnets = nn.ModuleList(resnets) + + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = nn.Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + else: + self.nonlinearity = None + + self.downsample = None + if add_downsample: + self.downsample = Downsample1D(out_channels, use_conv=True, padding=1) + + def forward(self, hidden_states, temb=None): + output_states = () + + hidden_states = self.resnets[0](hidden_states, temb) + for resnet in self.resnets[1:]: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states,) + + if self.nonlinearity is not None: + hidden_states = self.nonlinearity(hidden_states) + + if self.downsample is not None: + hidden_states = self.downsample(hidden_states) + + return hidden_states, output_states + + +class UpResnetBlock1D(nn.Module): + def __init__( + self, + in_channels, + out_channels=None, + num_layers=1, + temb_channels=32, + groups=32, + groups_out=None, + non_linearity=None, + time_embedding_norm="default", + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.time_embedding_norm = time_embedding_norm + self.add_upsample = add_upsample + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + # there will always be at least one resnet + resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)] + + for _ in range(num_layers): + resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels)) + + self.resnets = nn.ModuleList(resnets) + + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = nn.Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + else: + self.nonlinearity = None + + self.upsample = None + if add_upsample: + self.upsample = Upsample1D(out_channels, use_conv_transpose=True) + + def forward(self, hidden_states, res_hidden_states_tuple=None, temb=None): + if res_hidden_states_tuple is not None: + res_hidden_states = res_hidden_states_tuple[-1] + hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1) + + hidden_states = self.resnets[0](hidden_states, temb) + for resnet in self.resnets[1:]: + hidden_states = resnet(hidden_states, temb) + + if self.nonlinearity is not None: + hidden_states = self.nonlinearity(hidden_states) + + if self.upsample is not None: + hidden_states = self.upsample(hidden_states) + + return hidden_states + + +class ValueFunctionMidBlock1D(nn.Module): + def __init__(self, in_channels, out_channels, embed_dim): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.embed_dim = embed_dim + + self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim) + self.down1 = Downsample1D(out_channels // 2, use_conv=True) + self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim) + self.down2 = Downsample1D(out_channels // 4, use_conv=True) + + def forward(self, x, temb=None): + x = self.res1(x, temb) + x = self.down1(x) + x = self.res2(x, temb) + x = self.down2(x) + return x + + +class MidResTemporalBlock1D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + embed_dim, + num_layers: int = 1, + add_downsample: bool = False, + add_upsample: bool = False, + non_linearity=None, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.add_downsample = add_downsample + + # there will always be at least one resnet + resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)] + + for _ in range(num_layers): + resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim)) + + self.resnets = nn.ModuleList(resnets) + + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = nn.Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + else: + self.nonlinearity = None + + self.upsample = None + if add_upsample: + self.upsample = Downsample1D(out_channels, use_conv=True) + + self.downsample = None + if add_downsample: + self.downsample = Downsample1D(out_channels, use_conv=True) + + if self.upsample and self.downsample: + raise ValueError("Block cannot downsample and upsample") + + def forward(self, hidden_states, temb): + hidden_states = self.resnets[0](hidden_states, temb) + for resnet in self.resnets[1:]: + hidden_states = resnet(hidden_states, temb) + + if self.upsample: + hidden_states = self.upsample(hidden_states) + if self.downsample: + self.downsample = self.downsample(hidden_states) + + return hidden_states + + +class OutConv1DBlock(nn.Module): + def __init__(self, num_groups_out, out_channels, embed_dim, act_fn): + super().__init__() + self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2) + self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim) + if act_fn == "silu": + self.final_conv1d_act = nn.SiLU() + if act_fn == "mish": + self.final_conv1d_act = nn.Mish() + self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1) + + def forward(self, hidden_states, temb=None): + hidden_states = self.final_conv1d_1(hidden_states) + hidden_states = rearrange_dims(hidden_states) + hidden_states = self.final_conv1d_gn(hidden_states) + hidden_states = rearrange_dims(hidden_states) + hidden_states = self.final_conv1d_act(hidden_states) + hidden_states = self.final_conv1d_2(hidden_states) + return hidden_states + + +class OutValueFunctionBlock(nn.Module): + def __init__(self, fc_dim, embed_dim): + super().__init__() + self.final_block = nn.ModuleList( + [ + nn.Linear(fc_dim + embed_dim, fc_dim // 2), + nn.Mish(), + nn.Linear(fc_dim // 2, 1), + ] + ) + + def forward(self, hidden_states, temb): + hidden_states = hidden_states.view(hidden_states.shape[0], -1) + hidden_states = torch.cat((hidden_states, temb), dim=-1) + for layer in self.final_block: + hidden_states = layer(hidden_states) + + return hidden_states + _kernels = { "linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8], @@ -62,7 +312,7 @@ def __init__(self, kernel="linear", pad_mode="reflect"): self.pad = kernel_1d.shape[0] // 2 - 1 self.register_buffer("kernel", kernel_1d) - def forward(self, hidden_states): + def forward(self, hidden_states, temb=None): hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode) weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) @@ -162,32 +412,6 @@ def forward(self, hidden_states): return output -def get_down_block(down_block_type, out_channels, in_channels): - if down_block_type == "DownBlock1D": - return DownBlock1D(out_channels=out_channels, in_channels=in_channels) - elif down_block_type == "AttnDownBlock1D": - return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels) - elif down_block_type == "DownBlock1DNoSkip": - return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels) - raise ValueError(f"{down_block_type} does not exist.") - - -def get_up_block(up_block_type, in_channels, out_channels): - if up_block_type == "UpBlock1D": - return UpBlock1D(in_channels=in_channels, out_channels=out_channels) - elif up_block_type == "AttnUpBlock1D": - return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels) - elif up_block_type == "UpBlock1DNoSkip": - return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels) - raise ValueError(f"{up_block_type} does not exist.") - - -def get_mid_block(mid_block_type, in_channels, mid_channels, out_channels): - if mid_block_type == "UNetMidBlock1D": - return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels) - raise ValueError(f"{mid_block_type} does not exist.") - - class UNetMidBlock1D(nn.Module): def __init__(self, mid_channels, in_channels, out_channels=None): super().__init__() @@ -217,7 +441,7 @@ def __init__(self, mid_channels, in_channels, out_channels=None): self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states): + def forward(self, hidden_states, temb=None): hidden_states = self.down(hidden_states) for attn, resnet in zip(self.attentions, self.resnets): hidden_states = resnet(hidden_states) @@ -322,7 +546,7 @@ def __init__(self, in_channels, out_channels, mid_channels=None): self.resnets = nn.ModuleList(resnets) self.up = Upsample1d(kernel="cubic") - def forward(self, hidden_states, res_hidden_states_tuple): + def forward(self, hidden_states, res_hidden_states_tuple, temb=None): res_hidden_states = res_hidden_states_tuple[-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) @@ -349,7 +573,7 @@ def __init__(self, in_channels, out_channels, mid_channels=None): self.resnets = nn.ModuleList(resnets) self.up = Upsample1d(kernel="cubic") - def forward(self, hidden_states, res_hidden_states_tuple): + def forward(self, hidden_states, res_hidden_states_tuple, temb=None): res_hidden_states = res_hidden_states_tuple[-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) @@ -374,7 +598,7 @@ def __init__(self, in_channels, out_channels, mid_channels=None): self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states, res_hidden_states_tuple): + def forward(self, hidden_states, res_hidden_states_tuple, temb=None): res_hidden_states = res_hidden_states_tuple[-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) @@ -382,3 +606,63 @@ def forward(self, hidden_states, res_hidden_states_tuple): hidden_states = resnet(hidden_states) return hidden_states + + +def get_down_block(down_block_type, num_layers, in_channels, out_channels, temb_channels, add_downsample): + if down_block_type == "DownResnetBlock1D": + return DownResnetBlock1D( + in_channels=in_channels, + num_layers=num_layers, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + ) + elif down_block_type == "DownBlock1D": + return DownBlock1D(out_channels=out_channels, in_channels=in_channels) + elif down_block_type == "AttnDownBlock1D": + return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels) + elif down_block_type == "DownBlock1DNoSkip": + return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block(up_block_type, num_layers, in_channels, out_channels, temb_channels, add_upsample): + if up_block_type == "UpResnetBlock1D": + return UpResnetBlock1D( + in_channels=in_channels, + num_layers=num_layers, + out_channels=out_channels, + temb_channels=temb_channels, + add_upsample=add_upsample, + ) + elif up_block_type == "UpBlock1D": + return UpBlock1D(in_channels=in_channels, out_channels=out_channels) + elif up_block_type == "AttnUpBlock1D": + return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels) + elif up_block_type == "UpBlock1DNoSkip": + return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels) + raise ValueError(f"{up_block_type} does not exist.") + + +def get_mid_block(mid_block_type, num_layers, in_channels, mid_channels, out_channels, embed_dim, add_downsample): + if mid_block_type == "MidResTemporalBlock1D": + return MidResTemporalBlock1D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + embed_dim=embed_dim, + add_downsample=add_downsample, + ) + elif mid_block_type == "ValueFunctionMidBlock1D": + return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim) + elif mid_block_type == "UNetMidBlock1D": + return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels) + raise ValueError(f"{mid_block_type} does not exist.") + + +def get_out_block(*, out_block_type, num_groups_out, embed_dim, out_channels, act_fn, fc_dim): + if out_block_type == "OutConv1DBlock": + return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn) + elif out_block_type == "ValueFunction": + return OutValueFunctionBlock(fc_dim, embed_dim) + return None diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index a19d91879cc3..c3e373d2bdca 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -204,6 +204,7 @@ def _get_variance(self, t, predicted_variance=None, variance_type=None): # for rl-diffuser https://arxiv.org/abs/2205.09991 elif variance_type == "fixed_small_log": variance = torch.log(torch.clamp(variance, min=1e-20)) + variance = torch.exp(0.5 * variance) elif variance_type == "fixed_large": variance = self.betas[t] elif variance_type == "fixed_large_log": @@ -301,7 +302,10 @@ def step( variance_noise = torch.randn( model_output.shape, generator=generator, device=device, dtype=model_output.dtype ) - variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise + if self.variance_type == "fixed_small_log": + variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise + else: + variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise pred_prev_sample = pred_prev_sample + variance diff --git a/tests/models/test_models_unet_1d.py b/tests/models/test_models_unet_1d.py index c274ce419211..41c4fdecfa0a 100644 --- a/tests/models/test_models_unet_1d.py +++ b/tests/models/test_models_unet_1d.py @@ -18,13 +18,120 @@ import torch from diffusers import UNet1DModel -from diffusers.utils import slow, torch_device +from diffusers.utils import floats_tensor, slow, torch_device + +from ..test_modeling_common import ModelTesterMixin torch.backends.cuda.matmul.allow_tf32 = False -class UnetModel1DTests(unittest.TestCase): +class UNet1DModelTests(ModelTesterMixin, unittest.TestCase): + model_class = UNet1DModel + + @property + def dummy_input(self): + batch_size = 4 + num_features = 14 + seq_len = 16 + + noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device) + time_step = torch.tensor([10] * batch_size).to(torch_device) + + return {"sample": noise, "timestep": time_step} + + @property + def input_shape(self): + return (4, 14, 16) + + @property + def output_shape(self): + return (4, 14, 16) + + def test_ema_training(self): + pass + + def test_training(self): + pass + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_determinism(self): + super().test_determinism() + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_outputs_equivalence(self): + super().test_outputs_equivalence() + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_from_pretrained_save_pretrained(self): + super().test_from_pretrained_save_pretrained() + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_model_from_config(self): + super().test_model_from_config() + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_output(self): + super().test_output() + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "block_out_channels": (32, 64, 128, 256), + "in_channels": 14, + "out_channels": 14, + "time_embedding_type": "positional", + "use_timestep_embedding": True, + "flip_sin_to_cos": False, + "freq_shift": 1.0, + "out_block_type": "OutConv1DBlock", + "mid_block_type": "MidResTemporalBlock1D", + "down_block_types": ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"), + "up_block_types": ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D"), + "act_fn": "mish", + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_from_pretrained_hub(self): + model, loading_info = UNet1DModel.from_pretrained( + "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="unet" + ) + self.assertIsNotNone(model) + self.assertEqual(len(loading_info["missing_keys"]), 0) + + model.to(torch_device) + image = model(**self.dummy_input) + + assert image is not None, "Make sure output is not None" + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_output_pretrained(self): + model = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32", subfolder="unet") + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + num_features = model.in_channels + seq_len = 16 + noise = torch.randn((1, seq_len, num_features)).permute( + 0, 2, 1 + ) # match original, we can update values and remove + time_step = torch.full((num_features,), 0) + + with torch.no_grad(): + output = model(noise, time_step).sample.permute(0, 2, 1) + + output_slice = output[0, -3:, -3:].flatten() + # fmt: off + expected_output_slice = torch.tensor([-2.137172, 1.1426016, 0.3688687, -0.766922, 0.7303146, 0.11038864, -0.4760633, 0.13270172, 0.02591348]) + # fmt: on + self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3)) + + def test_forward_with_norm_groups(self): + # Not implemented yet for this UNet + pass + @slow def test_unet_1d_maestro(self): model_id = "harmonai/maestro-150k" @@ -43,3 +150,127 @@ def test_unet_1d_maestro(self): assert (output_sum - 224.0896).abs() < 4e-2 assert (output_max - 0.0607).abs() < 4e-4 + + +class UNetRLModelTests(ModelTesterMixin, unittest.TestCase): + model_class = UNet1DModel + + @property + def dummy_input(self): + batch_size = 4 + num_features = 14 + seq_len = 16 + + noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device) + time_step = torch.tensor([10] * batch_size).to(torch_device) + + return {"sample": noise, "timestep": time_step} + + @property + def input_shape(self): + return (4, 14, 16) + + @property + def output_shape(self): + return (4, 14, 1) + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_determinism(self): + super().test_determinism() + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_outputs_equivalence(self): + super().test_outputs_equivalence() + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_from_pretrained_save_pretrained(self): + super().test_from_pretrained_save_pretrained() + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_model_from_config(self): + super().test_model_from_config() + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_output(self): + # UNetRL is a value-function is different output shape + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output.sample + + self.assertIsNotNone(output) + expected_shape = torch.Size((inputs_dict["sample"].shape[0], 1)) + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + + def test_ema_training(self): + pass + + def test_training(self): + pass + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "in_channels": 14, + "out_channels": 14, + "down_block_types": ["DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"], + "up_block_types": [], + "out_block_type": "ValueFunction", + "mid_block_type": "ValueFunctionMidBlock1D", + "block_out_channels": [32, 64, 128, 256], + "layers_per_block": 1, + "downsample_each_block": True, + "use_timestep_embedding": True, + "freq_shift": 1.0, + "flip_sin_to_cos": False, + "time_embedding_type": "positional", + "act_fn": "mish", + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_from_pretrained_hub(self): + value_function, vf_loading_info = UNet1DModel.from_pretrained( + "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function" + ) + self.assertIsNotNone(value_function) + self.assertEqual(len(vf_loading_info["missing_keys"]), 0) + + value_function.to(torch_device) + image = value_function(**self.dummy_input) + + assert image is not None, "Make sure output is not None" + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_output_pretrained(self): + value_function, vf_loading_info = UNet1DModel.from_pretrained( + "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function" + ) + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + num_features = value_function.in_channels + seq_len = 14 + noise = torch.randn((1, seq_len, num_features)).permute( + 0, 2, 1 + ) # match original, we can update values and remove + time_step = torch.full((num_features,), 0) + + with torch.no_grad(): + output = value_function(noise, time_step).sample + + # fmt: off + expected_output_slice = torch.tensor([165.25] * seq_len) + # fmt: on + self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3)) + + def test_forward_with_norm_groups(self): + # Not implemented yet for this UNet + pass diff --git a/tests/pipelines/dance_diffusion/test_dance_diffusion.py b/tests/pipelines/dance_diffusion/test_dance_diffusion.py index 72e67e4479d2..a63ef84c63f5 100644 --- a/tests/pipelines/dance_diffusion/test_dance_diffusion.py +++ b/tests/pipelines/dance_diffusion/test_dance_diffusion.py @@ -44,6 +44,10 @@ def dummy_unet(self): sample_rate=16_000, in_channels=2, out_channels=2, + flip_sin_to_cos=True, + use_timestep_embedding=False, + time_embedding_type="fourier", + mid_block_type="UNetMidBlock1D", down_block_types=["DownBlock1DNoSkip"] + ["DownBlock1D"] + ["AttnDownBlock1D"], up_block_types=["AttnUpBlock1D"] + ["UpBlock1D"] + ["UpBlock1DNoSkip"], )