Skip to content

Commit 5913127

Browse files
author
Gustavo Delerue
committed
WIP: Better abstraction for ecCircuits
1 parent 39310b5 commit 5913127

File tree

3 files changed

+106
-106
lines changed

3 files changed

+106
-106
lines changed

src/ecCircuits.ml

Lines changed: 56 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1086,10 +1086,10 @@ type cache = (ident, (cinput * circuit)) Map.t
10861086
if not: remove env argument from recursive calls *)
10871087
let circuit_of_form
10881088
?(pstate : pstate = Map.empty) (* Program variable values *)
1089-
?(cache : cache = Map.empty) (* Let-bindings and such *)
10901089
(hyps : hyps)
10911090
(f_ : EcAst.form)
10921091
: circuit =
1092+
let cache = Map.empty in
10931093

10941094
let rec doit (cache: (ident, (cinput * circuit)) Map.t) (hyps: hyps) (f_: form) : hyps * circuit =
10951095
let env = toenv hyps in
@@ -1497,18 +1497,18 @@ let pstate_of_memtype ?pstate (env: env) (mt : memtype) =
14971497
) (Option.get lmt).lmt_decl in
14981498
pstate_of_variables ?pstate env vars
14991499

1500-
let process_instr (hyps: hyps) (mem: memory) ?(cache: cache = Map.empty) (pstate: _) (inst: instr) =
1500+
let process_instr (hyps: hyps) (mem: memory) (pstate: _) (inst: instr) =
15011501
let env = toenv hyps in
15021502
(* Format.eprintf "[W]Processing : %a@." (EcPrinting.pp_instr (EcPrinting.PPEnv.ofenv env)) inst; *)
15031503
(* let start = Unix.gettimeofday () in *)
15041504
try
15051505
match inst.i_node with
15061506
| Sasgn (LvVar (PVloc v, _ty), e) ->
1507-
let pstate = Map.add v (form_of_expr mem e |> circuit_of_form ~pstate ~cache hyps) pstate in
1507+
let pstate = Map.add v (form_of_expr mem e |> circuit_of_form ~pstate hyps) pstate in
15081508
(* Format.eprintf "[W] Took %f seconds@." (Unix.gettimeofday() -. start); *)
15091509
pstate
15101510
| Sasgn (LvTuple (vs), e) ->
1511-
let tp = (form_of_expr mem e |> circuit_of_form ~pstate ~cache hyps) in
1511+
let tp = (form_of_expr mem e |> circuit_of_form ~pstate hyps) in
15121512
assert (is_bwtuple tp.circ);
15131513
let comps = circuits_of_circuit tp in
15141514
let pstate = List.fold_left2 (fun pstate (pv, _ty) c ->
@@ -1590,3 +1590,55 @@ let instrs_equiv
15901590
let circ2 = { circ2 with inps = inputs @ circ2.inps } in
15911591
circ_equiv circ1 circ2 None
15921592
)
1593+
1594+
let initial_pstate_of_vars (env: env) (invs: variable list) : cinput list * (symbol, circuit) Map.t =
1595+
let pstate : (symbol, circuit) Map.t = Map.empty in
1596+
1597+
let inps = List.map (input_of_variable env) invs in
1598+
let inpcs, inps = List.split inps in
1599+
(* List.iter (fun c -> Format.eprintf "Inp: %s @." (cinput_to_string c)) inps; *)
1600+
let inpcs = List.combine inpcs @@ List.map (fun v -> v.v_name) invs in
1601+
1602+
inps, List.fold_left
1603+
(fun pstate (inp, v) -> Map.add v inp pstate)
1604+
pstate inpcs
1605+
1606+
(* Generates pstate : (symbol, circuit) Map from program
1607+
and inputs associated to the program
1608+
Throws: CircError on failure
1609+
*)
1610+
let pstate_of_prog (hyps: hyps) (mem: memory) (proc: instr list) (invs: variable list) : (symbol, circuit) Map.t =
1611+
let inps, pstate = initial_pstate_of_vars (toenv hyps) (invs) in
1612+
1613+
let pstate =
1614+
List.fold_left (process_instr hyps mem) pstate proc
1615+
in
1616+
Map.map (fun c -> assert (c.inps = []); {c with inps=inps}) pstate
1617+
1618+
(* FIXME: refactor this function *)
1619+
let rec circ_simplify_form_bitstring_equality
1620+
?(mem = mhr)
1621+
?(pstate: (symbol, circuit) Map.t = Map.empty)
1622+
?(pcond: circuit option)
1623+
(hyps: hyps)
1624+
(f: form)
1625+
: form =
1626+
let env = toenv hyps in
1627+
1628+
let rec check (f : form) =
1629+
match EcFol.sform_of_form f with
1630+
| SFeq (f1, f2)
1631+
when (Option.is_some @@ EcEnv.Circuit.lookup_bitstring env f1.f_ty)
1632+
|| (Option.is_some @@ EcEnv.Circuit.lookup_array env f1.f_ty)
1633+
->
1634+
let c1 = circuit_of_form ~pstate hyps f1 in
1635+
let c2 = circuit_of_form ~pstate hyps f2 in
1636+
Format.eprintf "[W]Testing circuit equivalence for forms:
1637+
%a@.%[email protected] circuits: %s | %s@."
1638+
(EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) f1
1639+
(EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) f2
1640+
(circuit_to_string c1)
1641+
(circuit_to_string c2);
1642+
f_bool (circ_equiv c1 c2 pcond)
1643+
| _ -> f_map (fun ty -> ty) check f
1644+
in check f

src/ecCircuits.mli

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,34 +9,38 @@ open LDecl
99
module Map = Batteries.Map
1010

1111
(* -------------------------------------------------------------------- *)
12-
type circ
13-
type cinput
14-
type circuit = { circ: circ; inps: cinput list; }
12+
type circuit
1513
type pstate = (symbol, circuit) Map.t
16-
type cache = (EcIdent.t, (cinput * circuit)) Map.t
14+
(*type cache = (EcIdent.t, (cinput * circuit)) Map.t*)
1715

1816
(* -------------------------------------------------------------------- *)
1917
exception CircError of string
2018

2119
(* -------------------------------------------------------------------- *)
2220
val get_specification_by_name : string -> Lospecs.Ast.adef option
2321
val circ_red : hyps -> EcReduction.reduction_info
24-
val cinput_to_string : cinput -> string
25-
val cinput_of_type : ?idn:ident -> env -> ty -> cinput
22+
(*val cinput_to_string : cinput -> string*)
23+
(*val cinput_of_type : ?idn:ident -> env -> ty -> cinput*)
2624
val width_of_type : env -> ty -> int
27-
val size_of_circ : circ -> int
25+
(*val size_of_circ : circ -> int *)
2826
val compute : sign:bool -> circuit -> BI.zint list -> BI.zint
2927
val circuit_to_string : circuit -> string
30-
val circ_ident : cinput -> circuit
28+
(*val circ_ident : cinput -> circuit*)
3129
val circuit_ueq : circuit -> circuit -> circuit
3230
val circuit_aggregate : circuit list -> circuit
3331
val circuit_aggregate_inps : circuit -> circuit
3432
val circuit_flatten : circuit -> circuit
3533
val circuit_permutation : int -> int -> (int -> int) -> circuit
3634
val circuit_mapreduce : ?perm:(int -> int) -> circuit -> int -> int -> circuit list
3735
val circ_equiv : ?strict:bool -> circuit -> circuit -> circuit option -> bool
38-
val circuit_of_form : ?pstate:pstate -> ?cache:cache -> hyps -> form -> circuit
39-
val pstate_of_memtype : ?pstate:pstate -> env -> memtype -> pstate * cinput list
40-
val input_of_variable : env -> variable -> circuit * cinput
36+
val circuit_of_form : ?pstate:pstate -> hyps -> form -> circuit
37+
(*val pstate_of_memtype : ?pstate:pstate -> env -> memtype -> pstate * cinput list*)
38+
val pstate_of_prog : hyps -> memory -> instr list -> variable list -> (symbol, circuit) Map.t
39+
(*val input_of_variable : env -> variable -> circuit * cinput*)
4140
val instrs_equiv : hyps -> memenv -> ?keep:EcPV.PV.t -> ?pstate:pstate -> instr list -> instr list -> bool
42-
val process_instr : hyps -> memory -> ?cache:cache -> pstate -> instr -> (symbol, circuit) Map.t
41+
val process_instr : hyps -> memory -> pstate -> instr -> (symbol, circuit) Map.t
42+
val circ_simplify_form_bitstring_equality :
43+
?mem:EcMemory.memory ->
44+
?pstate:(string, circuit) Map.t ->
45+
?pcond:circuit -> hyps -> form -> form
46+

src/phl/ecPhlBDep.ml

Lines changed: 34 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -57,32 +57,6 @@ let circ_of_qsymbol (hyps: hyps) (qs: qsymbol) : circuit =
5757
fc
5858
with CircError err ->
5959
raise (BDepError err)
60-
61-
62-
let initial_pstate_of_vars (env: env) (invs: variable list) : cinput list * (symbol, circuit) Map.t =
63-
let pstate : (symbol, circuit) Map.t = Map.empty in
64-
65-
let inps = List.map (EcCircuits.input_of_variable env) invs in
66-
let inpcs, inps = List.split inps in
67-
(* List.iter (fun c -> Format.eprintf "Inp: %s @." (cinput_to_string c)) inps; *)
68-
let inpcs = List.combine inpcs @@ List.map (fun v -> v.v_name) invs in
69-
70-
inps, List.fold_left
71-
(fun pstate (inp, v) -> Map.add v inp pstate)
72-
pstate inpcs
73-
74-
(* Generates pstate : (symbol, circuit) Map from program
75-
Throws: BDepError on failure
76-
*)
77-
let pstate_of_prog (hyps: hyps) (mem: memory) (proc: stmt) (invs: variable list) : (symbol, circuit) Map.t =
78-
let inps, pstate = initial_pstate_of_vars (toenv hyps) (invs) in
79-
80-
let pstate = try
81-
List.fold_left (EcCircuits.process_instr hyps mem) pstate proc.s_node
82-
with CircError err ->
83-
raise (BDepError err)
84-
in
85-
Map.map (fun c -> assert (c.inps = []); {c with inps=inps}) pstate
8660

8761

8862
(* -------------------------------------------------------------------- *)
@@ -117,7 +91,11 @@ let mapreduce
11791

11892
let tm = time tm "Precondition circuit generation done" in
11993

120-
let pstate = pstate_of_prog hyps mem proc invs in
94+
let pstate = try
95+
EcCircuits.pstate_of_prog hyps mem proc.s_node invs
96+
with CircError err ->
97+
raise (BDepError err)
98+
in
12199

122100
let tm = time tm "Program circuit generation done" in
123101

@@ -126,7 +104,7 @@ let mapreduce
126104
(List.map (fun v -> v.v_name) outvs) in
127105

128106
(* This is required for now as we do not allow mapreduce with multiple arguments *)
129-
assert (Set.cardinal @@ Set.of_list @@ List.map (fun c -> c.inps) circs = 1);
107+
(* assert (Set.cardinal @@ Set.of_list @@ List.map (fun c -> c.inps) circs = 1); *)
130108

131109
let c = try
132110
(circuit_aggregate circs)
@@ -178,9 +156,17 @@ let prog_equiv_prod
178156
in
179157
let tm = Unix.gettimeofday () in
180158

181-
let pstate_l : (symbol, circuit) Map.t = pstate_of_prog hyps meml proc_l invs_l in
159+
let pstate_l : (symbol, circuit) Map.t = try
160+
EcCircuits.pstate_of_prog hyps meml proc_l.s_node invs_l
161+
with CircError err ->
162+
raise (BDepError err)
163+
in
182164
let tm = time tm "Left program generation done" in
183-
let pstate_r : (symbol, circuit) Map.t = pstate_of_prog hyps memr proc_r invs_l in
165+
let pstate_r : (symbol, circuit) Map.t = try
166+
EcCircuits.pstate_of_prog hyps memr proc_r.s_node invs_l
167+
with CircError err ->
168+
raise (BDepError err)
169+
in
184170
let tm = time tm "Right program generation done" in
185171

186172
begin
@@ -189,14 +175,8 @@ let prog_equiv_prod
189175
let circs_r = List.map (fun v -> Option.get (Map.find_opt v pstate_r))
190176
(List.map (fun v -> v.v_name) outvs_r) in
191177

192-
(* let () = List.iter2 (fun c v -> Format.eprintf "%s inputs: " v.v_name; *)
193-
(* List.iter (Format.eprintf "%s ") (List.map cinput_to_string c.inps); *)
194-
(* Format.eprintf "@."; ) circs outvs in *)
195-
196-
(* let () = List.iter (fun c -> Format.eprintf "%s@." (circuit_to_string c)) circs in *)
197-
(* Only one input supported for now *)
198-
assert (Set.cardinal @@ Set.of_list @@ List.map (fun c -> c.inps) circs_l = 1);
199-
assert (Set.cardinal @@ Set.of_list @@ List.map (fun c -> c.inps) circs_r = 1);
178+
(*assert (Set.cardinal @@ Set.of_list @@ List.map (fun c -> c.inps) circs_l = 1); *)
179+
(*assert (Set.cardinal @@ Set.of_list @@ List.map (fun c -> c.inps) circs_r = 1);*)
200180
let c_l = try
201181
(circuit_aggregate circs_l)
202182
with CircError _err ->
@@ -263,37 +243,6 @@ let prog_equiv_prod
263243
if both sides are equivalent as circuits
264244
or false otherwise
265245
*)
266-
let rec circ_simplify_form_bitstring_equality
267-
?(mem = mhr)
268-
?(pstate: (symbol, circuit) Map.t = Map.empty)
269-
?(pcond: circuit option)
270-
?(inps: cinput list option)
271-
(hyps: hyps)
272-
(f: form)
273-
: form =
274-
let env = toenv hyps in
275-
276-
let rec check (f : form) =
277-
match sform_of_form f with
278-
| SFeq (f1, f2)
279-
when (Option.is_some @@ EcEnv.Circuit.lookup_bitstring env f1.f_ty)
280-
|| (Option.is_some @@ EcEnv.Circuit.lookup_array env f1.f_ty)
281-
->
282-
let c1 = circuit_of_form ~pstate hyps f1 in
283-
let c2 = circuit_of_form ~pstate hyps f2 in
284-
let c1, c2 = match inps with
285-
| Some inps -> {c1 with inps = inps}, {c2 with inps = inps}
286-
| None -> c1, c2
287-
in
288-
Format.eprintf "[W]Testing circuit equivalence for forms:
289-
%a@.%[email protected] circuits: %s | %s@."
290-
(EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) f1
291-
(EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) f2
292-
(circuit_to_string c1)
293-
(circuit_to_string c2);
294-
f_bool (circ_equiv c1 c2 pcond)
295-
| _ -> f_map (fun ty -> ty) check f
296-
in check f
297246

298247
let circ_form_eval_plus_equiv
299248
?(mem = mhr)
@@ -307,8 +256,8 @@ let circ_form_eval_plus_equiv
307256
let env = toenv hyps in
308257
let redmode = circ_red hyps in
309258
let (@@!) = EcTypesafeFol.f_app_safe env in
310-
let inps = List.map (EcCircuits.input_of_variable env) invs in
311-
let inpcs, inps = List.split inps in
259+
(*let inps = List.map (EcCircuits.input_of_variable env) invs in*)
260+
(*let inpcs, inps = List.split inps in*)
312261
let size, of_int = match EcEnv.Circuit.lookup_bitstring env v.v_type with
313262
| Some {size; ofint} -> size, ofint
314263
| None ->
@@ -322,11 +271,6 @@ let circ_form_eval_plus_equiv
322271
true
323272
else
324273
let cur_val = of_int @@! [f_int cur] in
325-
let pstate : (symbol, circuit) Map.t = Map.empty in
326-
let pstate = List.fold_left2
327-
(fun pstate inp v -> Map.add v inp pstate)
328-
pstate inpcs (invs |> List.map (fun v -> v.v_name))
329-
in
330274
let insts = List.map (fun i ->
331275
match i.i_node with
332276
| Sasgn (lv, e) ->
@@ -338,12 +282,12 @@ let circ_form_eval_plus_equiv
338282
| _ -> i
339283
) proc.s_node
340284
in
341-
let pstate = try
342-
List.fold_left (EcCircuits.process_instr hyps mem) pstate insts
343-
with CircError err ->
344-
raise (BDepError ("Program circuit generation failed with error:\n" ^ err))
285+
let pstate = try
286+
EcCircuits.pstate_of_prog hyps mem insts invs
287+
with CircError err ->
288+
raise (BDepError err)
345289
in
346-
let pstate = Map.map (fun c -> assert (c.inps = []); {c with inps=inps}) pstate in
290+
347291
let f = EcPV.PVM.subst1 env (PVloc v.v_name) mem cur_val f in
348292
let pcond = match Map.find_opt v.v_name pstate with
349293
| Some circ -> begin try
@@ -353,10 +297,10 @@ let circ_form_eval_plus_equiv
353297
end
354298
| None -> None
355299
in
356-
let () = Format.eprintf "Form before circuit simplify %a@." (EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) f in
300+
(*let () = Format.eprintf "Form before circuit simplify %a@." (EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) f in*)
357301
let f = EcCallbyValue.norm_cbv redmode hyps f in
358-
let () = Format.eprintf "Form after circuit simplify %a@." (EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) f in
359-
let f = circ_simplify_form_bitstring_equality ~mem ~pstate ~inps ?pcond hyps f in
302+
(*let () = Format.eprintf "Form after circuit simplify %a@." (EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) f in*)
303+
let f = EcCircuits.circ_simplify_form_bitstring_equality ~mem ~pstate ?pcond hyps f in
360304
let f = EcCallbyValue.norm_cbv (EcReduction.full_red) hyps f in
361305
if f <> f_true then
362306
(Format.eprintf "Got %a after reduction@." (EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) f;
@@ -387,17 +331,19 @@ let mapreduce_eval
387331

388332
let tm = time tm "Lane function circuit generation done" in
389333

390-
let pstate = pstate_of_prog hyps mem proc invs in
334+
let pstate = try
335+
EcCircuits.pstate_of_prog hyps mem proc.s_node invs
336+
with CircError err ->
337+
raise (BDepError err)
338+
in
391339

392340
let tm = time tm "Program circuit generation done" in
393341

394342
begin
395343
let circs = List.map (fun v -> Option.get (Map.find_opt v pstate)) (List.map (fun v -> v.v_name) outvs) in
396344

397-
assert (Set.cardinal @@ Set.of_list @@ List.map (fun c -> c.inps) circs = 1);
398-
let cinp = (List.hd circs).inps in
399345
let c = try
400-
{(circuit_aggregate circs) with inps=cinp}
346+
(circuit_aggregate circs)
401347
with CircError _err ->
402348
raise (BDepError "Failed to concatenate program outputs")
403349
in
@@ -410,8 +356,6 @@ let mapreduce_eval
410356

411357
let tm = time tm "circuit dependecy analysis + splitting done" in
412358

413-
List.iter (fun c -> Format.eprintf "%s@." (circuit_to_string c)) cs;
414-
415359
List.iteri (fun i c ->
416360
if circ_equiv ~strict:true (List.hd cs) c None
417361
then ()

0 commit comments

Comments
 (0)