-
Notifications
You must be signed in to change notification settings - Fork 13.6k
Reland "[MLIR][LLVM] Change addressof builders to use opaque pointers" #69292
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This commit changes the builders of the `llvm.mlir.addressof` operations to no longer produce typed pointers. As a consequence, a GPU to NVVM pattern had to be updated, that still relied on typed pointers.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-gpu Author: Christian Ulmann (Dinistro) ChangesThis commit changes the builders of the As a consequence, a GPU to NVVM pattern had to be updated, that still Full diff: https://github.com/llvm/llvm-project/pull/69292.diff 5 Files Affected:
diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
index 684ce37b2398ce2..f05f1c2dc33881d 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
@@ -117,8 +117,8 @@ class PrintOpLowering : public ConversionPattern {
/// * `i32 (i8*, ...)`
static LLVM::LLVMFunctionType getPrintfType(MLIRContext *context) {
auto llvmI32Ty = IntegerType::get(context, 32);
- auto llvmI8PtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
- auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy,
+ auto llvmPtrTy = LLVM::LLVMPointerType::get(context);
+ auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmPtrTy,
/*isVarArg=*/true);
return llvmFnType;
}
@@ -162,8 +162,7 @@ class PrintOpLowering : public ConversionPattern {
Value cst0 = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
builder.getIndexAttr(0));
return builder.create<LLVM::GEPOp>(
- loc,
- LLVM::LLVMPointerType::get(IntegerType::get(builder.getContext(), 8)),
+ loc, LLVM::LLVMPointerType::get(builder.getContext()), global.getType(),
globalPtr, ArrayRef<Value>({cst0, cst0}));
}
};
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
index 684ce37b2398ce2..f05f1c2dc33881d 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
@@ -117,8 +117,8 @@ class PrintOpLowering : public ConversionPattern {
/// * `i32 (i8*, ...)`
static LLVM::LLVMFunctionType getPrintfType(MLIRContext *context) {
auto llvmI32Ty = IntegerType::get(context, 32);
- auto llvmI8PtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
- auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy,
+ auto llvmPtrTy = LLVM::LLVMPointerType::get(context);
+ auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmPtrTy,
/*isVarArg=*/true);
return llvmFnType;
}
@@ -162,8 +162,7 @@ class PrintOpLowering : public ConversionPattern {
Value cst0 = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
builder.getIndexAttr(0));
return builder.create<LLVM::GEPOp>(
- loc,
- LLVM::LLVMPointerType::get(IntegerType::get(builder.getContext(), 8)),
+ loc, LLVM::LLVMPointerType::get(builder.getContext()), global.getType(),
globalPtr, ArrayRef<Value>({cst0, cst0}));
}
};
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 8745d14c8d48318..2a572ab4de706a3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1071,7 +1071,7 @@ def LLVM_AddressOfOp : LLVM_Op<"mlir.addressof",
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
[{
build($_builder, $_state,
- LLVM::LLVMPointerType::get(global.getType(), global.getAddrSpace()),
+ LLVM::LLVMPointerType::get($_builder.getContext(), global.getAddrSpace()),
global.getSymName());
$_state.addAttributes(attrs);
}]>,
@@ -1079,7 +1079,7 @@ def LLVM_AddressOfOp : LLVM_Op<"mlir.addressof",
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
[{
build($_builder, $_state,
- LLVM::LLVMPointerType::get(func.getFunctionType()), func.getName());
+ LLVM::LLVMPointerType::get($_builder.getContext()), func.getName());
$_state.addAttributes(attrs);
}]>
];
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 96d8fceba706617..6d2585aa30ab4c5 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -441,7 +441,7 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
Location loc = gpuPrintfOp->getLoc();
mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
- mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8);
+ mlir::Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
// Note: this is the GPUModule op, not the ModuleOp that surrounds it
// This ensures that global constants and declarations are placed within
@@ -449,7 +449,7 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
auto vprintfType =
- LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {i8Ptr, i8Ptr});
+ LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
LLVM::LLVMFuncOp vprintfDecl =
getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType);
@@ -473,7 +473,8 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
// Get a pointer to the format string's first element
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
Value stringStart = rewriter.create<LLVM::GEPOp>(
- loc, i8Ptr, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
+ loc, getTypeConverter()->getPointerType(globalType), globalType,
+ globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
SmallVector<Type> types;
SmallVector<Value> args;
// Promote and pack the arguments into a stack allocation.
@@ -490,18 +491,17 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
}
Type structType =
LLVM::LLVMStructType::getLiteral(gpuPrintfOp.getContext(), types);
- Type structPtrType = LLVM::LLVMPointerType::get(structType);
Value one = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(),
rewriter.getIndexAttr(1));
- Value tempAlloc = rewriter.create<LLVM::AllocaOp>(loc, structPtrType, one,
- /*alignment=*/0);
+ Value tempAlloc =
+ rewriter.create<LLVM::AllocaOp>(loc, ptrType, structType, one,
+ /*alignment=*/0);
for (auto [index, arg] : llvm::enumerate(args)) {
Value ptr = rewriter.create<LLVM::GEPOp>(
- loc, LLVM::LLVMPointerType::get(arg.getType()), tempAlloc,
- ArrayRef<LLVM::GEPArg>{0, index});
+ loc, getTypeConverter()->getPointerType(structType), structType,
+ tempAlloc, ArrayRef<LLVM::GEPArg>{0, index});
rewriter.create<LLVM::StoreOp>(loc, arg, ptr);
}
- tempAlloc = rewriter.create<LLVM::BitcastOp>(loc, i8Ptr, tempAlloc);
std::array<Value, 2> printfArgs = {stringStart, tempAlloc};
rewriter.create<LLVM::CallOp>(loc, vprintfDecl, printfArgs);
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index 391ccd74841dca4..a8c02e32ef92b6b 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -542,16 +542,15 @@ gpu.module @test_module_28 {
gpu.module @test_module_29 {
// CHECK-DAG: llvm.mlir.global internal constant @[[$PRINT_GLOBAL0:[A-Za-z0-9_]+]]("Hello, world\0A\00")
// CHECK-DAG: llvm.mlir.global internal constant @[[$PRINT_GLOBAL1:[A-Za-z0-9_]+]]("Hello: %d\0A\00")
- // CHECK-DAG: llvm.func @vprintf(!llvm.ptr<i8>, !llvm.ptr<i8>) -> i32
+ // CHECK-DAG: llvm.func @vprintf(!llvm.ptr, !llvm.ptr) -> i32
// CHECK-LABEL: func @test_const_printf
gpu.func @test_const_printf() {
- // CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL0]] : !llvm.ptr<array<14 x i8>>
- // CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr<array<14 x i8>>) -> !llvm.ptr<i8>
+ // CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL0]] : !llvm.ptr
+ // CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<14 x i8>
// CHECK-NEXT: %[[O:.*]] = llvm.mlir.constant(1 : index) : i64
- // CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<()> : (i64) -> !llvm.ptr<struct<()>>
- // CHECK-NEXT: %[[ARGPTR:.*]] = llvm.bitcast %[[ALLOC]] : !llvm.ptr<struct<()>> to !llvm.ptr<i8>
- // CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ARGPTR]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> i32
+ // CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<()> : (i64) -> !llvm.ptr
+ // CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ALLOC]]) : (!llvm.ptr, !llvm.ptr) -> i32
gpu.printf "Hello, world\n"
gpu.return
}
@@ -559,17 +558,16 @@ gpu.module @test_module_29 {
// CHECK-LABEL: func @test_printf
// CHECK: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: f32)
gpu.func @test_printf(%arg0: i32, %arg1: f32) {
- // CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL1]] : !llvm.ptr<array<11 x i8>>
- // CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr<array<11 x i8>>) -> !llvm.ptr<i8>
+ // CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL1]] : !llvm.ptr
+ // CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<11 x i8>
// CHECK-NEXT: %[[EXT:.+]] = llvm.fpext %[[ARG1]] : f32 to f64
// CHECK-NEXT: %[[O:.*]] = llvm.mlir.constant(1 : index) : i64
- // CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<(i32, f64)> : (i64) -> !llvm.ptr<struct<(i32, f64)>>
- // CHECK-NEXT: %[[EL0:.*]] = llvm.getelementptr %[[ALLOC]][0, 0] : (!llvm.ptr<struct<(i32, f64)>>) -> !llvm.ptr<i32>
- // CHECK-NEXT: llvm.store %[[ARG0]], %[[EL0]] : !llvm.ptr<i32>
- // CHECK-NEXT: %[[EL1:.*]] = llvm.getelementptr %[[ALLOC]][0, 1] : (!llvm.ptr<struct<(i32, f64)>>) -> !llvm.ptr<f64>
- // CHECK-NEXT: llvm.store %[[EXT]], %[[EL1]] : !llvm.ptr<f64>
- // CHECK-NEXT: %[[ARGPTR:.*]] = llvm.bitcast %[[ALLOC]] : !llvm.ptr<struct<(i32, f64)>> to !llvm.ptr<i8>
- // CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ARGPTR]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> i32
+ // CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<(i32, f64)> : (i64) -> !llvm.ptr
+ // CHECK-NEXT: %[[EL0:.*]] = llvm.getelementptr %[[ALLOC]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i32, f64)>
+ // CHECK-NEXT: llvm.store %[[ARG0]], %[[EL0]] : i32, !llvm.ptr
+ // CHECK-NEXT: %[[EL1:.*]] = llvm.getelementptr %[[ALLOC]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i32, f64)>
+ // CHECK-NEXT: llvm.store %[[EXT]], %[[EL1]] : f64, !llvm.ptr
+ // CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ALLOC]]) : (!llvm.ptr, !llvm.ptr) -> i32
gpu.printf "Hello: %d\n" %arg0, %arg1 : i32, f32
gpu.return
}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the fix!
LGTM!
…pointers" (llvm#69292)" This reverts commit 484668c.
…pointers" (llvm#69292)" This reverts commit 484668c.
This commit changes the builders of the
llvm.mlir.addressof
operationsto no longer produce typed pointers.
As a consequence, a GPU to NVVM pattern had to be updated, that still
relied on typed pointers.