Skip to content

Commit 864902e

Browse files
authored
[flang][cuda] Call CUFGetDeviceAddress to get global device address from host address (llvm#112989)
1 parent f7b6dc8 commit 864902e

File tree

3 files changed

+126
-15
lines changed

3 files changed

+126
-15
lines changed

flang/include/flang/Optimizer/Transforms/CufOpConversion.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@ class LLVMTypeConverter;
1818

1919
namespace mlir {
2020
class DataLayout;
21+
class SymbolTable;
2122
}
2223

2324
namespace cuf {
2425

2526
void populateCUFToFIRConversionPatterns(const fir::LLVMTypeConverter &converter,
2627
mlir::DataLayout &dl,
28+
const mlir::SymbolTable &symtab,
2729
mlir::RewritePatternSet &patterns);
2830

2931
} // namespace cuf

flang/lib/Optimizer/Transforms/CufOpConversion.cpp

Lines changed: 81 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,69 @@ static bool hasDoubleDescriptors(OpTy op) {
7777
return false;
7878
}
7979

80+
static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
81+
mlir::Location loc, mlir::Type toTy,
82+
mlir::Value val) {
83+
if (val.getType() != toTy)
84+
return rewriter.create<fir::ConvertOp>(loc, toTy, val);
85+
return val;
86+
}
87+
88+
mlir::Value getDeviceAddress(mlir::PatternRewriter &rewriter,
89+
mlir::OpOperand &operand,
90+
const mlir::SymbolTable &symtab) {
91+
mlir::Value v = operand.get();
92+
auto declareOp = v.getDefiningOp<fir::DeclareOp>();
93+
if (!declareOp)
94+
return v;
95+
96+
auto addrOfOp = declareOp.getMemref().getDefiningOp<fir::AddrOfOp>();
97+
if (!addrOfOp)
98+
return v;
99+
100+
auto globalOp = symtab.lookup<fir::GlobalOp>(
101+
addrOfOp.getSymbol().getRootReference().getValue());
102+
103+
if (!globalOp)
104+
return v;
105+
106+
bool isDevGlobal{false};
107+
auto attr = globalOp.getDataAttrAttr();
108+
if (attr) {
109+
switch (attr.getValue()) {
110+
case cuf::DataAttribute::Device:
111+
case cuf::DataAttribute::Managed:
112+
case cuf::DataAttribute::Pinned:
113+
isDevGlobal = true;
114+
break;
115+
default:
116+
break;
117+
}
118+
}
119+
if (!isDevGlobal)
120+
return v;
121+
mlir::OpBuilder::InsertionGuard guard(rewriter);
122+
rewriter.setInsertionPoint(operand.getOwner());
123+
auto loc = declareOp.getLoc();
124+
auto mod = declareOp->getParentOfType<mlir::ModuleOp>();
125+
fir::FirOpBuilder builder(rewriter, mod);
126+
127+
mlir::func::FuncOp callee =
128+
fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(loc, builder);
129+
auto fTy = callee.getFunctionType();
130+
auto toTy = fTy.getInput(0);
131+
mlir::Value inputArg =
132+
createConvertOp(rewriter, loc, toTy, declareOp.getResult());
133+
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
134+
mlir::Value sourceLine =
135+
fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
136+
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
137+
builder, loc, fTy, inputArg, sourceFile, sourceLine)};
138+
auto call = rewriter.create<fir::CallOp>(loc, callee, args);
139+
140+
return call->getResult(0);
141+
}
142+
80143
template <typename OpTy>
81144
static mlir::LogicalResult convertOpToCall(OpTy op,
82145
mlir::PatternRewriter &rewriter,
@@ -363,18 +426,14 @@ struct CufFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
363426
}
364427
};
365428

366-
static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
367-
mlir::Location loc, mlir::Type toTy,
368-
mlir::Value val) {
369-
if (val.getType() != toTy)
370-
return rewriter.create<fir::ConvertOp>(loc, toTy, val);
371-
return val;
372-
}
373-
374429
struct CufDataTransferOpConversion
375430
: public mlir::OpRewritePattern<cuf::DataTransferOp> {
376431
using OpRewritePattern::OpRewritePattern;
377432

433+
CufDataTransferOpConversion(mlir::MLIRContext *context,
434+
const mlir::SymbolTable &symtab)
435+
: OpRewritePattern(context), symtab{symtab} {}
436+
378437
mlir::LogicalResult
379438
matchAndRewrite(cuf::DataTransferOp op,
380439
mlir::PatternRewriter &rewriter) const override {
@@ -445,9 +504,11 @@ struct CufDataTransferOpConversion
445504
mlir::Value sourceLine =
446505
fir::factory::locationToLineNo(builder, loc, fTy.getInput(5));
447506

448-
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
449-
builder, loc, fTy, op.getDst(), op.getSrc(), bytes, modeValue,
450-
sourceFile, sourceLine)};
507+
mlir::Value dst = getDeviceAddress(rewriter, op.getDstMutable(), symtab);
508+
mlir::Value src = getDeviceAddress(rewriter, op.getSrcMutable(), symtab);
509+
llvm::SmallVector<mlir::Value> args{
510+
fir::runtime::createArguments(builder, loc, fTy, dst, src, bytes,
511+
modeValue, sourceFile, sourceLine)};
451512
builder.create<fir::CallOp>(loc, func, args);
452513
rewriter.eraseOp(op);
453514
return mlir::success();
@@ -552,6 +613,9 @@ struct CufDataTransferOpConversion
552613
}
553614
return mlir::success();
554615
}
616+
617+
private:
618+
const mlir::SymbolTable &symtab;
555619
};
556620

557621
class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
@@ -565,13 +629,15 @@ class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
565629
mlir::ModuleOp module = mlir::dyn_cast<mlir::ModuleOp>(op);
566630
if (!module)
567631
return signalPassFailure();
632+
mlir::SymbolTable symtab(module);
568633

569634
std::optional<mlir::DataLayout> dl =
570635
fir::support::getOrSetDataLayout(module, /*allowDefaultLayout=*/false);
571636
fir::LLVMTypeConverter typeConverter(module, /*applyTBAA=*/false,
572637
/*forceUnifiedTBAATree=*/false, *dl);
573638
target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect>();
574-
cuf::populateCUFToFIRConversionPatterns(typeConverter, *dl, patterns);
639+
cuf::populateCUFToFIRConversionPatterns(typeConverter, *dl, symtab,
640+
patterns);
575641
if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
576642
std::move(patterns)))) {
577643
mlir::emitError(mlir::UnknownLoc::get(ctx),
@@ -584,9 +650,9 @@ class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
584650

585651
void cuf::populateCUFToFIRConversionPatterns(
586652
const fir::LLVMTypeConverter &converter, mlir::DataLayout &dl,
587-
mlir::RewritePatternSet &patterns) {
653+
const mlir::SymbolTable &symtab, mlir::RewritePatternSet &patterns) {
588654
patterns.insert<CufAllocOpConversion>(patterns.getContext(), &dl, &converter);
589655
patterns.insert<CufAllocateOpConversion, CufDeallocateOpConversion,
590-
CufFreeOpConversion, CufDataTransferOpConversion>(
591-
patterns.getContext());
656+
CufFreeOpConversion>(patterns.getContext());
657+
patterns.insert<CufDataTransferOpConversion>(patterns.getContext(), symtab);
592658
}

flang/test/Fir/CUDA/cuda-data-transfer.fir

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,4 +189,47 @@ func.func @_QPsub7() {
189189
// CHECK: %[[SRC:.*]] = fir.convert %[[IHOST]]#0 : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
190190
// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %[[BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
191191

192+
fir.global @_QMmtestsEn(dense<[3, 4, 5, 6, 7]> : tensor<5xi32>) {data_attr = #cuf.cuda<device>} : !fir.array<5xi32>
193+
func.func @_QPsub8() attributes {fir.bindc_name = "t"} {
194+
%c5 = arith.constant 5 : index
195+
%0 = fir.alloca !fir.array<5xi32> {bindc_name = "m", uniq_name = "_QFEm"}
196+
%1 = fir.shape %c5 : (index) -> !fir.shape<1>
197+
%2 = fir.declare %0(%1) {uniq_name = "_QFEm"} : (!fir.ref<!fir.array<5xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<5xi32>>
198+
%3 = fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
199+
%4 = fir.declare %3(%1) {data_attr = #cuf.cuda<device>, uniq_name = "_QMmtestsEn"} : (!fir.ref<!fir.array<5xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<5xi32>>
200+
cuf.data_transfer %4 to %2 {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.ref<!fir.array<5xi32>>, !fir.ref<!fir.array<5xi32>>
201+
return
202+
}
203+
204+
// CHECK-LABEL: func.func @_QPsub8()
205+
// CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.array<5xi32>
206+
// CHECK: %[[LOCAL:.*]] = fir.declare %[[ALLOCA]]
207+
// CHECK: %[[GBL:.*]] = fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
208+
// CHECK: %[[DECL:.*]] = fir.declare %[[GBL]]
209+
// CHECK: %[[HOST:.*]] = fir.convert %[[DECL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
210+
// CHECK: %[[SRC:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[HOST]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
211+
// CHECK: %[[DST:.*]] = fir.convert %[[LOCAL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
212+
// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
213+
214+
215+
func.func @_QPsub9() {
216+
%c5 = arith.constant 5 : index
217+
%0 = fir.alloca !fir.array<5xi32> {bindc_name = "m", uniq_name = "_QFtest9Em"}
218+
%1 = fir.shape %c5 : (index) -> !fir.shape<1>
219+
%2 = fir.declare %0(%1) {uniq_name = "_QFtest9Em"} : (!fir.ref<!fir.array<5xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<5xi32>>
220+
%3 = fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
221+
%4 = fir.declare %3(%1) {data_attr = #cuf.cuda<device>, uniq_name = "_QMmtestsEn"} : (!fir.ref<!fir.array<5xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<5xi32>>
222+
cuf.data_transfer %2 to %4 {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<!fir.array<5xi32>>, !fir.ref<!fir.array<5xi32>>
223+
return
224+
}
225+
226+
// CHECK-LABEL: func.func @_QPsub9()
227+
// CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.array<5xi32>
228+
// CHECK: %[[LOCAL:.*]] = fir.declare %[[ALLOCA]]
229+
// CHECK: %[[GBL:.*]] = fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
230+
// CHECK: %[[DECL:.*]] = fir.declare %[[GBL]]
231+
// CHECK: %[[HOST:.*]] = fir.convert %[[DECL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
232+
// CHECK: %[[DST:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[HOST]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
233+
// CHECK: %[[SRC:.*]] = fir.convert %[[LOCAL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
234+
// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
192235
} // end of module

0 commit comments

Comments
 (0)