Skip to content

Commit 733be4e

Browse files
authored
[mlir][spirv] Add GpuToLLVM cconv suited to Vulkan, migrate last tests (llvm#123384)
This commit is a follow-up to 99a562b, which migrated some of the mlir-vulkan-runner tests to mlir-cpu-runner using a new pipeline and set of wrappers. That commit could not migrate all the tests, because the existing calling conventions/ABIs for kernel arguments generated by GPUToLLVMConversionPass were not a good fit for the Vulkan runtime. This commit fixes this and migrates the remaining tests. With this commit, mlir-vulkan-runner and many related components are now unused, and they will be removed in a later commit (see llvm#73457). The old calling conventions require both the caller (host LLVM code) and callee (device code) to have compile-time knowledge of the precise argument types. This works for CUDA, ROCm and SYCL, where there is a C-like calling convention agreed between the host and device code, and the runtime passes through arguments as raw data without comprehension. For Vulkan, however, the interface declared by the shader/kernel is in a more abstract form, so the device code has indirect access to the argument data, and the runtime must process the arguments to set up and bind appropriately-sized buffer descriptors. This commit introduces a new calling convention option to meet the Vulkan runtime's needs. It lowers memref arguments to {void*, size_t} pairs, which can be trivially interpreted by the runtime without it needing to know the original argument types. Unlike the stopgap measure in the previous commit, this system can support memrefs of various ranks and element types, which unblocked migrating the remaining tests.
1 parent 13c6abf commit 733be4e

File tree

10 files changed

+113
-68
lines changed

10 files changed

+113
-68
lines changed

mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,10 @@ struct FunctionCallBuilder {
6262

6363
/// Collect a set of patterns to convert from the GPU dialect to LLVM and
6464
/// populate converter for gpu types.
65-
void populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
66-
RewritePatternSet &patterns,
67-
bool kernelBarePtrCallConv = false,
68-
bool typeCheckKernelArgs = false);
65+
void populateGpuToLLVMConversionPatterns(
66+
LLVMTypeConverter &converter, RewritePatternSet &patterns,
67+
bool kernelBarePtrCallConv = false,
68+
bool kernelIntersperseSizeCallConv = false);
6969

7070
/// A function that maps a MemorySpace enum to a target-specific integer value.
7171
using MemorySpaceMapping = std::function<unsigned(gpu::AddressSpace)>;

mlir/include/mlir/Conversion/Passes.td

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -518,11 +518,13 @@ def GpuToLLVMConversionPass : Pass<"gpu-to-llvm", "ModuleOp"> {
518518
"Use bare pointers to pass memref arguments to kernels. "
519519
"The kernel must use the same setting for this option."
520520
>,
521-
Option<"typeCheckKernelArgs", "type-check-kernel-args", "bool",
521+
Option<"kernelIntersperseSizeCallConv", "intersperse-sizes-for-kernels", "bool",
522522
/*default=*/"false",
523-
"Require all kernel arguments to be memrefs of rank 1 and with a "
524-
"32-bit element size. This is a temporary option that will be "
525-
"removed; TODO(https://github.com/llvm/llvm-project/issues/73457)."
523+
"Inserts a size_t argument following each memref argument, "
524+
"containing the static size in bytes of the buffer. Incompatible "
525+
"arguments are rejected. This is intended for use by the Vulkan "
526+
"runtime with the kernel bare pointer calling convention, to enable "
527+
"dynamic binding of buffers as arguments without static type info."
526528
>
527529
];
528530

mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp

Lines changed: 57 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -428,18 +428,18 @@ class LegalizeLaunchFuncOpPattern
428428
public:
429429
LegalizeLaunchFuncOpPattern(const LLVMTypeConverter &typeConverter,
430430
bool kernelBarePtrCallConv,
431-
bool typeCheckKernelArgs)
431+
bool kernelIntersperseSizeCallConv)
432432
: ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
433433
kernelBarePtrCallConv(kernelBarePtrCallConv),
434-
typeCheckKernelArgs(typeCheckKernelArgs) {}
434+
kernelIntersperseSizeCallConv(kernelIntersperseSizeCallConv) {}
435435

436436
private:
437437
LogicalResult
438438
matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
439439
ConversionPatternRewriter &rewriter) const override;
440440

441441
bool kernelBarePtrCallConv;
442-
bool typeCheckKernelArgs;
442+
bool kernelIntersperseSizeCallConv;
443443
};
444444

445445
/// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
@@ -566,8 +566,9 @@ void GpuToLLVMConversionPass::runOnOperation() {
566566
populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
567567
populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
568568
target);
569-
populateGpuToLLVMConversionPatterns(
570-
converter, patterns, kernelBarePtrCallConv, typeCheckKernelArgs);
569+
populateGpuToLLVMConversionPatterns(converter, patterns,
570+
kernelBarePtrCallConv,
571+
kernelIntersperseSizeCallConv);
571572

572573
if (failed(
573574
applyPartialConversion(getOperation(), target, std::move(patterns))))
@@ -970,33 +971,55 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
970971
else if (launchOp.getAsyncToken())
971972
stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
972973

973-
if (typeCheckKernelArgs) {
974-
// The current non-bare-pointer ABI is a bad fit for `mgpuLaunchKernel`,
975-
// which takes an untyped list of arguments. The type check here prevents
976-
// accidentally violating the assumption made in vulkan-runtime-wrappers.cpp
977-
// and creating a unchecked runtime ABI mismatch.
978-
// TODO(https://github.com/llvm/llvm-project/issues/73457): Change the ABI
979-
// here to remove the need for this type check.
980-
for (Value arg : launchOp.getKernelOperands()) {
981-
if (auto memrefTy = dyn_cast<MemRefType>(arg.getType())) {
982-
if (memrefTy.getRank() != 1 ||
983-
memrefTy.getElementTypeBitWidth() != 32) {
984-
return rewriter.notifyMatchFailure(
985-
launchOp, "Operand to launch op is not a rank-1 memref with "
986-
"32-bit element type.");
987-
}
988-
} else {
974+
// Lower the kernel operands to match kernel parameters.
975+
// Note: If `useBarePtrCallConv` is set in the type converter's options,
976+
// the value of `kernelBarePtrCallConv` will be ignored.
977+
OperandRange origArguments = launchOp.getKernelOperands();
978+
SmallVector<Value, 8> llvmArguments = getTypeConverter()->promoteOperands(
979+
loc, origArguments, adaptor.getKernelOperands(), rewriter,
980+
/*useBarePtrCallConv=*/kernelBarePtrCallConv);
981+
SmallVector<Value, 8> llvmArgumentsWithSizes;
982+
983+
// Intersperse size information if requested.
984+
if (kernelIntersperseSizeCallConv) {
985+
if (origArguments.size() != llvmArguments.size()) {
986+
// This shouldn't happen if the bare-pointer calling convention is used.
987+
return rewriter.notifyMatchFailure(
988+
launchOp,
989+
"Cannot add sizes to arguments with one-to-many LLVM IR expansion.");
990+
}
991+
992+
llvmArgumentsWithSizes.reserve(llvmArguments.size() * 2);
993+
for (auto [llvmArg, origArg] : zip_equal(llvmArguments, origArguments)) {
994+
auto memrefTy = dyn_cast<MemRefType>(origArg.getType());
995+
if (!memrefTy) {
989996
return rewriter.notifyMatchFailure(
990997
launchOp, "Operand to launch op is not a memref.");
991998
}
999+
1000+
if (!memrefTy.hasStaticShape() ||
1001+
!memrefTy.getElementType().isIntOrFloat()) {
1002+
return rewriter.notifyMatchFailure(
1003+
launchOp, "Operand to launch op is not a memref with a static "
1004+
"shape and an integer or float element type.");
1005+
}
1006+
1007+
unsigned bitwidth = memrefTy.getElementTypeBitWidth();
1008+
if (bitwidth % 8 != 0) {
1009+
return rewriter.notifyMatchFailure(
1010+
launchOp, "Operand to launch op is not a memref with a "
1011+
"byte-aligned element type.");
1012+
}
1013+
1014+
uint64_t staticSize = static_cast<uint64_t>(bitwidth / 8) *
1015+
static_cast<uint64_t>(memrefTy.getNumElements());
1016+
1017+
Value sizeArg = rewriter.create<LLVM::ConstantOp>(
1018+
loc, getIndexType(), rewriter.getIndexAttr(staticSize));
1019+
llvmArgumentsWithSizes.push_back(llvmArg); // Presumably a bare pointer.
1020+
llvmArgumentsWithSizes.push_back(sizeArg);
9921021
}
9931022
}
994-
// Lower the kernel operands to match kernel parameters.
995-
// Note: If `useBarePtrCallConv` is set in the type converter's options,
996-
// the value of `kernelBarePtrCallConv` will be ignored.
997-
SmallVector<Value, 4> arguments = getTypeConverter()->promoteOperands(
998-
loc, launchOp.getKernelOperands(), adaptor.getKernelOperands(), rewriter,
999-
/*useBarePtrCallConv=*/kernelBarePtrCallConv);
10001023

10011024
std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
10021025
if (launchOp.hasClusterSize()) {
@@ -1010,7 +1033,9 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
10101033
adaptor.getGridSizeZ()},
10111034
gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
10121035
adaptor.getBlockSizeZ()},
1013-
adaptor.getDynamicSharedMemorySize(), arguments, stream, clusterSize);
1036+
adaptor.getDynamicSharedMemorySize(),
1037+
llvmArgumentsWithSizes.empty() ? llvmArguments : llvmArgumentsWithSizes,
1038+
stream, clusterSize);
10141039
if (launchOp.getAsyncToken())
10151040
rewriter.replaceOp(launchOp, {stream});
10161041
else
@@ -1760,10 +1785,9 @@ LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite(
17601785
return success();
17611786
}
17621787

1763-
void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
1764-
RewritePatternSet &patterns,
1765-
bool kernelBarePtrCallConv,
1766-
bool typeCheckKernelArgs) {
1788+
void mlir::populateGpuToLLVMConversionPatterns(
1789+
LLVMTypeConverter &converter, RewritePatternSet &patterns,
1790+
bool kernelBarePtrCallConv, bool kernelIntersperseSizeCallConv) {
17671791
addOpaquePointerConversion<gpu::AsyncTokenType>(converter);
17681792
addOpaquePointerConversion<gpu::SparseDnTensorHandleType>(converter);
17691793
addOpaquePointerConversion<gpu::SparseSpMatHandleType>(converter);
@@ -1801,7 +1825,7 @@ void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
18011825
ConvertSpMatGetSizeOpToGpuRuntimeCallPattern,
18021826
ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
18031827
patterns.add<LegalizeLaunchFuncOpPattern>(converter, kernelBarePtrCallConv,
1804-
typeCheckKernelArgs);
1828+
kernelIntersperseSizeCallConv);
18051829
}
18061830

18071831
//===----------------------------------------------------------------------===//
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// RUN: mlir-opt %s --gpu-to-llvm="use-bare-pointers-for-kernels=1 intersperse-sizes-for-kernels=1" -split-input-file | FileCheck %s
2+
3+
module attributes {gpu.container_module, spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>} {
4+
llvm.func @malloc(i64) -> !llvm.ptr
5+
gpu.binary @kernels [#gpu.object<#spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>, "">]
6+
func.func @main() attributes {llvm.emit_c_interface} {
7+
// CHECK: [[RANK1UMD:%.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
8+
%rank1UndefMemrefDescriptor = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
9+
// CHECK: [[RANK2UMD:%.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
10+
%rank2UndefMemrefDescriptor = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
11+
%c1 = arith.constant 1 : index
12+
// CHECK: [[PTR1:%.*]] = llvm.extractvalue [[RANK1UMD]][1]
13+
// CHECK: [[PTR2:%.*]] = llvm.extractvalue [[RANK2UMD]][1]
14+
// CHECK: [[PTR3:%.*]] = llvm.extractvalue [[RANK2UMD]][1]
15+
// CHECK: [[SIZE1:%.*]] = llvm.mlir.constant(32 : index) : i64
16+
// CHECK: [[SIZE2:%.*]] = llvm.mlir.constant(256 : index) : i64
17+
// CHECK: [[SIZE3:%.*]] = llvm.mlir.constant(48 : index) : i64
18+
%6 = builtin.unrealized_conversion_cast %rank1UndefMemrefDescriptor : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<8xf32>
19+
%10 = builtin.unrealized_conversion_cast %rank2UndefMemrefDescriptor : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<8x8xi32>
20+
%14 = builtin.unrealized_conversion_cast %rank2UndefMemrefDescriptor : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<4x12xi8>
21+
// CHECK: gpu.launch_func @kernels::@kernel_add blocks in ({{.*}}) threads in ({{.*}}) : i64 args([[PTR1]] : !llvm.ptr, [[SIZE1]] : i64, [[PTR2]] : !llvm.ptr, [[SIZE2]] : i64, [[PTR3]] : !llvm.ptr, [[SIZE3]] : i64)
22+
gpu.launch_func @kernels::@kernel_add blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%6 : memref<8xf32>, %10 : memref<8x8xi32>, %14 : memref<4x12xi8>)
23+
return
24+
}
25+
}

mlir/test/lib/Pass/TestVulkanRunnerPipeline.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,11 @@ void buildTestVulkanRunnerPipeline(OpPassManager &passManager,
6868
passManager.addPass(createFinalizeMemRefToLLVMConversionPass());
6969
passManager.nest<func::FuncOp>().addPass(
7070
LLVM::createRequestCWrappersPass());
71-
// vulkan-runtime-wrappers.cpp uses the non-bare-pointer calling convention,
72-
// and the type check is needed to prevent accidental ABI mismatches.
71+
// vulkan-runtime-wrappers.cpp requires these calling convention options.
7372
GpuToLLVMConversionPassOptions opt;
7473
opt.hostBarePtrCallConv = false;
75-
opt.kernelBarePtrCallConv = false;
76-
opt.typeCheckKernelArgs = true;
74+
opt.kernelBarePtrCallConv = true;
75+
opt.kernelIntersperseSizeCallConv = true;
7776
passManager.addPass(createGpuToLLVMConversionPass(opt));
7877
}
7978
}

mlir/test/mlir-vulkan-runner/addi.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
// RUN: mlir-opt %s -test-vulkan-runner-pipeline \
2-
// RUN: | mlir-vulkan-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
1+
// RUN: mlir-opt %s -test-vulkan-runner-pipeline=to-llvm \
2+
// RUN: | mlir-cpu-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
33

44
// CHECK-COUNT-64: [3, 3, 3, 3, 3, 3, 3, 3]
55
module attributes {

mlir/test/mlir-vulkan-runner/addi8.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
// RUN: mlir-opt %s -test-vulkan-runner-pipeline \
2-
// RUN: | mlir-vulkan-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
1+
// RUN: mlir-opt %s -test-vulkan-runner-pipeline=to-llvm \
2+
// RUN: | mlir-cpu-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
33

44
// CHECK-COUNT-64: [3, 3, 3, 3, 3, 3, 3, 3]
55
module attributes {

mlir/test/mlir-vulkan-runner/mulf.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
// RUN: mlir-opt %s -test-vulkan-runner-pipeline \
2-
// RUN: | mlir-vulkan-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
1+
// RUN: mlir-opt %s -test-vulkan-runner-pipeline=to-llvm \
2+
// RUN: | mlir-cpu-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
33

44
// CHECK-COUNT-4: [6, 6, 6, 6]
55
module attributes {

mlir/test/mlir-vulkan-runner/subf.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
// RUN: mlir-opt %s -test-vulkan-runner-pipeline \
2-
// RUN: | mlir-vulkan-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
1+
// RUN: mlir-opt %s -test-vulkan-runner-pipeline=to-llvm \
2+
// RUN: | mlir-cpu-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
33

44
// CHECK-COUNT-32: [2.2, 2.2, 2.2, 2.2]
55
module attributes {

mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -169,26 +169,21 @@ mgpuLaunchKernel(void *vkKernel, size_t gridX, size_t gridY, size_t gridZ,
169169
void ** /*extra*/, size_t paramsCount) {
170170
auto manager = static_cast<VulkanRuntimeManager *>(vkRuntimeManager);
171171

172-
// The non-bare-pointer memref ABI interacts badly with mgpuLaunchKernel's
173-
// signature:
174-
// - The memref descriptor struct gets split into several elements, each
175-
// passed as their own "param".
176-
// - No metadata is provided as to the rank or element type/size of a memref.
177-
// Here we assume that all MemRefs have rank 1 and an element size of
178-
// 4 bytes. This means each descriptor struct will have five members.
179-
// TODO(https://github.com/llvm/llvm-project/issues/73457): Refactor the
180-
// ABI/API of mgpuLaunchKernel to use a different ABI for memrefs, so
181-
// that other memref types can also be used. This will allow migrating
182-
// the remaining tests and removal of mlir-vulkan-runner.
183-
const size_t paramsPerMemRef = 5;
172+
// GpuToLLVMConversionPass with the kernelBarePtrCallConv and
173+
// kernelIntersperseSizeCallConv options will set up the params array like:
174+
// { &memref_ptr0, &memref_size0, &memref_ptr1, &memref_size1, ... }
175+
const size_t paramsPerMemRef = 2;
184176
if (paramsCount % paramsPerMemRef != 0) {
185-
abort();
177+
abort(); // This would indicate a serious calling convention mismatch.
186178
}
187179
const DescriptorSetIndex setIndex = 0;
188180
BindingIndex bindIndex = 0;
189181
for (size_t i = 0; i < paramsCount; i += paramsPerMemRef) {
190-
auto memref = static_cast<MemRefDescriptor<uint32_t, 1> *>(params[i]);
191-
bindMemRef<uint32_t, 1>(manager, setIndex, bindIndex, memref);
182+
void *memrefBufferBasePtr = *static_cast<void **>(params[i + 0]);
183+
size_t memrefBufferSize = *static_cast<size_t *>(params[i + 1]);
184+
VulkanHostMemoryBuffer memBuffer{memrefBufferBasePtr,
185+
static_cast<uint32_t>(memrefBufferSize)};
186+
manager->setResourceData(setIndex, bindIndex, memBuffer);
192187
++bindIndex;
193188
}
194189

0 commit comments

Comments
 (0)