Skip to content

Commit 0eb6000

Browse files
authored
Merge pull request llvm#122 from agozillon/fix-for-dependent-PR
[Flang][MLIR] Fix mapping by adding an additional map pointer component for descriptor mappings; this helps fix use_device_addr/use_device_ptr with these types
2 parents a01a588 + 5c943d8 commit 0eb6000

File tree

5 files changed

+121
-57
lines changed

5 files changed

+121
-57
lines changed

flang/lib/Optimizer/Transforms/OMPMapInfoFinalization.cpp

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,34 @@ class OMPMapInfoFinalizationPass
5959
/*corresponding local alloca=*/fir::AllocaOp>
6060
localBoxAllocas;
6161

62+
unsigned long getDescriptorMapType(unsigned long mapTypeFlag,
63+
mlir::Operation *target) {
64+
auto newDescFlag = llvm::omp::OpenMPOffloadMappingFlags(mapTypeFlag);
65+
66+
if ((llvm::isa_and_nonnull<mlir::omp::TargetDataOp>(target) ||
67+
llvm::isa_and_nonnull<mlir::omp::TargetOp>(target)) &&
68+
static_cast<
69+
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
70+
(newDescFlag &
71+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM)) &&
72+
static_cast<
73+
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
74+
(newDescFlag & llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO) !=
75+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO))
76+
return static_cast<
77+
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
78+
newDescFlag | llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
79+
80+
if ((llvm::isa_and_nonnull<mlir::omp::TargetDataOp>(target) ||
81+
llvm::isa_and_nonnull<mlir::omp::TargetEnterDataOp>(target)) &&
82+
mapTypeFlag == 0)
83+
return static_cast<
84+
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
85+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
86+
87+
return mapTypeFlag;
88+
}
89+
6290
void genDescriptorMemberMaps(mlir::omp::MapInfoOp op,
6391
fir::FirOpBuilder &builder,
6492
mlir::Operation *target) {
@@ -176,8 +204,9 @@ class OMPMapInfoFinalizationPass
176204
mlir::IntegerType::get(builder.getContext(), 32)),
177205
llvm::ArrayRef<int32_t>({0})),
178206
/*bounds=*/mlir::SmallVector<mlir::Value>{},
179-
builder.getIntegerAttr(builder.getIntegerType(64, false),
180-
op.getMapType().value()),
207+
builder.getIntegerAttr(
208+
builder.getIntegerType(64, false),
209+
getDescriptorMapType(op.getMapType().value(), target)),
181210
op.getMapCaptureTypeAttr(), op.getNameAttr(), op.getPartialMapAttr());
182211
op.replaceAllUsesWith(newDescParentMapOp);
183212
op->erase();

flang/test/Integration/OpenMP/map-types-and-sizes.f90

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,17 @@ subroutine mapType_array
3030
!$omp end target
3131
end subroutine mapType_array
3232

33-
!CHECK: @.offload_sizes{{.*}} = private unnamed_addr constant [3 x i64] [i64 0, i64 24, i64 4]
34-
!CHECK: @.offload_maptypes{{.*}} = private unnamed_addr constant [3 x i64] [i64 32, i64 281474976710657, i64 281474976711187]
33+
!CHECK: @.offload_sizes{{.*}} = private unnamed_addr constant [4 x i64] [i64 0, i64 24, i64 8, i64 4]
34+
!CHECK: @.offload_maptypes{{.*}} = private unnamed_addr constant [4 x i64] [i64 32, i64 281474976711171, i64 281474976711171, i64 281474976711187]
3535
subroutine mapType_ptr
3636
integer, pointer :: a
3737
!$omp target
3838
a = 10
3939
!$omp end target
4040
end subroutine mapType_ptr
4141

42-
!CHECK: @.offload_sizes{{.*}} = private unnamed_addr constant [3 x i64] [i64 0, i64 24, i64 4]
43-
!CHECK: @.offload_maptypes{{.*}} = private unnamed_addr constant [3 x i64] [i64 32, i64 281474976710657, i64 281474976711187]
42+
!CHECK: @.offload_sizes{{.*}} = private unnamed_addr constant [4 x i64] [i64 0, i64 24, i64 8, i64 4]
43+
!CHECK: @.offload_maptypes{{.*}} = private unnamed_addr constant [4 x i64] [i64 32, i64 281474976711171, i64 281474976711171, i64 281474976711187]
4444
subroutine mapType_allocatable
4545
integer, allocatable :: a
4646
allocate(a)
@@ -50,17 +50,17 @@ subroutine mapType_allocatable
5050
deallocate(a)
5151
end subroutine mapType_allocatable
5252

53-
!CHECK: @.offload_sizes{{.*}} = private unnamed_addr constant [3 x i64] [i64 0, i64 24, i64 4]
54-
!CHECK: @.offload_maptypes{{.*}} = private unnamed_addr constant [3 x i64] [i64 32, i64 281474976710657, i64 281474976710675]
53+
!CHECK: @.offload_sizes{{.*}} = private unnamed_addr constant [4 x i64] [i64 0, i64 24, i64 8, i64 4]
54+
!CHECK: @.offload_maptypes{{.*}} = private unnamed_addr constant [4 x i64] [i64 32, i64 281474976710659, i64 281474976710659, i64 281474976710675]
5555
subroutine mapType_ptr_explicit
5656
integer, pointer :: a
5757
!$omp target map(tofrom: a)
5858
a = 10
5959
!$omp end target
6060
end subroutine mapType_ptr_explicit
6161

62-
!CHECK: @.offload_sizes{{.*}} = private unnamed_addr constant [3 x i64] [i64 0, i64 24, i64 4]
63-
!CHECK: @.offload_maptypes{{.*}} = private unnamed_addr constant [3 x i64] [i64 32, i64 281474976710657, i64 281474976710675]
62+
!CHECK: @.offload_sizes{{.*}} = private unnamed_addr constant [4 x i64] [i64 0, i64 24, i64 8, i64 4]
63+
!CHECK: @.offload_maptypes{{.*}} = private unnamed_addr constant [4 x i64] [i64 32, i64 281474976710659, i64 281474976710659, i64 281474976710675]
6464
subroutine mapType_allocatable_explicit
6565
integer, allocatable :: a
6666
allocate(a)
@@ -263,7 +263,7 @@ end subroutine mapType_common_block_members
263263
!CHECK: %[[ALLOCA_INT:.*]] = ptrtoint ptr %[[ALLOCA]] to i64
264264
!CHECK: %[[SIZE_DIFF:.*]] = sub i64 %[[ALLOCA_GEP_INT]], %[[ALLOCA_INT]]
265265
!CHECK: %[[DIV:.*]] = sdiv exact i64 %[[SIZE_DIFF]], ptrtoint (ptr getelementptr (i8, ptr null, i32 1) to i64)
266-
!CHECK: %[[OFFLOAD_SIZE_ARR:.*]] = getelementptr inbounds [3 x i64], ptr %.offload_sizes, i32 0, i32 0
266+
!CHECK: %[[OFFLOAD_SIZE_ARR:.*]] = getelementptr inbounds [4 x i64], ptr %.offload_sizes, i32 0, i32 0
267267
!CHECK: store i64 %[[DIV]], ptr %[[OFFLOAD_SIZE_ARR]], align 8
268268

269269
!CHECK-LABEL: define {{.*}} @{{.*}}maptype_allocatable_explicit_{{.*}}
@@ -273,7 +273,7 @@ end subroutine mapType_common_block_members
273273
!CHECK: %[[ALLOCA_INT:.*]] = ptrtoint ptr %[[ALLOCA]] to i64
274274
!CHECK: %[[SIZE_DIFF:.*]] = sub i64 %[[ALLOCA_GEP_INT]], %[[ALLOCA_INT]]
275275
!CHECK: %[[DIV:.*]] = sdiv exact i64 %[[SIZE_DIFF]], ptrtoint (ptr getelementptr (i8, ptr null, i32 1) to i64)
276-
!CHECK: %[[OFFLOAD_SIZE_ARR:.*]] = getelementptr inbounds [3 x i64], ptr %.offload_sizes, i32 0, i32 0
276+
!CHECK: %[[OFFLOAD_SIZE_ARR:.*]] = getelementptr inbounds [4 x i64], ptr %.offload_sizes, i32 0, i32 0
277277
!CHECK: store i64 %[[DIV]], ptr %[[OFFLOAD_SIZE_ARR]], align 8
278278

279279
!CHECK-LABEL: define {{.*}} @{{.*}}maptype_derived_implicit_{{.*}}

flang/test/Transforms/omp-map-info-finalization.fir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ module attributes {omp.is_target_device = false} {
1111
%5 = fir.allocmem i32 {fir.must_be_heap = true}
1212
%6 = fir.embox %5 : (!fir.heap<i32>) -> !fir.box<!fir.heap<i32>>
1313
fir.store %6 to %4#1 : !fir.ref<!fir.box<!fir.heap<i32>>>
14-
%c0 = arith.constant 1 : index
14+
%c0 = arith.constant 1 : index
1515
%c1 = arith.constant 0 : index
1616
%c2 = arith.constant 10 : index
1717
%dims:3 = fir.box_dims %2#1, %c1 : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
@@ -39,7 +39,7 @@ module attributes {omp.is_target_device = false} {
3939
// CHECK: fir.store %[[DECLARE1]]#1 to %[[ALLOCA]] : !fir.ref<!fir.box<!fir.array<?xi32>>>
4040
// CHECK: %[[BASE_ADDR_OFF_2:.*]] = fir.box_offset %[[ALLOCA]] base_addr : (!fir.ref<!fir.box<!fir.array<?xi32>>>) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>
4141
// CHECK: %[[DESC_MEMBER_MAP_2:.*]] = omp.map.info var_ptr(%[[ALLOCA]] : !fir.ref<!fir.box<!fir.array<?xi32>>>, !fir.array<?xi32>) var_ptr_ptr(%[[BASE_ADDR_OFF_2]] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) map_clauses(from) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>> {name = ""}
42-
// CHECK: %[[DESC_PARENT_MAP_2:.*]] = omp.map.info var_ptr(%[[ALLOCA]] : !fir.ref<!fir.box<!fir.array<?xi32>>>, !fir.box<!fir.array<?xi32>>) map_clauses(from) capture(ByRef) members(%[[DESC_MEMBER_MAP_2]] : [0] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) -> !fir.ref<!fir.array<?xi32>>
42+
// CHECK: %[[DESC_PARENT_MAP_2:.*]] = omp.map.info var_ptr(%[[ALLOCA]] : !fir.ref<!fir.box<!fir.array<?xi32>>>, !fir.box<!fir.array<?xi32>>) map_clauses(tofrom) capture(ByRef) members(%[[DESC_MEMBER_MAP_2]] : [0] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) -> !fir.ref<!fir.array<?xi32>>
4343
// CHECK: omp.target map_entries(%[[DESC_MEMBER_MAP]] -> %[[ARG1:.*]], %[[DESC_PARENT_MAP]] -> %[[ARG2:.*]], %[[DESC_MEMBER_MAP_2]] -> %[[ARG3:.*]], %[[DESC_PARENT_MAP_2]] -> %[[ARG4:.*]] : {{.*}}) {
4444
// CHECK: ^bb0(%[[ARG1]]: !fir.llvm_ptr<!fir.ref<i32>>, %[[ARG2]]: !fir.ref<!fir.box<!fir.heap<i32>>>, %[[ARG3]]: !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, %[[ARG4]]: !fir.ref<!fir.array<?xi32>>):
4545

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2635,10 +2635,13 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
26352635
mlir::dyn_cast<mlir::omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
26362636
int firstMemberIdx = getMapDataMemberIdx(
26372637
mapData, getFirstOrLastMappedMemberPtr(mapOp, true));
2638-
lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
2639-
builder.getPtrTy());
26402638
int lastMemberIdx = getMapDataMemberIdx(
26412639
mapData, getFirstOrLastMappedMemberPtr(mapOp, false));
2640+
2641+
// NOTE/TODO: Should perhaps use OriginalValue here instead of Pointers to
2642+
// avoid offset or any manipulations interfering with the calculation.
2643+
lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
2644+
builder.getPtrTy());
26422645
highAddr = builder.CreatePointerCast(
26432646
builder.CreateGEP(mapData.BaseType[lastMemberIdx],
26442647
mapData.Pointers[lastMemberIdx], builder.getInt64(1)),
@@ -2652,17 +2655,8 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
26522655
/*isSigned=*/false);
26532656
combinedInfo.Sizes.push_back(size);
26542657

2655-
// TODO: This will need to be expanded to include the whole host of logic for
2656-
// the map flags that Clang currently supports (e.g. it should take the map
2657-
// flag of the parent map flag, remove the OMP_MAP_TARGET_PARAM and do some
2658-
// further case specific flag modifications). For the moment, it handles what
2659-
// we support as expected.
2660-
llvm::omp::OpenMPOffloadMappingFlags mapFlag =
2661-
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
2662-
26632658
llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
26642659
ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);
2665-
ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
26662660

26672661
// This creates the initial MEMBER_OF mapping that consists of
26682662
// the parent/top level container (same as above effectively, except
@@ -2671,6 +2665,12 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
26712665
// only relevant if the structure in its totality is being mapped,
26722666
// otherwise the above suffices.
26732667
if (!parentClause.getPartialMap()) {
2668+
// TODO: This will need to be expanded to include the whole host of logic
2669+
// for the map flags that Clang currently supports (e.g. it should do some
2670+
// further case specific flag modifications). For the moment, it handles
2671+
// what we support as expected.
2672+
llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
2673+
ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
26742674
combinedInfo.Types.emplace_back(mapFlag);
26752675
combinedInfo.DevicePointers.emplace_back(
26762676
mapData.DevicePointers[mapDataIndex]);
@@ -2721,6 +2721,24 @@ static void processMapMembersWithParent(
27212721

27222722
assert(memberDataIdx >= 0 && "could not find mapped member of structure");
27232723

2724+
if (checkIfPointerMap(memberClause)) {
2725+
auto mapFlag = llvm::omp::OpenMPOffloadMappingFlags(
2726+
memberClause.getMapType().value());
2727+
mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
2728+
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
2729+
ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
2730+
combinedInfo.Types.emplace_back(mapFlag);
2731+
combinedInfo.DevicePointers.emplace_back(
2732+
llvm::OpenMPIRBuilder::DeviceInfoTy::None);
2733+
combinedInfo.Names.emplace_back(
2734+
LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
2735+
combinedInfo.BasePointers.emplace_back(
2736+
mapData.BasePointers[mapDataIndex]);
2737+
combinedInfo.Pointers.emplace_back(mapData.BasePointers[memberDataIdx]);
2738+
combinedInfo.Sizes.emplace_back(builder.getInt64(
2739+
moduleTranslation.getLLVMModule()->getDataLayout().getPointerSize()));
2740+
}
2741+
27242742
// Same MemberOfFlag to indicate its link with parent and other members
27252743
// of.
27262744
auto mapFlag =
@@ -2736,7 +2754,14 @@ static void processMapMembersWithParent(
27362754
llvm::OpenMPIRBuilder::DeviceInfoTy::None);
27372755
combinedInfo.Names.emplace_back(
27382756
LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
2739-
combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
2757+
2758+
if (checkIfPointerMap(memberClause))
2759+
combinedInfo.BasePointers.emplace_back(
2760+
mapData.BasePointers[memberDataIdx]);
2761+
else
2762+
combinedInfo.BasePointers.emplace_back(
2763+
mapData.BasePointers[mapDataIndex]);
2764+
27402765
combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);
27412766
combinedInfo.Sizes.emplace_back(mapData.Sizes[memberDataIdx]);
27422767
}

0 commit comments

Comments
 (0)