Skip to content

Commit 0569021

Browse files
authored
Custom noderivative found error handler (rust-lang#438)
1 parent 47bdce8 commit 0569021

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ cl::opt<bool> nonmarkedglobals_inactiveloads(
9797
cl::opt<bool> EnzymeJuliaAddrLoad(
9898
"enzyme-julia-addr-load", cl::init(false), cl::Hidden,
9999
cl::desc("Mark all loads resulting in an addr(13)* to be legal to redo"));
100+
101+
void (*CustomErrorHandler)(const char *) = nullptr;
100102
}
101103

102104
struct CacheAnalysis {
@@ -1702,6 +1704,11 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
17021704
}
17031705

17041706
if (todiff->empty()) {
1707+
if (todiff->empty() && CustomErrorHandler) {
1708+
std::string s =
1709+
("No augmented forward pass found for " + todiff->getName()).str();
1710+
CustomErrorHandler(s.c_str());
1711+
}
17051712
llvm::errs() << "mod: " << *todiff->getParent() << "\n";
17061713
llvm::errs() << *todiff << "\n";
17071714
assert(0 && "attempting to differentiate function without definition");
@@ -3195,6 +3202,10 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
31953202
"return or non-constant");
31963203
}
31973204

3205+
if (key.todiff->empty() && CustomErrorHandler) {
3206+
std::string s = ("No derivative found for " + key.todiff->getName()).str();
3207+
CustomErrorHandler(s.c_str());
3208+
}
31983209
assert(!key.todiff->empty());
31993210

32003211
ReturnType retVal =
@@ -3922,6 +3933,11 @@ Function *EnzymeLogic::CreateForwardDiff(
39223933
todiff, &todiff->getEntryBlock(),
39233934
"Cannot use provided custom derivative pass");
39243935
}
3936+
if (todiff->empty() && CustomErrorHandler) {
3937+
std::string s =
3938+
("No forward derivative found for " + todiff->getName()).str();
3939+
CustomErrorHandler(s.c_str());
3940+
}
39253941
if (todiff->empty())
39263942
llvm::errs() << *todiff << "\n";
39273943
assert(!todiff->empty());

enzyme/Enzyme/EnzymeLogic.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050

5151
extern "C" {
5252
extern llvm::cl::opt<bool> EnzymePrint;
53+
extern void (*CustomErrorHandler)(const char *);
5354
}
5455

5556
enum class AugmentedStruct { Tape, Return, DifferentialReturn };

0 commit comments

Comments
 (0)