(* Basic simplifications *)
(* open Ast *)
open Syntax
open Cout
(* Simplifications of conditions *)
(* Constant conditions: the empty conjunction is "true", the empty
   disjunction is "false". *)
let sctrue = And []
let scfalse = Or []
(* Smart constructors: a one-element "or"/"and" list is just that element;
   any other list (including the empty one) is wrapped as is. *)
let mkor cl =
  match cl with
  | [single] -> single
  | _ -> Or cl
let mkand cl =
  match cl with
  | [single] -> single
  | _ -> And cl
(* Logical implication and equivalence *)
(* [implies c1 c2] is a sound but incomplete test that [c1] logically
   implies [c2]: [true] guarantees the implication, [false] only means it
   could not be established.
   Structurally equal conditions are accepted immediately, and the exact
   ("invertible") decompositions are tried first: an "or" on the left must
   have every disjunct imply [c2], and an "and" on the right must have
   every conjunct implied by [c1]. Only afterwards are the choice rules
   ("and" on the left, "or" on the right) used, trying both when both
   apply. Previously the "or on the right" choice was tried first, so that
   e.g. [implies (Or l) (Or l)] -- including [implies scfalse scfalse] --
   returned [false]. *)
let rec implies c1 c2 =
  c1 = c2
  || begin match (c1, c2) with
     | (Equals (v1, n1), Equals (v2, n2)) -> v1 = v2 && n1 = n2
     | (Or cl, _) -> List.for_all (fun c -> implies c c2) cl
     | (_, And cl) -> List.for_all (fun c -> implies c1 c) cl
     | (And cl1, Or cl2) ->
         (* Either one conjunct implies the whole "or", or the whole "and"
            implies one disjunct. *)
         List.exists (fun c -> implies c c2) cl1
         || List.exists (fun c -> implies c1 c) cl2
     | (_, Or cl) -> List.exists (fun c -> implies c1 c) cl
     | (And cl, _) -> List.exists (fun c -> implies c c2) cl
     end
(* Logical equivalence, as far as [implies] can establish it: implication
   in both directions (the reverse check is skipped when the forward one
   fails). *)
let equiv c1 c2 =
  if implies c1 c2 then implies c2 c1 else false
(* Logical contradiction *)
(* [contradicts c1 c2] is a sound but incomplete test that [c1] and [c2]
   cannot both hold: [true] guarantees the contradiction.
   The exact rules for "or" (every disjunct must contradict the other side)
   are tried before the choice rules for "and" (some conjunct contradicts
   the other side). With the previous order the "and" choice was made
   first and lost information, so e.g.
   [contradicts (Or [x=1; y=1]) (And [x=2; y=2])] returned [false]. *)
let rec contradicts c1 c2 =
  match (c1, c2) with
  | (Equals (v1, n1), Equals (v2, n2)) -> v1 = v2 && n1 <> n2
  | (Or cl, _) -> List.for_all (fun c -> contradicts c c2) cl
  | (_, Or cl) -> List.for_all (fun c -> contradicts c1 c) cl
  | (_, And cl) -> List.exists (fun c -> contradicts c1 c) cl
  | (And cl, _) -> List.exists (fun c -> contradicts c c2) cl
(* In an "and" list, remove all elements that are implied by other elements *)
(* Processed from the end of the list: an element is dropped when some
   element already kept after it implies it; otherwise it is kept and the
   kept elements it implies are discarded. *)
let remove_implied_and cl =
  List.fold_right
    (fun c1 kept ->
       if List.exists (fun c2 -> implies c2 c1) kept then kept
       else c1 :: List.filter (fun c2 -> not (implies c1 c2)) kept)
    cl []
(* In an "or" list, remove all elements that imply other elements *)
(* Processed from the end of the list: an element is dropped when it
   implies some element already kept after it; otherwise it is kept and the
   kept elements that imply it are discarded. *)
let remove_implied_or cl =
  List.fold_right
    (fun c1 kept ->
       if List.exists (fun c2 -> implies c1 c2) kept then kept
       else c1 :: List.filter (fun c2 -> not (implies c2 c1)) kept)
    cl []
(* In an "and" list, detect contradictions *)
(* [true] when some pair of elements of the list contradicts each other
   (each element is checked against the elements that follow it). *)
let rec contradiction_and cl =
  match cl with
  | [] -> false
  | c1 :: rest ->
      if List.exists (contradicts c1) rest then true
      else contradiction_and rest
(* Replace a self-contradictory "and" by [scfalse]; anything else is
   returned unchanged. *)
let check_contradiction c =
  let contradictory =
    match c with
    | And cl -> contradiction_and cl
    | _ -> false in
  if contradictory then scfalse else c
(* In an "or" list consisting entirely of "and"s, try to factor common
clauses:
(A & B) | (A & C) | ... = A & (B | C | ...) *)
(* Elements of [c1] (in order) that are [equiv]alent to some element of
   [c2]. *)
let intersect c1 c2 =
  List.filter (fun c -> List.exists (equiv c) c2) c1
(* Elements of [c1] (in order) that are not [equiv]alent to any element of
   [c2]. *)
let subtract c1 c2 =
  List.filter (fun c -> not (List.exists (equiv c) c2)) c1
(* Build the "or" of [cl], factoring out conjuncts common to every element
   when all (at least two) elements are "and"s; otherwise plain [mkor].
   [Exit] is used internally to abandon factoring. *)
let factor_or cl =
  let conjuncts = function And l -> l | _ -> raise Exit in
  match cl with
  | And first :: (_ :: _ as others) ->
      begin try
        let common =
          List.fold_left (fun acc c -> intersect (conjuncts c) acc)
            first others in
        if common = [] then raise Exit;
        let residues =
          List.map (fun c -> mkand (subtract (conjuncts c) common)) cl in
        mkand (mkor residues :: common)
      with Exit ->
        mkor cl
      end
  | _ -> mkor cl
(* In an "and" list consisting entirely of "or"s, try to factor common
clauses:
(A | B) & (A | C) & ... = A | (B & C & ...) *)
(* Build the "and" of [cl], factoring out disjuncts common to every element
   when all (at least two) elements are "or"s; otherwise plain [mkand].
   [Exit] is used internally to abandon factoring. *)
let factor_and cl =
  let disjuncts = function Or l -> l | _ -> raise Exit in
  match cl with
  | Or first :: (_ :: _ as others) ->
      begin try
        let common =
          List.fold_left (fun acc c -> intersect (disjuncts c) acc)
            first others in
        if common = [] then raise Exit;
        let residues =
          List.map (fun c -> mkor (subtract (disjuncts c) common)) cl in
        mkor (mkand residues :: common)
      with Exit ->
        mkand cl
      end
  | _ -> mkand cl
(* Flatten ands within ands / ors within ors *)
(* Splice the elements of directly nested "and"s into the list (one level). *)
let flatten_ands cl =
  List.fold_right
    (fun c acc ->
       match c with
       | And inner -> inner @ acc
       | _ -> c :: acc)
    cl []
(* Splice the elements of directly nested "or"s into the list (one level). *)
let flatten_ors cl =
  List.fold_right
    (fun c acc ->
       match c with
       | Or inner -> inner @ acc
       | _ -> c :: acc)
    cl []
(* One simplification pass: simplify the sub-conditions, flatten nested
   lists of the same kind, drop redundant elements, factor common parts,
   and finally replace a contradictory "and" by false. *)
let rec simpl_cnd c =
  match c with
  | Equals (_, _) -> c
  | And cl ->
      let parts = flatten_ands (List.map simpl_cnd cl) in
      check_contradiction (factor_and (remove_implied_and parts))
  | Or cl ->
      let parts = flatten_ors (List.map simpl_cnd cl) in
      check_contradiction (factor_or (remove_implied_or parts))
(* Simplification of conditions may cause impossible "and"s to appear,
so iterate simpl_cnd a few times *)
(* Apply [simpl_cnd] at most three times, stopping as soon as a pass
   leaves the condition (structurally) unchanged. *)
let simpl_cond c =
  let rec pass remaining cur =
    let next = simpl_cnd cur in
    if remaining <= 1 then next
    else if next = cur then cur
    else pass (remaining - 1) next in
  pass 3 c
(* The list of all possible states *)
(* The list of all possible states (global, mutable). *)
let possible_current_states : int list ref = ref []
(* A matching on "state" always succeeds when every possible state appears
   among the case labels [lbls]. *)
let case_state_succeeds lbls =
  not (List.exists (fun st -> not (List.mem st lbls)) !possible_current_states)
(* Statement of the first (labels, stmt) arm whose labels cover every
   possible state.
   @raise Not_found if no arm does. *)
let simplify_case_state arms =
  let covering = List.find (fun (lbls, _) -> case_state_succeeds lbls) arms in
  snd covering
(* Determine the total size of statements that are "duplicated" in the
two (label, stmt) lists, i.e. that occur in both. *)
(* Sum of [cout_mem_st] over the statements of [l1] that also occur (under
   any label) in [l2]; each matching entry of [l1] is counted. *)
let size_of_duplicates l1 l2 =
  let shared stmt = List.exists (fun (_, other) -> stmt = other) l2 in
  List.fold_right
    (fun (_, stmt) total ->
       if shared stmt then cout_mem_st stmt + total else total)
    l1 0
(* Wrap (labels, stmt) pairs into [Arm] values. *)
let arms_from_x pairs = List.map (fun (vl, st) -> Arm (vl, st)) pairs
(* Unwrap [Arm] values into (labels, stmt) pairs. *)
let arms_to_x arms = List.map (fun (Arm (vl, st)) -> (vl, st)) arms
(* Build a case or if from a list of (label, stmt),
re-sharing identical stmts *)
(* Build a "case" on [v] from (label, stmt) pairs, grouping labels that
   share an identical statement into one arm and omitting pairs whose
   statement equals [default]. An equivalent cascade of "if"s is also
   built, and whichever is cheaper by [cout_mem_st] is returned (the case
   on ties). *)
let build_case v lbl_stmt_list default =
  (* Add [lbl] to the arm already holding [stmt], or append a new arm. *)
  let rec insert_label lbl stmt = function
    | [] -> [([lbl], stmt)]
    | (lbls, s) :: others ->
        if stmt = s then (lbl :: lbls, s) :: others
        else (lbls, s) :: insert_label lbl stmt others in
  let arms =
    List.fold_right
      (fun (lbl, stmt) acc ->
         if stmt = default then acc else insert_label lbl stmt acc)
      lbl_stmt_list [] in
  let stmt_case = Case (v, arms_from_x arms, default) in
  let cond_of_labels = function
    | [lbl] -> Equals (v, lbl)
    | lbls -> Or (List.map (fun l -> Equals (v, l)) lbls) in
  let stmt_ifs =
    List.fold_right
      (fun (lbls, stmt) rest -> If (cond_of_labels lbls, stmt, [], rest))
      arms default in
  if cout_mem_st stmt_ifs < cout_mem_st stmt_case then stmt_ifs
  else stmt_case
(* Simplify a case statement.
Factor out identical statements.
Break the case into smaller cases or into ifs if beneficial. *)
(* Simplify a case statement on [v]: flatten [arms] into (label, stmt)
   pairs, sort them by label, split them into dense enough segments, and
   rebuild one case (or "if" cascade, via [build_case]) per segment, each
   falling through to the next segment and the last one to [default].
   Labels are ints (they are subtracted below). *)
let optimize_case v arms default =
  (* Base cost of starting a new case, weighed against label gaps. *)
  let case_base_cost = 10 in
  (* One (label, stmt) pair per label, in reverse order of appearance
     (kept so that sorting ties resolve exactly as before). *)
  let lbl_stmts =
    List.fold_left
      (fun acc (Arm (lbls, stmt)) ->
         List.fold_left (fun acc lbl -> (lbl, stmt) :: acc) acc lbls)
      [] arms in
  (* Sort by increasing label. The deprecated [Sort] module used here
     before is no longer part of recent standard libraries; [Sort.list]
     with a [<=] predicate was a stable merge sort, which
     [List.stable_sort] reproduces. *)
  let sorted =
    List.stable_sort (fun (l1, _) (l2, _) -> compare l1 l2) lbl_stmts in
  (* Split the cases into dense enough segments *)
  let rec split cases_before prev_lbl curr_case = function
    | [] -> List.rev (curr_case :: cases_before)
    | (lbl, stmt) :: rem as remaining ->
        (* Two possibilities:
           - add this arm to the current case; this increases its
             span from prev_lbl - start_lbl to lbl - start_lbl,
             i.e. the cost increases by lbl - prev_lbl.
           - terminate the current case and start a new one:
             this adds the base cost of a case (case_base_cost).
             Moreover, if this causes duplication of some statements,
             this adds the costs of those statements. A statement is
             duplicated if it occurs both in curr_case and in the
             remaining arms.
           We evaluate the cost of both possibilities and choose the
           cheapest *)
        if lbl - prev_lbl
           <= case_base_cost + size_of_duplicates curr_case remaining
        then
          (* Extend the current case *)
          split cases_before lbl ((lbl, stmt) :: curr_case) rem
        else
          (* Start a new case *)
          split (curr_case :: cases_before) lbl [(lbl, stmt)] rem in
  let new_cases =
    match sorted with
    | [] -> []
    | (lbl1, stmt1) :: rem -> split [] lbl1 [(lbl1, stmt1)] rem in
  (* Rebuild a cascade of cases *)
  let rec build_cases = function
    | [] -> default
    | lbl_stmt_list :: rem -> build_case v lbl_stmt_list (build_cases rem) in
  build_cases new_cases
(* Optimize cascades of "if" statements, grouping conditions with
"or" and "and" if possible. Also take care of the case where both
arms of the "if" are identical *)
(* Build "if cond then ifso else ifnot", merging a nested "if" into this
   one when one of its branches equals the outer opposite branch:
     if c then (if c1 then a else b) else b  =  if c & c1 then a else b
     if c then a else (if c1 then a else b)  =  if c | c1 then a else b
   The merged condition goes through [simpl_cond], which may reduce it to
   [sctrue] ([And []]) or [scfalse] ([Or []]); such constant conditions,
   and "if"s whose two branches are identical, are collapsed instead of
   being emitted with a dead branch. *)
let optimize_if cond ifso ifnot =
  let mk_if c yes no =
    if yes = no then yes
    else
      match c with
      | And [] -> yes   (* condition always true *)
      | Or [] -> no     (* condition always false *)
      | _ -> If (c, yes, [], no) in
  if ifso = ifnot then ifso
  else
    match (ifso, ifnot) with
    | If (cond1, ifso1, [], ifnot1), _ when ifnot1 = ifnot ->
        mk_if (simpl_cond (And [cond; cond1])) ifso1 ifnot
    | _, If (cond1, ifso1, [], ifnot1) when ifso1 = ifso ->
        mk_if (simpl_cond (Or [cond; cond1])) ifso ifnot1
    | _, _ ->
        mk_if cond ifso ifnot
(* Optimization of nested cases *)
(* [specialize_cond v n c] rewrites [c] assuming variable [v] equals [n]:
   tests on [v] become [sctrue]/[scfalse], tests on other variables are
   kept, and constant sub-conditions are folded away inside "and"/"or"
   lists (a nested list of the same kind returned for the tail is merged). *)
let rec specialize_cond v n c =
  (* Combine a specialized conjunct with the specialized rest. *)
  let conj first rest =
    match (first, rest) with
    | And [], sc -> sc
    | sc, And [] -> sc
    | Or [], _ -> scfalse
    | _, Or [] -> scfalse
    | sc1, And scl -> And (sc1 :: scl)
    | sc1, sc2 -> And [sc1; sc2] in
  (* Dual of [conj] for disjunctions. *)
  let disj first rest =
    match (first, rest) with
    | Or [], sc -> sc
    | sc, Or [] -> sc
    | And [], _ -> sctrue
    | _, And [] -> sctrue
    | sc1, Or scl -> Or (sc1 :: scl)
    | sc1, sc2 -> Or [sc1; sc2] in
  match c with
  | Equals (v', n') ->
      if v <> v' then c
      else if n = n' then sctrue
      else scfalse
  | And cl ->
      List.fold_right (fun ci acc -> conj (specialize_cond v n ci) acc)
        cl sctrue
  | Or cl ->
      List.fold_right (fun ci acc -> disj (specialize_cond v n ci) acc)
        cl scfalse
(*
let rec specialize_stmt v n = function
If(cond, ifso,[], ifnot) ->
begin match specialize_cond v n cond with
And [] -> specialize_stmt v n ifso
| Or [] -> specialize_stmt v n ifnot
| c ->
let sifso = specialize_stmt v n ifso
and sifnot = specialize_stmt v n ifnot in
if sifso = sifnot then sifso else If(c, sifso,[], sifnot)
end
| Decision(_, _) as s -> s
| Case(v', arms, default) ->
if v = v' then begin
let rec find_arm = function
[] -> specialize_stmt v n default
| (lbls, stmt) :: rem ->
if List.mem n lbls
then specialize_stmt v n stmt
else find_arm rem in
find_arm (arms_to_x arms)
end else
optimize_case v'
(List.map (fun (lbls, stmt) -> (lbls, specialize_stmt v n stmt))
(arms_to_x arms))
(specialize_stmt v n default)
let case_of_case v nl s =
optimize_case v
(List.map (fun n -> ([n], specialize_stmt v n s)) nl)
(specialize_stmt v (-1) s)
module IntSet = Set.Make(struct type t = int let compare = compare end)
let optimize_case_or_nested_cases v arms default =
let normal_case =
optimize_case v arms default in
(* Test if all arms are either non-cases or cases on the same variable *)
let splitvar = ref "" in
let splitvals = ref IntSet.empty in
let check_arm (lbls, stmt) =
match stmt with
Case(v', arms', default') ->
if !splitvar = "" then splitvar := v';
if !splitvar <> v' then raise Exit;
List.iter
(fun (lbls, stmt) ->
List.iter (fun lbl -> splitvals := IntSet.add lbl !splitvals) lbls)
(arms_to_x arms')
| _ -> () in
try
List.iter check_arm arms;
if !splitvar = "" then raise Exit;
(***
print_string "*** Considering splitting for:"; print_newline();
pretty_stmt normal_case; print_newline();
print_string "*** Splitting on "; print_string !splitvar;
print_newline();
print_string "Interesting values are : ";
List.iter (fun n -> print_int n; print_string " ")
(IntSet.elements !splitvals); print_newline();
***)
let split_case =
case_of_case !splitvar (IntSet.elements !splitvals) normal_case in
(***
print_string "*** Splitting result:"; print_newline();
pretty_stmt split_case; print_newline();
***) if cout_mem_st split_case < cout_mem_st normal_case
then split_case else normal_case
with Exit ->
normal_case
(* Flatten a case whose default case is a case on the same variable *)
let rec flatten_case lbls_seen v arms default =
match default with
Case(v', arms', default') when v = v' ->
(* Subtract the labels already seen from the arms *)
let rec subtract_arms = function
[] -> []
| (lbls, stmt) :: rem ->
let new_lbls =
List.fold_right
(fun lbl lst ->
if IntSet.mem lbl lbls_seen then lst else lbl :: lst)
lbls [] in
if new_lbls = [] then subtract_arms rem
else (new_lbls, stmt) :: subtract_arms rem in
(* Add the labels of the arms to lbls_seen *)
let rec add_labels seen = function
[] -> seen
| (lbls, stmt) :: rem ->
add_labels (List.fold_right IntSet.add lbls seen) rem in
(* Recursively simplify the default *)
let (arms'', default'') =
flatten_case (add_labels lbls_seen arms) v (arms_to_x arms') default' in
(subtract_arms arms @ arms'', default'')
| _ ->
(arms, default)
(* Simplification of a statement *)
let rec simpl_stmt = function
If(cond, ifso,[], ifnot) ->
begin match simpl_cond cond with
Or [] (* false *) -> simpl_stmt ifnot
| And [] (* true *) -> simpl_stmt ifso
| cond' -> optimize_if cond' (simpl_stmt ifso) (simpl_stmt ifnot)
end
(**
| Decision(Some n, utt) when !possible_current_states = [n] ->
Decision(None, utt)
**)
| Decision(_, _) as stmt -> stmt
| Case(v, arms, default) ->
let (arms, default) = flatten_case IntSet.empty v (arms_to_x arms) default in
optimize_case_or_nested_cases v
(List.map (fun (lbls, stmt) -> (lbls, simpl_stmt stmt))
arms)
(simpl_stmt default)
(* Conversions between characters and statements *)
let stmt_of_character c = Case("state", c, (Decision(None, "")))
let all_states c =
let allstates = ref IntSet.empty in
List.iter
(fun (lbls, stmt) ->
List.iter (fun lbl -> allstates := IntSet.add lbl !allstates) lbls)
c;
IntSet.elements !allstates
let character_of_stmt allstates s =
[allstates, s]
(* Simplification of a character *)
open Format
let char_from_x = List.map (function il,st -> Rule(il,st))
let simpl_character c =
let c' =
List.map
(fun (states, stmt) ->
possible_current_states := states;
(states, simpl_stmt stmt))
c in
(***
print_string "*** First simplification:"; print_newline();
pretty_character c'; print_newline();
***)
let rec add_rule (lbls, stmt as lbls_stmt) = function
[] -> [lbls_stmt]
| (lbls', stmt' as lbls_stmt') :: rem ->
if stmt' = stmt then (lbls @ lbls', stmt) :: rem
else lbls_stmt' :: add_rule lbls_stmt rem in
let normal_character = List.fold_right add_rule (List.rev c') [] in
(***
print_string "*** Normal character:"; print_newline();
pretty_character normal_character; print_newline();
***)
(* Now try to simplify the character as a whole statement *)
possible_current_states := all_states normal_character;
let stmt_character =
character_of_stmt !possible_current_states
(simpl_stmt (stmt_of_character (arms_from_x normal_character))) in
(***
print_string "*** One-statement character:"; print_newline();
pretty_character stmt_character; print_newline();
***)
(* Now choose the smallest of the three characters *)
let c' =
if cout_mem_character (char_from_x stmt_character) < cout_mem_character (char_from_x normal_character)
then stmt_character
else normal_character in
if cout_mem_character (char_from_x c') < cout_mem_character (char_from_x c) then c' else c
*)