Skip to content

Commit ba11731

Browse files
committed
working dupv for fwd mode
1 parent 8da0630 commit ba11731

File tree

5 files changed

+45
-10
lines changed

5 files changed

+45
-10
lines changed

compiler/rustc_ast/src/expand/autodiff_attrs.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ pub enum DiffActivity {
5050
/// with it.
5151
Dual,
5252
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
53+
/// with it. It expects the shadow argument to be `width` times larger than the original
54+
/// input/output.
55+
Dualv,
56+
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
5357
/// with it. Drop the code which updates the original input/output for maximum performance.
5458
DualOnly,
5559
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
@@ -127,6 +131,7 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
127131
DiffMode::Source => false,
128132
DiffMode::Forward => {
129133
activity == DiffActivity::Dual
134+
|| activity == DiffActivity::Dualv
130135
|| activity == DiffActivity::DualOnly
131136
|| activity == DiffActivity::Const
132137
}
@@ -150,7 +155,7 @@ pub fn valid_ty_for_activity(ty: &P<Ty>, activity: DiffActivity) -> bool {
150155
if matches!(activity, Const) {
151156
return true;
152157
}
153-
if matches!(activity, Dual | DualOnly) {
158+
if matches!(activity, Dual | DualOnly | Dualv) {
154159
return true;
155160
}
156161
// FIXME(ZuseZ4) We should make this more robust to also
@@ -167,7 +172,7 @@ pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
167172
DiffMode::Error => false,
168173
DiffMode::Source => false,
169174
DiffMode::Forward => {
170-
matches!(activity, Dual | DualOnly | Const)
175+
matches!(activity, Dual | DualOnly | Dualv | Const)
171176
}
172177
DiffMode::Reverse => {
173178
matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const)
@@ -183,6 +188,7 @@ impl Display for DiffActivity {
183188
DiffActivity::Active => write!(f, "Active"),
184189
DiffActivity::ActiveOnly => write!(f, "ActiveOnly"),
185190
DiffActivity::Dual => write!(f, "Dual"),
191+
DiffActivity::Dualv => write!(f, "Dualv"),
186192
DiffActivity::DualOnly => write!(f, "DualOnly"),
187193
DiffActivity::Duplicated => write!(f, "Duplicated"),
188194
DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"),
@@ -214,6 +220,7 @@ impl FromStr for DiffActivity {
214220
"ActiveOnly" => Ok(DiffActivity::ActiveOnly),
215221
"Const" => Ok(DiffActivity::Const),
216222
"Dual" => Ok(DiffActivity::Dual),
223+
"Dualv" => Ok(DiffActivity::Dualv),
217224
"DualOnly" => Ok(DiffActivity::DualOnly),
218225
"Duplicated" => Ok(DiffActivity::Duplicated),
219226
"DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly),

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -775,8 +775,13 @@ mod llvm_enzyme {
775775
d_inputs.push(shadow_arg.clone());
776776
}
777777
}
778-
DiffActivity::Dual | DiffActivity::DualOnly => {
779-
for i in 0..x.width {
778+
DiffActivity::Dual | DiffActivity::DualOnly | DiffActivity::Dualv => {
779+
let iterations = if matches!(activity, DiffActivity::Dualv) {
780+
1
781+
} else {
782+
x.width
783+
};
784+
for i in 0..iterations {
780785
let mut shadow_arg = arg.clone();
781786
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
782787
ident.name
@@ -863,8 +868,8 @@ mod llvm_enzyme {
863868
}
864869
};
865870

866-
if let DiffActivity::Dual = x.ret_activity {
867-
let kind = if x.width == 1 {
871+
if matches!(x.ret_activity, DiffActivity::Dual | DiffActivity::Dualv) {
872+
let kind = if x.width == 1 || matches!(x.ret_activity, DiffActivity::Dualv) {
868873
// Dual can only be used for f32/f64 ret.
869874
// In that case we return now a tuple with two floats.
870875
TyKind::Tup(thin_vec![ty.clone(), ty.clone()])

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
122122
/// Empty string, to be used where LLVM expects an instruction name, indicating
123123
/// that the instruction is to be left unnamed (i.e. numbered, in textual IR).
124124
// FIXME(eddyb) pass `&CStr` directly to FFI once it's a thin pointer.
125-
const UNNAMED: *const c_char = c"".as_ptr();
125+
pub(crate) const UNNAMED: *const c_char = c"".as_ptr();
126126

127127
impl<'ll, CX: Borrow<SCx<'ll>>> BackendTypes for GenericBuilder<'_, 'll, CX> {
128128
type Value = <GenericCx<'ll, CX> as BackendTypes>::Value;

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use rustc_middle::bug;
1010
use tracing::{debug, trace};
1111

1212
use crate::back::write::llvm_err;
13-
use crate::builder::SBuilder;
13+
use crate::builder::{SBuilder, UNNAMED};
1414
use crate::context::SimpleCx;
1515
use crate::declare::declare_simple_fn;
1616
use crate::errors::{AutoDiffWithoutEnable, LlvmError};
@@ -51,6 +51,7 @@ fn has_sret(fnc: &Value) -> bool {
5151
// using iterators and peek()?
5252
fn match_args_from_caller_to_enzyme<'ll>(
5353
cx: &SimpleCx<'ll>,
54+
builder: &SBuilder<'ll,'ll>,
5455
width: u32,
5556
args: &mut Vec<&'ll llvm::Value>,
5657
inputs: &[DiffActivity],
@@ -78,6 +79,7 @@ fn match_args_from_caller_to_enzyme<'ll>(
7879
let enzyme_const = cx.create_metadata("enzyme_const".to_string()).unwrap();
7980
let enzyme_out = cx.create_metadata("enzyme_out".to_string()).unwrap();
8081
let enzyme_dup = cx.create_metadata("enzyme_dup".to_string()).unwrap();
82+
let enzyme_dupv = cx.create_metadata("enzyme_dupv".to_string()).unwrap();
8183
let enzyme_dupnoneed = cx.create_metadata("enzyme_dupnoneed".to_string()).unwrap();
8284

8385
while activity_pos < inputs.len() {
@@ -90,13 +92,26 @@ fn match_args_from_caller_to_enzyme<'ll>(
9092
DiffActivity::Active => (enzyme_out, false),
9193
DiffActivity::ActiveOnly => (enzyme_out, false),
9294
DiffActivity::Dual => (enzyme_dup, true),
95+
DiffActivity::Dualv => (enzyme_dupv, true),
9396
DiffActivity::DualOnly => (enzyme_dupnoneed, true),
9497
DiffActivity::Duplicated => (enzyme_dup, true),
9598
DiffActivity::DuplicatedOnly => (enzyme_dupnoneed, true),
9699
DiffActivity::FakeActivitySize => (enzyme_const, false),
97100
};
98101
let outer_arg = outer_args[outer_pos];
99102
args.push(cx.get_metadata_value(activity));
103+
if matches!(diff_activity, DiffActivity::Dualv) {
104+
let next_outer_arg = outer_args[outer_pos + 1];
105+
// stride: sizeof(T) * n_elems.
106+
// T=f32 => 4 bytes
107+
// n_elems is the next integer.
108+
// Now we multiply `4 * next_outer_arg` to get the stride.
109+
//let mul = builder
110+
// .build_mul(cx.get_const_i64(4), next_outer_arg)
111+
// .unwrap();
112+
let mul = unsafe {llvm::LLVMBuildMul(builder.llbuilder, cx.get_const_i64(4), next_outer_arg, UNNAMED)};
113+
args.push(mul);
114+
}
100115
args.push(outer_arg);
101116
if duplicated {
102117
// We know that duplicated args by construction have a following argument,
@@ -125,7 +140,13 @@ fn match_args_from_caller_to_enzyme<'ll>(
125140
// int2 >= int1, which means the shadow vector is large enough to store the gradient.
126141
assert_eq!(cx.type_kind(next_outer_ty), TypeKind::Integer);
127142

128-
for i in 0..(width as usize) {
143+
let iterations = if matches!(diff_activity, DiffActivity::Dualv) {
144+
1
145+
} else {
146+
width as usize
147+
};
148+
149+
for i in 0..iterations {
129150
let next_outer_arg2 = outer_args[outer_pos + 2 * (i + 1)];
130151
let next_outer_ty2 = cx.val_ty(next_outer_arg2);
131152
assert_eq!(cx.type_kind(next_outer_ty2), TypeKind::Pointer);
@@ -136,7 +157,7 @@ fn match_args_from_caller_to_enzyme<'ll>(
136157
}
137158
args.push(cx.get_metadata_value(enzyme_const));
138159
args.push(next_outer_arg);
139-
outer_pos += 2 + 2 * width as usize;
160+
outer_pos += 2 + 2 * iterations;
140161
activity_pos += 2;
141162
} else {
142163
// A duplicated pointer will have the following two outer_fn arguments:
@@ -344,6 +365,7 @@ fn generate_enzyme_call<'ll>(
344365
let outer_args: Vec<&llvm::Value> = get_params(outer_fn);
345366
match_args_from_caller_to_enzyme(
346367
&cx,
368+
&builder,
347369
attrs.width,
348370
&mut args,
349371
&attrs.input_activity,

compiler/rustc_monomorphize/src/partitioning/autodiff.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec
3131
let activity = match da[i] {
3232
DiffActivity::DualOnly
3333
| DiffActivity::Dual
34+
| DiffActivity::Dualv
3435
| DiffActivity::DuplicatedOnly
3536
| DiffActivity::Duplicated => DiffActivity::FakeActivitySize,
3637
DiffActivity::Const => DiffActivity::Const,

0 commit comments

Comments
 (0)