Skip to content

Commit ae09f3d

Browse files
wsmosesvchuravy
authored andcommitted
Handle Freeze instruction
1 parent b1bd36a commit ae09f3d

File tree

6 files changed

+97
-0
lines changed

6 files changed

+97
-0
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,46 @@ class AdjointGenerator
232232
return B.CreateLoad(alloc);
233233
}
234234

235+
#if LLVM_VERSION_MAJOR >= 10
236+
void visitFreezeInst(llvm::FreezeInst &inst) {
237+
eraseIfUnused(inst);
238+
if (gutils->isConstantInstruction(&inst))
239+
return;
240+
Value *orig_op0 = inst.getOperand(0);
241+
242+
switch (Mode) {
243+
case DerivativeMode::ReverseModeCombined:
244+
case DerivativeMode::ReverseModeGradient: {
245+
IRBuilder<> Builder2(inst.getParent());
246+
getReverseBuilder(Builder2);
247+
248+
Value *idiff = diffe(&inst, Builder2);
249+
Value *dif1 = Builder2.CreateFreeze(idiff);
250+
setDiffe(&inst, Constant::getNullValue(inst.getType()), Builder2);
251+
size_t size = 1;
252+
if (inst.getType()->isSized())
253+
size = (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
254+
orig_op0->getType()) +
255+
7) /
256+
8;
257+
addToDiffe(orig_op0, dif1, Builder2, TR.addingType(size, orig_op0));
258+
return;
259+
}
260+
case DerivativeMode::ForwardMode: {
261+
IRBuilder<> BuilderZ(&inst);
262+
getForwardBuilder(BuilderZ);
263+
264+
Value *idiff = diffe(orig_op0, BuilderZ);
265+
Value *dif1 = BuilderZ.CreateFreeze(idiff);
266+
setDiffe(&inst, dif1, BuilderZ);
267+
return;
268+
}
269+
case DerivativeMode::ReverseModePrimal:
270+
return;
271+
}
272+
}
273+
#endif
274+
235275
void visitInstruction(llvm::Instruction &inst) {
236276
// TODO explicitly handle all instructions rather than using the catch all
237277
// below

enzyme/Enzyme/DifferentialUseAnalysis.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@ static inline bool is_use_directly_needed_in_reverse(
7070
cast<InsertElementInst>(user)->getOperand(2) != val) ||
7171
(isa<ExtractElementInst>(user) &&
7272
cast<ExtractElementInst>(user)->getIndexOperand() != val)
73+
#if LLVM_VERSION_MAJOR >= 10
74+
|| isa<FreezeInst>(user)
75+
#endif
7376
// isa<ExtractElement>(use) ||
7477
// isa<InsertElementInst>(use) || isa<ShuffleVectorInst>(use) ||
7578
// isa<ExtractValueInst>(use) || isa<AllocaInst>(use)

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,21 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
252252
return val;
253253
} else if (isa<AllocaInst>(val)) {
254254
return val;
255+
#if LLVM_VERSION_MAJOR >= 10
256+
} else if (auto op = dyn_cast<FreezeInst>(val)) {
257+
auto op0 = getOp(op->getOperand(0));
258+
if (op0 == nullptr)
259+
goto endCheck;
260+
auto toreturn = BuilderM.CreateFreeze(op0, op->getName() + "_unwrap");
261+
if (auto newi = dyn_cast<Instruction>(toreturn)) {
262+
newi->copyIRFlags(op);
263+
unwrappedLoads[newi] = val;
264+
}
265+
if (permitCache)
266+
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toreturn;
267+
assert(val->getType() == toreturn->getType());
268+
return toreturn;
269+
#endif
255270
} else if (auto op = dyn_cast<CastInst>(val)) {
256271
auto op0 = getOp(op->getOperand(0));
257272
if (op0 == nullptr)

enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1514,6 +1514,15 @@ void TypeAnalyzer::visitIntToPtrInst(IntToPtrInst &I) {
15141514
updateAnalysis(I.getOperand(0), getAnalysis(&I), &I);
15151515
}
15161516

1517+
#if LLVM_VERSION_MAJOR >= 10
1518+
void TypeAnalyzer::visitFreezeInst(FreezeInst &I) {
1519+
if (direction & DOWN)
1520+
updateAnalysis(&I, getAnalysis(I.getOperand(0)), &I);
1521+
if (direction & UP)
1522+
updateAnalysis(I.getOperand(0), getAnalysis(&I), &I);
1523+
}
1524+
#endif
1525+
15171526
void TypeAnalyzer::visitBitCastInst(BitCastInst &I) {
15181527
if (I.getType()->isIntOrIntVectorTy() || I.getType()->isFPOrFPVectorTy()) {
15191528
if (direction & DOWN)

enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,10 @@ class TypeAnalyzer : public llvm::InstVisitor<TypeAnalyzer> {
282282

283283
void visitBitCastInst(llvm::BitCastInst &I);
284284

285+
#if LLVM_VERSION_MAJOR >= 10
286+
void visitFreezeInst(llvm::FreezeInst &I);
287+
#endif
288+
285289
void visitSelectInst(llvm::SelectInst &I);
286290

287291
void visitExtractElementInst(llvm::ExtractElementInst &I);
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
; RUN: if [ %llvmver -ge 10 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi
2+
3+
; Function Attrs: nounwind readnone uwtable
4+
define double @tester(double %x) {
5+
entry:
6+
%out = freeze double %x
7+
ret double %out
8+
}
9+
10+
define double @test_derivative(double %x) {
11+
entry:
12+
%0 = tail call double (double (double)*, ...) @__enzyme_autodiff(double (double)* nonnull @tester, double %x)
13+
ret double %0
14+
}
15+
16+
declare double @cbrt(double)
17+
18+
; Function Attrs: nounwind
19+
declare double @__enzyme_autodiff(double (double)*, ...)
20+
21+
; CHECK: define internal { double } @diffetester(double %x, double %differeturn) {
22+
; CHECK-NEXT: entry:
23+
; CHECK-NEXT: %0 = freeze double %differeturn
24+
; CHECK-NEXT: %1 = insertvalue { double } undef, double %0, 0
25+
; CHECK-NEXT: ret { double } %1
26+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)