Skip to content

Commit d4d3ae9

Browse files
authored
Implement fast block lookup and phi recomputation (rust-lang#234)
* Implement fast block lookup and phi recomputation * Attempt debug ci * Fix cache lookup * try debug build * Fix pending unwrap bug
1 parent 9c842d8 commit d4d3ae9

File tree

4 files changed

+281
-78
lines changed

4 files changed

+281
-78
lines changed

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3013,8 +3013,13 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
30133013
for (auto V : unwrapToOrig[newi]) {
30143014
ValueToValueMapTy empty;
30153015
IRBuilder<> lb(cast<Instruction>(V));
3016+
// This must disallow caching here as otherwise performing the loop in
3017+
// the wrong order may result in first replacing the later unwrapped
3018+
// value, caching it, then attempting to reuse it for an earlier
3019+
// replacement.
30163020
V->replaceAllUsesWith(
3017-
gutils->unwrapM(nexti, lb, empty, UnwrapMode::LegalFullUnwrap));
3021+
gutils->unwrapM(nexti, lb, empty, UnwrapMode::LegalFullUnwrap,
3022+
/*scope*/ nullptr, /*permitCache*/ false));
30183023
cast<Instruction>(V)->eraseFromParent();
30193024
}
30203025
}

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 84 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -176,15 +176,13 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
176176
// assert(!val->getName().startswith("$tapeload"));
177177
if (permitCache && unwrap_cache[BuilderM.GetInsertBlock()].find(idx) !=
178178
unwrap_cache[BuilderM.GetInsertBlock()].end()) {
179-
if (unwrap_cache[BuilderM.GetInsertBlock()][idx]->getType() !=
180-
val->getType()) {
179+
auto cachedValue = unwrap_cache[BuilderM.GetInsertBlock()][idx];
180+
if (cachedValue->getType() != val->getType()) {
181181
llvm::errs() << "val: " << *val << "\n";
182-
llvm::errs() << "unwrap_cache[cidx]: "
183-
<< *unwrap_cache[BuilderM.GetInsertBlock()][idx] << "\n";
182+
llvm::errs() << "unwrap_cache[cidx]: " << *cachedValue << "\n";
184183
}
185-
assert(unwrap_cache[BuilderM.GetInsertBlock()][idx]->getType() ==
186-
val->getType());
187-
return unwrap_cache[BuilderM.GetInsertBlock()][idx];
184+
assert(cachedValue->getType() == val->getType());
185+
return cachedValue;
188186
}
189187

190188
#define getOpFullest(Builder, vtmp, frominst, check) \
@@ -361,8 +359,9 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
361359
auto toreturn = BuilderM.CreateBinOp(op->getOpcode(), op0, op1,
362360
op->getName() + "_unwrap");
363361
unwrappedLoads[toreturn] = val;
364-
if (auto newi = dyn_cast<Instruction>(toreturn))
362+
if (auto newi = dyn_cast<Instruction>(toreturn)) {
365363
newi->copyIRFlags(op);
364+
}
366365
if (permitCache)
367366
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toreturn;
368367
assert(val->getType() == toreturn->getType());
@@ -834,17 +833,10 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
834833
if (BuilderM.GetInsertPoint() != oldB->end())
835834
goto endCheck;
836835

837-
// todo speed this up
838-
BasicBlock *fwd = nullptr;
839-
for (const auto &pair : reverseBlocks) {
840-
const std::vector<BasicBlock *> &vec = pair.second;
841-
if (std::find(vec.begin(), vec.end(), oldB) != vec.end()) {
842-
fwd = pair.first;
843-
break;
844-
}
845-
}
846-
if (!fwd)
836+
auto found = reverseBlockToPrimal.find(oldB);
837+
if (found == reverseBlockToPrimal.end())
847838
goto endCheck;
839+
BasicBlock *fwd = found->second;
848840

849841
SmallVector<BasicBlock *, 2> predBlocks;
850842
predBlocks.push_back(bi2->getSuccessor(0));
@@ -876,6 +868,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
876868
blocks[i]->moveAfter(last);
877869
last = blocks[i];
878870
reverseBlocks[fwd].push_back(blocks[i]);
871+
reverseBlockToPrimal[blocks[i]] = fwd;
879872
IRBuilder<> B(blocks[i]);
880873

881874
unwrap_cache[blocks[i]] = unwrap_cache[oldB];
@@ -884,7 +877,17 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
884877

885878
if (auto inst = dyn_cast<Instruction>(
886879
phi->getIncomingValueForBlock(PB))) {
887-
if (inst->mayReadFromMemory() || !EnzymeSpeculatePHIs)
880+
// Recompute the phi computation with the conditional if:
881+
// 1) the instruction may read from memory AND does not
882+
// dominate the current insertion point (thereby
883+
// potentially making such recomputation without the
884+
// condition illegal)
885+
// 2) the value is a call or load and option is set to not
886+
// speculatively recompute values within a phi
887+
if ((inst->mayReadFromMemory() &&
888+
!DT.dominates(inst->getParent(), phi->getParent())) ||
889+
(!EnzymeSpeculatePHIs &&
890+
(isa<CallInst>(inst) || isa<LoadInst>(inst))))
888891
vals.push_back(
889892
getOpFull(B, phi->getIncomingValueForBlock(PB), PB));
890893
else
@@ -895,10 +898,11 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
895898
getOpFull(BuilderM, phi->getIncomingValueForBlock(PB), PB));
896899

897900
if (!vals[i]) {
898-
for (size_t j = 0; j < i; i++) {
901+
for (size_t j = 0; j < i; j++) {
899902
reverseBlocks[fwd].erase(std::find(reverseBlocks[fwd].begin(),
900903
reverseBlocks[fwd].end(),
901904
blocks[j]));
905+
reverseBlockToPrimal.erase(blocks[j]);
902906
unwrap_cache.erase(blocks[j]);
903907
lookup_cache.erase(blocks[j]);
904908
SmallVector<Instruction *, 4> toErase;
@@ -910,7 +914,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
910914
}
911915
}
912916
bret->eraseFromParent();
913-
for (size_t j = 0; j < i; i++) {
917+
for (size_t j = 0; j < i; j++) {
914918
blocks[j]->eraseFromParent();
915919
};
916920
goto endCheck;
@@ -939,6 +943,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
939943

940944
BuilderM.SetInsertPoint(bret);
941945
reverseBlocks[fwd].push_back(bret);
946+
reverseBlockToPrimal[bret] = fwd;
942947
auto toret = BuilderM.CreatePHI(val->getType(), vals.size());
943948
for (size_t i = 0; i < vals.size(); i++)
944949
toret->addIncoming(vals[i], endingBlocks[i]);
@@ -992,17 +997,10 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
992997
if (BuilderM.GetInsertPoint() != oldB->end())
993998
goto endCheck;
994999

995-
// todo speed this up
996-
BasicBlock *fwd = nullptr;
997-
for (const auto &pair : reverseBlocks) {
998-
const std::vector<BasicBlock *> &vec = pair.second;
999-
if (std::find(vec.begin(), vec.end(), oldB) != vec.end()) {
1000-
fwd = pair.first;
1001-
break;
1002-
}
1003-
}
1004-
if (!fwd)
1000+
auto found = reverseBlockToPrimal.find(oldB);
1001+
if (found == reverseBlockToPrimal.end())
10051002
goto endCheck;
1003+
BasicBlock *fwd = found->second;
10061004

10071005
SmallVector<BasicBlock *, 2> predBlocks;
10081006
Value *cond = nullptr;
@@ -1046,14 +1044,24 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
10461044
blocks[i]->moveAfter(last);
10471045
last = blocks[i];
10481046
reverseBlocks[fwd].push_back(blocks[i]);
1047+
reverseBlockToPrimal[blocks[i]] = fwd;
10491048
IRBuilder<> B(blocks[i]);
10501049

10511050
unwrap_cache[blocks[i]] = unwrap_cache[oldB];
10521051
lookup_cache[blocks[i]] = lookup_cache[oldB];
10531052

10541053
if (auto inst =
10551054
dyn_cast<Instruction>(phi->getIncomingValueForBlock(PB))) {
1056-
if (inst->mayReadFromMemory() || !EnzymeSpeculatePHIs)
1055+
// Recompute the phi computation with the conditional if:
1056+
// 1) the instruction may read from memory AND does not dominate
1057+
// the current insertion point (thereby potentially making such
1058+
// recomputation without the condition illegal)
1059+
// 2) the value is a call or load and option is set to not
1060+
// speculatively recompute values within a phi
1061+
if ((inst->mayReadFromMemory() &&
1062+
!DT.dominates(inst->getParent(), phi->getParent())) ||
1063+
(!EnzymeSpeculatePHIs &&
1064+
(isa<CallInst>(inst) || isa<LoadInst>(inst))))
10571065
vals.push_back(getOpFull(B, phi->getIncomingValueForBlock(PB), PB));
10581066
else
10591067
vals.push_back(
@@ -1063,10 +1071,11 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
10631071
getOpFull(BuilderM, phi->getIncomingValueForBlock(PB), PB));
10641072

10651073
if (!vals[i]) {
1066-
for (size_t j = 0; j < i; i++) {
1074+
for (size_t j = 0; j < i; j++) {
10671075
reverseBlocks[fwd].erase(std::find(reverseBlocks[fwd].begin(),
10681076
reverseBlocks[fwd].end(),
10691077
blocks[j]));
1078+
reverseBlockToPrimal.erase(blocks[j]);
10701079
unwrap_cache.erase(blocks[j]);
10711080
lookup_cache.erase(blocks[j]);
10721081
SmallVector<Instruction *, 4> toErase;
@@ -1078,7 +1087,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
10781087
}
10791088
}
10801089
bret->eraseFromParent();
1081-
for (size_t j = 0; j < i; i++) {
1090+
for (size_t j = 0; j < i; j++) {
10821091
blocks[j]->eraseFromParent();
10831092
};
10841093
goto endCheck;
@@ -1088,6 +1097,38 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
10881097
endingBlocks.push_back(B.GetInsertBlock());
10891098
}
10901099

1100+
// Fast path to not make a split block if no additional instructions
1101+
// were made in the two blocks
1102+
if (isa<BranchInst>(equivalentTerminator) && blocks[0]->size() == 1 &&
1103+
blocks[1]->size() == 1) {
1104+
for (size_t j = 0; j < blocks.size(); j++) {
1105+
reverseBlocks[fwd].erase(std::find(
1106+
reverseBlocks[fwd].begin(), reverseBlocks[fwd].end(), blocks[j]));
1107+
reverseBlockToPrimal.erase(blocks[j]);
1108+
unwrap_cache.erase(blocks[j]);
1109+
lookup_cache.erase(blocks[j]);
1110+
SmallVector<Instruction *, 4> toErase;
1111+
for (auto &I : *blocks[j]) {
1112+
toErase.push_back(&I);
1113+
}
1114+
for (auto I : toErase) {
1115+
erase(I);
1116+
}
1117+
}
1118+
bret->eraseFromParent();
1119+
for (size_t j = 0; j < blocks.size(); j++) {
1120+
blocks[j]->eraseFromParent();
1121+
};
1122+
Value *toret = BuilderM.CreateSelect(cond, vals[0], vals[1],
1123+
phi->getName() + "_unwrap");
1124+
if (permitCache) {
1125+
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toret;
1126+
}
1127+
if (auto instRet = dyn_cast<Instruction>(toret))
1128+
unwrappedLoads[instRet] = val;
1129+
return toret;
1130+
}
1131+
10911132
bret->moveAfter(last);
10921133
if (isa<BranchInst>(equivalentTerminator)) {
10931134
BuilderM.CreateCondBr(cond, blocks[0], blocks[1]);
@@ -1102,6 +1143,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
11021143
}
11031144
BuilderM.SetInsertPoint(bret);
11041145
reverseBlocks[fwd].push_back(bret);
1146+
reverseBlockToPrimal[bret] = fwd;
11051147
auto toret = BuilderM.CreatePHI(val->getType(), vals.size());
11061148
for (size_t i = 0; i < vals.size(); i++)
11071149
toret->addIncoming(vals[i], endingBlocks[i]);
@@ -1328,9 +1370,10 @@ Value *GradientUtils::cacheForReverse(IRBuilder<> &BuilderQ, Value *malloc,
13281370
for (auto u : users) {
13291371
if (auto li = dyn_cast<LoadInst>(u)) {
13301372
IRBuilder<> lb(li);
1331-
ValueToValueMapTy empty;
1332-
li->replaceAllUsesWith(
1333-
unwrapM(ret, lb, empty, UnwrapMode::LegalFullUnwrap));
1373+
auto replacewith =
1374+
(idx < 0) ? tape
1375+
: lb.CreateExtractValue(tape, {(unsigned)idx});
1376+
li->replaceAllUsesWith(replacewith);
13341377
erase(li);
13351378
} else {
13361379
llvm::errs() << "newFunc: " << *newFunc << "\n";
@@ -1768,13 +1811,10 @@ bool GradientUtils::legalRecompute(const Value *val,
17681811
if (BuilderM) {
17691812
fwdBlockIfReverse = BuilderM->GetInsertBlock();
17701813
if (!reverse) {
1771-
for (auto pair : reverseBlocks) {
1772-
if (std::find(pair.second.begin(), pair.second.end(),
1773-
BuilderM->GetInsertBlock()) != pair.second.end()) {
1774-
fwdBlockIfReverse = pair.first;
1775-
reverse = true;
1776-
break;
1777-
}
1814+
auto found = reverseBlockToPrimal.find(BuilderM->GetInsertBlock());
1815+
if (found != reverseBlockToPrimal.end()) {
1816+
fwdBlockIfReverse = found->second;
1817+
reverse = true;
17781818
}
17791819
}
17801820
if (fwdBlockIfReverse->getParent() != oldFunc)

enzyme/Enzyme/GradientUtils.h

Lines changed: 32 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,12 @@ class GradientUtils : public CacheUtility {
106106
ScalarEvolution &OrigSE;
107107
std::shared_ptr<ActivityAnalyzer> ATA;
108108
SmallVector<BasicBlock *, 12> originalBlocks;
109+
110+
// Map of primal block to corresponding block(s) in reverse
109111
std::map<BasicBlock *, std::vector<BasicBlock *>> reverseBlocks;
112+
// Map of block in reverse to corresponding primal block
113+
std::map<BasicBlock *, BasicBlock *> reverseBlockToPrimal;
114+
110115
SmallPtrSet<PHINode *, 4> fictiousPHIs;
111116
ValueToValueMapTy originalToNewFn;
112117
std::vector<CallInst *> originalCalls;
@@ -441,25 +446,23 @@ class GradientUtils : public CacheUtility {
441446
BasicBlock *addReverseBlock(BasicBlock *currentBlock, Twine name,
442447
bool forkCache = true) {
443448
assert(reverseBlocks.size());
444-
445-
// todo speed this up
446-
for (auto &pair : reverseBlocks) {
447-
std::vector<BasicBlock *> &vec = pair.second;
448-
if (vec.back() == currentBlock) {
449-
450-
BasicBlock *rev =
451-
BasicBlock::Create(currentBlock->getContext(), name, newFunc);
452-
rev->moveAfter(currentBlock);
453-
vec.push_back(rev);
454-
if (forkCache) {
455-
unwrap_cache[rev] = unwrap_cache[currentBlock];
456-
lookup_cache[rev] = lookup_cache[currentBlock];
457-
}
458-
return rev;
459-
}
460-
}
461-
assert(0 && "cannot find reverse location to add into");
462-
llvm_unreachable("cannot find reverse location to add into");
449+
auto found = reverseBlockToPrimal.find(currentBlock);
450+
assert(found != reverseBlockToPrimal.end());
451+
452+
std::vector<BasicBlock *> &vec = reverseBlocks[found->second];
453+
assert(vec.size());
454+
assert(vec.back() == currentBlock);
455+
456+
BasicBlock *rev =
457+
BasicBlock::Create(currentBlock->getContext(), name, newFunc);
458+
rev->moveAfter(currentBlock);
459+
vec.push_back(rev);
460+
reverseBlockToPrimal[rev] = found->second;
461+
if (forkCache) {
462+
unwrap_cache[rev] = unwrap_cache[currentBlock];
463+
lookup_cache[rev] = lookup_cache[currentBlock];
464+
}
465+
return rev;
463466
}
464467

465468
public:
@@ -831,19 +834,13 @@ class GradientUtils : public CacheUtility {
831834

832835
private:
833836
BasicBlock *originalForReverseBlock(BasicBlock &BB2) const {
834-
assert(reverseBlocks.size() != 0);
835-
for (auto BB : originalBlocks) {
836-
auto it = reverseBlocks.find(BB);
837-
assert(it != reverseBlocks.end());
838-
if (std::find(it->second.begin(), it->second.end(), &BB2) !=
839-
it->second.end()) {
840-
return BB;
841-
}
837+
auto found = reverseBlockToPrimal.find(&BB2);
838+
if (found == reverseBlockToPrimal.end()) {
839+
llvm::errs() << "newFunc: " << *newFunc << "\n";
840+
llvm::errs() << BB2 << "\n";
842841
}
843-
llvm::errs() << *newFunc << "\n";
844-
llvm::errs() << BB2 << "\n";
845-
assert(0 && "could not find original block for given reverse block");
846-
report_fatal_error("could not find original block for given reverse block");
842+
assert(found != reverseBlockToPrimal.end());
843+
return found->second;
847844
}
848845

849846
public:
@@ -1264,8 +1261,10 @@ class DiffeGradientUtils : public GradientUtils {
12641261
for (BasicBlock *BB : originalBlocks) {
12651262
if (BB == inversionAllocs)
12661263
continue;
1267-
reverseBlocks[BB].push_back(BasicBlock::Create(
1268-
BB->getContext(), "invert" + BB->getName(), newFunc));
1264+
BasicBlock *RBB = BasicBlock::Create(BB->getContext(),
1265+
"invert" + BB->getName(), newFunc);
1266+
reverseBlocks[BB].push_back(RBB);
1267+
reverseBlockToPrimal[RBB] = BB;
12691268
}
12701269
assert(reverseBlocks.size() != 0);
12711270
}

0 commit comments

Comments
 (0)