Skip to content

Commit 6aaa03a

Browse files
Revert "Reland "[MLIR][LLVM] Change addressof builders to use opaque pointers" (llvm#69292)"
This reverts commit 484668c.
1 parent 1d514d7 commit 6aaa03a

File tree

5 files changed

+34
-30
lines changed

5 files changed

+34
-30
lines changed

mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,8 @@ class PrintOpLowering : public ConversionPattern {
117117
/// * `i32 (i8*, ...)`
118118
static LLVM::LLVMFunctionType getPrintfType(MLIRContext *context) {
119119
auto llvmI32Ty = IntegerType::get(context, 32);
120-
auto llvmPtrTy = LLVM::LLVMPointerType::get(context);
121-
auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmPtrTy,
120+
auto llvmI8PtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
121+
auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy,
122122
/*isVarArg=*/true);
123123
return llvmFnType;
124124
}
@@ -162,7 +162,8 @@ class PrintOpLowering : public ConversionPattern {
162162
Value cst0 = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
163163
builder.getIndexAttr(0));
164164
return builder.create<LLVM::GEPOp>(
165-
loc, LLVM::LLVMPointerType::get(builder.getContext()), global.getType(),
165+
loc,
166+
LLVM::LLVMPointerType::get(IntegerType::get(builder.getContext(), 8)),
166167
globalPtr, ArrayRef<Value>({cst0, cst0}));
167168
}
168169
};

mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,8 @@ class PrintOpLowering : public ConversionPattern {
117117
/// * `i32 (i8*, ...)`
118118
static LLVM::LLVMFunctionType getPrintfType(MLIRContext *context) {
119119
auto llvmI32Ty = IntegerType::get(context, 32);
120-
auto llvmPtrTy = LLVM::LLVMPointerType::get(context);
121-
auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmPtrTy,
120+
auto llvmI8PtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
121+
auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy,
122122
/*isVarArg=*/true);
123123
return llvmFnType;
124124
}
@@ -162,7 +162,8 @@ class PrintOpLowering : public ConversionPattern {
162162
Value cst0 = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
163163
builder.getIndexAttr(0));
164164
return builder.create<LLVM::GEPOp>(
165-
loc, LLVM::LLVMPointerType::get(builder.getContext()), global.getType(),
165+
loc,
166+
LLVM::LLVMPointerType::get(IntegerType::get(builder.getContext(), 8)),
166167
globalPtr, ArrayRef<Value>({cst0, cst0}));
167168
}
168169
};

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,15 +1071,15 @@ def LLVM_AddressOfOp : LLVM_Op<"mlir.addressof",
10711071
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
10721072
[{
10731073
build($_builder, $_state,
1074-
LLVM::LLVMPointerType::get($_builder.getContext(), global.getAddrSpace()),
1074+
LLVM::LLVMPointerType::get(global.getType(), global.getAddrSpace()),
10751075
global.getSymName());
10761076
$_state.addAttributes(attrs);
10771077
}]>,
10781078
OpBuilder<(ins "LLVMFuncOp":$func,
10791079
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
10801080
[{
10811081
build($_builder, $_state,
1082-
LLVM::LLVMPointerType::get($_builder.getContext()), func.getName());
1082+
LLVM::LLVMPointerType::get(func.getFunctionType()), func.getName());
10831083
$_state.addAttributes(attrs);
10841084
}]>
10851085
];

mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -441,15 +441,15 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
441441
Location loc = gpuPrintfOp->getLoc();
442442

443443
mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
444-
mlir::Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
444+
mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8);
445445

446446
// Note: this is the GPUModule op, not the ModuleOp that surrounds it
447447
// This ensures that global constants and declarations are placed within
448448
// the device code, not the host code
449449
auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
450450

451451
auto vprintfType =
452-
LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
452+
LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {i8Ptr, i8Ptr});
453453
LLVM::LLVMFuncOp vprintfDecl =
454454
getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType);
455455

@@ -473,8 +473,7 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
473473
// Get a pointer to the format string's first element
474474
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
475475
Value stringStart = rewriter.create<LLVM::GEPOp>(
476-
loc, getTypeConverter()->getPointerType(globalType), globalType,
477-
globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
476+
loc, i8Ptr, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
478477
SmallVector<Type> types;
479478
SmallVector<Value> args;
480479
// Promote and pack the arguments into a stack allocation.
@@ -491,17 +490,18 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
491490
}
492491
Type structType =
493492
LLVM::LLVMStructType::getLiteral(gpuPrintfOp.getContext(), types);
493+
Type structPtrType = LLVM::LLVMPointerType::get(structType);
494494
Value one = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(),
495495
rewriter.getIndexAttr(1));
496-
Value tempAlloc =
497-
rewriter.create<LLVM::AllocaOp>(loc, ptrType, structType, one,
498-
/*alignment=*/0);
496+
Value tempAlloc = rewriter.create<LLVM::AllocaOp>(loc, structPtrType, one,
497+
/*alignment=*/0);
499498
for (auto [index, arg] : llvm::enumerate(args)) {
500499
Value ptr = rewriter.create<LLVM::GEPOp>(
501-
loc, getTypeConverter()->getPointerType(structType), structType,
502-
tempAlloc, ArrayRef<LLVM::GEPArg>{0, index});
500+
loc, LLVM::LLVMPointerType::get(arg.getType()), tempAlloc,
501+
ArrayRef<LLVM::GEPArg>{0, index});
503502
rewriter.create<LLVM::StoreOp>(loc, arg, ptr);
504503
}
504+
tempAlloc = rewriter.create<LLVM::BitcastOp>(loc, i8Ptr, tempAlloc);
505505
std::array<Value, 2> printfArgs = {stringStart, tempAlloc};
506506

507507
rewriter.create<LLVM::CallOp>(loc, vprintfDecl, printfArgs);

mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -542,32 +542,34 @@ gpu.module @test_module_28 {
542542
gpu.module @test_module_29 {
543543
// CHECK-DAG: llvm.mlir.global internal constant @[[$PRINT_GLOBAL0:[A-Za-z0-9_]+]]("Hello, world\0A\00")
544544
// CHECK-DAG: llvm.mlir.global internal constant @[[$PRINT_GLOBAL1:[A-Za-z0-9_]+]]("Hello: %d\0A\00")
545-
// CHECK-DAG: llvm.func @vprintf(!llvm.ptr, !llvm.ptr) -> i32
545+
// CHECK-DAG: llvm.func @vprintf(!llvm.ptr<i8>, !llvm.ptr<i8>) -> i32
546546

547547
// CHECK-LABEL: func @test_const_printf
548548
gpu.func @test_const_printf() {
549-
// CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL0]] : !llvm.ptr
550-
// CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<14 x i8>
549+
// CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL0]] : !llvm.ptr<array<14 x i8>>
550+
// CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr<array<14 x i8>>) -> !llvm.ptr<i8>
551551
// CHECK-NEXT: %[[O:.*]] = llvm.mlir.constant(1 : index) : i64
552-
// CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<()> : (i64) -> !llvm.ptr
553-
// CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ALLOC]]) : (!llvm.ptr, !llvm.ptr) -> i32
552+
// CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<()> : (i64) -> !llvm.ptr<struct<()>>
553+
// CHECK-NEXT: %[[ARGPTR:.*]] = llvm.bitcast %[[ALLOC]] : !llvm.ptr<struct<()>> to !llvm.ptr<i8>
554+
// CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ARGPTR]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> i32
554555
gpu.printf "Hello, world\n"
555556
gpu.return
556557
}
557558

558559
// CHECK-LABEL: func @test_printf
559560
// CHECK: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: f32)
560561
gpu.func @test_printf(%arg0: i32, %arg1: f32) {
561-
// CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL1]] : !llvm.ptr
562-
// CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<11 x i8>
562+
// CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL1]] : !llvm.ptr<array<11 x i8>>
563+
// CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr<array<11 x i8>>) -> !llvm.ptr<i8>
563564
// CHECK-NEXT: %[[EXT:.+]] = llvm.fpext %[[ARG1]] : f32 to f64
564565
// CHECK-NEXT: %[[O:.*]] = llvm.mlir.constant(1 : index) : i64
565-
// CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<(i32, f64)> : (i64) -> !llvm.ptr
566-
// CHECK-NEXT: %[[EL0:.*]] = llvm.getelementptr %[[ALLOC]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i32, f64)>
567-
// CHECK-NEXT: llvm.store %[[ARG0]], %[[EL0]] : i32, !llvm.ptr
568-
// CHECK-NEXT: %[[EL1:.*]] = llvm.getelementptr %[[ALLOC]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i32, f64)>
569-
// CHECK-NEXT: llvm.store %[[EXT]], %[[EL1]] : f64, !llvm.ptr
570-
// CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ALLOC]]) : (!llvm.ptr, !llvm.ptr) -> i32
566+
// CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<(i32, f64)> : (i64) -> !llvm.ptr<struct<(i32, f64)>>
567+
// CHECK-NEXT: %[[EL0:.*]] = llvm.getelementptr %[[ALLOC]][0, 0] : (!llvm.ptr<struct<(i32, f64)>>) -> !llvm.ptr<i32>
568+
// CHECK-NEXT: llvm.store %[[ARG0]], %[[EL0]] : !llvm.ptr<i32>
569+
// CHECK-NEXT: %[[EL1:.*]] = llvm.getelementptr %[[ALLOC]][0, 1] : (!llvm.ptr<struct<(i32, f64)>>) -> !llvm.ptr<f64>
570+
// CHECK-NEXT: llvm.store %[[EXT]], %[[EL1]] : !llvm.ptr<f64>
571+
// CHECK-NEXT: %[[ARGPTR:.*]] = llvm.bitcast %[[ALLOC]] : !llvm.ptr<struct<(i32, f64)>> to !llvm.ptr<i8>
572+
// CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ARGPTR]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> i32
571573
gpu.printf "Hello: %d\n" %arg0, %arg1 : i32, f32
572574
gpu.return
573575
}

0 commit comments

Comments
 (0)