@@ -1732,23 +1732,23 @@ class AdjointGenerator
1732
1732
template <typename Func, typename... Args>
1733
1733
Value *applyChainRule(Type *diffType, IRBuilder<> &Builder, Func rule,
1734
1734
Args... args) {
1735
- return ((DiffeGradientUtils *)gutils)
1735
+ return ((GradientUtils *)gutils)
1736
1736
->applyChainRule(diffType, Builder, rule, args...);
1737
1737
}
1738
1738
1739
1739
/// Unwraps a vector derivative from its internal representation and applies a
1740
1740
/// function f to each element.
1741
1741
template <typename Func, typename... Args>
1742
1742
void applyChainRule(IRBuilder<> &Builder, Func rule, Args... args) {
1743
- ((DiffeGradientUtils *)gutils)->applyChainRule(Builder, rule, args...);
1743
+ ((GradientUtils *)gutils)->applyChainRule(Builder, rule, args...);
1744
1744
}
1745
1745
1746
1746
/// Unwraps an collection of constant vector derivatives from their internal
1747
1747
/// representations and applies a function f to each element.
1748
1748
template <typename Func>
1749
1749
void applyChainRule(ArrayRef<Value *> diffs, IRBuilder<> &Builder,
1750
1750
Func rule) {
1751
- ((DiffeGradientUtils *)gutils)->applyChainRule(diffs, Builder, rule);
1751
+ ((GradientUtils *)gutils)->applyChainRule(diffs, Builder, rule);
1752
1752
}
1753
1753
1754
1754
bool shouldFree() {
0 commit comments