Add UNet 1d for RL model for planning + colab #105
src/diffusers/models/resnet.py

@@ -5,6 +5,70 @@

import torch.nn.functional as F


class Upsample1D(nn.Module):
    """
    An upsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied after upsampling.
    :param use_conv_transpose: a bool determining if a transposed convolution is used instead of
        nearest-neighbor interpolation.
    :param out_channels: number of output channels; defaults to `channels`.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        self.conv = None
        if use_conv_transpose:
            self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.use_conv_transpose:
            return self.conv(x)

        x = F.interpolate(x, scale_factor=2.0, mode="nearest")

        if self.use_conv:
            x = self.conv(x)

        return x
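A quick shape check (a sketch, not part of this diff, assuming `Upsample1D` above is in scope): both paths exactly double the temporal length.

```python
import torch

x = torch.randn(4, 32, 64)  # (batch, channels, length)

up = Upsample1D(channels=32, use_conv=True)
assert up(x).shape == (4, 32, 128)  # F.interpolate doubles the length; Conv1d(k=3, p=1) preserves it

up_t = Upsample1D(channels=32, use_conv_transpose=True)
assert up_t(x).shape == (4, 32, 128)  # ConvTranspose1d: (64 - 1) * 2 - 2 * 1 + 4 = 128
```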
class Downsample1D(nn.Module):
    """
    A downsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a strided convolution is applied; otherwise average
        pooling is used.
    :param out_channels: number of output channels; defaults to `channels`, and must equal
        `channels` when `use_conv` is False.
    :param padding: padding for the strided convolution.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.conv(x)
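The mirror-image check (again a sketch, not from the PR): both variants halve the length, which is what lets the UNet below shrink and regrow `training_horizon` symmetrically.

```python
import torch

x = torch.randn(4, 32, 128)

down = Downsample1D(channels=32, use_conv=True)
assert down(x).shape == (4, 32, 64)  # Conv1d(k=3, s=2, p=1): floor((128 + 2 - 3) / 2) + 1 = 64

down_pool = Downsample1D(channels=32)  # falls back to AvgPool1d(kernel_size=2, stride=2)
assert down_pool(x).shape == (4, 32, 64)
```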
class Upsample2D(nn.Module):
    """
    An upsampling layer with an optional convolution.

@@ -374,6 +438,76 @@ def forward(self, x):

        return x * torch.tanh(torch.nn.functional.softplus(x))


class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        self.block = nn.Sequential(
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
            RearrangeDim(),
            # Rearrange("batch channels horizon -> batch channels 1 horizon"),
            nn.GroupNorm(n_groups, out_channels),
            RearrangeDim(),
            # Rearrange("batch channels 1 horizon -> batch channels horizon"),
            nn.Mish(),
        )

    def forward(self, x):
        return self.block(x)


class RearrangeDim(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, tensor):
        if len(tensor.shape) == 2:
            return tensor[:, :, None]
        if len(tensor.shape) == 3:
            return tensor[:, :, None, :]
        elif len(tensor.shape) == 4:
            return tensor[:, :, 0, :]
        else:
            raise ValueError(f"`len(tensor.shape)`: {len(tensor.shape)} has to be 2, 3 or 4.")
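`RearrangeDim` stands in for the commented-out einops `Rearrange` layers: it briefly lifts the (batch, channels, horizon) activation to 4D around `GroupNorm`, presumably to match the reference implementation's layout, then drops the extra axis again. A shape sketch (mine, not from the diff; it assumes the two classes above are in scope):

```python
import torch

r = RearrangeDim()
assert r(torch.randn(2, 8)).shape == (2, 8, 1)          # batch t -> batch t 1
assert r(torch.randn(2, 8, 16)).shape == (2, 8, 1, 16)  # batch channels horizon -> batch channels 1 horizon
assert r(torch.randn(2, 8, 1, 16)).shape == (2, 8, 16)  # ... and back

block = Conv1dBlock(inp_channels=14, out_channels=32, kernel_size=5)
x = torch.randn(2, 14, 128)
assert block(x).shape == (2, 32, 128)  # padding=kernel_size // 2 preserves the horizon for odd kernels
```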
# unet_rl.py
class ResidualTemporalBlock(nn.Module):
    def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
        super().__init__()

        self.blocks = nn.ModuleList(

> Review comment on `self.blocks`: Not a big fan of using `blocks` here, as it makes it very hard to adapt the class to future models. Sorry, this is a bit annoying to do when converting the checkpoint, but could we maybe instead call the layers …

            [
                Conv1dBlock(inp_channels, out_channels, kernel_size),
                Conv1dBlock(out_channels, out_channels, kernel_size),
            ]
        )

        self.time_mlp = nn.Sequential(
            nn.Mish(),
            nn.Linear(embed_dim, out_channels),
            RearrangeDim(),
            # Rearrange("batch t -> batch t 1"),
        )

        self.residual_conv = (
            nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
        )

    def forward(self, x, t):
        """
        x : [ batch_size x inp_channels x horizon ]
        t : [ batch_size x embed_dim ]

        returns:
        out : [ batch_size x out_channels x horizon ]
        """
        out = self.blocks[0](x) + self.time_mlp(t)
        out = self.blocks[1](out)
        return out + self.residual_conv(x)
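A usage sketch (mine, with arbitrary sizes): the first `Conv1dBlock` output is shifted by the time embedding, which `RearrangeDim` broadcasts over the horizon, and the 1x1 `residual_conv` aligns channels for the skip connection.

```python
import torch

block = ResidualTemporalBlock(inp_channels=14, out_channels=32, embed_dim=32, horizon=128)
x = torch.randn(2, 14, 128)  # (batch, inp_channels, horizon)
t = torch.randn(2, 32)       # (batch, embed_dim) time embedding
assert block(x, t).shape == (2, 32, 128)  # residual_conv is Conv1d(14, 32, 1) since channels differ
```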
def upsample_2d(x, kernel=None, factor=2, gain=1):
    r"""Upsample2D a batch of 2D images with the given filter.


src/diffusers/models/unet_rl.py (new file)

@@ -0,0 +1,193 @@
# model adapted from diffuser https://github.com/jannerm/diffuser/blob/main/diffuser/models/temporal.py
from dataclasses import dataclass
from typing import Tuple, Union

import torch
import torch.nn as nn

from diffusers.models.resnet import Downsample1D, ResidualTemporalBlock, Upsample1D

from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from ..utils import BaseOutput
from .embeddings import get_timestep_embedding


@dataclass
class TemporalUNetOutput(BaseOutput):
    """
    Args:
        sample (`torch.FloatTensor` of shape `(batch, horizon, obs_dimension)`):
            Hidden states output. Output of the last layer of the model.
    """

    sample: torch.FloatTensor


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        return get_timestep_embedding(x, self.dim)
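`SinusoidalPosEmb` is a thin wrapper so `get_timestep_embedding` can sit inside an `nn.Sequential`. A shape sketch (mine; it assumes a 1-D tensor of timesteps, which is what `get_timestep_embedding` expects):

```python
import torch

emb = SinusoidalPosEmb(dim=32)
timesteps = torch.tensor([0, 10, 100])  # one diffusion timestep per batch element
assert emb(timesteps).shape == (3, 32)  # (batch, dim) sinusoidal features
```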
class RearrangeDim(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, tensor):
        if len(tensor.shape) == 2:
            return tensor[:, :, None]
        if len(tensor.shape) == 3:
            return tensor[:, :, None, :]
        elif len(tensor.shape) == 4:
            return tensor[:, :, 0, :]
        else:
            raise ValueError(f"`len(tensor.shape)`: {len(tensor.shape)} has to be 2, 3 or 4.")


class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        self.block = nn.Sequential(
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
            RearrangeDim(),
            nn.GroupNorm(n_groups, out_channels),
            RearrangeDim(),
            nn.Mish(),
        )

    def forward(self, x):
        return self.block(x)
class TemporalUNet(ModelMixin, ConfigMixin):  # (nn.Module):
    @register_to_config
    def __init__(
        self,
        training_horizon=128,
        transition_dim=14,
        cond_dim=3,
        predict_epsilon=False,
        clip_denoised=True,
        dim=32,
        dim_mults=(1, 4, 8),
    ):
        super().__init__()

        self.transition_dim = transition_dim
        self.cond_dim = cond_dim
        self.predict_epsilon = predict_epsilon
        self.clip_denoised = clip_denoised

        dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        time_dim = dim
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            nn.Mish(),
            nn.Linear(dim * 4, dim),
        )

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList(
                    [
                        ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=training_horizon),
                        ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=training_horizon),
                        Downsample1D(dim_out, use_conv=True) if not is_last else nn.Identity(),
                    ]
                )
            )

            if not is_last:
                training_horizon = training_horizon // 2

        mid_dim = dims[-1]
        self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon)
        self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(
                nn.ModuleList(
                    [
                        ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=training_horizon),
                        ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=training_horizon),
                        Upsample1D(dim_in, use_conv_transpose=True) if not is_last else nn.Identity(),
                    ]
                )
            )

            if not is_last:
                training_horizon = training_horizon * 2

        self.final_conv = nn.Sequential(
            Conv1dBlock(dim, dim, kernel_size=5),
            nn.Conv1d(dim, transition_dim, 1),
        )
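To make the bookkeeping concrete, here is the down-path layout for the default config (a worked sketch, not code from this PR):

```python
# Defaults: transition_dim=14, dim=32, dim_mults=(1, 4, 8), training_horizon=128.
transition_dim, dim, dim_mults = 14, 32, (1, 4, 8)
dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]  # [14, 32, 128, 256]
in_out = list(zip(dims[:-1], dims[1:]))                      # [(14, 32), (32, 128), (128, 256)]
# Three down stages; Downsample1D runs on the first two (is_last is True on the third),
# so the horizon shrinks 128 -> 64 -> 32 into the mid blocks, and the two Upsample1D
# stages double it back to 128 on the way up.
```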
    # def forward(self, sample, timestep):
    #     """
    #     x : [ batch x horizon x transition ]
    #     """
    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        return_dict: bool = True,
    ) -> Union[TemporalUNetOutput, Tuple]:
        r"""
        Args:
            sample (`torch.FloatTensor`): (batch, horizon, obs_dimension) noisy inputs tensor

> Review comment on the `sample` shape: I think the shape should be …

            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`TemporalUNetOutput`] instead of a plain tuple.

        Returns:
            [`TemporalUNetOutput`] or `tuple`: [`TemporalUNetOutput`] if `return_dict` is True,
            otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
        """
        sample = sample.permute(0, 2, 1)

        t = self.time_mlp(timestep)
        h = []

        for resnet, resnet2, downsample in self.downs:

> Review comment: Let's really try to mirror the design in …
>
> Reply: I'll try and make it work with the …

            sample = resnet(sample, t)
            sample = resnet2(sample, t)
            h.append(sample)
            sample = downsample(sample)

        sample = self.mid_block1(sample, t)
        sample = self.mid_block2(sample, t)

        for resnet, resnet2, upsample in self.ups:
            sample = torch.cat((sample, h.pop()), dim=1)
            sample = resnet(sample, t)
            sample = resnet2(sample, t)
            sample = upsample(sample)

        sample = self.final_conv(sample)

        sample = sample.permute(0, 2, 1)

        if not return_dict:
            return (sample,)

        return TemporalUNetOutput(sample=sample)
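End-to-end, the model maps a noisy (batch, horizon, transition_dim) trajectory to a tensor of the same shape. A minimal smoke test under the default config (my sketch; note the forward pass does not coerce scalar timesteps to tensors, so a 1-D tensor is passed here):

```python
import torch

model = TemporalUNet(training_horizon=128, transition_dim=14)
sample = torch.randn(1, 128, 14)  # (batch, horizon, transition_dim)
timestep = torch.tensor([10])     # one diffusion step per batch element
out = model(sample, timestep)
assert out.sample.shape == (1, 128, 14)  # TemporalUNetOutput.sample keeps the input layout
```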