(* Simplifications on statements *)

open Format
open Ast
open Cond

module IntSet = Set.Make(struct type t = int let compare = compare end)

(* Test equality of two statements, modulo the unification of NSchoose
   nodes in decisions.
   Two statements are equal when they have the same shape, equal
   conditions, case variables and labels, and decisions with the same
   utterance whose next states unify.  Unifying an NSchoose node overwrites
   the [next] field of its decision in place (with NSdefault, or with the
   matching NSfixed).  These updates are kept when the whole comparison
   succeeds; when it fails, every one of them is undone before returning
   false, so the statements are left exactly as they were. *)
let same_stmt s1 s2 =
  (* Undo log, newest entry first: each entry pairs a decision with the
     value its [next] field held just before being overwritten. *)
  let trail = ref [] in
  let update_decision d new_next =
    trail := (d, d.next) :: !trail;
    d.next <- new_next in
  let unify_decision d1 d2 =
    match (d1.next, d2.next) with
      NSdefault, NSdefault -> true
    | NSfixed i, NSfixed j -> i = j
    | NSchoose _, NSdefault -> update_decision d1 NSdefault; true
    | NSchoose i, (NSfixed j as n2) ->
        if i = j then (update_decision d1 n2; true) else false
    | NSchoose i, NSchoose j ->
        (* Two different choices are both resolved to the default *)
        if i <> j then begin
          update_decision d1 NSdefault;
          update_decision d2 NSdefault
        end;
        true
    | NSdefault, NSchoose _ -> update_decision d2 NSdefault; true
    | (NSfixed j as n1), NSchoose i ->
        if i = j then (update_decision d2 n1; true) else false
    | (_, _) -> false in
  let rec unify_stmt s1 s2 =
    match (s1, s2) with
      Edecision d1, Edecision d2 ->
        d1.utterance = d2.utterance && unify_decision d1 d2
    | Eif(c1, so1, not1), Eif(c2, so2, not2) ->
        c1 = c2 && unify_stmt so1 so2 && unify_stmt not1 not2
    | Ecase(v1, arms1, default1), Ecase(v2, arms2, default2) ->
        v1 = v2 && unify_stmt default1 default2
        && List.length arms1 = List.length arms2
        && List.for_all2
             (fun (lbl1, stmt1) (lbl2, stmt2) ->
                lbl1 = lbl2 && unify_stmt stmt1 stmt2)
             arms1 arms2
    | (_, _) -> false in
  if unify_stmt s1 s2 then true
  else begin
    (* Walking the newest-first log restores, for a decision updated more
       than once, its oldest saved value last, i.e. its original value. *)
    List.iter (fun (d, old_next) -> d.next <- old_next) !trail;
    false
  end

(* Check if a matching on a variable can be resolved at compile-time because
   the set of possible values for the variable is included in one of the
   arms (as decided by [Env.matches_arm]).  Returns the statement of the
   first such arm, and raises Not_found when no arm qualifies. *)
let rec resolve_case env v = function
    [] -> raise Not_found
  | (lbls, stmt) :: rem ->
      if Env.matches_arm env v lbls then stmt else resolve_case env v rem

(* Determine the total size of statements that are "duplicated" in the two
   (label, stmt) lists, i.e. that occur in both.
*)
(* Statements are compared with structural equality, so, unlike [same_stmt],
   this never modifies a decision.  Labels are ignored: each pair of [l1]
   whose statement also appears somewhere in [l2] adds [size_stmt] of that
   statement to the total. *)
let size_of_duplicates l1 l2 =
  let appears_in_l2 stmt = List.exists (fun (_, other) -> stmt = other) l2 in
  List.fold_left
    (fun total (_, stmt) ->
       if appears_in_l2 stmt then total + size_stmt stmt else total)
    0 l1

(* Build a case or if from a list of (label, stmt), re-sharing identical
   stmts.
   Pairs whose label makes [Env.equals env v lbl] return [Env.No] are
   dropped, and so are pairs whose statement unifies with [default].  The
   remaining pairs are grouped into arms by statement, through [same_stmt]
   (which may update NSchoose decisions in place).  If [resolve_case] finds
   an arm covering every possible value of [v], that statement is returned
   directly; otherwise the smaller, per [size_stmt], of the case and of the
   equivalent cascade of ifs is returned. *)
let build_case env v lbl_stmt_list default =
  let possible =
    List.filter (fun (lbl, _) -> Env.equals env v lbl <> Env.No) lbl_stmt_list in
  (* Merge [pair] into the first arm whose statement unifies with its own,
     or add it at the end as a new single-label arm. *)
  let rec insert ((lbl, stmt) as pair) = function
      [] -> [([lbl], stmt)]
    | (lbls, stmt') :: rest ->
        if same_stmt stmt stmt' then (lbl :: lbls, stmt') :: rest
        else (lbls, stmt') :: insert pair rest in
  let arms =
    List.fold_right
      (fun ((_, stmt) as pair) acc ->
         if same_stmt stmt default then acc else insert pair acc)
      possible [] in
  try resolve_case env v arms
  with Not_found ->
    let test_of = function
        [lbl] -> Cequals(v, lbl)
      | lbls -> Cor (List.map (fun l -> Cequals(v, l)) lbls) in
    let as_case = Ecase(v, arms, default) in
    let as_ifs =
      List.fold_right
        (fun (lbls, stmt) rest -> Eif(test_of lbls, stmt, rest))
        arms default in
    if size_stmt as_ifs < size_stmt as_case then as_ifs else as_case

(* Simplify a case statement.  Factor out identical statements.  Break the
   case into smaller cases or into ifs if beneficial.
*)
(* [arms] is a list of (labels, stmt) pairs.  The labels are sorted and
   split into segments, extending the current segment while that costs no
   more than starting a new one.  Each segment becomes a case (or a cascade
   of ifs) through [build_case], and the segments are chained through their
   defaults, the last one falling back to [default]. *)
let optimize_case env v arms default =
  (* Build a list (label, stmt) with one pair per label.  Pairs are
     prepended, so they end up in reverse order of appearance. *)
  let lbl_stmts = ref [] in
  List.iter
    (fun (lbls, stmt) ->
       List.iter (fun lbl -> lbl_stmts := (lbl, stmt) :: !lbl_stmts) lbls)
    arms;
  (* Sort it by increasing label.  The sort is stable, so pairs with equal
     labels keep their relative order, as the former Sort.list call with a
     <= predicate did (Sort is deprecated and no longer part of current
     standard libraries). *)
  let arms =
    List.stable_sort (fun (l1, _) (l2, _) -> compare l1 l2) !lbl_stmts in
  (* Split the cases into dense enough segments *)
  let rec split cases_before prev_lbl curr_case = function
      [] -> List.rev (curr_case :: cases_before)
    | (lbl, stmt) :: rem as arms ->
        (* Two possibilities:
           - add this arm to the current case; this increases its span from
             prev_lbl - start_lbl to lbl - start_lbl, i.e. the cost increases
             by lbl - prev_lbl.
           - terminate the current case and start a new one: this adds the
             base cost of a case (10), plus the cost of every statement that
             would be duplicated, i.e. that occurs both in curr_case and in
             the remaining arms.
           Choose the cheaper one. *)
        if lbl - prev_lbl <= 10 + size_of_duplicates curr_case arms then
          (* Extend the current case *)
          split cases_before lbl ((lbl, stmt) :: curr_case) rem
        else
          (* Start a new case *)
          split (curr_case :: cases_before) lbl [(lbl, stmt)] rem in
  let new_cases =
    match arms with
      [] -> []
    | (lbl1, stmt1) :: rem -> split [] lbl1 [(lbl1, stmt1)] rem in
  (* Rebuild a cascade of cases, each one defaulting to the next *)
  let rec build_cases = function
      [] -> default
    | lbl_stmt_list :: rem ->
        build_case env v lbl_stmt_list (build_cases rem) in
  build_cases new_cases

(* Flatten a case whose default case is a case on the same variable [v].
   [lbls_seen] holds the labels already handled by enclosing levels: since
   those outer arms take precedence, the labels are removed from the arms at
   this level, and an arm left with no label is dropped.  This is done at
   every level, including the innermost one, so the result never repeats a
   label that an earlier level already matched.
   Returns the merged arms, outermost first, and the innermost default. *)
let rec flatten_case lbls_seen v arms default =
  let visible_arms =
    List.fold_right
      (fun (lbls, stmt) acc ->
         match List.filter (fun lbl -> not (IntSet.mem lbl lbls_seen)) lbls with
           [] -> acc
         | lbls' -> (lbls', stmt) :: acc)
      arms [] in
  match default with
    Ecase(v', arms', default') when v = v' ->
      (* Labels matched at this level shadow those of the nested case *)
      let seen =
        List.fold_left
          (fun seen (lbls, _) -> List.fold_right IntSet.add lbls seen)
          lbls_seen arms in
      let (inner_arms, inner_default) = flatten_case seen v arms' default' in
      (visible_arms @ inner_arms, inner_default)
  | _ -> (visible_arms, default)

(* Recognize a cascade of if statements that test the same variable against
   constants, and try to turn it into a case statement if possible.  The
   extraction stops at the first condition that is not a disjunction of
   equality tests on that variable; a trailing case on the same variable is
   flattened into the result.  The case form is only tried when more than 3
   distinct labels were extracted, and is kept only if [size_stmt] rates it
   strictly smaller than the if cascade. *)
let ifs_to_case env cond ifso ifnot =
  let case_var = ref "" in
  let lbls_seen = ref IntSet.empty in
  (* [cond_to_lbls (var, lbls) c] adds to [lbls] every [n] of a test
     Cequals(var, n) in the disjunction [c], fixing [var] on the first test
     when it is still empty.  Raises Exit if [c] is not a disjunction of
     equality tests on a single variable.  Nothing is recorded in
     [case_var] or [lbls_seen] here, so a rejected condition leaves no
     trace. *)
  let rec cond_to_lbls (var, lbls) = function
      Cequals(v, n) ->
        if var <> "" && var <> v then raise Exit;
        (v, IntSet.add n lbls)
    | Cor cl -> List.fold_left cond_to_lbls (var, lbls) cl
    | _ -> raise Exit in
  let rec extract_case arms stmt =
    match stmt with
      Eif(cond, ifso, ifnot) ->
        let accepted =
          try Some (cond_to_lbls (!case_var, IntSet.empty) cond)
          with Exit -> None in
        begin match accepted with
          None -> (List.rev arms, stmt)
        | Some (var, lbls) ->
            case_var := var;
            (* Labels tested by an earlier if are shadowed by that arm *)
            let fresh = IntSet.diff lbls !lbls_seen in
            lbls_seen := IntSet.union !lbls_seen fresh;
            if IntSet.is_empty fresh then extract_case arms ifnot
            else extract_case ((IntSet.elements fresh, ifso) :: arms) ifnot
        end
    | Ecase(v', _, _) when v' = !case_var ->
        flatten_case IntSet.empty v' (List.rev arms) stmt
    | _ -> (List.rev arms, stmt) in
  let if_stmt = Eif(cond, ifso, ifnot) in
  let (arms, default) = extract_case [] if_stmt in
  if IntSet.cardinal !lbls_seen <= 3 then if_stmt
  else begin
    let case_stmt = optimize_case env !case_var arms default in
    if size_stmt case_stmt < size_stmt if_stmt then case_stmt else if_stmt
  end

(* Optimize cascades of "if" statements, grouping conditions with "or" and
   "and" if possible.
Also take care of the case where both arms of the "if" are identical, in
   which case the test is dropped.  [ifso] and [ifnot] are the already
   simplified branches (as passed by [simpl]).  Note that [same_stmt] may
   update NSchoose decisions in place when it succeeds. *)
let optimize_if env cond ifso ifnot =
  if same_stmt ifso ifnot then ifso else begin
    match (ifso, ifnot) with
      (* if c then (if c1 then A else B) else B  ==>  if c and c1 then A else B *)
      Eif(cond1, ifso1, ifnot1), _ when same_stmt ifnot1 ifnot ->
        Eif(Cond.simpl env (Cand [cond; cond1]), ifso1, ifnot)
      (* if c then A else (if c1 then A else B)  ==>  if c or c1 then A else B *)
    | _, Eif(cond1, ifso1, ifnot1) when same_stmt ifso1 ifso ->
        Eif(Cond.simpl env (Cor [cond; cond1]), ifso, ifnot1)
      (* No grouping possible: maybe the cascade is really a case *)
    | _, _ -> ifs_to_case env cond ifso ifnot
  end

(* Optimization of nested cases *)

(* [specialize_cond v n c] partially evaluates condition [c] assuming that
   variable [v] equals [n].  Equality tests on [v] become [sctrue] or
   [scfalse]; tests on other variables are kept.  Sub-results are matched
   against [Cand []] (true) and [Cor []] (false), the same markers [simpl]
   uses; this presumes [sctrue] is [Cand []] and [scfalse] is [Cor []]
   (defined outside this section, NOTE(review): confirm). *)
let rec specialize_cond v n = function
    Cequals(v', n') as c ->
      if v = v' then (if n = n' then sctrue else scfalse) else c
  | Cand [] -> sctrue
  | Cand (c1 :: cl) ->
      begin match (specialize_cond v n c1, specialize_cond v n (Cand cl)) with
        Cand [], sc -> sc                  (* true and sc is sc *)
      | sc, Cand [] -> sc
      | Cor [], _ -> scfalse               (* false absorbs the conjunction *)
      | _, Cor [] -> scfalse
      | sc1, Cand scl -> Cand(sc1::scl)    (* keep the conjunction flat *)
      | sc1, sc2 -> Cand [sc1;sc2]
      end
  | Cor [] -> scfalse
  | Cor (c1 :: cl) ->
      begin match (specialize_cond v n c1, specialize_cond v n (Cor cl)) with
        Cor [], sc -> sc                   (* false or sc is sc *)
      | sc, Cor [] -> sc
      | Cand [], _ -> sctrue               (* true absorbs the disjunction *)
      | _, Cand [] -> sctrue
      | sc1, Cor scl -> Cor(sc1::scl)      (* keep the disjunction flat *)
      | sc1, sc2 -> Cor [sc1;sc2]
      end

(* [specialize_stmt env v n s] specializes [s] for the case where variable
   [v] equals [n]:
   - an if whose test is decided keeps only the selected branch; otherwise
     both branches are specialized, and the test is dropped when they unify;
   - a case on [v] becomes the specialization of the first arm whose labels
     contain [n], or of the default when no arm does;
   - a case on another variable has its arms and default specialized and is
     rebuilt through [optimize_case];
   - decisions are returned unchanged. *)
let rec specialize_stmt env v n = function
    Eif(cond, ifso, ifnot) ->
      begin match specialize_cond v n cond with
        Cand [] -> specialize_stmt env v n ifso
      | Cor [] -> specialize_stmt env v n ifnot
      | c ->
          let sifso = specialize_stmt env v n ifso
          and sifnot = specialize_stmt env v n ifnot in
          if same_stmt sifso sifnot then sifso else Eif(c, sifso, sifnot)
      end
  | Edecision d as s -> s
  | Ecase(v', arms, default) ->
      if v = v' then begin
        let rec find_arm = function
            [] -> specialize_stmt env v n default
          | (lbls, stmt) :: rem ->
              if List.mem n lbls then specialize_stmt env v n stmt
              else find_arm rem in
        find_arm arms
      end else
        optimize_case env v'
          (List.map (fun (lbls, stmt) -> (lbls, specialize_stmt env v n stmt)) arms)
          (specialize_stmt env v n default)

(* [case_of_case env v nl s] rewrites [s] as a case on [v] with one arm per
   value in [nl], each arm holding [s] specialized for that value.  The
   default is [s] specialized for [max_int], presumably a value assumed
   never to be used as a label (NOTE(review): confirm). *)
let case_of_case env v nl s =
  optimize_case env v
    (List.map (fun n -> ([n], specialize_stmt env v n s)) nl)
    (specialize_stmt env v max_int s)

(* Determine all values "of interest" of the given variable in the given
   statement: every constant it is compared to in a Cequals test, and every
   label of a case on it.  The result comes from IntSet.elements, so it is
   in increasing order without duplicates. *)
let interesting_values v stmt =
  let intval = ref IntSet.empty in
  let rec int_cond = function
      Cequals(v', n) -> if v = v' then intval := IntSet.add n !intval
    | Cor cl -> List.iter int_cond cl
    | Cand cl -> List.iter int_cond cl in
  let rec int_stmt = function
      Eif(cond, ifso, ifnot) -> int_cond cond; int_stmt ifso; int_stmt ifnot
    | Edecision d -> ()
    | Ecase(v', arms, default) ->
        (* Only the label collection is guarded by the test on v; the arms
           and the default are always traversed. *)
        if v = v' then
          List.iter
            (fun (lbls, stmt) ->
               List.iter (fun n -> intval := IntSet.add n !intval) lbls)
            arms;
        List.iter (fun (lbls, stmt) -> int_stmt stmt) arms;
        int_stmt default in
  int_stmt stmt;
  IntSet.elements !intval

(* Build a case on [v] from [arms] and [default] through [optimize_case].
   When at least one arm is itself a case and all such arms test the same
   variable (the default is not inspected), also try splitting the result
   on that variable with [case_of_case], over the values returned by
   [Env.possible_values] for it, and keep whichever of the two [size_stmt]
   rates strictly smaller.  Exit is used internally to fall back to the
   plain result. *)
let optimize_case_or_nested_cases env v arms default =
  let normal_case = optimize_case env v arms default in
  (* Test if all arms are either non-cases or cases on the same variable *)
  let splitvar = ref "" in
  let check_arm (lbls, stmt) =
    match stmt with
      Ecase(v', arms', default') ->
        if !splitvar = "" then splitvar := v';
        if !splitvar <> v' then raise Exit
    | _ -> () in
  try
    List.iter check_arm arms;
    (* No arm is a case: nothing to split on *)
    if !splitvar = "" then raise Exit;
    (* NOTE(review): the disabled debug code below refers to splitvals,
       which is not defined anywhere in this file. *)
    (***
    print_string "*** Considering splitting for:"; print_newline();
    pretty_stmt normal_case; print_newline();
    print_string "*** Splitting on "; print_string !splitvar;
    print_newline();
    print_string "Interesting values are : ";
    List.iter (fun n -> print_int n; print_string " ")
              (IntSet.elements !splitvals);
    print_newline();
    ***)
    let int_val =
      Env.possible_values env !splitvar
        (interesting_values !splitvar normal_case) in
    let split_case = case_of_case env !splitvar int_val normal_case in
    (***
    print_string "*** Splitting result:"; print_newline();
    pretty_stmt split_case; print_newline();
    ***)
    if size_stmt split_case < size_stmt normal_case
    then split_case
    else normal_case
  with Exit -> normal_case

(* Simplification of a statement under the environment [env].  [env]
   presumably records what is known about variable values; it is refined by
   Env.refine_cond, Env.exclude_arms and Env.set_possible_values below.
   - if: the condition goes through Cond.simpl; one that reduces to false
     (Cor []) or true (Cand []) selects a branch, otherwise each branch is
     simplified under its refined environment and the result goes through
     optimize_if.
   - decision: when Env.value_of finds a value n for the variable named
     state, a decision whose next state is NSdefault, or NSfixed to that
     same n, becomes NSchoose n in a fresh copy of the record (the original
     record is not mutated); otherwise, or when Env.value_of raises
     Not_found, the decision is returned unchanged.
   - case: nested cases on the same variable are flattened, the arms and
     the default are simplified, and the result goes through
     optimize_case_or_nested_cases. *)
let rec simpl env = function
    Eif(cond, ifso, ifnot) ->
      begin match Cond.simpl env cond with
        Cor [] (* false *) -> simpl env ifnot
      | Cand [] (* true *) -> simpl env ifso
      | cond' ->
          let (envso, envnot) = Env.refine_cond env cond' in
          optimize_if env cond' (simpl envso ifso) (simpl envnot ifnot)
      end
  | Edecision d ->
      begin try
        let n = Env.value_of env "state" in
        match d.next with
          NSdefault -> Edecision {d with next = NSchoose n}
        | NSfixed n' when n = n' -> Edecision {d with next = NSchoose n}
        | _ -> Edecision d
      with Not_found -> Edecision d
      end
  | Ecase(v, arms, default) ->
      let (arms, default) = flatten_case IntSet.empty v arms default in
      (***
      print_string "***"; print_newline();
      Env.pretty (Env.exclude_arms env v arms);
      ***)
      (* The default is only reached when no arm matches, hence the
         environment that excludes the arm labels. *)
      optimize_case_or_nested_cases env v
        (simpl_arms env v arms)
        (simpl (Env.exclude_arms env v arms) default)

(* Simplify the arms of a case on [v].  Arms for which Env.refine_arm
   returns no label are dropped; the others keep the refined labels and
   have their statement simplified under [env] restricted, through
   Env.set_possible_values, to the labels of the arm.
   NOTE(review): the statement is simplified with the original labels
   rather than the refined ones; confirm that this is intentional. *)
and simpl_arms env v = function
    [] -> []
  | (lbls, stmt) :: rem ->
      let lbls' = Env.refine_arm env v lbls in
      if lbls' = [] then simpl_arms env v rem
      else (lbls', simpl (Env.set_possible_values env v lbls) stmt)
           :: simpl_arms env v rem