Skip to content

[flang][OpenMP] Implement HAS_DEVICE_ADDR clause #128568

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Mar 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions flang/include/flang/Support/OpenMP-utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ struct EntryBlockArgsEntry {
/// Structure holding the information needed to create and bind entry block
/// arguments associated to all clauses that can define them.
struct EntryBlockArgs {
EntryBlockArgsEntry hasDeviceAddr;
llvm::ArrayRef<mlir::Value> hostEvalVars;
EntryBlockArgsEntry inReduction;
EntryBlockArgsEntry map;
Expand All @@ -44,21 +45,21 @@ struct EntryBlockArgs {
EntryBlockArgsEntry useDevicePtr;

bool isValid() const {
return inReduction.isValid() && map.isValid() && priv.isValid() &&
reduction.isValid() && taskReduction.isValid() &&
return hasDeviceAddr.isValid() && inReduction.isValid() && map.isValid() &&
priv.isValid() && reduction.isValid() && taskReduction.isValid() &&
useDeviceAddr.isValid() && useDevicePtr.isValid();
}

auto getSyms() const {
return llvm::concat<const semantics::Symbol *const>(inReduction.syms,
map.syms, priv.syms, reduction.syms, taskReduction.syms,
useDeviceAddr.syms, useDevicePtr.syms);
return llvm::concat<const semantics::Symbol *const>(hasDeviceAddr.syms,
inReduction.syms, map.syms, priv.syms, reduction.syms,
taskReduction.syms, useDeviceAddr.syms, useDevicePtr.syms);
}

auto getVars() const {
return llvm::concat<const mlir::Value>(hostEvalVars, inReduction.vars,
map.vars, priv.vars, reduction.vars, taskReduction.vars,
useDeviceAddr.vars, useDevicePtr.vars);
return llvm::concat<const mlir::Value>(hasDeviceAddr.vars, hostEvalVars,
inReduction.vars, map.vars, priv.vars, reduction.vars,
taskReduction.vars, useDeviceAddr.vars, useDevicePtr.vars);
}
};

Expand Down
34 changes: 27 additions & 7 deletions flang/lib/Lower/OpenMP/ClauseProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -849,14 +849,34 @@ bool ClauseProcessor::processDepend(mlir::omp::DependClauseOps &result) const {
}

bool ClauseProcessor::processHasDeviceAddr(
mlir::omp::HasDeviceAddrClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const {
return findRepeatableClause<omp::clause::HasDeviceAddr>(
[&](const omp::clause::HasDeviceAddr &devAddrClause,
const parser::CharBlock &) {
addUseDeviceClause(converter, devAddrClause.v, result.hasDeviceAddrVars,
isDeviceSyms);
lower::StatementContext &stmtCtx, mlir::omp::HasDeviceAddrClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &hasDeviceSyms) const {
// For HAS_DEVICE_ADDR objects, implicitly map the top-level entities.
// Their address (or the whole descriptor, if the entity had one) will be
// passed to the target region.
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
bool clauseFound = findRepeatableClause<omp::clause::HasDeviceAddr>(
[&](const omp::clause::HasDeviceAddr &clause,
const parser::CharBlock &source) {
mlir::Location location = converter.genLocation(source);
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
omp::ObjectList baseObjects;
llvm::transform(clause.v, std::back_inserter(baseObjects),
[&](const omp::Object &object) {
if (auto maybeBase = getBaseObject(object, semaCtx))
return *maybeBase;
return object;
});
processMapObjects(stmtCtx, location, baseObjects, mapTypeBits,
parentMemberIndices, result.hasDeviceAddrVars,
hasDeviceSyms);
});

insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
result.hasDeviceAddrVars, hasDeviceSyms);
return clauseFound;
}

bool ClauseProcessor::processIf(
Expand Down
3 changes: 2 additions & 1 deletion flang/lib/Lower/OpenMP/ClauseProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,9 @@ class ClauseProcessor {
bool processFinal(lower::StatementContext &stmtCtx,
mlir::omp::FinalClauseOps &result) const;
bool processHasDeviceAddr(
lower::StatementContext &stmtCtx,
mlir::omp::HasDeviceAddrClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
llvm::SmallVectorImpl<const semantics::Symbol *> &hasDeviceSyms) const;
bool processHint(mlir::omp::HintClauseOps &result) const;
bool processInclusive(mlir::Location currentLocation,
mlir::omp::InclusiveClauseOps &result) const;
Expand Down
7 changes: 6 additions & 1 deletion flang/lib/Lower/OpenMP/Clauses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,12 @@ std::optional<Object> getBaseObject(const Object &object,
return Object{SymbolAndDesignatorExtractor::symbol_addr(comp->symbol()),
ea.Designate(evaluate::DataRef{
SymbolAndDesignatorExtractor::AsRvalueRef(*comp)})};
} else if (base.UnwrapSymbolRef()) {
} else if (auto *symRef = base.UnwrapSymbolRef()) {
// This is the base symbol of the array reference, which is the same
// as the symbol in the input object. For example, A(i) is represented
// as {Symbol(A), Designator(ArrayRef(A, i))}, so here we have
// Symbol(A), which is what we started with.
assert(&**symRef == object.sym());
return std::nullopt;
}
} else {
Expand Down
12 changes: 10 additions & 2 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ static void bindEntryBlockArgs(lower::AbstractConverter &converter,
// Process in clause name alphabetical order to match block arguments order.
// Do not bind host_eval variables because they cannot be used inside of the
// corresponding region, except for very specific cases handled separately.
bindMapLike(args.hasDeviceAddr.syms, op.getHasDeviceAddrBlockArgs());
bindPrivateLike(args.inReduction.syms, args.inReduction.vars,
op.getInReductionBlockArgs());
bindMapLike(args.map.syms, op.getMapBlockArgs());
Expand Down Expand Up @@ -1654,7 +1655,7 @@ static void genTargetClauses(
cp.processBare(clauseOps);
cp.processDepend(clauseOps);
cp.processDevice(stmtCtx, clauseOps);
cp.processHasDeviceAddr(clauseOps, hasDeviceAddrSyms);
cp.processHasDeviceAddr(stmtCtx, clauseOps, hasDeviceAddrSyms);
if (!hostEvalInfo.empty()) {
// Only process host_eval if compiling for the host device.
processHostEvalClauses(converter, semaCtx, stmtCtx, eval, loc);
Expand Down Expand Up @@ -2200,6 +2201,10 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
if (dsp.getAllSymbolsToPrivatize().contains(&sym))
return;

// These symbols are mapped individually in processHasDeviceAddr.
if (llvm::is_contained(hasDeviceAddrSyms, &sym))
return;

// Structure component symbols don't have bindings, and can only be
// explicitly mapped individually. If a member is captured implicitly
// we map the entirety of the derived type when we find its symbol.
Expand Down Expand Up @@ -2290,10 +2295,13 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,

auto targetOp = firOpBuilder.create<mlir::omp::TargetOp>(loc, clauseOps);

llvm::SmallVector<mlir::Value> mapBaseValues;
llvm::SmallVector<mlir::Value> hasDeviceAddrBaseValues, mapBaseValues;
extractMappedBaseValues(clauseOps.hasDeviceAddrVars, hasDeviceAddrBaseValues);
extractMappedBaseValues(clauseOps.mapVars, mapBaseValues);

EntryBlockArgs args;
args.hasDeviceAddr.syms = hasDeviceAddrSyms;
args.hasDeviceAddr.vars = hasDeviceAddrBaseValues;
args.hostEvalVars = clauseOps.hostEvalVars;
// TODO: Add in_reduction syms and vars.
args.map.syms = mapSyms;
Expand Down
4 changes: 2 additions & 2 deletions flang/lib/Lower/OpenMP/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ mlir::Value createParentSymAndGenIntermediateMaps(
/// Checks if an omp::Object is an array expression with a subscript, e.g.
/// array(1,2).
auto isArrayExprWithSubscript = [](omp::Object obj) {
if (auto maybeRef = evaluate::ExtractDataRef(*obj.ref())) {
if (auto maybeRef = evaluate::ExtractDataRef(obj.ref())) {
evaluate::DataRef ref = *maybeRef;
if (auto *arr = std::get_if<evaluate::ArrayRef>(&ref.u))
return !arr->subscript().empty();
Expand Down Expand Up @@ -454,7 +454,7 @@ getComponentObject(std::optional<Object> object,
if (!object)
return std::nullopt;

auto ref = evaluate::ExtractDataRef(*object.value().ref());
auto ref = evaluate::ExtractDataRef(object.value().ref());
if (!ref)
return std::nullopt;

Expand Down
55 changes: 48 additions & 7 deletions flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,10 +250,24 @@ class MapInfoFinalizationPass

mapFlags flags = mapFlags::OMP_MAP_TO |
(mapFlags(mapTypeFlag) &
(mapFlags::OMP_MAP_IMPLICIT | mapFlags::OMP_MAP_CLOSE));
(mapFlags::OMP_MAP_IMPLICIT | mapFlags::OMP_MAP_CLOSE |
mapFlags::OMP_MAP_ALWAYS));
return llvm::to_underlying(flags);
}

/// Returns true when `mapOp` is the defining operation of one of the
/// has_device_addr operands of `userOp`. Only mlir::omp::TargetOp is
/// inspected; any other kind of `userOp` yields false.
bool isHasDeviceAddr(mlir::omp::MapInfoOp mapOp, mlir::Operation *userOp) {
  assert(userOp && "Expecting non-null argument");
  auto targetOp = llvm::dyn_cast<mlir::omp::TargetOp>(userOp);
  if (!targetOp)
    return false;

  // Look for an operand whose producer is exactly this map operation.
  for (mlir::Value deviceAddrVar : targetOp.getHasDeviceAddrVars())
    if (deviceAddrVar.getDefiningOp() == mapOp)
      return true;
  return false;
}

mlir::omp::MapInfoOp genDescriptorMemberMaps(mlir::omp::MapInfoOp op,
fir::FirOpBuilder &builder,
mlir::Operation *target) {
Expand All @@ -263,11 +277,11 @@ class MapInfoFinalizationPass
// TODO: map the addendum segment of the descriptor, similarly to the
// base address/data pointer member.
mlir::Value descriptor = getDescriptorFromBoxMap(op, builder);
auto baseAddr = genBaseAddrMap(descriptor, op.getBounds(),
op.getMapType().value_or(0), builder);

mlir::ArrayAttr newMembersAttr;
mlir::SmallVector<mlir::Value> newMembers;
llvm::SmallVector<llvm::SmallVector<int64_t>> memberIndices;
bool IsHasDeviceAddr = isHasDeviceAddr(op, target);

if (!mapMemberUsers.empty() || !op.getMembers().empty())
getMemberIndicesAsVectors(
Expand All @@ -281,6 +295,12 @@ class MapInfoFinalizationPass
// member information to now have one new member for the base address, or
// we are expanding a parent that is a descriptor and we have to adjust
// all of its members to reflect the insertion of the base address.
//
// If we're expanding a top-level descriptor for a map operation that
// resulted from a "has_device_addr" clause, then we want the base pointer
// from the descriptor to be used verbatim, i.e. without additional
// remapping. To avoid this remapping, simply don't generate any map
// information for the descriptor members.
if (!mapMemberUsers.empty()) {
// Currently, there should only be one user per map when this pass
// is executed. Either a parent map, holding the current map in its
Expand All @@ -291,6 +311,8 @@ class MapInfoFinalizationPass
assert(mapMemberUsers.size() == 1 &&
"OMPMapInfoFinalization currently only supports single users of a "
"MapInfoOp");
auto baseAddr = genBaseAddrMap(descriptor, op.getBounds(),
op.getMapType().value_or(0), builder);
ParentAndPlacement mapUser = mapMemberUsers[0];
adjustMemberIndices(memberIndices, mapUser.index);
llvm::SmallVector<mlir::Value> newMemberOps;
Expand All @@ -302,7 +324,9 @@ class MapInfoFinalizationPass
mapUser.parent.getMembersMutable().assign(newMemberOps);
mapUser.parent.setMembersIndexAttr(
builder.create2DI64ArrayAttr(memberIndices));
} else {
} else if (!IsHasDeviceAddr) {
auto baseAddr = genBaseAddrMap(descriptor, op.getBounds(),
op.getMapType().value_or(0), builder);
newMembers.push_back(baseAddr);
if (!op.getMembers().empty()) {
for (auto &indices : memberIndices)
Expand All @@ -316,15 +340,26 @@ class MapInfoFinalizationPass
}
}

// Descriptors for objects listed on the `has_device_addr` clause will
// always be copied. This is because the descriptor can be rematerialized
// by the compiler, and so the address of the descriptor for a given object
// at one place in the code may differ from that address in another place.
// The contents of the descriptor (the base address in particular) will
// remain unchanged though.
uint64_t MapType = op.getMapType().value_or(0);
if (IsHasDeviceAddr) {
MapType |= llvm::to_underlying(
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
}

mlir::omp::MapInfoOp newDescParentMapOp =
builder.create<mlir::omp::MapInfoOp>(
op->getLoc(), op.getResult().getType(), descriptor,
mlir::TypeAttr::get(fir::unwrapRefType(descriptor.getType())),
/*varPtrPtr=*/mlir::Value{}, newMembers, newMembersAttr,
/*bounds=*/mlir::SmallVector<mlir::Value>{},
builder.getIntegerAttr(
builder.getIntegerType(64, false),
getDescriptorMapType(op.getMapType().value_or(0), target)),
builder.getIntegerAttr(builder.getIntegerType(64, false),
getDescriptorMapType(MapType, target)),
/*mapperId*/ mlir::FlatSymbolRefAttr(), op.getMapCaptureTypeAttr(),
op.getNameAttr(),
/*partial_map=*/builder.getBoolAttr(false));
Expand Down Expand Up @@ -443,6 +478,12 @@ class MapInfoFinalizationPass
addOperands(useDevPtrMutableOpRange, target,
argIface.getUseDevicePtrBlockArgsStart() +
argIface.numUseDevicePtrBlockArgs());
} else if (auto targetOp = llvm::dyn_cast<mlir::omp::TargetOp>(target)) {
mlir::MutableOperandRange hasDevAddrMutableOpRange =
targetOp.getHasDeviceAddrVarsMutable();
addOperands(hasDevAddrMutableOpRange, target,
argIface.getHasDeviceAddrBlockArgsStart() +
argIface.numHasDeviceAddrBlockArgs());
}
}

Expand Down
10 changes: 6 additions & 4 deletions flang/lib/Support/OpenMP-utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@ mlir::Block *genEntryBlock(mlir::OpBuilder &builder, const EntryBlockArgs &args,

llvm::SmallVector<mlir::Type> types;
llvm::SmallVector<mlir::Location> locs;
unsigned numVars = args.hostEvalVars.size() + args.inReduction.vars.size() +
args.map.vars.size() + args.priv.vars.size() +
args.reduction.vars.size() + args.taskReduction.vars.size() +
args.useDeviceAddr.vars.size() + args.useDevicePtr.vars.size();
unsigned numVars = args.hasDeviceAddr.vars.size() + args.hostEvalVars.size() +
args.inReduction.vars.size() + args.map.vars.size() +
args.priv.vars.size() + args.reduction.vars.size() +
args.taskReduction.vars.size() + args.useDeviceAddr.vars.size() +
args.useDevicePtr.vars.size();
types.reserve(numVars);
locs.reserve(numVars);

Expand All @@ -34,6 +35,7 @@ mlir::Block *genEntryBlock(mlir::OpBuilder &builder, const EntryBlockArgs &args,

// Populate block arguments in clause name alphabetical order to match
// expected order by the BlockArgOpenMPOpInterface.
extractTypeLoc(args.hasDeviceAddr.vars);
extractTypeLoc(args.hostEvalVars);
extractTypeLoc(args.inReduction.vars);
extractTypeLoc(args.map.vars);
Expand Down
23 changes: 23 additions & 0 deletions flang/test/Lower/OpenMP/has_device_addr-mapinfo.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=51 %s -mmlir -mlir-print-op-generic -o - | FileCheck %s
!RUN: bbc -emit-hlfir -fopenmp -fopenmp-version=51 %s -mlir-print-op-generic -o - | FileCheck %s

! Check that we don't generate member information for the descriptor of `a`
! on entry to the target region.

! Test subject: returns size(a, 1), evaluated inside a target region
! where `a` appears on has_device_addr.
integer function s(a)
! `a` is an assumed-shape dummy argument.
integer :: a(:)
! `t` carries the result out of the target region via map(from:t).
integer :: t
! Outer data region: maps `a` with map(to:) and lists it on use_device_addr.
!$omp target data map(to:a) use_device_addr(a)
! Inner target region: `a` is listed on has_device_addr, `t` on map(from:).
!$omp target map(from:t) has_device_addr(a)
t = size(a, 1)
!$omp end target
!$omp end target data
s = t
end

! Check that the map.info for `a` only takes a single parameter.

!CHECK-DAG: %[[MAP_A:[0-9]+]] = "omp.map.info"(%[[STORAGE_A:[0-9#]+]]) <{map_capture_type = #omp<variable_capture_kind(ByRef)>, map_type = 517 : ui64, name = "a", operandSegmentSizes = array<i32: 1, 0, 0, 0>, partial_map = false, var_type = !fir.box<!fir.array<?xi32>>}> : (!fir.ref<!fir.box<!fir.array<?xi32>>>) -> !fir.ref<!fir.array<?xi32>>
!CHECK-DAG: %[[MAP_T:[0-9]+]] = "omp.map.info"(%[[STORAGE_T:[0-9#]+]]) <{map_capture_type = #omp<variable_capture_kind(ByRef)>, map_type = 2 : ui64, name = "t", operandSegmentSizes = array<i32: 1, 0, 0, 0>, partial_map = false, var_type = i32}> : (!fir.ref<i32>) -> !fir.ref<i32>

!CHECK: "omp.target"(%[[MAP_A]], %[[MAP_T]])
11 changes: 8 additions & 3 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
Original file line number Diff line number Diff line change
Expand Up @@ -462,13 +462,18 @@ class OpenMP_HasDeviceAddrClauseSkip<
bit description = false, bit extraClassDeclaration = false
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let traits = [
BlockArgOpenMPOpInterface
];

let arguments = (ins
Variadic<OpenMP_PointerLikeType>:$has_device_addr_vars
);

let optAssemblyFormat = [{
`has_device_addr` `(` $has_device_addr_vars `:` type($has_device_addr_vars)
`)`
let extraClassDeclaration = [{
unsigned numHasDeviceAddrBlockArgs() {
return getHasDeviceAddrVars().size();
}
}];

let description = [{
Expand Down
5 changes: 3 additions & 2 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1326,8 +1326,9 @@ def TargetOp : OpenMP_Op<"target", traits = [
}] # clausesExtraClassDeclaration;

let assemblyFormat = clausesAssemblyFormat # [{
custom<HostEvalInReductionMapPrivateRegion>(
$region, $host_eval_vars, type($host_eval_vars), $in_reduction_vars,
custom<TargetOpRegion>(
$region, $has_device_addr_vars, type($has_device_addr_vars),
$host_eval_vars, type($host_eval_vars), $in_reduction_vars,
type($in_reduction_vars), $in_reduction_byref, $in_reduction_syms,
$map_vars, type($map_vars), $private_vars, type($private_vars),
$private_syms, $private_maps) attr-dict
Expand Down
13 changes: 8 additions & 5 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,10 @@ class BlockArgOpenMPClause<string clauseNameSnake, string clauseNameCamel,
>;
}

def BlockArgHostEvalClause : BlockArgOpenMPClause<"host_eval", "HostEval", ?>;
def BlockArgHasDeviceAddrClause : BlockArgOpenMPClause<
"has_device_addr", "HasDeviceAddr", ?>;
def BlockArgHostEvalClause : BlockArgOpenMPClause<
"host_eval", "HostEval", BlockArgHasDeviceAddrClause>;
def BlockArgInReductionClause : BlockArgOpenMPClause<
"in_reduction", "InReduction", BlockArgHostEvalClause>;
def BlockArgMapClause : BlockArgOpenMPClause<
Expand All @@ -100,10 +103,10 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {

let cppNamespace = "::mlir::omp";

defvar clauses = [ BlockArgHostEvalClause, BlockArgInReductionClause,
BlockArgMapClause, BlockArgPrivateClause, BlockArgReductionClause,
BlockArgTaskReductionClause, BlockArgUseDeviceAddrClause,
BlockArgUseDevicePtrClause ];
defvar clauses = [ BlockArgHasDeviceAddrClause, BlockArgHostEvalClause,
BlockArgInReductionClause, BlockArgMapClause, BlockArgPrivateClause,
BlockArgReductionClause, BlockArgTaskReductionClause,
BlockArgUseDeviceAddrClause, BlockArgUseDevicePtrClause ];

let methods = !listconcat(
!foreach(clause, clauses, clause.numArgsMethod),
Expand Down
Loading