@@ -77,6 +77,69 @@ static bool hasDoubleDescriptors(OpTy op) {
77
77
return false ;
78
78
}
79
79
80
+ static mlir::Value createConvertOp (mlir::PatternRewriter &rewriter,
81
+ mlir::Location loc, mlir::Type toTy,
82
+ mlir::Value val) {
83
+ if (val.getType () != toTy)
84
+ return rewriter.create <fir::ConvertOp>(loc, toTy, val);
85
+ return val;
86
+ }
87
+
88
+ mlir::Value getDeviceAddress (mlir::PatternRewriter &rewriter,
89
+ mlir::OpOperand &operand,
90
+ const mlir::SymbolTable &symtab) {
91
+ mlir::Value v = operand.get ();
92
+ auto declareOp = v.getDefiningOp <fir::DeclareOp>();
93
+ if (!declareOp)
94
+ return v;
95
+
96
+ auto addrOfOp = declareOp.getMemref ().getDefiningOp <fir::AddrOfOp>();
97
+ if (!addrOfOp)
98
+ return v;
99
+
100
+ auto globalOp = symtab.lookup <fir::GlobalOp>(
101
+ addrOfOp.getSymbol ().getRootReference ().getValue ());
102
+
103
+ if (!globalOp)
104
+ return v;
105
+
106
+ bool isDevGlobal{false };
107
+ auto attr = globalOp.getDataAttrAttr ();
108
+ if (attr) {
109
+ switch (attr.getValue ()) {
110
+ case cuf::DataAttribute::Device:
111
+ case cuf::DataAttribute::Managed:
112
+ case cuf::DataAttribute::Pinned:
113
+ isDevGlobal = true ;
114
+ break ;
115
+ default :
116
+ break ;
117
+ }
118
+ }
119
+ if (!isDevGlobal)
120
+ return v;
121
+ mlir::OpBuilder::InsertionGuard guard (rewriter);
122
+ rewriter.setInsertionPoint (operand.getOwner ());
123
+ auto loc = declareOp.getLoc ();
124
+ auto mod = declareOp->getParentOfType <mlir::ModuleOp>();
125
+ fir::FirOpBuilder builder (rewriter, mod);
126
+
127
+ mlir::func::FuncOp callee =
128
+ fir::runtime::getRuntimeFunc<mkRTKey (CUFGetDeviceAddress)>(loc, builder);
129
+ auto fTy = callee.getFunctionType ();
130
+ auto toTy = fTy .getInput (0 );
131
+ mlir::Value inputArg =
132
+ createConvertOp (rewriter, loc, toTy, declareOp.getResult ());
133
+ mlir::Value sourceFile = fir::factory::locationToFilename (builder, loc);
134
+ mlir::Value sourceLine =
135
+ fir::factory::locationToLineNo (builder, loc, fTy .getInput (2 ));
136
+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments (
137
+ builder, loc, fTy , inputArg, sourceFile, sourceLine)};
138
+ auto call = rewriter.create <fir::CallOp>(loc, callee, args);
139
+
140
+ return call->getResult (0 );
141
+ }
142
+
80
143
template <typename OpTy>
81
144
static mlir::LogicalResult convertOpToCall (OpTy op,
82
145
mlir::PatternRewriter &rewriter,
@@ -363,18 +426,14 @@ struct CufFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
363
426
}
364
427
};
365
428
366
- static mlir::Value createConvertOp (mlir::PatternRewriter &rewriter,
367
- mlir::Location loc, mlir::Type toTy,
368
- mlir::Value val) {
369
- if (val.getType () != toTy)
370
- return rewriter.create <fir::ConvertOp>(loc, toTy, val);
371
- return val;
372
- }
373
-
374
429
struct CufDataTransferOpConversion
375
430
: public mlir::OpRewritePattern<cuf::DataTransferOp> {
376
431
using OpRewritePattern::OpRewritePattern;
377
432
433
+ CufDataTransferOpConversion (mlir::MLIRContext *context,
434
+ const mlir::SymbolTable &symtab)
435
+ : OpRewritePattern(context), symtab{symtab} {}
436
+
378
437
mlir::LogicalResult
379
438
matchAndRewrite (cuf::DataTransferOp op,
380
439
mlir::PatternRewriter &rewriter) const override {
@@ -445,9 +504,11 @@ struct CufDataTransferOpConversion
445
504
mlir::Value sourceLine =
446
505
fir::factory::locationToLineNo (builder, loc, fTy .getInput (5 ));
447
506
448
- llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments (
449
- builder, loc, fTy , op.getDst (), op.getSrc (), bytes, modeValue,
450
- sourceFile, sourceLine)};
507
+ mlir::Value dst = getDeviceAddress (rewriter, op.getDstMutable (), symtab);
508
+ mlir::Value src = getDeviceAddress (rewriter, op.getSrcMutable (), symtab);
509
+ llvm::SmallVector<mlir::Value> args{
510
+ fir::runtime::createArguments (builder, loc, fTy , dst, src, bytes,
511
+ modeValue, sourceFile, sourceLine)};
451
512
builder.create <fir::CallOp>(loc, func, args);
452
513
rewriter.eraseOp (op);
453
514
return mlir::success ();
@@ -552,6 +613,9 @@ struct CufDataTransferOpConversion
552
613
}
553
614
return mlir::success ();
554
615
}
616
+
617
+ private:
618
+ const mlir::SymbolTable &symtab;
555
619
};
556
620
557
621
class CufOpConversion : public fir ::impl::CufOpConversionBase<CufOpConversion> {
@@ -565,13 +629,15 @@ class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
565
629
mlir::ModuleOp module = mlir::dyn_cast<mlir::ModuleOp>(op);
566
630
if (!module)
567
631
return signalPassFailure ();
632
+ mlir::SymbolTable symtab (module);
568
633
569
634
std::optional<mlir::DataLayout> dl =
570
635
fir::support::getOrSetDataLayout (module, /* allowDefaultLayout=*/ false );
571
636
fir::LLVMTypeConverter typeConverter (module, /* applyTBAA=*/ false ,
572
637
/* forceUnifiedTBAATree=*/ false , *dl);
573
638
target.addLegalDialect <fir::FIROpsDialect, mlir::arith::ArithDialect>();
574
- cuf::populateCUFToFIRConversionPatterns (typeConverter, *dl, patterns);
639
+ cuf::populateCUFToFIRConversionPatterns (typeConverter, *dl, symtab,
640
+ patterns);
575
641
if (mlir::failed (mlir::applyPartialConversion (getOperation (), target,
576
642
std::move (patterns)))) {
577
643
mlir::emitError (mlir::UnknownLoc::get (ctx),
@@ -584,9 +650,9 @@ class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
584
650
585
651
void cuf::populateCUFToFIRConversionPatterns (
586
652
const fir::LLVMTypeConverter &converter, mlir::DataLayout &dl,
587
- mlir::RewritePatternSet &patterns) {
653
+ const mlir::SymbolTable &symtab, mlir::RewritePatternSet &patterns) {
588
654
patterns.insert <CufAllocOpConversion>(patterns.getContext (), &dl, &converter);
589
655
patterns.insert <CufAllocateOpConversion, CufDeallocateOpConversion,
590
- CufFreeOpConversion, CufDataTransferOpConversion>(
591
- patterns.getContext ());
656
+ CufFreeOpConversion>(patterns. getContext ());
657
+ patterns.insert <CufDataTransferOpConversion>(patterns. getContext (), symtab );
592
658
}
0 commit comments