Skip to content

Reland "[MLIR][LLVM] Change addressof builders to use opaque pointers" #69292

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 17, 2023

Conversation

Dinistro
Copy link
Contributor

This commit changes the builders of the llvm.mlir.addressof operations
to no longer produce typed pointers.

As a consequence, a GPU to NVVM pattern had to be updated, that still
relied on typed pointers.

This commit changes the builders of the `llvm.mlir.addressof` operations
to no longer produce typed pointers.

As a consequence, a GPU to NVVM pattern had to be updated, that still
relied on typed pointers.
@llvmbot
Copy link
Member

llvmbot commented Oct 17, 2023

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir-gpu

Author: Christian Ulmann (Dinistro)

Changes

This commit changes the builders of the llvm.mlir.addressof operations
to no longer produce typed pointers.

As a consequence, a GPU to NVVM pattern had to be updated, that still
relied on typed pointers.


Full diff: https://github.com/llvm/llvm-project/pull/69292.diff

5 Files Affected:

  • (modified) mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp (+3-4)
  • (modified) mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp (+3-4)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td (+2-2)
  • (modified) mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp (+9-9)
  • (modified) mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir (+13-15)
diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
index 684ce37b2398ce2..f05f1c2dc33881d 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
@@ -117,8 +117,8 @@ class PrintOpLowering : public ConversionPattern {
   ///   * `i32 (i8*, ...)`
   static LLVM::LLVMFunctionType getPrintfType(MLIRContext *context) {
     auto llvmI32Ty = IntegerType::get(context, 32);
-    auto llvmI8PtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
-    auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy,
+    auto llvmPtrTy = LLVM::LLVMPointerType::get(context);
+    auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmPtrTy,
                                                   /*isVarArg=*/true);
     return llvmFnType;
   }
@@ -162,8 +162,7 @@ class PrintOpLowering : public ConversionPattern {
     Value cst0 = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
                                                   builder.getIndexAttr(0));
     return builder.create<LLVM::GEPOp>(
-        loc,
-        LLVM::LLVMPointerType::get(IntegerType::get(builder.getContext(), 8)),
+        loc, LLVM::LLVMPointerType::get(builder.getContext()), global.getType(),
         globalPtr, ArrayRef<Value>({cst0, cst0}));
   }
 };
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
index 684ce37b2398ce2..f05f1c2dc33881d 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
@@ -117,8 +117,8 @@ class PrintOpLowering : public ConversionPattern {
   ///   * `i32 (i8*, ...)`
   static LLVM::LLVMFunctionType getPrintfType(MLIRContext *context) {
     auto llvmI32Ty = IntegerType::get(context, 32);
-    auto llvmI8PtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
-    auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy,
+    auto llvmPtrTy = LLVM::LLVMPointerType::get(context);
+    auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmPtrTy,
                                                   /*isVarArg=*/true);
     return llvmFnType;
   }
@@ -162,8 +162,7 @@ class PrintOpLowering : public ConversionPattern {
     Value cst0 = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
                                                   builder.getIndexAttr(0));
     return builder.create<LLVM::GEPOp>(
-        loc,
-        LLVM::LLVMPointerType::get(IntegerType::get(builder.getContext(), 8)),
+        loc, LLVM::LLVMPointerType::get(builder.getContext()), global.getType(),
         globalPtr, ArrayRef<Value>({cst0, cst0}));
   }
 };
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 8745d14c8d48318..2a572ab4de706a3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1071,7 +1071,7 @@ def LLVM_AddressOfOp : LLVM_Op<"mlir.addressof",
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
       build($_builder, $_state,
-            LLVM::LLVMPointerType::get(global.getType(), global.getAddrSpace()),
+            LLVM::LLVMPointerType::get($_builder.getContext(), global.getAddrSpace()),
             global.getSymName());
       $_state.addAttributes(attrs);
     }]>,
@@ -1079,7 +1079,7 @@ def LLVM_AddressOfOp : LLVM_Op<"mlir.addressof",
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
       build($_builder, $_state,
-            LLVM::LLVMPointerType::get(func.getFunctionType()), func.getName());
+            LLVM::LLVMPointerType::get($_builder.getContext()), func.getName());
       $_state.addAttributes(attrs);
     }]>
   ];
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 96d8fceba706617..6d2585aa30ab4c5 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -441,7 +441,7 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
   Location loc = gpuPrintfOp->getLoc();
 
   mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
-  mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8);
+  mlir::Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
 
   // Note: this is the GPUModule op, not the ModuleOp that surrounds it
   // This ensures that global constants and declarations are placed within
@@ -449,7 +449,7 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
   auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
 
   auto vprintfType =
-      LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {i8Ptr, i8Ptr});
+      LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
   LLVM::LLVMFuncOp vprintfDecl =
       getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType);
 
@@ -473,7 +473,8 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
   // Get a pointer to the format string's first element
   Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
   Value stringStart = rewriter.create<LLVM::GEPOp>(
-      loc, i8Ptr, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
+      loc, getTypeConverter()->getPointerType(globalType), globalType,
+      globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
   SmallVector<Type> types;
   SmallVector<Value> args;
   // Promote and pack the arguments into a stack allocation.
@@ -490,18 +491,17 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
   }
   Type structType =
       LLVM::LLVMStructType::getLiteral(gpuPrintfOp.getContext(), types);
-  Type structPtrType = LLVM::LLVMPointerType::get(structType);
   Value one = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(),
                                                 rewriter.getIndexAttr(1));
-  Value tempAlloc = rewriter.create<LLVM::AllocaOp>(loc, structPtrType, one,
-                                                    /*alignment=*/0);
+  Value tempAlloc =
+      rewriter.create<LLVM::AllocaOp>(loc, ptrType, structType, one,
+                                      /*alignment=*/0);
   for (auto [index, arg] : llvm::enumerate(args)) {
     Value ptr = rewriter.create<LLVM::GEPOp>(
-        loc, LLVM::LLVMPointerType::get(arg.getType()), tempAlloc,
-        ArrayRef<LLVM::GEPArg>{0, index});
+        loc, getTypeConverter()->getPointerType(structType), structType,
+        tempAlloc, ArrayRef<LLVM::GEPArg>{0, index});
     rewriter.create<LLVM::StoreOp>(loc, arg, ptr);
   }
-  tempAlloc = rewriter.create<LLVM::BitcastOp>(loc, i8Ptr, tempAlloc);
   std::array<Value, 2> printfArgs = {stringStart, tempAlloc};
 
   rewriter.create<LLVM::CallOp>(loc, vprintfDecl, printfArgs);
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index 391ccd74841dca4..a8c02e32ef92b6b 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -542,16 +542,15 @@ gpu.module @test_module_28 {
 gpu.module @test_module_29 {
   // CHECK-DAG: llvm.mlir.global internal constant @[[$PRINT_GLOBAL0:[A-Za-z0-9_]+]]("Hello, world\0A\00")
   // CHECK-DAG: llvm.mlir.global internal constant @[[$PRINT_GLOBAL1:[A-Za-z0-9_]+]]("Hello: %d\0A\00")
-  // CHECK-DAG: llvm.func @vprintf(!llvm.ptr<i8>, !llvm.ptr<i8>) -> i32
+  // CHECK-DAG: llvm.func @vprintf(!llvm.ptr, !llvm.ptr) -> i32
 
   // CHECK-LABEL: func @test_const_printf
   gpu.func @test_const_printf() {
-    // CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL0]] : !llvm.ptr<array<14 x i8>>
-    // CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr<array<14 x i8>>) -> !llvm.ptr<i8>
+    // CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL0]] : !llvm.ptr
+    // CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<14 x i8>
     // CHECK-NEXT: %[[O:.*]] = llvm.mlir.constant(1 : index) : i64
-    // CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<()> : (i64) -> !llvm.ptr<struct<()>>
-    // CHECK-NEXT: %[[ARGPTR:.*]] = llvm.bitcast %[[ALLOC]] : !llvm.ptr<struct<()>> to !llvm.ptr<i8>
-    // CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ARGPTR]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> i32
+    // CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<()> : (i64) -> !llvm.ptr
+    // CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ALLOC]]) : (!llvm.ptr, !llvm.ptr) -> i32
     gpu.printf "Hello, world\n"
     gpu.return
   }
@@ -559,17 +558,16 @@ gpu.module @test_module_29 {
   // CHECK-LABEL: func @test_printf
   // CHECK: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: f32)
   gpu.func @test_printf(%arg0: i32, %arg1: f32) {
-    // CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL1]] : !llvm.ptr<array<11 x i8>>
-    // CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr<array<11 x i8>>) -> !llvm.ptr<i8>
+    // CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL1]] : !llvm.ptr
+    // CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<11 x i8>
     // CHECK-NEXT: %[[EXT:.+]] = llvm.fpext %[[ARG1]] : f32 to f64
     // CHECK-NEXT: %[[O:.*]] = llvm.mlir.constant(1 : index) : i64
-    // CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<(i32, f64)> : (i64) -> !llvm.ptr<struct<(i32, f64)>>
-    // CHECK-NEXT: %[[EL0:.*]] = llvm.getelementptr %[[ALLOC]][0, 0] : (!llvm.ptr<struct<(i32, f64)>>) -> !llvm.ptr<i32>
-    // CHECK-NEXT: llvm.store %[[ARG0]], %[[EL0]] : !llvm.ptr<i32>
-    // CHECK-NEXT: %[[EL1:.*]] = llvm.getelementptr %[[ALLOC]][0, 1] : (!llvm.ptr<struct<(i32, f64)>>) -> !llvm.ptr<f64>
-    // CHECK-NEXT: llvm.store %[[EXT]], %[[EL1]] : !llvm.ptr<f64>
-    // CHECK-NEXT: %[[ARGPTR:.*]] = llvm.bitcast %[[ALLOC]] : !llvm.ptr<struct<(i32, f64)>> to !llvm.ptr<i8>
-    // CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ARGPTR]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> i32
+    // CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<(i32, f64)> : (i64) -> !llvm.ptr
+    // CHECK-NEXT: %[[EL0:.*]] = llvm.getelementptr %[[ALLOC]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i32, f64)>
+    // CHECK-NEXT: llvm.store %[[ARG0]], %[[EL0]] : i32, !llvm.ptr
+    // CHECK-NEXT: %[[EL1:.*]] = llvm.getelementptr %[[ALLOC]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i32, f64)>
+    // CHECK-NEXT: llvm.store %[[EXT]], %[[EL1]] : f64, !llvm.ptr
+    // CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ALLOC]]) : (!llvm.ptr, !llvm.ptr) -> i32
     gpu.printf "Hello: %d\n" %arg0, %arg1 : i32, f32
     gpu.return
   }

@Dinistro Dinistro changed the title Reland [MLIR][LLVM] Change addressof builders to use opaque pointers Reland "[MLIR][LLVM] Change addressof builders to use opaque pointers" Oct 17, 2023
Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix!

LGTM!

@Dinistro Dinistro merged commit 484668c into llvm:main Oct 17, 2023
@Dinistro Dinistro deleted the change-addressof-builders branch October 17, 2023 09:33
MaheshRavishankar added a commit to iree-org/llvm-project that referenced this pull request Oct 25, 2023
MaheshRavishankar added a commit to MaheshRavishankar/llvm-project that referenced this pull request Oct 25, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants