Skip to content

Commit 66efdd7

Browse files
committed
Filter example working
1 parent 33a6f93 commit 66efdd7

File tree

3 files changed

+73
-21
lines changed

3 files changed

+73
-21
lines changed

src/ecCircuits.ml

Lines changed: 68 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ module type CircuitInterface = sig
206206

207207
(* Mapreduce/Dependecy analysis related functions *)
208208
val is_decomposable : int -> int -> cbitstring cfun -> bool
209-
val decompose : int -> int -> cbitstring cfun -> (cbitstring cfun) list
209+
val decompose : int -> int -> cbitstring cfun -> (cbitstring cfun) list * (int * int)
210210
val permute : int -> (int -> int) -> cbitstring cfun -> cbitstring cfun
211211

212212
(* Wraps the backend call to deal with args/inputs *)
@@ -320,6 +320,10 @@ module type CBackend = sig
320320
val is_splittable : int -> int -> deps -> bool
321321

322322
val are_independent : block_deps -> bool
323+
324+
val single_dep : deps -> bool
325+
(* Assumes single_dep *)
326+
val dep_range : deps -> int * int
323327
end
324328
end
325329

@@ -425,11 +429,14 @@ module TestBack : CBackend = struct
425429
let get (r: reg) (idx: int) = r.(idx)
426430

427431
let permute (w: int) (perm: int -> int) (r: reg) : reg =
432+
Format.eprintf "Applying permutation to reg of size %d with block size of %d@." (size_of_reg r) w;
428433
Array.init (size_of_reg r) (fun i ->
429-
let block_idx, bit_idx = (i / w), (i mod w) in
430-
let idx = (perm block_idx)*w + bit_idx in
431-
r.(idx)
432-
)
434+
let block_idx, bit_idx = perm (i / w), (i mod w) in
435+
if block_idx < 0 then None
436+
else
437+
let idx = block_idx*w + bit_idx in
438+
Some r.(idx)
439+
) |> Array.filter_map (fun x -> x)
433440

434441

435442
(* Node operations *)
@@ -536,17 +543,17 @@ module TestBack : CBackend = struct
536543
| 0 -> true
537544
| 1 ->
538545
let blocks = block_deps_of_deps w_out d in
539-
(* Format.eprintf "Checking block width...@."; *)
546+
Format.eprintf "Checking block width...@.";
540547
Array.for_all (fun (_, d) ->
541548
if Map.is_empty d then true
542549
else
543550
let _, bits = Map.any d in
544551
Set.is_empty bits ||
545552
let base = Set.at_rank_exn 0 bits in
546-
(* Format.eprintf "Base for current block: %d@." base; *)
553+
Format.eprintf "Base for current block: %d@." base;
547554
Set.for_all (fun bit ->
548555
let dist = bit - base in
549-
(* Format.eprintf "Current bit: %d | Current dist: %d | Limit: %d@." bit dist w_in; *)
556+
Format.eprintf "Current bit: %d | Current dist: %d | Limit: %d@." bit dist w_in;
550557
0 <= dist && dist < w_in
551558
) bits
552559
) blocks
@@ -576,6 +583,28 @@ module TestBack : CBackend = struct
576583
true
577584
with BreakOut ->
578585
false
586+
587+
588+
let single_dep (d: deps) : bool =
589+
match Set.cardinal
590+
(Array.fold_left (Set.union) Set.empty
591+
(Array.map (fun dep -> Map.keys dep |> Set.of_enum) d))
592+
with
593+
| 0 | 1 -> true
594+
| _ -> false
595+
596+
(* Assumes single_dep, returns range (bot, top) such that valid idxs are bot <= i < top *)
597+
let dep_range (d: deps) : int * int =
598+
assert (single_dep d);
599+
let idxs =
600+
Array.fold_left (fun acc d ->
601+
Set.union (Map.fold Set.union d Set.empty) acc) Set.empty d
602+
in
603+
Format.eprintf "%a@." pp_deps d;
604+
Format.eprintf "Dep range for dependencies:@.";
605+
Set.iter (fun i -> Format.eprintf "%d " i) idxs;
606+
Format.eprintf "@.Min: %d | Max: %d@." (Set.min_elt idxs) (Set.max_elt idxs);
607+
(Set.min_elt idxs, Set.max_elt idxs + 1)
579608
end
580609

581610
end
@@ -1272,7 +1301,7 @@ module MakeCircuitInterfaceFromCBackend(Backend: CBackend) : CircuitInterface =
12721301
let array_oflist (circs : circuit list) (dfl: circuit) (len: int) : circuit =
12731302
let circs, inps = List.split circs in
12741303
let dif = len - List.length circs in
1275-
Format.eprintf "Len, Dif in array_oflist: %d, %d@." len dif;
1304+
(* Format.eprintf "Len, Dif in array_oflist: %d, %d@." len dif; *)
12761305
let circs = circs @ (List.init dif (fun _ -> fst dfl)) in
12771306
let inps = if dif > 0 then inps @ [snd dfl] else inps in
12781307
let circs = List.map
@@ -1518,14 +1547,32 @@ module MakeCircuitInterfaceFromCBackend(Backend: CBackend) : CircuitInterface =
15181547
(* For more complex circuits, we might be able to simulate this with a int -> (int, int) map *)
15191548
let is_decomposable (in_w: width) (out_w: width) ((`CBitstring r, inps) as c: cbitstring cfun) : bool =
15201549
match inps with
1521-
| {type_=`CIBitstring w} :: [] when w mod in_w = 0 && Backend.size_of_reg r mod out_w = 0 ->
1550+
| {type_=`CIBitstring w} :: [] when (Backend.size_of_reg r mod out_w = 0) ->
15221551
let deps = Backend.Deps.deps_of_reg r in
1523-
Backend.Deps.is_splittable in_w out_w deps
1552+
Backend.Deps.is_splittable in_w out_w deps &&
1553+
let base, top = Backend.Deps.dep_range deps in
1554+
let () = Format.eprintf "Passed backend check, checking width of deps (top - base = %d | in_w = %d)@." (top - base) in_w in
1555+
(top - base) mod in_w = 0
15241556
| _ ->
15251557
Format.eprintf "Failed decomposition type check@\n";
15261558
Format.eprintf "In_w: %d | Out_w : %d | Circ: %a" in_w out_w pp_circuit c;
15271559
false
15281560

1561+
(* TODO: Extend this for multiple inputs? *)
1562+
let align_renamer ((`CBitstring r, inps) : cbitstring cfun) : (int * int) * cinp * (Backend.inp -> Backend.inp option) =
1563+
match inps with
1564+
| [{type_ = `CIBitstring w; id}] ->
1565+
let d = Backend.Deps.deps_of_reg r in
1566+
assert (Backend.Deps.single_dep d);
1567+
let (start_idx, end_idx) as range = Backend.Deps.dep_range d in
1568+
range,
1569+
{type_ = `CIBitstring (end_idx - start_idx); id},
1570+
(fun (id_, w) ->
1571+
if id <> id_ then None else
1572+
if w < start_idx || w >= end_idx then None
1573+
else Some (id_, w - start_idx))
1574+
| _ -> assert false
1575+
15291576
let split_renamer (n: count) (in_w: width) (inp: cinp) : (cinp array) * (Backend.inp -> cbool_type option) =
15301577
match inp with
15311578
| {type_ = `CIBitstring w; id} when w mod in_w = 0 ->
@@ -1535,9 +1582,12 @@ module MakeCircuitInterfaceFromCBackend(Backend: CBackend) : CircuitInterface =
15351582
if id <> id_ then None else
15361583
let id_idx, bit_idx = (w / in_w), (w mod in_w) in
15371584
Some (Backend.input_node ~id:ids.(id_idx) bit_idx))
1585+
| {type_ = `CIBitstring w; id} ->
1586+
Format.eprintf "Failed to build split renamer for n=%d in_w=%d w=%d@." n in_w w;
1587+
assert false
15381588
| _ -> assert false
15391589

1540-
let decompose (in_w: width) (out_w: width) ((`CBitstring r, inps) as c: cbitstring cfun) : cbitstring cfun list =
1590+
let decompose (in_w: width) (out_w: width) ((`CBitstring r, inps) as c: cbitstring cfun) : cbitstring cfun list * (int * int) =
15411591
if not (is_decomposable in_w out_w c) then
15421592
let deps = Backend.Deps.block_deps_of_reg out_w r in
15431593
Format.eprintf "Failed to decompose. in_w=%d out_w=%d Deps:@.%a" in_w out_w (Backend.Deps.pp_block_deps) deps;
@@ -1546,11 +1596,13 @@ module MakeCircuitInterfaceFromCBackend(Backend: CBackend) : CircuitInterface =
15461596
let n = (Backend.size_of_reg r) / out_w in
15471597
let blocks = Array.init n (fun i ->
15481598
Backend.slice r (i*out_w) out_w) in
1549-
let cinps, renamer = split_renamer n in_w (List.hd inps) in
1599+
let range, cinp, aligner = align_renamer c in
1600+
let cinps, renamer = split_renamer n in_w cinp in
1601+
let renamer = fun i -> Option.bind (aligner i) renamer in
15501602
Array.map2 (fun r inp ->
15511603
let r = Backend.applys renamer r in
15521604
(`CBitstring r, [inp])
1553-
) blocks cinps |> Array.to_list
1605+
) blocks cinps |> Array.to_list, range
15541606

15551607
let permute (w: width) (perm: (int -> int)) ((`CBitstring r, inps): cbitstring cfun) : cbitstring cfun =
15561608
`CBitstring (Backend.permute w perm r), inps
@@ -2164,13 +2216,13 @@ let circuit_permute (bsz: int) (perm: int -> int) (c: circuit) : circuit =
21642216
in
21652217
(permute bsz perm c :> circuit)
21662218

2167-
let circuit_mapreduce ?(perm : (int -> int) option) (c: circuit) (w_in: width) (w_out: width) : circuit list =
2219+
let circuit_mapreduce ?(perm : (int -> int) option) (c: circuit) (w_in: width) (w_out: width) : circuit list * (int * int) =
21682220
let c = match c, perm with
21692221
| (`CBitstring _, inps) as c, None -> c
21702222
| (`CBitstring _, inps) as c, Some perm -> permute w_out perm c
21712223
| _ -> assert false
21722224
in
2173-
(decompose w_in w_out c :> circuit list)
2225+
(decompose w_in w_out c :> circuit list * (int * int))
21742226

21752227
type circuit = ExampleInterface.circuit
21762228
type pstate = ExampleInterface.PState.pstate

src/ecCircuits.mli

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ val circuit_aggregate : circuit list -> circuit
3939
val circuit_aggregate_inps : circuit -> circuit
4040
val circuit_flatten : circuit -> circuit
4141
val circuit_permute : int -> (int -> int) -> circuit -> circuit
42-
val circuit_mapreduce : ?perm:(int -> int) -> circuit -> int -> int -> circuit list
42+
val circuit_mapreduce : ?perm:(int -> int) -> circuit -> int -> int -> circuit list * (int * int)
4343

4444
(* Use circuits *)
4545
val compute : sign:bool -> circuit -> BI.zint list -> BI.zint

src/phl/ecPhlBDep.ml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ let mapreduce
118118
raise (BDepError "Failed to concatenate outputs")
119119
in
120120

121-
let cs = try
121+
let cs, mr_range = try
122122
circuit_mapreduce ?perm c n m
123123
with CircError err ->
124124
raise (BDepError err)
@@ -206,13 +206,13 @@ let prog_equiv_prod
206206

207207

208208
let tm = time tm "Preprocessing for mapreduce done" in
209-
let lanes_l = try
209+
let lanes_l, mr_range_l = try
210210
circuit_mapreduce c_l n m
211211
with CircError err ->
212212
raise (BDepError ("Left program split step failed with error:\n" ^ err))
213213
in
214214
let tm = time tm "Left program deps + split done" in
215-
let lanes_r = try
215+
let lanes_r, mr_range_r = try
216216
circuit_mapreduce c_r n m
217217
with CircError err ->
218218
raise (BDepError ("Right program split step failed with error:\n" ^ err))
@@ -372,7 +372,7 @@ let mapreduce_eval
372372
in
373373

374374

375-
let cs = try
375+
let cs, mr_range = try
376376
circuit_mapreduce c n m
377377
with CircError err ->
378378
raise (BDepError ("Split step failed with error:\n" ^ err))

0 commit comments

Comments
 (0)