
Commit 0605e04

[DSE] Optimizing shrinking of memory intrinsics
Currently, for the snippet `memcpy(dst, src, 8); dst[7] = 0;` DSE will transform it into `memcpy(dst, src, 7); dst[7] = 0;`. Likewise, given `memcpy(dst, src, 9); dst[7] = 0; dst[8] = 0;` DSE will transform it into `memcpy(dst, src, 7); dst[7] = 0; dst[8] = 0;`.

However, in both cases we would prefer to emit an 8-byte `memcpy` followed by any overwrites of the trailing byte(s), since a single 8-byte access is cheaper to lower than the separate 4-, 2-, and 1-byte accesses a 7-byte copy decomposes into. This patch attempts to optimize the new intrinsic length within the range available between the original size and the maximally shrunk size.
Parent: b70792a
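For intuition: if the backend lowers a small `memcpy` as power-of-two-sized accesses, the number of accesses needed for a given length is that length's popcount, so the shrunk 7-byte copy costs three stores (4 + 2 + 1) where the original 8-byte copy costs one. A minimal sketch of that cost model (plain C++20, illustrative only; `chunkCount` is a hypothetical helper, not part of the patch):

```cpp
#include <bit>
#include <cstdint>
#include <cstdio>

// Number of power-of-two load/store chunks needed to cover Len bytes,
// assuming the backend splits small memcpys into power-of-two accesses.
static int chunkCount(uint64_t Len) { return std::popcount(Len); }

int main() {
  // The shrunk 7-byte copy needs 4 + 2 + 1 = three accesses, while the
  // original 8-byte copy is a single access.
  std::printf("len 7 -> %d accesses\n", chunkCount(7)); // 3
  std::printf("len 8 -> %d access\n", chunkCount(8));   // 1
}
```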

3 files changed: +405 additions, −155 deletions

llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp

Lines changed: 123 additions & 18 deletions
@@ -48,6 +48,7 @@
 #include "llvm/Analysis/MustExecute.h"
 #include "llvm/Analysis/PostDominators.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/Argument.h"
 #include "llvm/IR/BasicBlock.h"
@@ -558,9 +559,10 @@ static void shortenAssignment(Instruction *Inst, Value *OriginalDest,
   for_each(LinkedDVRAssigns, InsertAssignForOverlap);
 }
 
-static bool tryToShorten(Instruction *DeadI, int64_t &DeadStart,
-                         uint64_t &DeadSize, int64_t KillingStart,
-                         uint64_t KillingSize, bool IsOverwriteEnd) {
+static bool tryToShorten(Instruction *DeadI, int64_t DeadStart,
+                         uint64_t DeadSize, int64_t KillingStart,
+                         uint64_t KillingSize, bool IsOverwriteEnd,
+                         const TargetTransformInfo &TTI) {
   auto *DeadIntrinsic = cast<AnyMemIntrinsic>(DeadI);
   Align PrefAlign = DeadIntrinsic->getDestAlign().valueOrOne();
 
@@ -583,11 +585,7 @@ static bool tryToShorten(Instruction *DeadI, int64_t &DeadStart,
   // Compute start and size of the region to remove. Make sure 'PrefAlign' is
   // maintained on the remaining store.
   if (IsOverwriteEnd) {
-    // Calculate required adjustment for 'KillingStart' in order to keep
-    // remaining store size aligned on 'PerfAlign'.
-    uint64_t Off =
-        offsetToAlignment(uint64_t(KillingStart - DeadStart), PrefAlign);
-    ToRemoveStart = KillingStart + Off;
+    ToRemoveStart = KillingStart;
     if (DeadSize <= uint64_t(ToRemoveStart - DeadStart))
       return false;
     ToRemoveSize = DeadSize - uint64_t(ToRemoveStart - DeadStart);
@@ -612,6 +610,108 @@ static bool tryToShorten(Instruction *DeadI, int64_t &DeadStart,
   assert(DeadSize > ToRemoveSize && "Can't remove more than original size");
 
   uint64_t NewSize = DeadSize - ToRemoveSize;
+
+  // Try to coerce the new memcpy/memset size to a "fast" value. This typically
+  // means some exact multiple of the register width of the loads/stores.
+
+  // If scalar size >= vec size, assume the target will use scalars for
+  // implementing memset/memcpy.
+  TypeSize ScalarSize =
+      TTI.getRegisterBitWidth(TargetTransformInfo::RGK_Scalar);
+  TypeSize VecSize =
+      TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector);
+  uint64_t MemUnit = 0;
+  if (ScalarSize >= VecSize)
+    MemUnit = ScalarSize.getFixedValue();
+  // Otherwise assume memset/memcpy will be lowered with vector registers.
+  else
+    MemUnit =
+        TTI.getLoadStoreVecRegBitWidth(DeadIntrinsic->getDestAddressSpace());
+
+  MemUnit /= 8U;
+
+  // Assume loads/stores are issued in power-of-2 regions. Try to minimize the
+  // number of power-of-2 blocks.
+  // I.e. if we have DeadSize = 15:
+  //    NewSize = 7  -> 8  (4 + 2 + 1) -> (8)
+  //    NewSize = 9  -> 9  (8 + 1) == (8 + 1)
+  //    NewSize = 11 -> 12 (8 + 2 + 1) -> (8 + 4)
+  uint64_t Upper = DeadSize;
+  uint64_t Lower = NewSize;
+
+  uint64_t RoundLower = MemUnit * (Lower / MemUnit);
+
+  // We have some trailing loads/stores we can try to optimize.
+  if (RoundLower != Lower && Lower != 0 && (RoundLower + MemUnit) != 0) {
+    Upper = std::min(Upper, RoundLower + MemUnit - 1);
+    // Don't bust inlining doing this.
+    uint64_t InlineThresh = TTI.getMaxMemIntrinsicInlineSizeThreshold();
+    if (Upper > InlineThresh && Lower <= InlineThresh)
+      Upper = InlineThresh;
+
+    // Replace Lower with the value in the range [Lower, Upper] that has the
+    // minimum popcount (selecting the minimum value as a tiebreaker when the
+    // popcount is the same). The idea here is this will require the minimum
+    // number of loads/stores and within that will use the presumably
+    // preferable minimum width.
+
+    // Get the highest bit that differs between Lower and Upper. Anything above
+    // this bit must be in the new value. Anything below it that's larger than
+    // Lower is fair game.
+    uint64_t Dif = (Lower - 1) ^ Upper;
+    uint64_t HighestBit = 63 - llvm::countl_zero(Dif);
+
+    // Make Lo/Hi masks from the HighestBit. The Lo mask is used to find the
+    // value we can round up to the minimum power-of-2 chunk; the Hi mask is
+    // preserved.
+    uint64_t HighestP2 = static_cast<uint64_t>(1) << HighestBit;
+    uint64_t LoMask = HighestP2 - 1;
+    uint64_t HiMask = -HighestP2;
+
+    // Minimum power of 2 for the "tail".
+    uint64_t LoVal = Lower & LoMask;
+    if (LoVal)
+      LoVal = llvm::bit_ceil(LoVal);
+    // Preserved high bits to stay in range.
+    uint64_t HiVal = Lower & HiMask;
+    Lower = LoVal | HiVal;
+
+    // If we still have more than one tail store, see if we can just round up
+    // to the next memunit.
+    if (llvm::popcount(Lower % MemUnit) > 1 &&
+        DeadSize >= (RoundLower + MemUnit))
+      Lower = RoundLower + MemUnit;
+
+    uint64_t OptimizedNewSize = NewSize;
+    // If we are overwriting the beginning, make sure we don't mess up the
+    // alignment.
+    if (IsOverwriteEnd || isAligned(PrefAlign, DeadSize - Lower)) {
+      OptimizedNewSize = Lower;
+    } else {
+      // Our minimal value isn't properly aligned; see if we can increase the
+      // size of the tail loads/stores.
+      Lower = HiVal | HighestP2;
+      if (isAligned(PrefAlign, DeadSize - Lower))
+        OptimizedNewSize = Lower;
+      // If we can't adjust the size without messing up alignment, see if the
+      // new size is actually preferable.
+      // TODO: We should probably do better here than just giving up.
+      else if ((NewSize <= InlineThresh) == (DeadSize <= InlineThresh) &&
+               llvm::popcount(NewSize) > llvm::popcount(DeadSize) &&
+               DeadSize / MemUnit == NewSize / MemUnit)
+        return false;
+    }
+
+    // Adjust the new starting point for the memset/memcpy.
+    if (OptimizedNewSize != NewSize) {
+      if (!IsOverwriteEnd)
+        ToRemoveSize = DeadSize - OptimizedNewSize;
+      NewSize = OptimizedNewSize;
+    }
+
+    // Our optimal length is the original length; skip the transform.
+    if (NewSize == DeadSize)
+      return false;
+  }
+
   if (auto *AMI = dyn_cast<AtomicMemIntrinsic>(DeadI)) {
     // When shortening an atomic memory intrinsic, the newly shortened
     // length must remain an integer multiple of the element size.
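The closed-form bit manipulation above can be sanity-checked against a brute-force search for the minimum-popcount value in `[Lower, Upper]`, smallest value winning ties. A standalone cross-check (plain C++20, illustrative only; `minPopcountInRange` is a hypothetical helper, not part of the patch):

```cpp
#include <bit>
#include <cassert>
#include <cstdint>

// Brute-force reference for the bit trick: the preferred length in
// [Lower, Upper] is the one with minimum popcount (fewest power-of-two
// chunks), taking the smallest such value as the tiebreaker.
static uint64_t minPopcountInRange(uint64_t Lower, uint64_t Upper) {
  uint64_t Best = Lower;
  for (uint64_t V = Lower + 1; V <= Upper; ++V)
    if (std::popcount(V) < std::popcount(Best))
      Best = V;
  return Best;
}

int main() {
  // DeadSize = 15, MemUnit = 8, mirroring the examples in the comment.
  assert(minPopcountInRange(9, 15) == 9);   // NewSize = 9  stays 9.
  assert(minPopcountInRange(11, 15) == 12); // NewSize = 11 becomes 12.
  // For NewSize = 7, Upper = min(15, 0 + 8 - 1) = 7, so the in-range search
  // alone keeps 7; it is the separate round-up-to-the-next-MemUnit branch
  // that turns 7 into 8.
  assert(minPopcountInRange(7, 7) == 7);
}
```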
@@ -654,7 +754,8 @@ static bool tryToShorten(Instruction *DeadI, int64_t &DeadStart,
 }
 
 static bool tryToShortenEnd(Instruction *DeadI, OverlapIntervalsTy &IntervalMap,
-                            int64_t &DeadStart, uint64_t &DeadSize) {
+                            int64_t &DeadStart, uint64_t &DeadSize,
+                            const TargetTransformInfo &TTI) {
   if (IntervalMap.empty() || !isShortenableAtTheEnd(DeadI))
     return false;
 
@@ -672,7 +773,7 @@ static bool tryToShortenEnd(Instruction *DeadI, OverlapIntervalsTy &IntervalMap,
       // be non negative due to preceding checks.
       KillingSize >= DeadSize - (uint64_t)(KillingStart - DeadStart)) {
     if (tryToShorten(DeadI, DeadStart, DeadSize, KillingStart, KillingSize,
-                     true)) {
+                     true, TTI)) {
       IntervalMap.erase(OII);
       return true;
     }
@@ -682,7 +783,8 @@ static bool tryToShortenEnd(Instruction *DeadI, OverlapIntervalsTy &IntervalMap,
 
 static bool tryToShortenBegin(Instruction *DeadI,
                               OverlapIntervalsTy &IntervalMap,
-                              int64_t &DeadStart, uint64_t &DeadSize) {
+                              int64_t &DeadStart, uint64_t &DeadSize,
+                              const TargetTransformInfo &TTI) {
   if (IntervalMap.empty() || !isShortenableAtTheBeginning(DeadI))
     return false;
 
@@ -701,7 +803,7 @@ static bool tryToShortenBegin(Instruction *DeadI,
     assert(KillingSize - (uint64_t)(DeadStart - KillingStart) < DeadSize &&
            "Should have been handled as OW_Complete");
     if (tryToShorten(DeadI, DeadStart, DeadSize, KillingStart, KillingSize,
-                     false)) {
+                     false, TTI)) {
       IntervalMap.erase(OII);
       return true;
     }
@@ -824,6 +926,7 @@ struct DSEState {
   DominatorTree &DT;
   PostDominatorTree &PDT;
   const TargetLibraryInfo &TLI;
+  const TargetTransformInfo &TTI;
   const DataLayout &DL;
   const LoopInfo &LI;
 
@@ -868,9 +971,9 @@ struct DSEState {
 
   DSEState(Function &F, AliasAnalysis &AA, MemorySSA &MSSA, DominatorTree &DT,
            PostDominatorTree &PDT, const TargetLibraryInfo &TLI,
-           const LoopInfo &LI)
+           const TargetTransformInfo &TTI, const LoopInfo &LI)
       : F(F), AA(AA), EI(DT, &LI), BatchAA(AA, &EI), MSSA(MSSA), DT(DT),
-        PDT(PDT), TLI(TLI), DL(F.getDataLayout()), LI(LI) {
+        PDT(PDT), TLI(TLI), TTI(TTI), DL(F.getDataLayout()), LI(LI) {
     // Collect blocks with throwing instructions not modeled in MemorySSA and
     // alloc-like objects.
     unsigned PO = 0;
@@ -2066,10 +2169,10 @@ struct DSEState {
       uint64_t DeadSize = Loc.Size.getValue();
       GetPointerBaseWithConstantOffset(Ptr, DeadStart, DL);
       OverlapIntervalsTy &IntervalMap = OI.second;
-      Changed |= tryToShortenEnd(DeadI, IntervalMap, DeadStart, DeadSize);
+      Changed |= tryToShortenEnd(DeadI, IntervalMap, DeadStart, DeadSize, TTI);
       if (IntervalMap.empty())
        continue;
-      Changed |= tryToShortenBegin(DeadI, IntervalMap, DeadStart, DeadSize);
+      Changed |= tryToShortenBegin(DeadI, IntervalMap, DeadStart, DeadSize, TTI);
     }
     return Changed;
   }
@@ -2137,10 +2240,11 @@ struct DSEState {
 static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA,
                                 DominatorTree &DT, PostDominatorTree &PDT,
                                 const TargetLibraryInfo &TLI,
+                                const TargetTransformInfo &TTI,
                                 const LoopInfo &LI) {
   bool MadeChange = false;
 
-  DSEState State(F, AA, MSSA, DT, PDT, TLI, LI);
+  DSEState State(F, AA, MSSA, DT, PDT, TLI, TTI, LI);
   // For each store:
   for (unsigned I = 0; I < State.MemDefs.size(); I++) {
     MemoryDef *KillingDef = State.MemDefs[I];
@@ -2332,12 +2436,13 @@ static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA,
 PreservedAnalyses DSEPass::run(Function &F, FunctionAnalysisManager &AM) {
   AliasAnalysis &AA = AM.getResult<AAManager>(F);
   const TargetLibraryInfo &TLI = AM.getResult<TargetLibraryAnalysis>(F);
+  const TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F);
   DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
   MemorySSA &MSSA = AM.getResult<MemorySSAAnalysis>(F).getMSSA();
   PostDominatorTree &PDT = AM.getResult<PostDominatorTreeAnalysis>(F);
   LoopInfo &LI = AM.getResult<LoopAnalysis>(F);
 
-  bool Changed = eliminateDeadStores(F, AA, MSSA, DT, PDT, TLI, LI);
+  bool Changed = eliminateDeadStores(F, AA, MSSA, DT, PDT, TLI, TTI, LI);
 
 #ifdef LLVM_ENABLE_STATS
   if (AreStatisticsEnabled())
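To recap the policy the patch adds to `tryToShorten`, here is a compact standalone model of the length selection (plain C++20, illustrative only; it omits the inline-threshold cap and the alignment fallback, and `chooseLength`/`memUnit` are hypothetical stand-ins for the patched code and its TTI queries):

```cpp
#include <algorithm>
#include <bit>
#include <cstdint>

// Model of the patched size selection: given the original length (deadSize),
// the maximally shrunk length (newSize), and the preferred access width in
// bytes (memUnit), pick a length in [newSize, deadSize] with a cheap
// power-of-two decomposition, rounding a multi-chunk tail up to the next
// memUnit when the original length allows it.
static uint64_t chooseLength(uint64_t deadSize, uint64_t newSize,
                             uint64_t memUnit) {
  uint64_t lower = newSize;
  uint64_t roundLower = memUnit * (lower / memUnit);
  if (roundLower == lower || lower == 0)
    return newSize; // Already a multiple of memUnit (or empty); keep as-is.
  uint64_t upper = std::min(deadSize, roundLower + memUnit - 1);
  // Keep the bits above the highest differing bit; round the rest up to a
  // single power-of-two "tail" chunk.
  uint64_t dif = (lower - 1) ^ upper;
  uint64_t highestP2 = uint64_t(1) << (63 - std::countl_zero(dif));
  uint64_t loVal = lower & (highestP2 - 1);
  if (loVal)
    loVal = std::bit_ceil(loVal);
  lower = (lower & -highestP2) | loVal;
  // If the tail still needs more than one chunk, round up to the next
  // memUnit when that stays within the original size.
  if (std::popcount(lower % memUnit) > 1 && deadSize >= roundLower + memUnit)
    lower = roundLower + memUnit;
  return lower;
}

int main() {
  // Matches the worked examples: with deadSize = 15 and memUnit = 8,
  // newSize 7 -> 8, 9 -> 9, 11 -> 12.
  return !(chooseLength(15, 7, 8) == 8 && chooseLength(15, 9, 8) == 9 &&
           chooseLength(15, 11, 8) == 12);
}
```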
