Skip to content

Commit 969ddce

Browse files
authored
Cast gutils to correct subclass for applyChainRule (rust-lang#733)
1 parent 95a6666 commit 969ddce

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1732,23 +1732,23 @@ class AdjointGenerator
17321732
template <typename Func, typename... Args>
17331733
Value *applyChainRule(Type *diffType, IRBuilder<> &Builder, Func rule,
17341734
Args... args) {
1735-
return ((DiffeGradientUtils *)gutils)
1735+
return ((GradientUtils *)gutils)
17361736
->applyChainRule(diffType, Builder, rule, args...);
17371737
}
17381738

17391739
/// Unwraps a vector derivative from its internal representation and applies a
17401740
/// function f to each element.
17411741
template <typename Func, typename... Args>
17421742
void applyChainRule(IRBuilder<> &Builder, Func rule, Args... args) {
1743-
((DiffeGradientUtils *)gutils)->applyChainRule(Builder, rule, args...);
1743+
((GradientUtils *)gutils)->applyChainRule(Builder, rule, args...);
17441744
}
17451745

17461746
/// Unwraps an collection of constant vector derivatives from their internal
17471747
/// representations and applies a function f to each element.
17481748
template <typename Func>
17491749
void applyChainRule(ArrayRef<Value *> diffs, IRBuilder<> &Builder,
17501750
Func rule) {
1751-
((DiffeGradientUtils *)gutils)->applyChainRule(diffs, Builder, rule);
1751+
((GradientUtils *)gutils)->applyChainRule(diffs, Builder, rule);
17521752
}
17531753

17541754
bool shouldFree() {

0 commit comments

Comments
 (0)