
Commit df262d3

Reland #113487 and #112527 (sdpa shim & fp8 AOTInductor support)
This is a backout of #113747, which reverted the above two commits.

Pull Request resolved: #114974
ghstack-source-id: fecfab4

7 files changed: 268 additions, 21 deletions

test/inductor/test_aot_inductor.py

Lines changed: 44 additions & 2 deletions
@@ -569,6 +569,50 @@ def forward(self, x, y):
         example_inputs = (a, b)
         self.check_model(Model(), example_inputs, constraints=constraints)
 
+    @unittest.skipIf(
+        not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0),
+        "FP8 is only supported on H100+",
+    )
+    def test_fp8(self):
+        class Model(torch.nn.Module):
+            def __init__(self, dtype):
+                super().__init__()
+                self.out_dtype = dtype
+
+            def forward(self, x, weight, bias, scale_a, scale_b):
+                weight = weight.to(torch.float8_e4m3fn)
+                output, updated_amax = torch._scaled_mm(
+                    x,
+                    weight,
+                    bias=input_bias,
+                    out_dtype=self.out_dtype,
+                    scale_a=scale_a,
+                    scale_b=scale_b,
+                )
+                return output
+
+        dtype = torch.float16
+
+        a_scale = torch.Tensor([1.0]).to(device="cuda")
+        b_scale = torch.Tensor([1.0]).to(device="cuda")
+        input_bias = torch.rand(32, device="cuda", dtype=dtype)
+        weight_shape = (32, 16)
+        weight = torch.rand(*weight_shape, device="cuda", dtype=dtype).T
+        a_inverse_scale = 1 / a_scale
+        b_inverse_scale = 1 / b_scale
+
+        x_shape = (16, 16)
+        x = torch.rand(*x_shape, device="cuda", dtype=dtype).to(torch.float8_e4m3fn)
+        constraints = [
+            torch._export.dynamic_dim(x, 0) >= 1,
+            torch._export.dynamic_dim(x, 0) <= 2048,
+        ]
+        self.check_model(
+            Model(dtype),
+            (x, weight, input_bias, a_inverse_scale, b_inverse_scale),
+            constraints=constraints,
+        )
+
     def test_poi_multiple_dynamic(self):
         class Model(torch.nn.Module):
             def __init__(self):
@@ -1420,8 +1464,6 @@ class AOTInductorTestABICompatibleCpu(TestCase):
     "test_poi_multiple_dynamic": TestFailure(("abi_compatible_cpu",)),
     # There is a double-free issue which will be fixed in another PR
     "test_repeat_output": TestFailure(("abi_compatible_cpu",), is_skip=True),
-    "test_sdpa": TestFailure(("abi_compatible_cpu",)),
-    "test_sdpa_2": TestFailure(("abi_compatible_cpu",)),
     "test_simple_dynamic": TestFailure(("abi_compatible_cpu",)),
     # error: could not find s0
     "test_shifted_constraint_ranges": TestFailure(

torch/_inductor/codegen/cpp.py

Lines changed: 2 additions & 0 deletions
@@ -77,6 +77,8 @@
     torch.bool: "at::kBool",
     torch.bfloat16: "at::kBFloat16",
     torch.complex64: "at::kComplexFloat",
+    torch.float8_e4m3fn: "at::kFloat8_e4m3fn",
+    torch.float8_e5m2: "at::kFloat8_e5m2",
 }
 
 DEVICE_TO_ATEN = {

torch/_inductor/codegen/wrapper.py

Lines changed: 29 additions & 6 deletions
@@ -1041,6 +1041,9 @@ def writeline(self, line):
     def enter_context(self, ctx):
         self.lines.append(LineContext(ctx))
 
+    def val_to_cpp_arg_str(self, type_, val, is_legacy_abi) -> str:
+        raise NotImplementedError()
+
     def val_to_arg_str(self, s):
         if isinstance(s, SymTypes):
             return pexpr(sympy.expand(repr(s)))
@@ -1755,8 +1758,11 @@ def generate_c_shim_extern_kernel_alloc_call(self, extern_kernel, args):
         else:
             raise NotImplementedError("unsupported type of {output=}")
         args = args + output_args
+        assert (
+            extern_kernel.abi_compatible_kernel is not None
+        ), f"abi_compatible_kernel is None for {extern_kernel.kernel=}"
         self.generate_c_shim_extern_kernel_call(
-            extern_kernel.codegen_kernel_name(), args
+            extern_kernel.abi_compatible_kernel, args
         )
         for raii_handle in output_raii_handles:
             self.writeline(raii_handle)
@@ -2363,13 +2369,29 @@ def extract_output_name(out):
 
         self.extern_call_ops.add(cpp_kernel_key)
 
-    def val_to_arg_str(self, val):
+    def val_to_cpp_arg_str(self, type_, val, is_legacy_abi) -> str:
+        if (
+            config.aot_inductor.abi_compatible
+            and not is_legacy_abi
+            and isinstance(type_, torch.OptionalType)
+        ):
+            if val is None:
+                return "0"  # nullptr is not available in C
+            if isinstance(val, (bool, int, str, float)):
+                var_name = f"var_{next(self.arg_var_id)}"
+                self.writeline(f"auto {var_name} = {self.val_to_arg_str(val)};")
+                return f"&{var_name}"
+            if not isinstance(type_.getElementType(), torch.TensorType):
+                return f"&{self.val_to_arg_str(val)}"
+
+        return self.val_to_arg_str(val)
+
+    def val_to_arg_str(self, val) -> str:
         if val is None:
             # When None is passed as an argument, it represents an optional that does not contain a value.
             if config.aot_inductor.abi_compatible:
-                return "nullptr"
-            else:
-                return "c10::nullopt"
+                return "0"  # nullptr is not available in C
+            return "c10::nullopt"
         elif isinstance(val, bool):
             if config.aot_inductor.abi_compatible:
                 return "1" if val else "0"
@@ -2391,7 +2413,8 @@ def val_to_arg_str(self, val):
             else:
                 return "-std::numeric_limits<float>::infinity()"
         elif isinstance(val, (list, tuple)):
-            result = f"{{{', '.join(list(map(self.val_to_arg_str, val)))}}}"
+            # FIXME handle embedded optional types?
+            result = f"{{{', '.join(self.val_to_arg_str(x) for x in val)}}}"
             if config.aot_inductor.abi_compatible:
                 # Need to pass the array length because we can't use std::vector
                 return f"{self.codegen_int_array_var(result)}, {len(val)}"

torch/_inductor/graph.py

Lines changed: 2 additions & 0 deletions
@@ -79,6 +79,8 @@ def supported_dtype_of_cpp_wrapper(dtype, cuda):
     }
     if cuda:
         supported_dtype.add(torch.float16)
+        supported_dtype.add(torch.float8_e4m3fn)
+        supported_dtype.add(torch.float8_e5m2)
 
     return dtype in supported_dtype
 
torch/_inductor/ir.py

Lines changed: 54 additions & 4 deletions
@@ -3628,9 +3628,10 @@ def get_kwargs_value(self, arg_name):
             f"arg {arg_name} not found in self.kwargs or self.kwargs_default_value"
         )
 
+    def is_legacy_abi_kernel(self):
+        return False
+
     def codegen_kwargs(self):
-        if not self.kwargs:
-            return []
         if V.graph.cpp_wrapper:
             # FIXME: we should unconditionally fill self.kwargs with missing default values
             # instead of carrying an extra self.ordered_kwargs_for_cpp_kernel
@@ -3642,7 +3643,16 @@ def codegen_kwargs(self):
                 if isinstance(v, sympy.Expr):
                     kwargs.append(v)
                 else:
-                    kwargs.append(V.graph.wrapper_code.val_to_arg_str(v))
+                    # FIXME We should let ExternKernel have access to the cpp schema where possible.
+                    if hasattr(self, "kwargs_default_value"):
+                        type_ = self.kwargs_default_value.get(arg_name).get("type")
+                    else:
+                        type_ = None
+                    kwargs.append(
+                        V.graph.wrapper_code.val_to_cpp_arg_str(
+                            type_, v, self.is_legacy_abi_kernel()
+                        )
+                    )
         else:
             kwargs = [
                 f"{k}={V.graph.wrapper_code.val_to_arg_str(v)}"
@@ -3777,12 +3787,38 @@ def __init__(self, count: int, device: torch.device):
 
 
 class ExternKernelAlloc(ExternKernel):
+    # Generate abi-compatible kernel names for shim kernels.
+    # Each individual shim kernel may have its own versioning rule.
+    # However, we don't expect we would end up with too many of such rules.
+    def _get_abi_compatible_kernel(self):
+        if not V.graph.cpp_wrapper:
+            return self.kernel
+
+        def sdpa_ver_fn():
+            # For sdpa, we need the v2 version only if any optional
+            # kwarg is missing.
+            if any(
+                self.get_kwargs_value(arg_name) is None
+                for arg_name in self.ordered_kwargs_for_cpp_kernel
+            ):
+                return f"{self.cpp_kernel}_v2"
+            else:
+                return self.cpp_kernel
+
+        kernel_to_ver = {"at::_scaled_dot_product_flash_attention": sdpa_ver_fn}
+        if (ver_fn := kernel_to_ver.get(self.cpp_kernel, None)) is not None:
+            return ver_fn()
+        return self.cpp_kernel
+
     def codegen_kernel_name(self):
         return self.cpp_kernel if V.graph.cpp_wrapper else self.kernel
 
     def codegen(self, wrapper):
         self.codegen_comment(wrapper)
         args = [*self.codegen_args(), *self.codegen_kwargs()]
+        # Now we setup abi_compatible_kernel after self.kernel
+        # and kwargs are adjusted appropriately.
+        self.abi_compatible_kernel = self._get_abi_compatible_kernel()
         V.graph.wrapper_code.generate_extern_kernel_alloc(self, args)
         if isinstance(self.layout, Layout):
             self.codegen_size_asserts(wrapper)
@@ -3803,6 +3839,7 @@ def __init__(
         self.name = V.graph.register_buffer(self)
         self.kernel = kernel
         self.cpp_kernel = cpp_kernel
+        self.abi_compatible_kernel = None
         self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel
 
     def should_allocate(self):
@@ -4302,6 +4339,9 @@ def is_not_write(arg):
             x.name for x in kernel._schema.arguments if x.kwarg_only
         ]
 
+    def is_legacy_abi_kernel(self):
+        return "_scaled_dot_product_flash_attention" in str(self.kernel)
+
     def get_arg_default_value(self, pos):
         assert hasattr(
             self, "args_default_value"
@@ -4321,7 +4361,17 @@ def __repr__(self):
 
         tensor_args = [Shim(x.codegen_reference()) for x in self.inputs]
         args, kwargs = self.unflatten_args(tensor_args, self.constant_args)
-        args = [V.graph.wrapper_code.val_to_arg_str(x) for x in args]
+
+        if V.graph.cpp_wrapper and isinstance(self.op_overload, torch._ops.OpOverload):
+            args = [
+                V.graph.wrapper_code.val_to_cpp_arg_str(
+                    param.real_type, x, self.is_legacy_abi_kernel()
+                )
+                for param, x in zip(self.op_overload._schema.arguments, args)
+            ]
+        else:
+            args = [V.graph.wrapper_code.val_to_arg_str(x) for x in args]
+
         # Previously, we want to maintain forward-compatibility by skipping
         # default args in the serialized artifacts in fbcode. However,
         # some of our shim interfaces require default values being set.
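
The _get_abi_compatible_kernel hook above is what routes sdpa to the _v2 shim: per its comment, the _v2 entry point is selected only when some optional kwarg is None, presumably because only the v2 declaration (with its pointer-typed scale parameter, see shim.h below) can represent an absent value. A toy restatement of that rule, with a hypothetical kwargs dict standing in for the ExternKernel state:

# Toy restatement of the sdpa versioning rule for illustration only; this is
# not the Inductor code path, and the kwargs dict below is hypothetical.
SDPA = "at::_scaled_dot_product_flash_attention"

def pick_shim_kernel(cpp_kernel, kwargs):
    if cpp_kernel == SDPA and any(v is None for v in kwargs.values()):
        return f"{cpp_kernel}_v2"  # missing optionals become null pointers in v2
    return cpp_kernel

print(pick_shim_kernel(SDPA, {"scale": None}))   # at::_scaled_dot_product_flash_attention_v2
print(pick_shim_kernel(SDPA, {"scale": 0.125}))  # at::_scaled_dot_product_flash_attention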

torch/csrc/inductor/aoti_torch/c/shim.h

Lines changed: 36 additions & 1 deletion
@@ -83,6 +83,8 @@ using AOTITorchError = int32_t;
 AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_cpu();
 AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_cuda();
 
+AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e5m2();
+AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e4m3fn();
 AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_bfloat16();
 AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float16();
 AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float32();
@@ -175,6 +177,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob(
     AtenTensorHandle* ret // returns new reference
 );
 
+// This version is deprecated. We will remove it later
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
     AtenTensorHandle query,
     AtenTensorHandle key,
@@ -194,6 +197,38 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
     AtenTensorHandle* ret8 // returns new reference
 );
 
+AOTI_TORCH_EXPORT AOTITorchError
+aoti_torch__scaled_dot_product_flash_attention_v2(
+    AtenTensorHandle query,
+    AtenTensorHandle key,
+    AtenTensorHandle value,
+    double dropout_p,
+    int is_causal,
+    int return_debug_mask,
+    double* scale,
+    AtenTensorHandle* ret0, // returns new reference
+    AtenTensorHandle* ret1, // returns new reference
+    AtenTensorHandle* ret2, // returns new reference
+    AtenTensorHandle* ret3, // returns new reference
+    int64_t* ret4,
+    int64_t* ret5,
+    AtenTensorHandle* ret6, // returns new reference
+    AtenTensorHandle* ret7, // returns new reference
+    AtenTensorHandle* ret8 // returns new reference
+);
+
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_mm(
+    AtenTensorHandle self,
+    AtenTensorHandle mat2,
+    AtenTensorHandle bias,
+    int32_t* out_dtype,
+    AtenTensorHandle scale_a,
+    AtenTensorHandle scale_b,
+    AtenTensorHandle scale_result,
+    int8_t use_fast_accum,
+    AtenTensorHandle* ret0,
+    AtenTensorHandle* ret1);
+
 // This function will create a new uninitialized tensor object
 // and its pointer is returned through *ret.
 AOTI_TORCH_EXPORT AOTITorchError
@@ -238,7 +273,7 @@ aoti_torch_nonzero(AtenTensorHandle self, AtenTensorHandle* out);
 
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_repeat_interleave_Tensor(
     AtenTensorHandle repeats,
-    int64_t output_size,
+    int64_t* output_size,
     AtenTensorHandle* out);
 
 AOTI_TORCH_EXPORT AOTITorchError
