@@ -208,6 +208,7 @@ module type CircuitInterface = sig
208
208
val is_decomposable : int -> int -> cbitstring cfun -> bool
209
209
val decompose : int -> int -> cbitstring cfun -> (cbitstring cfun ) list * (int * int )
210
210
val permute : int -> (int -> int ) -> cbitstring cfun -> cbitstring cfun
211
+ val align_inputs : circuit -> (int * int ) list -> circuit
211
212
212
213
(* Wraps the backend call to deal with args/inputs *)
213
214
module CircuitSpec : sig
@@ -709,7 +710,9 @@ module MakeCircuitInterfaceFromCBackend(Backend: CBackend) : CircuitInterface =
709
710
710
711
(* Inputs helper functions *)
711
712
let merge_inputs (cs : cinp list ) (ds : cinp list ) : cinp list =
712
- cs @ ds
713
+ (* FIXME: hack *)
714
+ if cs = ds then cs
715
+ else cs @ ds
713
716
714
717
let merge_inputs_list (cs : cinp list list ) : cinp list =
715
718
List. fold_right (merge_inputs) cs []
@@ -1573,6 +1576,30 @@ module MakeCircuitInterfaceFromCBackend(Backend: CBackend) : CircuitInterface =
1573
1576
else Some (id_, w - start_idx))
1574
1577
| _ -> assert false
1575
1578
1579
+ let align_inputs ((c , inps ): circuit ) (slcs : (int * int) list ) : circuit =
1580
+ assert ((List. length inps = 1 ) && (List. length slcs = 1 ));
1581
+ let (sz, offset) = List. hd slcs in
1582
+ let inp = match inps with
1583
+ | {type_ = `CIBitstring w_ } as inp :: [] ->
1584
+ {inp with type_ = `CIBitstring sz}
1585
+ | {type_ = `CIArray (w , n )} as inp :: [] ->
1586
+ assert (sz mod w = 0 );
1587
+ {inp with type_ = `CIArray (w, sz / w)}
1588
+ | _ -> assert false
1589
+ in
1590
+ let aligner =
1591
+ (fun (id_ , w ) ->
1592
+ Format. eprintf " Aligning id=%d w=%d offset=%d sz=%d@." id_ w offset sz;
1593
+ if inp.id <> id_ then None else
1594
+ if w < offset || w > = offset + sz then Some Backend. bad
1595
+ else Some (Backend. input_node ~id: id_ (w - offset)))
1596
+ in
1597
+ match c with
1598
+ | `CBitstring r -> (`CBitstring (Backend. applys aligner r), [inp])
1599
+ | `CArray (r , w ) -> (`CArray (Backend. applys aligner r, w), [inp])
1600
+ | `CTuple (r , ws ) -> (`CTuple (Backend. applys aligner r, ws), [inp])
1601
+ | `CBool b -> (`CBool (Backend. apply aligner b), [inp])
1602
+
1576
1603
let split_renamer (n : count ) (in_w : width ) (inp : cinp ) : (cinp array) * (Backend.inp -> cbool_type option) =
1577
1604
match inp with
1578
1605
| {type_ = `CIBitstring w ; id} when w mod in_w = 0 ->
@@ -1596,13 +1623,14 @@ module MakeCircuitInterfaceFromCBackend(Backend: CBackend) : CircuitInterface =
1596
1623
let n = (Backend. size_of_reg r) / out_w in
1597
1624
let blocks = Array. init n (fun i ->
1598
1625
Backend. slice r (i* out_w) out_w) in
1599
- let range, cinp, aligner = align_renamer c in
1626
+ (* let range, cinp, aligner = align_renamer c in *)
1627
+ let cinp = (List. hd inps) in
1600
1628
let cinps, renamer = split_renamer n in_w cinp in
1601
- let renamer = fun i -> Option. bind (aligner i) renamer in
1629
+ (* let renamer = fun i -> Option.bind (aligner i) renamer in *)
1602
1630
Array. map2 (fun r inp ->
1603
1631
let r = Backend. applys renamer r in
1604
1632
(`CBitstring r, [inp])
1605
- ) blocks cinps |> Array. to_list, range
1633
+ ) blocks cinps |> Array. to_list, ( 0 , 0 )
1606
1634
1607
1635
let permute (w : width ) (perm : (int -> int) ) ((`CBitstring r , inps ): cbitstring cfun ) : cbitstring cfun =
1608
1636
`CBitstring (Backend. permute w perm r), inps
@@ -2241,6 +2269,12 @@ let circuit_aggregate =
2241
2269
let circuit_aggregate_inps =
2242
2270
circuit_aggregate_inputs
2243
2271
2272
+ let circuit_slice (c : circuit ) (sz : int ) (offset : int ) = assert false
2273
+
2274
+ (* FIXME: this should use ids instead of strings *)
2275
+ let circuit_align_inputs (c : circuit ) (slcs : (symbol * (int * int) ) list ) =
2276
+ align_inputs c (List. snd slcs)
2277
+
2244
2278
let circuit_flatten (c : circuit ) =
2245
2279
(cbitstring_of_circuit ~strict: false c :> circuit )
2246
2280
0 commit comments