Skip to content

[SCEV] Use Step and Start to check if SCEVWrapPredicate is implied. #118184

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions llvm/include/llvm/Analysis/ScalarEvolution.h
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ class SCEVPredicate : public FoldingSetNode {
virtual bool isAlwaysTrue() const = 0;

/// Returns true if this predicate implies \p N.
virtual bool implies(const SCEVPredicate *N) const = 0;
virtual bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const = 0;

/// Prints a textual representation of this predicate with an indentation of
/// \p Depth.
Expand Down Expand Up @@ -286,7 +286,7 @@ class SCEVComparePredicate final : public SCEVPredicate {
const SCEV *LHS, const SCEV *RHS);

/// Implementation of the SCEVPredicate interface
bool implies(const SCEVPredicate *N) const override;
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override;
void print(raw_ostream &OS, unsigned Depth = 0) const override;
bool isAlwaysTrue() const override;

Expand Down Expand Up @@ -393,7 +393,7 @@ class SCEVWrapPredicate final : public SCEVPredicate {

/// Implementation of the SCEVPredicate interface
const SCEVAddRecExpr *getExpr() const;
bool implies(const SCEVPredicate *N) const override;
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override;
void print(raw_ostream &OS, unsigned Depth = 0) const override;
bool isAlwaysTrue() const override;

Expand All @@ -418,16 +418,17 @@ class SCEVUnionPredicate final : public SCEVPredicate {
SmallVector<const SCEVPredicate *, 16> Preds;

/// Adds a predicate to this union.
void add(const SCEVPredicate *N);
void add(const SCEVPredicate *N, ScalarEvolution &SE);

public:
SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds);
SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds,
ScalarEvolution &SE);

ArrayRef<const SCEVPredicate *> getPredicates() const { return Preds; }

/// Implementation of the SCEVPredicate interface
bool isAlwaysTrue() const override;
bool implies(const SCEVPredicate *N) const override;
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override;
void print(raw_ostream &OS, unsigned Depth) const override;

/// We estimate the complexity of a union predicate as the size number of
Expand Down
80 changes: 58 additions & 22 deletions llvm/lib/Analysis/ScalarEvolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5706,8 +5706,9 @@ bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
return true;

auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
if (Expr1 != Expr2 && !Preds->implies(SE.getEqualPredicate(Expr1, Expr2)) &&
!Preds->implies(SE.getEqualPredicate(Expr2, Expr1)))
if (Expr1 != Expr2 &&
!Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
!Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
return false;
return true;
};
Expand Down Expand Up @@ -14857,7 +14858,7 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
bool addOverflowAssumption(const SCEVPredicate *P) {
if (!NewPreds) {
// Check if we've already made this assumption.
return Pred && Pred->implies(P);
return Pred && Pred->implies(P, SE);
}
NewPreds->push_back(P);
return true;
Expand Down Expand Up @@ -14938,7 +14939,8 @@ SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID,
assert(LHS != RHS && "LHS and RHS are the same SCEV");
}

bool SCEVComparePredicate::implies(const SCEVPredicate *N) const {
bool SCEVComparePredicate::implies(const SCEVPredicate *N,
ScalarEvolution &SE) const {
const auto *Op = dyn_cast<SCEVComparePredicate>(N);

if (!Op)
Expand Down Expand Up @@ -14968,10 +14970,40 @@ SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,

const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }

bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const {
bool SCEVWrapPredicate::implies(const SCEVPredicate *N,
ScalarEvolution &SE) const {
const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
if (!Op || setFlags(Flags, Op->Flags) != Flags)
return false;

if (Op->AR == AR)
return true;

if (Flags != SCEVWrapPredicate::IncrementNSSW &&
Flags != SCEVWrapPredicate::IncrementNUSW)
return false;

return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags;
const SCEV *Step = AR->getStepRecurrence(SE);
const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
if (!SE.isKnownPositive(Step) || !SE.isKnownPositive(OpStep))
return false;

// If both steps are positive, this predicate implies N if N's start and step
// are ULE/SLE (for NUSW/NSSW) than this predicate's start and step.
Type *WiderTy = SE.getWiderType(Step->getType(), OpStep->getType());
Step = SE.getNoopOrZeroExtend(Step, WiderTy);
OpStep = SE.getNoopOrZeroExtend(OpStep, WiderTy);

bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
const SCEV *OpStart = Op->AR->getStart();
const SCEV *Start = AR->getStart();
OpStart = IsNUW ? SE.getNoopOrZeroExtend(OpStart, WiderTy)
: SE.getNoopOrSignExtend(OpStart, WiderTy);
Start = IsNUW ? SE.getNoopOrZeroExtend(Start, WiderTy)
: SE.getNoopOrSignExtend(Start, WiderTy);
CmpInst::Predicate Pred = IsNUW ? CmpInst::ICMP_ULE : CmpInst::ICMP_SLE;
return SE.isKnownPredicate(Pred, OpStep, Step) &&
SE.isKnownPredicate(Pred, OpStart, Start);
}

bool SCEVWrapPredicate::isAlwaysTrue() const {
Expand Down Expand Up @@ -15015,48 +15047,51 @@ SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
}

/// Union predicates don't get cached, so create a dummy set ID for them.
SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds)
: SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds,
ScalarEvolution &SE)
: SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
for (const auto *P : Preds)
add(P);
add(P, SE);
}

bool SCEVUnionPredicate::isAlwaysTrue() const {
return all_of(Preds,
[](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
}

bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
bool SCEVUnionPredicate::implies(const SCEVPredicate *N,
ScalarEvolution &SE) const {
if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
return all_of(Set->Preds,
[this](const SCEVPredicate *I) { return this->implies(I); });
return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
return this->implies(I, SE);
});

return any_of(Preds,
[N](const SCEVPredicate *I) { return I->implies(N); });
[N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); });
}

void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
for (const auto *Pred : Preds)
Pred->print(OS, Depth);
}

void SCEVUnionPredicate::add(const SCEVPredicate *N) {
void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) {
if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
for (const auto *Pred : Set->Preds)
add(Pred);
add(Pred, SE);
return;
}

// Only add predicate if it is not already implied by this union predicate.
if (!implies(N))
if (!implies(N, SE))
Preds.push_back(N);
}

PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
Loop &L)
: SE(SE), L(L) {
SmallVector<const SCEVPredicate*, 4> Empty;
Preds = std::make_unique<SCEVUnionPredicate>(Empty);
Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE);
}

void ScalarEvolution::registerUser(const SCEV *User,
Expand Down Expand Up @@ -15120,12 +15155,12 @@ unsigned PredicatedScalarEvolution::getSmallConstantMaxTripCount() {
}

void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
if (Preds->implies(&Pred))
if (Preds->implies(&Pred, SE))
return;

SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
NewPreds.push_back(&Pred);
Preds = std::make_unique<SCEVUnionPredicate>(NewPreds);
Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE);
updateGeneration();
}

Expand Down Expand Up @@ -15192,9 +15227,10 @@ const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {

PredicatedScalarEvolution::PredicatedScalarEvolution(
const PredicatedScalarEvolution &Init)
: RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates())),
Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
: RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
SE)),
Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
for (auto I : Init.FlagsMap)
FlagsMap.insert(I);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,19 @@ target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
; CHECK-NEXT: Run-time memory checks:
; CHECK-NEXT: Check 0:
; CHECK-NEXT: Comparing group
; CHECK-NEXT: %arrayidx = getelementptr inbounds i32, ptr %a, i64 %idxprom
; CHECK-NEXT: Against group
; CHECK-NEXT: %arrayidx4 = getelementptr inbounds i32, ptr %b, i64 %conv11
; CHECK-NEXT: Against group
; CHECK-NEXT: %arrayidx = getelementptr inbounds i32, ptr %a, i64 %idxprom
; CHECK-NEXT: Grouped accesses:
; CHECK-NEXT: Group
; CHECK-NEXT: (Low: (4 + %a) High: (4 + (4 * (1 umax %x)) + %a))
; CHECK-NEXT: Member: {(4 + %a),+,4}<%for.body>
; CHECK-NEXT: Group
; CHECK-NEXT: (Low: %b High: ((4 * (1 umax %x)) + %b))
; CHECK-NEXT: Member: {%b,+,4}<%for.body>
; CHECK-NEXT: Group
; CHECK-NEXT: (Low: (4 + %a) High: (4 + (4 * (1 umax %x)) + %a))
; CHECK-NEXT: Member: {(4 + %a),+,4}<%for.body>
; CHECK: Non vectorizable stores to invariant address were not found in loop.
; CHECK-NEXT: SCEV assumptions:
; CHECK-NEXT: {1,+,1}<%for.body> Added Flags: <nusw>
; CHECK-NEXT: {0,+,1}<%for.body> Added Flags: <nusw>
; CHECK: Expressions re-written:
; CHECK-NEXT: [PSE] %arrayidx = getelementptr inbounds i32, ptr %a, i64 %idxprom:
; CHECK-NEXT: ((4 * (zext i32 {1,+,1}<%for.body> to i64))<nuw><nsw> + %a)<nuw>
Expand Down Expand Up @@ -85,7 +84,6 @@ exit:
; CHECK: Memory dependences are safe
; CHECK: SCEV assumptions:
; CHECK-NEXT: {1,+,1}<%for.body> Added Flags: <nusw>
; CHECK-NEXT: {0,+,1}<%for.body> Added Flags: <nusw>
define void @test2(i64 %x, ptr %a) {
entry:
br label %for.body
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

target datalayout = "e-m:o-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-n32:64-S128-Fn32"

; FIXME: {0,+,3} implies {0,+,2}.
; {0,+,3} [nssw] implies {0,+,2} [nssw].
define void @wrap_check_iv.3_implies_iv.2(i32 noundef %N, ptr %dst, ptr %src) {
; CHECK-LABEL: 'wrap_check_iv.3_implies_iv.2'
; CHECK-NEXT: loop:
Expand All @@ -26,7 +26,6 @@ define void @wrap_check_iv.3_implies_iv.2(i32 noundef %N, ptr %dst, ptr %src) {
; CHECK-NEXT: Non vectorizable stores to invariant address were not found in loop.
; CHECK-NEXT: SCEV assumptions:
; CHECK-NEXT: {0,+,3}<%loop> Added Flags: <nssw>
; CHECK-NEXT: {0,+,2}<%loop> Added Flags: <nssw>
; CHECK-EMPTY:
; CHECK-NEXT: Expressions re-written:
; CHECK-NEXT: [PSE] %gep.iv.2 = getelementptr inbounds i32, ptr %src, i64 %ext.iv.2:
Expand Down Expand Up @@ -59,7 +58,7 @@ exit:
ret void
}

; FIXME: {2,+,2} implies {0,+,2}.
; {2,+,2} [nssw] implies {0,+,2} [nssw].
define void @wrap_check_iv.3_implies_iv.2_different_start(i32 noundef %N, ptr %dst, ptr %src) {
; CHECK-LABEL: 'wrap_check_iv.3_implies_iv.2_different_start'
; CHECK-NEXT: loop:
Expand All @@ -82,7 +81,6 @@ define void @wrap_check_iv.3_implies_iv.2_different_start(i32 noundef %N, ptr %d
; CHECK-NEXT: Non vectorizable stores to invariant address were not found in loop.
; CHECK-NEXT: SCEV assumptions:
; CHECK-NEXT: {2,+,2}<%loop> Added Flags: <nssw>
; CHECK-NEXT: {0,+,2}<%loop> Added Flags: <nssw>
; CHECK-EMPTY:
; CHECK-NEXT: Expressions re-written:
; CHECK-NEXT: [PSE] %gep.iv.2 = getelementptr inbounds i32, ptr %src, i64 %ext.iv.2:
Expand Down
Loading