@@ -176,15 +176,13 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
176
176
// assert(!val->getName().startswith("$tapeload"));
177
177
if (permitCache && unwrap_cache[BuilderM.GetInsertBlock ()].find (idx) !=
178
178
unwrap_cache[BuilderM.GetInsertBlock ()].end ()) {
179
- if ( unwrap_cache[BuilderM.GetInsertBlock ()][idx]-> getType () !=
180
- val->getType ()) {
179
+ auto cachedValue = unwrap_cache[BuilderM.GetInsertBlock ()][idx];
180
+ if (cachedValue-> getType () != val->getType ()) {
181
181
llvm::errs () << " val: " << *val << " \n " ;
182
- llvm::errs () << " unwrap_cache[cidx]: "
183
- << *unwrap_cache[BuilderM.GetInsertBlock ()][idx] << " \n " ;
182
+ llvm::errs () << " unwrap_cache[cidx]: " << *cachedValue << " \n " ;
184
183
}
185
- assert (unwrap_cache[BuilderM.GetInsertBlock ()][idx]->getType () ==
186
- val->getType ());
187
- return unwrap_cache[BuilderM.GetInsertBlock ()][idx];
184
+ assert (cachedValue->getType () == val->getType ());
185
+ return cachedValue;
188
186
}
189
187
190
188
#define getOpFullest (Builder, vtmp, frominst, check ) \
@@ -361,8 +359,9 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
361
359
auto toreturn = BuilderM.CreateBinOp (op->getOpcode (), op0, op1,
362
360
op->getName () + " _unwrap" );
363
361
unwrappedLoads[toreturn] = val;
364
- if (auto newi = dyn_cast<Instruction>(toreturn))
362
+ if (auto newi = dyn_cast<Instruction>(toreturn)) {
365
363
newi->copyIRFlags (op);
364
+ }
366
365
if (permitCache)
367
366
unwrap_cache[BuilderM.GetInsertBlock ()][idx] = toreturn;
368
367
assert (val->getType () == toreturn->getType ());
@@ -834,17 +833,10 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
834
833
if (BuilderM.GetInsertPoint () != oldB->end ())
835
834
goto endCheck;
836
835
837
- // todo speed this up
838
- BasicBlock *fwd = nullptr ;
839
- for (const auto &pair : reverseBlocks) {
840
- const std::vector<BasicBlock *> &vec = pair.second ;
841
- if (std::find (vec.begin (), vec.end (), oldB) != vec.end ()) {
842
- fwd = pair.first ;
843
- break ;
844
- }
845
- }
846
- if (!fwd)
836
+ auto found = reverseBlockToPrimal.find (oldB);
837
+ if (found == reverseBlockToPrimal.end ())
847
838
goto endCheck;
839
+ BasicBlock *fwd = found->second ;
848
840
849
841
SmallVector<BasicBlock *, 2 > predBlocks;
850
842
predBlocks.push_back (bi2->getSuccessor (0 ));
@@ -876,6 +868,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
876
868
blocks[i]->moveAfter (last);
877
869
last = blocks[i];
878
870
reverseBlocks[fwd].push_back (blocks[i]);
871
+ reverseBlockToPrimal[blocks[i]] = fwd;
879
872
IRBuilder<> B (blocks[i]);
880
873
881
874
unwrap_cache[blocks[i]] = unwrap_cache[oldB];
@@ -884,7 +877,17 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
884
877
885
878
if (auto inst = dyn_cast<Instruction>(
886
879
phi->getIncomingValueForBlock (PB))) {
887
- if (inst->mayReadFromMemory () || !EnzymeSpeculatePHIs)
880
+ // Recompute the phi computation with the conditional if:
881
+ // 1) the instruction may read from memory AND does not
882
+ // dominate the current insertion point (thereby
883
+ // potentially making such recomputation without the
884
+ // condition illegal)
885
+ // 2) the value is a call or load and option is set to not
886
+ // speculatively recompute values within a phi
887
+ if ((inst->mayReadFromMemory () &&
888
+ !DT.dominates (inst->getParent (), phi->getParent ())) ||
889
+ (!EnzymeSpeculatePHIs &&
890
+ (isa<CallInst>(inst) || isa<LoadInst>(inst))))
888
891
vals.push_back (
889
892
getOpFull (B, phi->getIncomingValueForBlock (PB), PB));
890
893
else
@@ -895,10 +898,11 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
895
898
getOpFull (BuilderM, phi->getIncomingValueForBlock (PB), PB));
896
899
897
900
if (!vals[i]) {
898
- for (size_t j = 0 ; j < i; i ++) {
901
+ for (size_t j = 0 ; j < i; j ++) {
899
902
reverseBlocks[fwd].erase (std::find (reverseBlocks[fwd].begin (),
900
903
reverseBlocks[fwd].end (),
901
904
blocks[j]));
905
+ reverseBlockToPrimal.erase (blocks[j]);
902
906
unwrap_cache.erase (blocks[j]);
903
907
lookup_cache.erase (blocks[j]);
904
908
SmallVector<Instruction *, 4 > toErase;
@@ -910,7 +914,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
910
914
}
911
915
}
912
916
bret->eraseFromParent ();
913
- for (size_t j = 0 ; j < i; i ++) {
917
+ for (size_t j = 0 ; j < i; j ++) {
914
918
blocks[j]->eraseFromParent ();
915
919
};
916
920
goto endCheck;
@@ -939,6 +943,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
939
943
940
944
BuilderM.SetInsertPoint (bret);
941
945
reverseBlocks[fwd].push_back (bret);
946
+ reverseBlockToPrimal[bret] = fwd;
942
947
auto toret = BuilderM.CreatePHI (val->getType (), vals.size ());
943
948
for (size_t i = 0 ; i < vals.size (); i++)
944
949
toret->addIncoming (vals[i], endingBlocks[i]);
@@ -992,17 +997,10 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
992
997
if (BuilderM.GetInsertPoint () != oldB->end ())
993
998
goto endCheck;
994
999
995
- // todo speed this up
996
- BasicBlock *fwd = nullptr ;
997
- for (const auto &pair : reverseBlocks) {
998
- const std::vector<BasicBlock *> &vec = pair.second ;
999
- if (std::find (vec.begin (), vec.end (), oldB) != vec.end ()) {
1000
- fwd = pair.first ;
1001
- break ;
1002
- }
1003
- }
1004
- if (!fwd)
1000
+ auto found = reverseBlockToPrimal.find (oldB);
1001
+ if (found == reverseBlockToPrimal.end ())
1005
1002
goto endCheck;
1003
+ BasicBlock *fwd = found->second ;
1006
1004
1007
1005
SmallVector<BasicBlock *, 2 > predBlocks;
1008
1006
Value *cond = nullptr ;
@@ -1046,14 +1044,24 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
1046
1044
blocks[i]->moveAfter (last);
1047
1045
last = blocks[i];
1048
1046
reverseBlocks[fwd].push_back (blocks[i]);
1047
+ reverseBlockToPrimal[blocks[i]] = fwd;
1049
1048
IRBuilder<> B (blocks[i]);
1050
1049
1051
1050
unwrap_cache[blocks[i]] = unwrap_cache[oldB];
1052
1051
lookup_cache[blocks[i]] = lookup_cache[oldB];
1053
1052
1054
1053
if (auto inst =
1055
1054
dyn_cast<Instruction>(phi->getIncomingValueForBlock (PB))) {
1056
- if (inst->mayReadFromMemory () || !EnzymeSpeculatePHIs)
1055
+ // Recompute the phi computation with the conditional if:
1056
+ // 1) the instruction may read from memory AND does not dominate
1057
+ // the current insertion point (thereby potentially making such
1058
+ // recomputation without the condition illegal)
1059
+ // 2) the value is a call or load and option is set to not
1060
+ // speculatively recompute values within a phi
1061
+ if ((inst->mayReadFromMemory () &&
1062
+ !DT.dominates (inst->getParent (), phi->getParent ())) ||
1063
+ (!EnzymeSpeculatePHIs &&
1064
+ (isa<CallInst>(inst) || isa<LoadInst>(inst))))
1057
1065
vals.push_back (getOpFull (B, phi->getIncomingValueForBlock (PB), PB));
1058
1066
else
1059
1067
vals.push_back (
@@ -1063,10 +1071,11 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
1063
1071
getOpFull (BuilderM, phi->getIncomingValueForBlock (PB), PB));
1064
1072
1065
1073
if (!vals[i]) {
1066
- for (size_t j = 0 ; j < i; i ++) {
1074
+ for (size_t j = 0 ; j < i; j ++) {
1067
1075
reverseBlocks[fwd].erase (std::find (reverseBlocks[fwd].begin (),
1068
1076
reverseBlocks[fwd].end (),
1069
1077
blocks[j]));
1078
+ reverseBlockToPrimal.erase (blocks[j]);
1070
1079
unwrap_cache.erase (blocks[j]);
1071
1080
lookup_cache.erase (blocks[j]);
1072
1081
SmallVector<Instruction *, 4 > toErase;
@@ -1078,7 +1087,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
1078
1087
}
1079
1088
}
1080
1089
bret->eraseFromParent ();
1081
- for (size_t j = 0 ; j < i; i ++) {
1090
+ for (size_t j = 0 ; j < i; j ++) {
1082
1091
blocks[j]->eraseFromParent ();
1083
1092
};
1084
1093
goto endCheck;
@@ -1088,6 +1097,38 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
1088
1097
endingBlocks.push_back (B.GetInsertBlock ());
1089
1098
}
1090
1099
1100
+ // Fast path to not make a split block if no additional instructions
1101
+ // were made in the two blocks
1102
+ if (isa<BranchInst>(equivalentTerminator) && blocks[0 ]->size () == 1 &&
1103
+ blocks[1 ]->size () == 1 ) {
1104
+ for (size_t j = 0 ; j < blocks.size (); j++) {
1105
+ reverseBlocks[fwd].erase (std::find (
1106
+ reverseBlocks[fwd].begin (), reverseBlocks[fwd].end (), blocks[j]));
1107
+ reverseBlockToPrimal.erase (blocks[j]);
1108
+ unwrap_cache.erase (blocks[j]);
1109
+ lookup_cache.erase (blocks[j]);
1110
+ SmallVector<Instruction *, 4 > toErase;
1111
+ for (auto &I : *blocks[j]) {
1112
+ toErase.push_back (&I);
1113
+ }
1114
+ for (auto I : toErase) {
1115
+ erase (I);
1116
+ }
1117
+ }
1118
+ bret->eraseFromParent ();
1119
+ for (size_t j = 0 ; j < blocks.size (); j++) {
1120
+ blocks[j]->eraseFromParent ();
1121
+ };
1122
+ Value *toret = BuilderM.CreateSelect (cond, vals[0 ], vals[1 ],
1123
+ phi->getName () + " _unwrap" );
1124
+ if (permitCache) {
1125
+ unwrap_cache[BuilderM.GetInsertBlock ()][idx] = toret;
1126
+ }
1127
+ if (auto instRet = dyn_cast<Instruction>(toret))
1128
+ unwrappedLoads[instRet] = val;
1129
+ return toret;
1130
+ }
1131
+
1091
1132
bret->moveAfter (last);
1092
1133
if (isa<BranchInst>(equivalentTerminator)) {
1093
1134
BuilderM.CreateCondBr (cond, blocks[0 ], blocks[1 ]);
@@ -1102,6 +1143,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
1102
1143
}
1103
1144
BuilderM.SetInsertPoint (bret);
1104
1145
reverseBlocks[fwd].push_back (bret);
1146
+ reverseBlockToPrimal[bret] = fwd;
1105
1147
auto toret = BuilderM.CreatePHI (val->getType (), vals.size ());
1106
1148
for (size_t i = 0 ; i < vals.size (); i++)
1107
1149
toret->addIncoming (vals[i], endingBlocks[i]);
@@ -1328,9 +1370,10 @@ Value *GradientUtils::cacheForReverse(IRBuilder<> &BuilderQ, Value *malloc,
1328
1370
for (auto u : users) {
1329
1371
if (auto li = dyn_cast<LoadInst>(u)) {
1330
1372
IRBuilder<> lb (li);
1331
- ValueToValueMapTy empty;
1332
- li->replaceAllUsesWith (
1333
- unwrapM (ret, lb, empty, UnwrapMode::LegalFullUnwrap));
1373
+ auto replacewith =
1374
+ (idx < 0 ) ? tape
1375
+ : lb.CreateExtractValue (tape, {(unsigned )idx});
1376
+ li->replaceAllUsesWith (replacewith);
1334
1377
erase (li);
1335
1378
} else {
1336
1379
llvm::errs () << " newFunc: " << *newFunc << " \n " ;
@@ -1768,13 +1811,10 @@ bool GradientUtils::legalRecompute(const Value *val,
1768
1811
if (BuilderM) {
1769
1812
fwdBlockIfReverse = BuilderM->GetInsertBlock ();
1770
1813
if (!reverse) {
1771
- for (auto pair : reverseBlocks) {
1772
- if (std::find (pair.second .begin (), pair.second .end (),
1773
- BuilderM->GetInsertBlock ()) != pair.second .end ()) {
1774
- fwdBlockIfReverse = pair.first ;
1775
- reverse = true ;
1776
- break ;
1777
- }
1814
+ auto found = reverseBlockToPrimal.find (BuilderM->GetInsertBlock ());
1815
+ if (found != reverseBlockToPrimal.end ()) {
1816
+ fwdBlockIfReverse = found->second ;
1817
+ reverse = true ;
1778
1818
}
1779
1819
}
1780
1820
if (fwdBlockIfReverse->getParent () != oldFunc)
0 commit comments