(* Basic simplifications *)

(* open Ast *)
open Syntax
open Cout

(* Simplifications of conditions *)

(* The always-true and always-false conditions: the empty conjunction and
   the empty disjunction. *)
let sctrue = And []
let scfalse = Or []

(* Smart constructors: a one-element disjunction or conjunction is just that
   element. *)
let mkor = function [c] -> c | cl -> Or cl
let mkand = function [c] -> c | cl -> And cl

(* Logical implication and equivalence.

   [implies c1 c2] is a conservative test. A result of [true] guarantees
   that every assignment satisfying [c1] also satisfies [c2]. A result of
   [false] only means that the entailment could not be established.

   Rules that lose no information are tried first:
   - a disjunction on the left: every disjunct must entail [c2];
   - a conjunction on the right: [c1] must entail every conjunct.
   Only then do we fall back to the partial rules, which pick one disjunct
   on the right or one conjunct on the left. When both partial rules apply,
   both are tried.

   Trying the partial rules first makes [implies c c] false for, e.g., a
   disjunction of two different equalities. [equiv] then fails to recognise
   identical disjunctions. *)
let rec implies c1 c2 =
  match (c1, c2) with
  | (Or cl, _) -> List.for_all (fun c -> implies c c2) cl
  | (_, And cl) -> List.for_all (fun c -> implies c1 c) cl
  | (Equals(v1, n1), Equals(v2, n2)) -> v1 = v2 && n1 = n2
  | (_, Or cl) ->
      (* Either c1 entails one disjunct, or (when c1 is a conjunction) one
         of its conjuncts already entails the whole disjunction. *)
      List.exists (fun c -> implies c1 c) cl
      || (match c1 with
          | And cl1 -> List.exists (fun c -> implies c c2) cl1
          | _ -> false)
  | (And cl, _) -> List.exists (fun c -> implies c c2) cl

(* Two conditions are equivalent when each implies the other. This uses the
   same conservative meaning as [implies]. *)
let equiv c1 c2 = implies c1 c2 && implies c2 c1

(* Logical contradiction.

   [contradicts c1 c2] is a conservative test. A result of [true] guarantees
   that [c1] and [c2] can never hold at the same time. A result of [false]
   means no contradiction was found.

   Disjunctions on either side are split first, since that is exact.
   Conjunctions are handled afterwards, by looking for one conflicting
   conjunct on either side. Two equalities contradict when they test the
   same variable against different values. *)
let rec contradicts c1 c2 =
  match (c1, c2) with
  | (Or cl, _) -> List.for_all (fun c -> contradicts c c2) cl
  | (_, Or cl) -> List.for_all (fun c -> contradicts c1 c) cl
  | (Equals(v1, n1), Equals(v2, n2)) -> v1 = v2 && n1 <> n2
  | (_, And cl) ->
      List.exists (fun c -> contradicts c1 c) cl
      || (match c1 with
          | And cl1 -> List.exists (fun c -> contradicts c c2) cl1
          | _ -> false)
  | (And cl, _) -> List.exists (fun c -> contradicts c c2) cl

(* In a conjunction list, remove every element that is implied by another
   element. Of several equivalent elements, the last one is kept. *)
let rec remove_implied_and = function
  | [] -> []
  | c1 :: rem ->
      let rem' = remove_implied_and rem in
      if List.exists (fun c2 -> implies c2 c1) rem' then rem'
      else c1 :: List.filter (fun c2 -> not (implies c1 c2)) rem'

(* In a disjunction list, remove every element that implies another
   element. Of several equivalent elements, the last one is kept. *)
let rec remove_implied_or = function
  | [] -> []
  | c1 :: rem ->
      let rem' = remove_implied_or rem in
      if List.exists (fun c2 -> implies c1 c2) rem' then rem'
      else c1 :: List.filter (fun c2 -> not (implies c2 c1)) rem'

(* [contradiction_and cl] is true when two elements of the conjunction list
   [cl] contradict each other. *)
let rec contradiction_and = function
  | [] -> false
  | c1 :: rem ->
      List.exists (fun c -> contradicts c1 c) rem || contradiction_and rem

(* Replace a conjunction that contains two contradictory elements by false.
   Every other condition is returned unchanged. *)
let check_contradiction = function
  | And cl when contradiction_and cl -> scfalse
  | c -> c

(* In an "or" list consisting entirely of "and"s, try to factor common
   clauses:
     (A & B) | (A & C) | ... = A & (B | C | ...)
*)

(* [intersect c1 c2]: the elements of [c1] that are equivalent to some
   element of [c2], in the order they appear in [c1]. *)
let intersect c1 c2 =
  List.filter (fun c -> List.exists (equiv c) c2) c1

(* [subtract c1 c2]: the elements of [c1] that are equivalent to no element
   of [c2], in the order they appear in [c1]. *)
let subtract c1 c2 =
  List.filter (fun c -> not (List.exists (equiv c) c2)) c1

(* Shared engine of [factor_or] and [factor_and].

   Parameters:
   - [split c] returns [Some terms] when [c] is built with the inner
     connective, and [None] otherwise;
   - [rebuild_inner] and [rebuild_outer] are the smart constructors of the
     inner and outer connectives.

   Shared terms are pulled out only when all of the following hold:
   - [cl] has at least two elements;
   - every element splits;
   - the term lists share at least one term, up to [equiv].
   Otherwise [cl] is rebuilt unchanged with [rebuild_outer]. *)
let factor_common split rebuild_inner rebuild_outer cl =
  let unchanged () = rebuild_outer cl in
  (* Term lists of all elements, or None as soon as one does not split. *)
  let rec split_all acc = function
    | [] -> Some (List.rev acc)
    | c :: rest ->
        (match split c with
         | Some terms -> split_all (terms :: acc) rest
         | None -> None)
  in
  match cl with
  | _ :: _ :: _ ->
      (match split_all [] cl with
       | None | Some [] -> unchanged ()
       | Some (first :: others) ->
           let shared =
             List.fold_left (fun acc terms -> intersect terms acc) first others
           in
           if shared = [] then unchanged ()
           else
             let residues =
               List.map
                 (fun terms -> rebuild_inner (subtract terms shared))
                 (first :: others)
             in
             rebuild_inner (rebuild_outer residues :: shared))
  | _ -> unchanged ()

(* Or-list of conjunctions: (A & B) | (A & C) | ... = A & (B | C | ...).
   Any other list is simply rebuilt as a disjunction. *)
let factor_or cl =
  factor_common (function And l -> Some l | _ -> None) mkand mkor cl

(* And-list of disjunctions: (A | B) & (A | C) & ... = A | (B & C & ...).
   Any other list is simply rebuilt as a conjunction. *)
let factor_and cl =
  factor_common (function Or l -> Some l | _ -> None) mkor mkand cl

(* Flatten ands within ands / ors within ors, one level deep. *)
let flatten_ands cl =
  List.fold_right
    (fun c acc -> match c with And inner -> inner @ acc | _ -> c :: acc)
    cl []

let flatten_ors cl =
  List.fold_right
    (fun c acc -> match c with Or inner -> inner @ acc | _ -> c :: acc)
    cl []

(* One simplification pass over a condition:
   - simplify the children;
   - flatten;
   - drop redundant elements;
   - factor common terms;
   - finally fold contradictory conjunctions to false. *)
let rec simpl_cnd c =
  match c with
  | Equals (_, _) -> c
  | And cl ->
      let parts = flatten_ands (List.map simpl_cnd cl) in
      check_contradiction (factor_and (remove_implied_and parts))
  | Or cl ->
      let parts = flatten_ors (List.map simpl_cnd cl) in
      check_contradiction (factor_or (remove_implied_or parts))

(* Simplification of conditions may cause impossible conjunctions to appear,
   so [simpl_cnd] is iterated. It makes at most three passes and stops early
   as soon as a pass changes nothing. *)
let simpl_cond c =
  let rec loop budget current =
    if budget = 0 then current
    else
      let next = simpl_cnd current in
      if next = current then current else loop (budget - 1) next
  in
  loop 3 c

(* The list of all possible states *)
let possible_current_states = ref ([] : int list)

(* A match on the state always succeeds when every possible current state
   appears among the labels [lbls]. *)
let case_state_succeeds lbls =
  not (List.exists (fun l -> not (List.mem l lbls)) !possible_current_states)

(* Statement of the first (labels, stmt) arm whose labels cover every
   possible current state.
   @raise Not_found when no arm does. *)
let simplify_case_state arms =
  snd (List.find (fun (lbls, _) -> case_state_succeeds lbls) arms)

(* Total size (as measured by cout_mem_st) of the statements of [l1] that
   also occur as a statement somewhere in [l2]; labels are ignored. Each
   matching pair of [l1] is counted once. *)
let size_of_duplicates l1 l2 =
  List.fold_left
    (fun total (_, stmt) ->
       if List.exists (fun (_, other) -> stmt = other) l2
       then total + cout_mem_st stmt
       else total)
    0 l1

(* Conversions between (labels, stmt) pairs and Arm values. *)
let arms_from_x l = List.map (fun (vl, st) -> Arm (vl, st)) l
let arms_to_x l = List.map (fun (Arm (vl, st)) -> (vl, st)) l

(* Build a case on [v] from a list of (label, stmt) pairs.
   - Pairs whose statement equals [default] are dropped.
   - Labels with structurally equal statements share one arm.
   - The equivalent cascade of ifs is also built, and whichever of the two
     is cheaper according to cout_mem_st is returned. *)
let build_case v lbl_stmt_list default =
  let rec insert ((lbl, stmt) as pair) = function
    | [] -> [([lbl], stmt)]
    | (lbls, stmt') :: rest ->
        if stmt = stmt' then (lbl :: lbls, stmt') :: rest
        else (lbls, stmt') :: insert pair rest
  in
  let arms =
    List.fold_right
      (fun ((_, stmt) as pair) acc ->
         if stmt = default then acc else insert pair acc)
      lbl_stmt_list []
  in
  let stmt_case = Case (v, arms_from_x arms, default) in
  let cond_of_labels = function
    | [lbl] -> Equals (v, lbl)
    | lbls -> Or (List.map (fun l -> Equals (v, l)) lbls)
  in
  let stmt_ifs =
    List.fold_right
      (fun (lbls, stmt) rest -> If (cond_of_labels lbls, stmt, [], rest))
      arms default
  in
  if cout_mem_st stmt_ifs < cout_mem_st stmt_case then stmt_ifs else stmt_case

(* Simplify a case statement. Factor out identical statements. Break the
   case into smaller cases or into ifs if beneficial.
*)
(* [v] is the variable cased on, [arms] the list of Arm (labels, stmt) and
   [default] the statement used for labels matched by no arm. *)
let optimize_case v arms default =
  (* One (label, stmt) pair per label of every arm. Pairs are prepended, so
     the list ends up in reverse order of appearance. *)
  let lbl_stmts = ref [] in
  List.iter
    (fun (Arm (lbls, stmt)) ->
       List.iter (fun lbl -> lbl_stmts := (lbl, stmt) :: !lbl_stmts) lbls)
    arms;
  (* Sort by increasing label.
     List.stable_sort replaces Sort.list: the Sort module was deprecated and
     has since been removed from the standard library (OCaml 4.08). Like the
     old Sort.list with a (<=) predicate, List.stable_sort is stable, so
     pairs that share a label keep their relative order. *)
  let sorted =
    List.stable_sort (fun (l1, _) (l2, _) -> compare l1 l2) !lbl_stmts in
  (* Base cost of starting a new case. It is compared against label gaps and
     cout_mem_st sizes of duplicated statements below. *)
  let case_base_cost = 10 in
  (* Split the sorted pairs into dense enough segments. For each pair there
     are two possibilities:
     - extend the current case. Its span grows from prev_lbl - start_lbl to
       lbl - start_lbl, so the cost increases by lbl - prev_lbl;
     - terminate the current case and start a new one. This adds the base
       cost of a case, plus the size of every statement that would then be
       duplicated, i.e. that occurs both in curr_case and in the remaining
       pairs.
     The cheaper option wins; ties extend the current case. Segments are
     accumulated in reverse label order; build_case does not depend on it. *)
  let rec split cases_before prev_lbl curr_case = function
    | [] -> List.rev (curr_case :: cases_before)
    | (lbl, stmt) :: rem as remaining ->
        if lbl - prev_lbl
           <= case_base_cost + size_of_duplicates curr_case remaining
        then split cases_before lbl ((lbl, stmt) :: curr_case) rem
        else split (curr_case :: cases_before) lbl [(lbl, stmt)] rem
  in
  let new_cases =
    match sorted with
    | [] -> []
    | (lbl1, stmt1) :: rem -> split [] lbl1 [(lbl1, stmt1)] rem
  in
  (* Rebuild a cascade of cases. Each segment falls back to the case built
     from the following segments; the last one falls back to [default]. *)
  let rec build_cases = function
    | [] -> default
    | lbl_stmt_list :: rem -> build_case v lbl_stmt_list (build_cases rem)
  in
  build_cases new_cases

(* Optimize cascades of "if" statements, grouping conditions with "or" and "and" if possible.
Also take care of the case where both arms of the "if" are identical *)
let optimize_if cond ifso ifnot =
  if ifso = ifnot then
    (* Both branches agree: the test is useless. *)
    ifso
  else
    match (ifso, ifnot) with
    | (If (inner_cond, inner_then, [], inner_else), _)
      when inner_else = ifnot ->
        (* if c then (if c1 then a else b) else b  ==  if c & c1 then a else b *)
        If (simpl_cond (And [cond; inner_cond]), inner_then, [], ifnot)
    | (_, If (inner_cond, inner_then, [], inner_else))
      when inner_then = ifso ->
        (* if c then a else (if c1 then a else b)  ==  if c | c1 then a else b *)
        If (simpl_cond (Or [cond; inner_cond]), ifso, [], inner_else)
    | _ -> If (cond, ifso, [], ifnot)

(* Optimization of nested cases *)

(* [specialize_cond v n c] rewrites [c] under the assumption that variable
   [v] has value [n]:
   - equalities on [v] become true or false;
   - true and false are then constant-folded through conjunctions and
     disjunctions.
   Equalities on other variables are kept. *)
let rec specialize_cond v n c =
  (* Conjunction of two specialized conditions, folding true/false. *)
  let conj sc1 sc2 =
    match (sc1, sc2) with
    | (And [], _) -> sc2
    | (_, And []) -> sc1
    | (Or [], _) | (_, Or []) -> scfalse
    | (_, And scl) -> And (sc1 :: scl)
    | (_, _) -> And [sc1; sc2]
  in
  (* Disjunction of two specialized conditions, folding true/false. *)
  let disj sc1 sc2 =
    match (sc1, sc2) with
    | (Or [], _) -> sc2
    | (_, Or []) -> sc1
    | (And [], _) | (_, And []) -> sctrue
    | (_, Or scl) -> Or (sc1 :: scl)
    | (_, _) -> Or [sc1; sc2]
  in
  match c with
  | Equals (v', n') ->
      if v <> v' then c else if n = n' then sctrue else scfalse
  | And cl ->
      List.fold_right (fun c1 acc -> conj (specialize_cond v n c1) acc) cl sctrue
  | Or cl ->
      List.fold_right (fun c1 acc -> disj (specialize_cond v n c1) acc) cl scfalse

(* NOTE(review): the comment below contains disabled code kept for
   reference. None of it is compiled. *)
(* let rec specialize_stmt v n = function
     If(cond, ifso,[], ifnot) ->
       begin match specialize_cond v n cond with
         And [] -> specialize_stmt v n ifso
       | Or [] -> specialize_stmt v n ifnot
       | c ->
           let sifso = specialize_stmt v n ifso
           and sifnot = specialize_stmt v n ifnot in
           if sifso = sifnot then sifso else If(c, sifso,[], sifnot)
       end
   | Decision(_, _) as s -> s
   | Case(v', arms, default) ->
       if v = v' then begin
         let rec find_arm = function
           [] -> specialize_stmt v n default
         | (lbls, stmt) :: rem ->
             if List.mem n lbls then specialize_stmt v n stmt else find_arm rem in
         find_arm (arms_to_x arms)
       end else
         optimize_case v'
           (List.map (fun (lbls, stmt) -> (lbls, specialize_stmt v n stmt))
              (arms_to_x arms))
           (specialize_stmt v n default)

   let case_of_case v nl s =
     optimize_case v
       (List.map (fun n -> ([n], specialize_stmt v n s)) nl)
       (specialize_stmt v (-1) s)

   module IntSet = Set.Make(struct type t = int let compare = compare end)

   let optimize_case_or_nested_cases v arms default =
     let normal_case = optimize_case v arms default in
     (* Test if all arms are either non-cases or cases on the same variable *)
     let splitvar = ref "" in
     let splitvals = ref IntSet.empty in
     let check_arm (lbls, stmt) =
       match stmt with
         Case(v', arms', default') ->
           if !splitvar = "" then splitvar := v';
           if !splitvar <> v' then raise Exit;
           List.iter
             (fun (lbls, stmt) ->
                List.iter (fun lbl -> splitvals := IntSet.add lbl !splitvals) lbls)
             (arms_to_x arms')
       | _ -> () in
     try
       List.iter check_arm arms;
       if !splitvar = "" then raise Exit;
       (*** print_string "*** Considering splitting for:"; print_newline();
            pretty_stmt normal_case; print_newline();
            print_string "*** Splitting on "; print_string !splitvar;
            print_newline();
            print_string "Interesting values are : ";
            List.iter (fun n -> print_int n; print_string " ")
              (IntSet.elements !splitvals);
            print_newline(); ***)
       let split_case =
         case_of_case !splitvar (IntSet.elements !splitvals) normal_case in
       (*** print_string "*** Splitting result:"; print_newline();
            pretty_stmt split_case; print_newline(); ***)
       if cout_mem_st split_case < cout_mem_st normal_case
       then split_case else normal_case
     with Exit -> normal_case

   (* Flatten a case whose default case is a case on the same variable *)
   let rec flatten_case lbls_seen v arms default =
     match default with
       Case(v', arms', default') when v = v' ->
         (* Subtract the labels already seen from the arms *)
         let rec subtract_arms = function
           [] -> []
         | (lbls, stmt) :: rem ->
             let new_lbls =
               List.fold_right
                 (fun lbl lst -> if IntSet.mem lbl lbls_seen then lst else lbl :: lst)
                 lbls [] in
             if new_lbls = [] then subtract_arms rem
             else (new_lbls, stmt) :: subtract_arms rem in
         (* Add the labels of the arms to lbls_seen *)
         let rec add_labels seen = function
           [] -> seen
         | (lbls, stmt) :: rem ->
             add_labels (List.fold_right IntSet.add lbls seen) rem in
         (* Recursively simplify the default *)
         let (arms'', default'') =
           flatten_case (add_labels lbls_seen arms) v (arms_to_x arms') default' in
         (subtract_arms arms @ arms'', default'')
     | _ -> (arms, default)

   (* Simplification of a statement *)
   let rec simpl_stmt = function
       If(cond, ifso,[], ifnot) ->
         begin match simpl_cond cond with
           Or [] (* false *) -> simpl_stmt ifnot
         | And [] (* true *) -> simpl_stmt ifso
         | cond' -> optimize_if cond' (simpl_stmt ifso) (simpl_stmt ifnot)
         end
     (** | Decision(Some n, utt) when !possible_current_states = [n] ->
           Decision(None, utt) **)
     | Decision(_, _) as stmt -> stmt
     | Case(v, arms, default) ->
         let (arms, default) =
           flatten_case IntSet.empty v (arms_to_x arms) default in
         optimize_case_or_nested_cases v
           (List.map (fun (lbls, stmt) -> (lbls, simpl_stmt stmt)) arms)
           (simpl_stmt default)

   (* Conversions between characters and statements *)
   let stmt_of_character c = Case("state", c, (Decision(None, "")))

   let all_states c =
     let allstates = ref IntSet.empty in
     List.iter
       (fun (lbls, stmt) ->
          List.iter (fun lbl -> allstates := IntSet.add lbl !allstates) lbls)
       c;
     IntSet.elements !allstates

   let character_of_stmt allstates s = [allstates, s]

   (* Simplification of a character *)
   open Format

   let char_from_x = List.map (function il,st -> Rule(il,st))

   let simpl_character c =
     let c' =
       List.map
         (fun (states, stmt) ->
            possible_current_states := states; (states, simpl_stmt stmt))
         c in
     (*** print_string "*** First simplification:"; print_newline();
          pretty_character c'; print_newline(); ***)
     let rec add_rule (lbls, stmt as lbls_stmt) = function
         [] -> [lbls_stmt]
       | (lbls', stmt' as lbls_stmt') :: rem ->
           if stmt' = stmt then (lbls @ lbls', stmt) :: rem
           else lbls_stmt' :: add_rule lbls_stmt rem in
     let normal_character = List.fold_right add_rule (List.rev c') [] in
     (*** print_string "*** Normal character:"; print_newline();
          pretty_character normal_character; print_newline(); ***)
     (* Now try to simplify the character as a whole statement *)
     possible_current_states := all_states normal_character;
     let stmt_character =
       character_of_stmt !possible_current_states
         (simpl_stmt (stmt_of_character (arms_from_x normal_character))) in
     (*** print_string "*** One-statement character:"; print_newline();
          pretty_character stmt_character; print_newline(); ***)
     (* Now choose the smallest of the three characters *)
     let c' =
       if cout_mem_character (char_from_x stmt_character)
          < cout_mem_character (char_from_x normal_character)
       then stmt_character else normal_character in
     if cout_mem_character (char_from_x c') < cout_mem_character (char_from_x c)
     then c' else c
*)