Skip to content

Commit 60105ac

Browse files
authored
[flang][cuda] Fix kernel registration (#113372)
Kernel registration needs both the function pointer and the function name. This patch adds an extra argument to the runtime entry point and updates the LLVM IR translation to pass it.
1 parent a3508e0 commit 60105ac

File tree

3 files changed

+14
-10
lines changed

3 files changed

+14
-10
lines changed

flang/include/flang/Runtime/CUDA/registration.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ extern "C" {
2020
void *RTDECL(CUFRegisterModule)(void *data);
2121

2222
/// Register a device function.
23-
void RTDECL(CUFRegisterFunction)(void **module, const char *fct);
23+
void RTDECL(CUFRegisterFunction)(
24+
void **module, const char *fctSym, char *fctName);
2425

2526
} // extern "C"
2627

flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,15 @@ LogicalResult registerKernel(cuf::RegisterKernelOp op,
6363
llvm::Type *ptrTy = builder.getPtrTy(0);
6464
llvm::FunctionCallee fct = module->getOrInsertFunction(
6565
RTNAME_STRING(CUFRegisterFunction),
66-
llvm::FunctionType::get(ptrTy, ArrayRef<llvm::Type *>({ptrTy, ptrTy}),
67-
false));
66+
llvm::FunctionType::get(
67+
ptrTy, ArrayRef<llvm::Type *>({ptrTy, ptrTy, ptrTy}), false));
6868
llvm::Value *modulePtr = moduleTranslation.lookupValue(op.getModulePtr());
69-
builder.CreateCall(
70-
fct, {modulePtr, getOrCreateFunctionName(module, builder,
71-
op.getKernelModuleName().str(),
72-
op.getKernelName().str())});
69+
llvm::Function *fctSym =
70+
moduleTranslation.lookupFunction(op.getKernelName().str());
71+
builder.CreateCall(fct, {modulePtr, fctSym,
72+
getOrCreateFunctionName(
73+
module, builder, op.getKernelModuleName().str(),
74+
op.getKernelName().str())});
7375
return mlir::success();
7476
}
7577

flang/runtime/CUDA/registration.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,10 @@ void *RTDECL(CUFRegisterModule)(void *data) {
2626
return fatHandle;
2727
}
2828

29-
void RTDEF(CUFRegisterFunction)(void **module, const char *fct) {
30-
__cudaRegisterFunction(module, fct, const_cast<char *>(fct), fct, -1,
31-
(uint3 *)0, (uint3 *)0, (dim3 *)0, (dim3 *)0, (int *)0);
29+
void RTDEF(CUFRegisterFunction)(
30+
void **module, const char *fctSym, char *fctName) {
31+
__cudaRegisterFunction(module, fctSym, fctName, fctName, -1, (uint3 *)0,
32+
(uint3 *)0, (dim3 *)0, (dim3 *)0, (int *)0);
3233
}
3334
}
3435
} // namespace Fortran::runtime::cuda

0 commit comments

Comments
 (0)