Skip to content

Commit 968dac4

Browse files
authored
Initial support for mapping allocatables/pointers in derived types and related map syntax (llvm#160)
* Apply upstream allocatable member mapping with some minor modifications * fix rebase issues * Fix tests
1 parent e7b6197 commit 968dac4

28 files changed

+2080
-450
lines changed

flang/include/flang/Lower/OpenMP/Clauses.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,13 @@ struct IdTyTemplate {
5555
return designator == other.designator;
5656
}
5757

58+
// Defining an "ordering" which allows types derived from this to be
59+
// utilised in maps and other containers that require comparison
60+
// operators for ordering
61+
bool operator<(const IdTyTemplate &other) const {
62+
return symbol < other.symbol;
63+
}
64+
5865
operator bool() const { return symbol != nullptr; }
5966
};
6067

@@ -76,6 +83,10 @@ struct ObjectT<Fortran::lower::omp::IdTyTemplate<Fortran::lower::omp::ExprTy>,
7683
Fortran::semantics::Symbol *sym() const { return identity.symbol; }
7784
const std::optional<ExprTy> &ref() const { return identity.designator; }
7885

86+
bool operator<(const ObjectT<IdTy, ExprTy> &other) const {
87+
return identity < other.identity;
88+
}
89+
7990
IdTy identity;
8091
};
8192
} // namespace tomp::type

flang/include/flang/Lower/OpenMP/Utils.h

Lines changed: 97 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/IR/Location.h"
1515
#include "mlir/IR/Value.h"
1616
#include "llvm/Support/CommandLine.h"
17+
#include <cstdint>
1718

1819
extern llvm::cl::opt<bool> treatIndexAsSection;
1920
extern llvm::cl::opt<bool> enableDelayedPrivatization;
@@ -34,6 +35,7 @@ struct OmpObjectList;
3435
} // namespace parser
3536

3637
namespace lower {
38+
class StatementContext;
3739
namespace pft {
3840
struct Evaluation;
3941
}
@@ -49,38 +51,111 @@ using DeclareTargetCapturePair =
4951
// and index data when lowering OpenMP map clauses. Keeps track of the
5052
// placement of the component in the derived type hierarchy it rests within,
5153
// alongside the generated mlir::omp::MapInfoOp for the mapped component.
52-
struct OmpMapMemberIndicesData {
54+
//
55+
// As an example of what the contents of this data structure may be like,
56+
// when provided the following derived type and map of that type:
57+
//
58+
// type :: bottom_layer
59+
// real(8) :: i2
60+
// real(4) :: array_i2(10)
61+
// real(4) :: array_j2(10)
62+
// end type bottom_layer
63+
//
64+
// type :: top_layer
65+
// real(4) :: i
66+
// integer(4) :: array_i(10)
67+
// real(4) :: j
68+
// type(bottom_layer) :: nested
69+
// integer, allocatable :: array_j(:)
70+
// integer(4) :: k
71+
// end type top_layer
72+
//
73+
// type(top_layer) :: top_dtype
74+
//
75+
// map(tofrom: top_dtype%nested%i2, top_dtype%k, top_dtype%nested%array_i2)
76+
//
77+
// We would end up with an OmpMapParentAndMemberData populated like below:
78+
//
79+
// memberPlacementIndices:
80+
// Vector 1: 3, 0
81+
// Vector 2: 5
82+
// Vector 3: 3, 1
83+
//
84+
// memberMap:
85+
// Entry 1: omp.map.info for "top_dtype%nested%i2"
86+
// Entry 2: omp.map.info for "top_dtype%k"
87+
// Entry 3: omp.map.info for "top_dtype%nested%array_i2"
88+
//
89+
// And this OmpMapParentAndMemberData would be accessed via the parent
90+
// symbol for top_dtype. Other parent derived type instances that have
91+
// members mapped would have there own OmpMapParentAndMemberData entry
92+
// accessed via their own symbol.
93+
struct OmpMapParentAndMemberData {
5394
// The indices representing the component members placement in its derived
5495
// type parents hierarchy.
55-
llvm::SmallVector<int> memberPlacementIndices;
96+
llvm::SmallVector<llvm::SmallVector<int64_t>> memberPlacementIndices;
5697

5798
// Placement of the member in the member vector.
58-
mlir::omp::MapInfoOp memberMap;
99+
llvm::SmallVector<mlir::omp::MapInfoOp> memberMap;
100+
101+
// The list of associated parent object symbols. used to track data we
102+
// need for various parent processing tasks when performing member
103+
// mapping, the main example currently being re-evaluating the parent
104+
// maps bounds at the final step of map processing, where we need to
105+
// keep a hold of all of the omp::Object's which contain array bounds
106+
// for the respective parent to calculate the final bounds from.
107+
//
108+
// As an Example:
109+
//
110+
// !$omp target map(tofrom: alloca_dtype_arr(2)%array_i,
111+
// alloca_dtype_arr(3)%array_i)
112+
//
113+
// parentObjList will contain alloca_dtype_arr(3) as well as
114+
// alloca_dtype_arr(2).
115+
ObjectList parentObjList;
59116
};
60117

61-
mlir::omp::MapInfoOp
62-
createMapInfoOp(mlir::OpBuilder &builder, mlir::Location loc,
63-
mlir::Value baseAddr, mlir::Value varPtrPtr, std::string name,
64-
mlir::ArrayRef<mlir::Value> bounds,
65-
mlir::ArrayRef<mlir::Value> members,
66-
mlir::DenseIntElementsAttr membersIndex, uint64_t mapType,
67-
mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
68-
bool partialMap = false);
69-
70-
void addChildIndexAndMapToParent(
71-
const omp::Object &object,
72-
std::map<const semantics::Symbol *,
73-
llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices,
74-
mlir::omp::MapInfoOp &mapOp, semantics::SemanticsContext &semaCtx);
118+
void generateMemberPlacementIndices(
119+
const Object &object, llvm::SmallVectorImpl<int64_t> &indices,
120+
Fortran::semantics::SemanticsContext &semaCtx);
121+
122+
bool isMemberOrParentAllocatableOrPointer(
123+
const Object &object, Fortran::semantics::SemanticsContext &semaCtx);
124+
125+
bool isDuplicateMemberMapInfo(OmpMapParentAndMemberData &parentMembers,
126+
llvm::SmallVectorImpl<int64_t> &memberIndices);
127+
128+
mlir::omp::MapInfoOp createMapInfoOp(
129+
fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value baseAddr,
130+
mlir::Value varPtrPtr, std::string name, mlir::ArrayRef<mlir::Value> bounds,
131+
mlir::ArrayRef<mlir::Value> members, mlir::ArrayAttr membersIndex,
132+
uint64_t mapType, mlir::omp::VariableCaptureKind mapCaptureType,
133+
mlir::Type retTy, bool partialMap = false);
134+
135+
mlir::Value createParentSymAndGenIntermediateMaps(
136+
mlir::Location clauseLocation, Fortran::lower::AbstractConverter &converter,
137+
semantics::SemanticsContext &semaCtx, lower::StatementContext &stmtCtx,
138+
omp::ObjectList &objectList, llvm::SmallVector<int64_t> &indices,
139+
OmpMapParentAndMemberData &parentMemberIndices, std::string asFortran,
140+
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits);
141+
142+
omp::ObjectList gatherObjects(omp::Object obj,
143+
semantics::SemanticsContext &semaCtx);
144+
145+
void addChildIndexAndMapToParent(const omp::Object &object,
146+
OmpMapParentAndMemberData &parentMemberIndices,
147+
mlir::omp::MapInfoOp &mapOp,
148+
semantics::SemanticsContext &semaCtx);
75149

76150
void insertChildMapInfoIntoParent(
77-
lower::AbstractConverter &converter,
78-
std::map<const semantics::Symbol *,
79-
llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices,
151+
Fortran::lower::AbstractConverter &converter,
152+
Fortran::semantics::SemanticsContext &semaCtx,
153+
Fortran::lower::StatementContext &stmtCtx,
154+
std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices,
80155
llvm::SmallVectorImpl<mlir::Value> &mapOperands,
81-
llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms,
82156
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes,
83-
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs);
157+
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
158+
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSymbols);
84159

85160
mlir::Type getLoopVarType(lower::AbstractConverter &converter,
86161
std::size_t loopVarTypeSize);
@@ -94,8 +169,6 @@ void gatherFuncAndVarSyms(
94169

95170
int64_t getCollapseValue(const List<Clause> &clauses);
96171

97-
semantics::Symbol *getOmpObjectSymbol(const parser::OmpObject &ompObject);
98-
99172
void genObjectList(const ObjectList &objects,
100173
lower::AbstractConverter &converter,
101174
llvm::SmallVectorImpl<mlir::Value> &operands);

flang/include/flang/Optimizer/Builder/FIRBuilder.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,11 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
212212
llvm::ArrayRef<mlir::Value> lenParams,
213213
bool asTarget = false);
214214

215+
/// Create a two dimensional ArrayAttr containing integer data as
216+
/// IntegerAttrs, effectively: ArrayAttr<ArrayAttr<IntegerAttr>>>.
217+
mlir::ArrayAttr create2DIntegerArrayAttr(
218+
llvm::SmallVectorImpl<llvm::SmallVector<int64_t>> &intData);
219+
215220
/// Create a temporary using `fir.alloca`. This function does not hoist.
216221
/// It is the callers responsibility to set the insertion point if
217222
/// hoisting is required.

flang/lib/Frontend/FrontendActions.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,7 @@ bool CodeGenAction::beginSourceFileAction() {
372372
return false;
373373
}
374374

375+
375376
// Print initial full MLIR module, before lowering or transformations, if
376377
// -save-temps has been specified.
377378
if (!saveMLIRTempFile(ci.getInvocation(), *mlirModule, getCurrentFile(),

flang/lib/Lower/DirectivesCommon.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1026,7 +1026,10 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
10261026
// If it is a scalar subscript, then the upper bound
10271027
// is equal to the lower bound, and the extent is one.
10281028
ubound = lbound;
1029-
extent = one;
1029+
if (treatIndexAsSection)
1030+
extent = fir::factory::readExtent(builder, loc, dataExv, dimension);
1031+
else
1032+
extent = one;
10301033
} else {
10311034
asFortran << ':';
10321035
Fortran::semantics::MaybeExpr upper =

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 46 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -889,16 +889,17 @@ void ClauseProcessor::processMapObjects(
889889
lower::StatementContext &stmtCtx, mlir::Location clauseLocation,
890890
const omp::ObjectList &objects,
891891
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits,
892-
std::map<const semantics::Symbol *,
893-
llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices,
892+
std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices,
894893
llvm::SmallVectorImpl<mlir::Value> &mapVars,
895894
llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms,
896895
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
897896
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes) const {
898897
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
898+
899899
for (const omp::Object &object : objects) {
900900
llvm::SmallVector<mlir::Value> bounds;
901901
std::stringstream asFortran;
902+
std::optional<omp::Object> parentObj;
902903

903904
lower::AddrAndBoundsInfo info =
904905
lower::gatherDataOperandAddrAndBounds<mlir::omp::MapBoundsOp,
@@ -907,28 +908,47 @@ void ClauseProcessor::processMapObjects(
907908
object.ref(), clauseLocation, asFortran, bounds,
908909
treatIndexAsSection);
909910

911+
mlir::Value baseOp = info.rawInput;
912+
if (object.sym()->owner().IsDerivedType()) {
913+
omp::ObjectList objectList = gatherObjects(object, semaCtx);
914+
assert(!objectList.empty() &&
915+
"could not find parent objects of derived type member");
916+
parentObj = objectList[0];
917+
auto insert = parentMemberIndices.emplace(parentObj.value(),
918+
OmpMapParentAndMemberData{});
919+
insert.first->second.parentObjList.push_back(parentObj.value());
920+
921+
if (isMemberOrParentAllocatableOrPointer(object, semaCtx)) {
922+
llvm::SmallVector<int64_t> indices;
923+
generateMemberPlacementIndices(object, indices, semaCtx);
924+
baseOp = createParentSymAndGenIntermediateMaps(
925+
clauseLocation, converter, semaCtx, stmtCtx, objectList, indices,
926+
parentMemberIndices[parentObj.value()], asFortran.str(),
927+
mapTypeBits);
928+
}
929+
}
930+
910931
// Explicit map captures are captured ByRef by default,
911932
// optimisation passes may alter this to ByCopy or other capture
912933
// types to optimise
913-
mlir::Value baseOp = info.rawInput;
914934
auto location = mlir::NameLoc::get(
915935
mlir::StringAttr::get(firOpBuilder.getContext(), asFortran.str()),
916936
baseOp.getLoc());
917937
mlir::omp::MapInfoOp mapOp = createMapInfoOp(
918938
firOpBuilder, location, baseOp,
919939
/*varPtrPtr=*/mlir::Value{}, asFortran.str(), bounds,
920-
/*members=*/{}, /*membersIndex=*/mlir::DenseIntElementsAttr{},
940+
/*members=*/{}, /*membersIndex=*/mlir::ArrayAttr{},
921941
static_cast<
922942
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
923943
mapTypeBits),
924944
mlir::omp::VariableCaptureKind::ByRef, baseOp.getType());
925945

926-
if (object.sym()->owner().IsDerivedType()) {
927-
addChildIndexAndMapToParent(object, parentMemberIndices, mapOp, semaCtx);
946+
if (parentObj.has_value()) {
947+
addChildIndexAndMapToParent(
948+
object, parentMemberIndices[parentObj.value()], mapOp, semaCtx);
928949
} else {
929950
mapVars.push_back(mapOp);
930-
if (mapSyms)
931-
mapSyms->push_back(object.sym());
951+
mapSyms->push_back(object.sym());
932952
if (mapSymTypes)
933953
mapSymTypes->push_back(baseOp.getType());
934954
if (mapSymLocs)
@@ -949,9 +969,7 @@ bool ClauseProcessor::processMap(
949969
llvm::SmallVector<const semantics::Symbol *> localMapSyms;
950970
llvm::SmallVectorImpl<const semantics::Symbol *> *ptrMapSyms =
951971
mapSyms ? mapSyms : &localMapSyms;
952-
std::map<const semantics::Symbol *,
953-
llvm::SmallVector<OmpMapMemberIndicesData>>
954-
parentMemberIndices;
972+
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
955973

956974
bool clauseFound = findRepeatableClause<omp::clause::Map>(
957975
[&](const omp::clause::Map &clause, const parser::CharBlock &source) {
@@ -997,23 +1015,22 @@ bool ClauseProcessor::processMap(
9971015
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
9981016
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
9991017
}
1018+
10001019
processMapObjects(stmtCtx, clauseLocation,
10011020
std::get<omp::ObjectList>(clause.t), mapTypeBits,
10021021
parentMemberIndices, result.mapVars, ptrMapSyms,
10031022
mapSymLocs, mapSymTypes);
10041023
});
10051024

1006-
insertChildMapInfoIntoParent(converter, parentMemberIndices, result.mapVars,
1007-
*ptrMapSyms, mapSymTypes, mapSymLocs);
1008-
1025+
insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
1026+
result.mapVars, mapSymTypes, mapSymLocs,
1027+
ptrMapSyms);
10091028
return clauseFound;
10101029
}
10111030

10121031
bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx,
10131032
mlir::omp::MapClauseOps &result) {
1014-
std::map<const semantics::Symbol *,
1015-
llvm::SmallVector<OmpMapMemberIndicesData>>
1016-
parentMemberIndices;
1033+
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
10171034
llvm::SmallVector<const semantics::Symbol *> mapSymbols;
10181035

10191036
auto callbackFn = [&](const auto &clause, const parser::CharBlock &source) {
@@ -1034,9 +1051,9 @@ bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx,
10341051
clauseFound =
10351052
findRepeatableClause<omp::clause::From>(callbackFn) || clauseFound;
10361053

1037-
insertChildMapInfoIntoParent(converter, parentMemberIndices, result.mapVars,
1038-
mapSymbols,
1039-
/*mapSymTypes=*/nullptr, /*mapSymLocs=*/nullptr);
1054+
insertChildMapInfoIntoParent(
1055+
converter, semaCtx, stmtCtx, parentMemberIndices, result.mapVars,
1056+
/*mapSymTypes=*/nullptr, /*mapSymLocs=*/nullptr, &mapSymbols);
10401057
return clauseFound;
10411058
}
10421059

@@ -1110,9 +1127,7 @@ bool ClauseProcessor::processUseDeviceAddr(
11101127
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
11111128
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
11121129
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const {
1113-
std::map<const semantics::Symbol *,
1114-
llvm::SmallVector<OmpMapMemberIndicesData>>
1115-
parentMemberIndices;
1130+
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
11161131
bool clauseFound = findRepeatableClause<omp::clause::UseDeviceAddr>(
11171132
[&](const omp::clause::UseDeviceAddr &clause,
11181133
const parser::CharBlock &source) {
@@ -1125,9 +1140,9 @@ bool ClauseProcessor::processUseDeviceAddr(
11251140
&useDeviceSyms, &useDeviceLocs, &useDeviceTypes);
11261141
});
11271142

1128-
insertChildMapInfoIntoParent(converter, parentMemberIndices,
1129-
result.useDeviceAddrVars, useDeviceSyms,
1130-
&useDeviceTypes, &useDeviceLocs);
1143+
insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
1144+
result.useDeviceAddrVars, &useDeviceTypes,
1145+
&useDeviceLocs, &useDeviceSyms);
11311146
return clauseFound;
11321147
}
11331148

@@ -1136,9 +1151,8 @@ bool ClauseProcessor::processUseDevicePtr(
11361151
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
11371152
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
11381153
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const {
1139-
std::map<const semantics::Symbol *,
1140-
llvm::SmallVector<OmpMapMemberIndicesData>>
1141-
parentMemberIndices;
1154+
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
1155+
11421156
bool clauseFound = findRepeatableClause<omp::clause::UseDevicePtr>(
11431157
[&](const omp::clause::UseDevicePtr &clause,
11441158
const parser::CharBlock &source) {
@@ -1151,9 +1165,9 @@ bool ClauseProcessor::processUseDevicePtr(
11511165
&useDeviceSyms, &useDeviceLocs, &useDeviceTypes);
11521166
});
11531167

1154-
insertChildMapInfoIntoParent(converter, parentMemberIndices,
1155-
result.useDevicePtrVars, useDeviceSyms,
1156-
&useDeviceTypes, &useDeviceLocs);
1168+
insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
1169+
result.useDevicePtrVars, &useDeviceTypes,
1170+
&useDeviceLocs, &useDeviceSyms);
11571171
return clauseFound;
11581172
}
11591173

0 commit comments

Comments
 (0)