@@ -206,7 +206,7 @@ module type CircuitInterface = sig
206
206
207
207
(* Mapreduce/Dependecy analysis related functions *)
208
208
val is_decomposable : int -> int -> cbitstring cfun -> bool
209
- val decompose : int -> int -> cbitstring cfun -> (cbitstring cfun ) list
209
+ val decompose : int -> int -> cbitstring cfun -> (cbitstring cfun ) list * ( int * int )
210
210
val permute : int -> (int -> int ) -> cbitstring cfun -> cbitstring cfun
211
211
212
212
(* Wraps the backend call to deal with args/inputs *)
@@ -320,6 +320,10 @@ module type CBackend = sig
320
320
val is_splittable : int -> int -> deps -> bool
321
321
322
322
val are_independent : block_deps -> bool
323
+
324
+ val single_dep : deps -> bool
325
+ (* Assumes single_dep *)
326
+ val dep_range : deps -> int * int
323
327
end
324
328
end
325
329
@@ -425,11 +429,14 @@ module TestBack : CBackend = struct
425
429
let get (r : reg ) (idx : int ) = r.(idx)
426
430
427
431
let permute (w : int ) (perm : int -> int ) (r : reg ) : reg =
432
+ Format. eprintf " Applying permutation to reg of size %d with block size of %d@." (size_of_reg r) w;
428
433
Array. init (size_of_reg r) (fun i ->
429
- let block_idx, bit_idx = (i / w), (i mod w) in
430
- let idx = (perm block_idx)* w + bit_idx in
431
- r.(idx)
432
- )
434
+ let block_idx, bit_idx = perm (i / w), (i mod w) in
435
+ if block_idx < 0 then None
436
+ else
437
+ let idx = block_idx* w + bit_idx in
438
+ Some r.(idx)
439
+ ) |> Array. filter_map (fun x -> x)
433
440
434
441
435
442
(* Node operations *)
@@ -536,17 +543,17 @@ module TestBack : CBackend = struct
536
543
| 0 -> true
537
544
| 1 ->
538
545
let blocks = block_deps_of_deps w_out d in
539
- (* Format.eprintf "Checking block width...@."; *)
546
+ Format. eprintf " Checking block width...@." ;
540
547
Array. for_all (fun (_ , d ) ->
541
548
if Map. is_empty d then true
542
549
else
543
550
let _, bits = Map. any d in
544
551
Set. is_empty bits ||
545
552
let base = Set. at_rank_exn 0 bits in
546
- (* Format.eprintf "Base for current block: %d@." base; *)
553
+ Format. eprintf " Base for current block: %d@." base;
547
554
Set. for_all (fun bit ->
548
555
let dist = bit - base in
549
- (* Format.eprintf "Current bit: %d | Current dist: %d | Limit: %d@." bit dist w_in; *)
556
+ Format. eprintf " Current bit: %d | Current dist: %d | Limit: %d@." bit dist w_in;
550
557
0 < = dist && dist < w_in
551
558
) bits
552
559
) blocks
@@ -576,6 +583,28 @@ module TestBack : CBackend = struct
576
583
true
577
584
with BreakOut ->
578
585
false
586
+
587
+
588
+ let single_dep (d : deps ) : bool =
589
+ match Set. cardinal
590
+ (Array. fold_left (Set. union) Set. empty
591
+ (Array. map (fun dep -> Map. keys dep |> Set. of_enum) d))
592
+ with
593
+ | 0 | 1 -> true
594
+ | _ -> false
595
+
596
+ (* Assumes single_dep, returns range (bot, top) such that valid idxs are bot <= i < top *)
597
+ let dep_range (d : deps ) : int * int =
598
+ assert (single_dep d);
599
+ let idxs =
600
+ Array. fold_left (fun acc d ->
601
+ Set. union (Map. fold Set. union d Set. empty) acc) Set. empty d
602
+ in
603
+ Format. eprintf " %a@." pp_deps d;
604
+ Format. eprintf " Dep range for dependencies:@." ;
605
+ Set. iter (fun i -> Format. eprintf " %d " i) idxs;
606
+ Format. eprintf " @.Min: %d | Max: %d@." (Set. min_elt idxs) (Set. max_elt idxs);
607
+ (Set. min_elt idxs, Set. max_elt idxs + 1 )
579
608
end
580
609
581
610
end
@@ -1272,7 +1301,7 @@ module MakeCircuitInterfaceFromCBackend(Backend: CBackend) : CircuitInterface =
1272
1301
let array_oflist (circs : circuit list ) (dfl : circuit ) (len : int ) : circuit =
1273
1302
let circs, inps = List. split circs in
1274
1303
let dif = len - List. length circs in
1275
- Format. eprintf " Len, Dif in array_oflist: %d, %d@." len dif;
1304
+ (* Format.eprintf "Len, Dif in array_oflist: %d, %d@." len dif; *)
1276
1305
let circs = circs @ (List. init dif (fun _ -> fst dfl)) in
1277
1306
let inps = if dif > 0 then inps @ [snd dfl] else inps in
1278
1307
let circs = List. map
@@ -1518,14 +1547,32 @@ module MakeCircuitInterfaceFromCBackend(Backend: CBackend) : CircuitInterface =
1518
1547
(* For more complex circuits, we might be able to simulate this with a int -> (int, int) map *)
1519
1548
let is_decomposable (in_w : width ) (out_w : width ) ((`CBitstring r , inps ) as c : cbitstring cfun ) : bool =
1520
1549
match inps with
1521
- | {type_ =`CIBitstring w } :: [] when w mod in_w = 0 && Backend. size_of_reg r mod out_w = 0 ->
1550
+ | {type_ =`CIBitstring w } :: [] when ( Backend. size_of_reg r mod out_w = 0 ) ->
1522
1551
let deps = Backend.Deps. deps_of_reg r in
1523
- Backend.Deps. is_splittable in_w out_w deps
1552
+ Backend.Deps. is_splittable in_w out_w deps &&
1553
+ let base, top = Backend.Deps. dep_range deps in
1554
+ let () = Format. eprintf " Passed backend check, checking width of deps (top - base = %d | in_w = %d)@." (top - base) in_w in
1555
+ (top - base) mod in_w = 0
1524
1556
| _ ->
1525
1557
Format. eprintf " Failed decomposition type check@\n " ;
1526
1558
Format. eprintf " In_w: %d | Out_w : %d | Circ: %a" in_w out_w pp_circuit c;
1527
1559
false
1528
1560
1561
+ (* TODO: Extend this for multiple inputs? *)
1562
+ let align_renamer ((`CBitstring r , inps ) : cbitstring cfun ) : (int * int) * cinp * (Backend.inp -> Backend.inp option) =
1563
+ match inps with
1564
+ | [{type_ = `CIBitstring w; id}] ->
1565
+ let d = Backend.Deps. deps_of_reg r in
1566
+ assert (Backend.Deps. single_dep d);
1567
+ let (start_idx, end_idx) as range = Backend.Deps. dep_range d in
1568
+ range,
1569
+ {type_ = `CIBitstring (end_idx - start_idx); id},
1570
+ (fun (id_ , w ) ->
1571
+ if id <> id_ then None else
1572
+ if w < start_idx || w > = end_idx then None
1573
+ else Some (id_, w - start_idx))
1574
+ | _ -> assert false
1575
+
1529
1576
let split_renamer (n : count ) (in_w : width ) (inp : cinp ) : (cinp array) * (Backend.inp -> cbool_type option) =
1530
1577
match inp with
1531
1578
| {type_ = `CIBitstring w ; id} when w mod in_w = 0 ->
@@ -1535,9 +1582,12 @@ module MakeCircuitInterfaceFromCBackend(Backend: CBackend) : CircuitInterface =
1535
1582
if id <> id_ then None else
1536
1583
let id_idx, bit_idx = (w / in_w), (w mod in_w) in
1537
1584
Some (Backend. input_node ~id: ids.(id_idx) bit_idx))
1585
+ | {type_ = `CIBitstring w ; id} ->
1586
+ Format. eprintf " Failed to build split renamer for n=%d in_w=%d w=%d@." n in_w w;
1587
+ assert false
1538
1588
| _ -> assert false
1539
1589
1540
- let decompose (in_w : width ) (out_w : width ) ((`CBitstring r , inps ) as c : cbitstring cfun ) : cbitstring cfun list =
1590
+ let decompose (in_w : width ) (out_w : width ) ((`CBitstring r , inps ) as c : cbitstring cfun ) : cbitstring cfun list * (int * int) =
1541
1591
if not (is_decomposable in_w out_w c) then
1542
1592
let deps = Backend.Deps. block_deps_of_reg out_w r in
1543
1593
Format. eprintf " Failed to decompose. in_w=%d out_w=%d Deps:@.%a" in_w out_w (Backend.Deps. pp_block_deps) deps;
@@ -1546,11 +1596,13 @@ module MakeCircuitInterfaceFromCBackend(Backend: CBackend) : CircuitInterface =
1546
1596
let n = (Backend. size_of_reg r) / out_w in
1547
1597
let blocks = Array. init n (fun i ->
1548
1598
Backend. slice r (i* out_w) out_w) in
1549
- let cinps, renamer = split_renamer n in_w (List. hd inps) in
1599
+ let range, cinp, aligner = align_renamer c in
1600
+ let cinps, renamer = split_renamer n in_w cinp in
1601
+ let renamer = fun i -> Option. bind (aligner i) renamer in
1550
1602
Array. map2 (fun r inp ->
1551
1603
let r = Backend. applys renamer r in
1552
1604
(`CBitstring r, [inp])
1553
- ) blocks cinps |> Array. to_list
1605
+ ) blocks cinps |> Array. to_list, range
1554
1606
1555
1607
let permute (w : width ) (perm : (int -> int) ) ((`CBitstring r , inps ): cbitstring cfun ) : cbitstring cfun =
1556
1608
`CBitstring (Backend. permute w perm r), inps
@@ -2164,13 +2216,13 @@ let circuit_permute (bsz: int) (perm: int -> int) (c: circuit) : circuit =
2164
2216
in
2165
2217
(permute bsz perm c :> circuit )
2166
2218
2167
- let circuit_mapreduce ?(perm : (int -> int) option ) (c : circuit ) (w_in : width ) (w_out : width ) : circuit list =
2219
+ let circuit_mapreduce ?(perm : (int -> int) option ) (c : circuit ) (w_in : width ) (w_out : width ) : circuit list * (int * int) =
2168
2220
let c = match c, perm with
2169
2221
| (`CBitstring _ , inps ) as c , None -> c
2170
2222
| (`CBitstring _ , inps ) as c , Some perm -> permute w_out perm c
2171
2223
| _ -> assert false
2172
2224
in
2173
- (decompose w_in w_out c :> circuit list )
2225
+ (decompose w_in w_out c :> circuit list * (int * int ) )
2174
2226
2175
2227
type circuit = ExampleInterface .circuit
2176
2228
type pstate = ExampleInterface.PState .pstate
0 commit comments