Skip to content

Commit 34e5a71

Browse files
[InstCombine] Combine ptrauth constants into ptrauth intrinsics. (#94705)
When we encounter two consecutive ptrauth intrinsics, we can already combine the inner matching sign + auth pair, e.g.: resign(sign(p,ks,ds),ks,ds,kr,dr) -> sign(p,kr,dr) We can generalize that to ptrauth constants, which are effectively constant equivalents to ptrauth.sign, i.e.: resign(ptrauth(p,ks,ds),ks,ds,kr,dr) -> ptrauth(p,kr,dr) auth(ptrauth(p,k,d),k,d) -> p While there, cleanup a redundant return after eraseInstFromFunction in the shared (intrinsic|constant)->intrinsic folding code.
1 parent 1b8ab2f commit 34e5a71

File tree

2 files changed

+98
-3
lines changed

2 files changed

+98
-3
lines changed

llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2643,13 +2643,14 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
26432643
// (sign|resign) + (auth|resign) can be folded by omitting the middle
26442644
// sign+auth component if the key and discriminator match.
26452645
bool NeedSign = II->getIntrinsicID() == Intrinsic::ptrauth_resign;
2646+
Value *Ptr = II->getArgOperand(0);
26462647
Value *Key = II->getArgOperand(1);
26472648
Value *Disc = II->getArgOperand(2);
26482649

26492650
// AuthKey will be the key we need to end up authenticating against in
26502651
// whatever we replace this sequence with.
26512652
Value *AuthKey = nullptr, *AuthDisc = nullptr, *BasePtr;
2652-
if (auto CI = dyn_cast<CallBase>(II->getArgOperand(0))) {
2653+
if (const auto *CI = dyn_cast<CallBase>(Ptr)) {
26532654
BasePtr = CI->getArgOperand(0);
26542655
if (CI->getIntrinsicID() == Intrinsic::ptrauth_sign) {
26552656
if (CI->getArgOperand(1) != Key || CI->getArgOperand(2) != Disc)
@@ -2661,6 +2662,27 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
26612662
AuthDisc = CI->getArgOperand(2);
26622663
} else
26632664
break;
2665+
} else if (const auto *PtrToInt = dyn_cast<PtrToIntOperator>(Ptr)) {
2666+
// ptrauth constants are equivalent to a call to @llvm.ptrauth.sign for
2667+
// our purposes, so check for that too.
2668+
const auto *CPA = dyn_cast<ConstantPtrAuth>(PtrToInt->getOperand(0));
2669+
if (!CPA || !CPA->isKnownCompatibleWith(Key, Disc, DL))
2670+
break;
2671+
2672+
// resign(ptrauth(p,ks,ds),ks,ds,kr,dr) -> ptrauth(p,kr,dr)
2673+
if (NeedSign && isa<ConstantInt>(II->getArgOperand(4))) {
2674+
auto *SignKey = cast<ConstantInt>(II->getArgOperand(3));
2675+
auto *SignDisc = cast<ConstantInt>(II->getArgOperand(4));
2676+
auto *SignAddrDisc = ConstantPointerNull::get(Builder.getPtrTy());
2677+
auto *NewCPA = ConstantPtrAuth::get(CPA->getPointer(), SignKey,
2678+
SignDisc, SignAddrDisc);
2679+
replaceInstUsesWith(
2680+
*II, ConstantExpr::getPointerCast(NewCPA, II->getType()));
2681+
return eraseInstFromFunction(*II);
2682+
}
2683+
2684+
// auth(ptrauth(p,k,d),k,d) -> p
2685+
BasePtr = Builder.CreatePtrToInt(CPA->getPointer(), II->getType());
26642686
} else
26652687
break;
26662688

@@ -2677,8 +2699,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
26772699
} else {
26782700
// sign(0) + auth(0) = nop
26792701
replaceInstUsesWith(*II, BasePtr);
2680-
eraseInstFromFunction(*II);
2681-
return nullptr;
2702+
return eraseInstFromFunction(*II);
26822703
}
26832704

26842705
SmallVector<Value *, 4> CallArgs;

llvm/test/Transforms/InstCombine/ptrauth-intrinsics.ll

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,27 @@ define i64 @test_ptrauth_nop(ptr %p) {
1212
ret i64 %authed
1313
}
1414

15+
declare void @foo()
16+
declare void @bar()
17+
18+
define i64 @test_ptrauth_nop_constant() {
19+
; CHECK-LABEL: @test_ptrauth_nop_constant(
20+
; CHECK-NEXT: ret i64 ptrtoint (ptr @foo to i64)
21+
;
22+
%authed = call i64 @llvm.ptrauth.auth(i64 ptrtoint(ptr ptrauth(ptr @foo, i32 1, i64 1234) to i64), i32 1, i64 1234)
23+
ret i64 %authed
24+
}
25+
26+
define i64 @test_ptrauth_nop_constant_addrdisc() {
27+
; CHECK-LABEL: @test_ptrauth_nop_constant_addrdisc(
28+
; CHECK-NEXT: ret i64 ptrtoint (ptr @foo to i64)
29+
;
30+
%addr = ptrtoint ptr @foo to i64
31+
%blended = call i64 @llvm.ptrauth.blend(i64 %addr, i64 1234)
32+
%authed = call i64 @llvm.ptrauth.auth(i64 ptrtoint(ptr ptrauth(ptr @foo, i32 1, i64 1234, ptr @foo) to i64), i32 1, i64 %blended)
33+
ret i64 %authed
34+
}
35+
1536
define i64 @test_ptrauth_nop_mismatch(ptr %p) {
1637
; CHECK-LABEL: @test_ptrauth_nop_mismatch(
1738
; CHECK-NEXT: [[TMP0:%.*]] = ptrtoint ptr [[P:%.*]] to i64
@@ -87,6 +108,59 @@ define i64 @test_ptrauth_resign_auth_mismatch(ptr %p) {
87108
ret i64 %authed
88109
}
89110

111+
define i64 @test_ptrauth_nop_constant_mismatch() {
112+
; CHECK-LABEL: @test_ptrauth_nop_constant_mismatch(
113+
; CHECK-NEXT: [[AUTHED:%.*]] = call i64 @llvm.ptrauth.auth(i64 ptrtoint (ptr ptrauth (ptr @foo, i32 1, i64 1234) to i64), i32 1, i64 12)
114+
; CHECK-NEXT: ret i64 [[AUTHED]]
115+
;
116+
%authed = call i64 @llvm.ptrauth.auth(i64 ptrtoint(ptr ptrauth(ptr @foo, i32 1, i64 1234) to i64), i32 1, i64 12)
117+
ret i64 %authed
118+
}
119+
120+
define i64 @test_ptrauth_nop_constant_mismatch_key() {
121+
; CHECK-LABEL: @test_ptrauth_nop_constant_mismatch_key(
122+
; CHECK-NEXT: [[AUTHED:%.*]] = call i64 @llvm.ptrauth.auth(i64 ptrtoint (ptr ptrauth (ptr @foo, i32 1, i64 1234) to i64), i32 0, i64 1234)
123+
; CHECK-NEXT: ret i64 [[AUTHED]]
124+
;
125+
%authed = call i64 @llvm.ptrauth.auth(i64 ptrtoint(ptr ptrauth(ptr @foo, i32 1, i64 1234) to i64), i32 0, i64 1234)
126+
ret i64 %authed
127+
}
128+
129+
define i64 @test_ptrauth_nop_constant_addrdisc_mismatch() {
130+
; CHECK-LABEL: @test_ptrauth_nop_constant_addrdisc_mismatch(
131+
; CHECK-NEXT: [[BLENDED:%.*]] = call i64 @llvm.ptrauth.blend(i64 ptrtoint (ptr @foo to i64), i64 12)
132+
; CHECK-NEXT: [[AUTHED:%.*]] = call i64 @llvm.ptrauth.auth(i64 ptrtoint (ptr ptrauth (ptr @foo, i32 1, i64 1234, ptr @foo) to i64), i32 1, i64 [[BLENDED]])
133+
; CHECK-NEXT: ret i64 [[AUTHED]]
134+
;
135+
%addr = ptrtoint ptr @foo to i64
136+
%blended = call i64 @llvm.ptrauth.blend(i64 %addr, i64 12)
137+
%authed = call i64 @llvm.ptrauth.auth(i64 ptrtoint(ptr ptrauth(ptr @foo, i32 1, i64 1234, ptr @foo) to i64), i32 1, i64 %blended)
138+
ret i64 %authed
139+
}
140+
141+
define i64 @test_ptrauth_nop_constant_addrdisc_mismatch2() {
142+
; CHECK-LABEL: @test_ptrauth_nop_constant_addrdisc_mismatch2(
143+
; CHECK-NEXT: [[BLENDED:%.*]] = call i64 @llvm.ptrauth.blend(i64 ptrtoint (ptr @bar to i64), i64 1234)
144+
; CHECK-NEXT: [[AUTHED:%.*]] = call i64 @llvm.ptrauth.auth(i64 ptrtoint (ptr ptrauth (ptr @foo, i32 1, i64 1234, ptr @foo) to i64), i32 1, i64 [[BLENDED]])
145+
; CHECK-NEXT: ret i64 [[AUTHED]]
146+
;
147+
%addr = ptrtoint ptr @bar to i64
148+
%blended = call i64 @llvm.ptrauth.blend(i64 %addr, i64 1234)
149+
%authed = call i64 @llvm.ptrauth.auth(i64 ptrtoint(ptr ptrauth(ptr @foo, i32 1, i64 1234, ptr @foo) to i64), i32 1, i64 %blended)
150+
ret i64 %authed
151+
}
152+
153+
define i64 @test_ptrauth_resign_ptrauth_constant(ptr %p) {
154+
; CHECK-LABEL: @test_ptrauth_resign_ptrauth_constant(
155+
; CHECK-NEXT: ret i64 ptrtoint (ptr ptrauth (ptr @foo, i32 0, i64 42) to i64)
156+
;
157+
158+
%tmp0 = ptrtoint ptr %p to i64
159+
%authed = call i64 @llvm.ptrauth.resign(i64 ptrtoint(ptr ptrauth(ptr @foo, i32 1, i64 1234) to i64), i32 1, i64 1234, i32 0, i64 42)
160+
ret i64 %authed
161+
}
162+
90163
declare i64 @llvm.ptrauth.auth(i64, i32, i64)
91164
declare i64 @llvm.ptrauth.sign(i64, i32, i64)
92165
declare i64 @llvm.ptrauth.resign(i64, i32, i64, i32, i64)
166+
declare i64 @llvm.ptrauth.blend(i64, i64)

0 commit comments

Comments
 (0)