More complete Map library for OCaml

From: Christophe Raffalli (raffalli@logique.jussieu.fr)
Date: Wed Aug 27 1997 - 11:55:22 MET DST


Date: Wed, 27 Aug 1997 11:55:22 +0200 (MET DST)
Message-Id: <199708270955.LAA05515@boole.logique.jussieu.fr>
From: Christophe Raffalli <raffalli@logique.jussieu.fr>
To: caml-list@margaux.inria.fr
Subject: More complete Map library for OCaml

Version Anglaise:

I needed an implementation of multisets, and I found that some functions
were missing from the Map module to implement multisets of elements of
X as a map from X to integers. The needed functions can be programmed
using Map.fold, but then you lose the performance of the well-balanced
tree representation! The modified map.ml and map.mli follow, along with
the multiset example.

The extra functions are Map.union, Map.inter, Map.diff, Map.compare
and Map.equal. But to give a meaning to these functions, you
need an extra argument to decide what to do when a binding is present
in both maps. You can look at the file map.mli for more details.
 
I also added the useful Map.map function.

- ---

French version:

J'ai eu besoin d'une implantation des multi-ensembles et je me suis
apercu que certaines fonctions manquaient dans la librairie Map pour
implanter un multi-ensemble d'elements de X comme une "map" de X dans
les entiers. Les fonctions necessaires peuvent etre ecrites a l'aide de
Map.fold mais on perd la performance de la representation en arbre bien
balancee. Les fichiers modifies map.ml et map.mli suivent ainsi que
l'exemple des multi-ensembles.
 
Les fonctions supplementaires sont Map.union, Map.inter, Map.diff, Map.compare
et Map.equal. Mais pour donner un sens a ces fonctions, il faut ajouter un
argument supplementaire pour decider quoi faire lorsque la meme liaison est
presente dans les deux Maps. Pour plus de details, voir le fichier map.mli.

J'ai aussi ajoute une fonction Map.map.

- ----
Christophe Raffalli
Universite Paris XII

URL: http://www.logique.jussieu.fr/www.raffalli

Here are the files:
Voici les fichiers:

- ----8<---- map.mli ----8<----
(***********************************************************************)
(* *)
(* Objective Caml *)
(* *)
(* Xavier Leroy, projet Cristal, INRIA Rocquencourt *)
(* *)
(* Copyright 1996 Institut National de Recherche en Informatique et *)
(* Automatique. Distributed only by permission. *)
(* *)
(***********************************************************************)

(* $Id: map.mli,v 1.7 1996/04/30 14:50:17 xleroy Exp $ *)

(* Module [Map]: association tables over ordered types *)

(* extended by C. Raffalli *)

(* This module implements applicative association tables, also known as
   finite maps or dictionaries, given a total ordering function
   over the keys.
   All operations over maps are purely applicative (no side-effects).
   The implementation uses balanced binary trees, and therefore searching
   and insertion take time logarithmic in the size of the map. *)

module type OrderedType =
  sig
    type t
    val compare: t -> t -> int
  end
          (* The input signature of the functor [Map.Make].
             [t] is the type of the map keys.
             [compare] is a total ordering function over the keys.
             This is a two-argument function [f] such that
             [f e1 e2] is zero if the keys [e1] and [e2] are equal,
             [f e1 e2] is strictly negative if [e1] is smaller than [e2],
             and [f e1 e2] is strictly positive if [e1] is greater than [e2].
             Example: the generic structural comparison function [compare]
             is a suitable ordering function for most key types.
             For type [int], [(-)] also works as long as the difference
             of two keys cannot overflow; when keys of large magnitude
             and opposite signs may occur, the subtraction wraps around
             and yields the wrong sign, so prefer [compare] in that case. *)

module type S =
  sig
    type key
          (* The type of the map keys. *)
    type 'a t
          (* The type of maps from type [key] to type ['a]. *)
    val empty: 'a t
          (* The empty map. *)
    val is_empty: 'a t -> bool
        (* Test whether a map is empty or not. *)
    val add: key -> 'a -> 'a t -> 'a t
        (* [add x y m] returns a map containing the same bindings as
           [m], plus a binding of [x] to [y]. If [x] was already bound
           in [m], its previous binding disappears. *)
    val find: key -> 'a t -> 'a
        (* [find x m] returns the current binding of [x] in [m],
           or raises [Not_found] if no such binding exists. *)
    val remove: key -> 'a t -> 'a t
        (* [remove x m] returns a map containing the same bindings as
           [m], except for [x] which is unbound in the returned map. *)
    val update: key -> ('a -> 'a) -> (unit -> 'a) -> 'a t -> 'a t
        (* [update x fn gn m] returns a map containing the same bindings as
           [m], except for [x]. If [x] was bound to [d] then it is bound to
           [fn d] in the returned map. If [x] was unbound, it is bound to
           [gn ()] in the returned map. Exceptions raised by [fn] or [gn]
           propagate to the caller. *)
    val union: ('a -> 'a -> 'a) -> 'a t -> 'a t -> 'a t
        (* [union fn m m']. If [x] is bound to [d] in [m] and [d'] in [m']
           then it is bound to [fn d d'] in the returned map. Otherwise,
           if [x] is bound to [d] in exactly one of [m] and [m'] then it is
           bound to [d] in the returned map. No other bindings are
           created. *)
    val inter: ('a -> 'a -> 'a) -> 'a t -> 'a t -> 'a t
        (* [inter fn m m']. If [x] is bound to [d] in [m] and [d'] in [m']
           then it is bound to [fn d d'] in the returned map. No other
           bindings are created. *)
    val diff: ('a -> 'a -> 'a option) -> 'a t -> 'a t -> 'a t
        (* [diff fn m m']. If [x] is bound to [d] in [m] and [d'] in [m']
           then: if [fn d d' = Some d''], [x] is bound to [d''] in the
           returned map; if [fn d d' = None], [x] is unbound in the
           returned map. Otherwise, if [x] is bound to [d] in [m] and
           unbound in [m'] then [x] is bound to [d] in the returned map.
           No other bindings are created. *)
    val compare: ('a -> 'a -> int) -> 'a t -> 'a t -> int
        (* [compare fn m1 m2] is a total ordering between maps. Can be used
           as the ordering function for doing sets of maps or maps keyed by
           maps. The first argument [fn] must be a total ordering on ['a];
           it is used to compare the values bound to equal keys. *)
    val equal: ('a -> 'a -> bool) -> 'a t -> 'a t -> bool
        (* [equal fn m1 m2] tests whether the maps [m1] and [m2] are
           equal, that is, contain the same bindings. [fn] is used to
           compare the values in the maps. *)
    val map: (key -> 'a -> 'b) -> 'a t -> 'b t
        (* [map f m] applies [f] to all bindings in map [m],
           and returns the corresponding map: if [x] is bound to [d] in [m]
           then [x] is bound to [f x d] in the resulting map. *)
    val iter: (key -> 'a -> 'b) -> 'a t -> unit
        (* [iter f m] applies [f] to all bindings in map [m],
           discarding the results.
           [f] receives the key as first argument, and the associated value
           as second argument. The order in which the bindings are passed to
           [f] is unspecified. Only current bindings are presented to [f]:
           bindings hidden by more recent bindings are not passed to [f]. *)
    val fold: (key -> 'a -> 'b -> 'b) -> 'a t -> 'b -> 'b
        (* [fold f m a] computes [(f kN dN ... (f k1 d1 a)...)],
           where [k1 ... kN] are the keys of all bindings in [m],
           and [d1 ... dN] are the associated data.
           The order in which the bindings are presented to [f] is
           not specified. *)
  end

module Make(Ord: OrderedType): (S with type key = Ord.t)
        (* Functor building an implementation of the map structure
           given a totally ordered type. The keys of the resulting
           maps have type [Ord.t] and are ordered by [Ord.compare]. *)
- ----8<-----------------8<----

- ----8<---- map.ml ----8<----
(***********************************************************************)
(* *)
(* Objective Caml *)
(* *)
(* Xavier Leroy, projet Cristal, INRIA Rocquencourt *)
(* *)
(* Copyright 1996 Institut National de Recherche en Informatique et *)
(* Automatique. Distributed only by permission. *)
(* *)
(***********************************************************************)

(* extended by C. Raffalli *)

(* $Id: map.ml,v 1.5 1996/04/30 14:50:17 xleroy Exp $ *)

(* Input signature of the functor [Make]: the type [t] of the map keys
   together with a comparison function [compare] on it.  The tree
   operations below treat a zero result as "same key", a negative result
   as "go left" and a positive result as "go right". *)
module type OrderedType =
  sig
    type t
    val compare: t -> t -> int
  end

(* Output signature of the functor [Make].  This repeats the signature
   published in map.mli, where each operation is documented in detail. *)
module type S =
  sig
    type key
    type 'a t
    val empty: 'a t
    val is_empty: 'a t -> bool
    val add: key -> 'a -> 'a t -> 'a t
    val find: key -> 'a t -> 'a
    val remove: key -> 'a t -> 'a t
    (* [update x fn gn m]: [fn] transforms an existing binding of [x],
       [gn ()] supplies the value when [x] is unbound. *)
    val update: key -> ('a -> 'a) -> (unit -> 'a) -> 'a t -> 'a t
    (* For [union], [inter] and [diff] the first argument combines the
       two values of a key that is bound in both maps. *)
    val union: ('a -> 'a -> 'a) -> 'a t -> 'a t -> 'a t
    val inter: ('a -> 'a -> 'a) -> 'a t -> 'a t -> 'a t
    val diff: ('a -> 'a -> 'a option) -> 'a t -> 'a t -> 'a t
    (* For [compare] and [equal] the first argument compares values. *)
    val compare: ('a -> 'a -> int) -> 'a t -> 'a t -> int
    val equal: ('a -> 'a -> bool) -> 'a t -> 'a t -> bool
    val map: (key -> 'a -> 'b) -> 'a t -> 'b t
    val iter: (key -> 'a -> 'b) -> 'a t -> unit
    val fold: (key -> 'a -> 'b -> 'b) -> 'a t -> 'b -> 'b
  end

module Make(Ord: OrderedType) = struct

    type key = Ord.t

    (* A map is a binary search tree ordered by [Ord.compare] on its keys.
       [Node(l, k, d, r, h)] binds [k] to [d]; all keys of [l] are smaller
       than [k] and all keys of [r] are greater.  [h] caches the height of
       the node ([Empty] has height 0).  Every operation below keeps the
       heights of [l] and [r] within 2 of each other, which is what makes
       lookups and insertions logarithmic. *)
    type 'a t =
        Empty
      | Node of 'a t * key * 'a * 'a t * int

    (* Cached height of a tree. *)
    let height = function
        Empty -> 0
      | Node(_,_,_,_,h) -> h

    (* [create l x d r] builds the node binding [x] to [d] between [l] and
       [r] and computes its height.  No rebalancing: the caller must know
       that the heights of [l] and [r] differ by at most 2. *)
    let create l x d r =
      let hl = height l and hr = height r in
      Node(l, x, d, r, (if hl >= hr then hl + 1 else hr + 1))

    (* Same as [create], but performs one step of rebalancing (a single or
       double rotation) when one side is more than 2 levels higher than the
       other.  This is enough to restore the invariant when the heights of
       [l] and [r] differ by at most 3.  The [invalid_arg] branches cannot
       be reached when the cached heights are correct. *)
    let bal l x d r =
      let hl = height l and hr = height r in
      if hl > hr + 2 then begin
        match l with
          Empty -> invalid_arg "Map.bal"
        | Node(ll, lv, ld, lr, _) ->
            if height ll >= height lr then
              create ll lv ld (create lr x d r)
            else begin
              match lr with
                Empty -> invalid_arg "Map.bal"
              | Node(lrl, lrv, lrd, lrr, _)->
                  create (create ll lv ld lrl) lrv lrd (create lrr x d r)
            end
      end else if hr > hl + 2 then begin
        match r with
          Empty -> invalid_arg "Map.bal"
        | Node(rl, rv, rd, rr, _) ->
            if height rr >= height rl then
              create (create l x d rl) rv rd rr
            else begin
              match rl with
                Empty -> invalid_arg "Map.bal"
              | Node(rll, rlv, rld, rlr, _) ->
                  create (create l x d rll) rlv rld (create rlr rv rd rr)
            end
      end else
        Node(l, x, d, r, (if hl >= hr then hl + 1 else hr + 1))

    (* [add_min_binding x d t] adds the binding of [x] to [d] to [t],
       assuming [x] is smaller than every key of [t]. *)
    let rec add_min_binding x d = function
        Empty -> Node(Empty, x, d, Empty, 1)
      | Node(l, v, dv, r, _) -> bal (add_min_binding x d l) v dv r

    (* [add_max_binding x d t] adds the binding of [x] to [d] to [t],
       assuming [x] is greater than every key of [t]. *)
    let rec add_max_binding x d = function
        Empty -> Node(Empty, x, d, Empty, 1)
      | Node(l, v, dv, r, _) -> bal l v dv (add_max_binding x d r)

    (* Same as [bal], but makes no assumption on the heights of [l] and
       [r].  All keys of [l] must precede [x], and [x] must precede all
       keys of [r].  The binding is pushed down the spine of the taller
       tree until both sides have comparable heights, rebalancing on the
       way back up, so that every node of the result (not only its root)
       satisfies the balance invariant. *)
    let rec join l x d r =
      match (l, r) with
        (Empty, _) -> add_min_binding x d r
      | (_, Empty) -> add_max_binding x d l
      | (Node(ll, lv, ld, lr, lh), Node(rl, rv, rd, rr, rh)) ->
          if lh > rh + 2 then bal ll lv ld (join lr x d r)
          else if rh > lh + 2 then bal (join l x d rl) rv rd rr
          else create l x d r

    (* Smallest binding of a map.  Raises [Not_found] on the empty map. *)
    let rec min_binding = function
        Empty -> raise Not_found
      | Node(Empty, x, d, _, _) -> (x, d)
      | Node(l, _, _, _, _) -> min_binding l

    (* The map without its smallest binding.  Must not be applied to the
       empty map. *)
    let rec remove_min_binding = function
        Empty -> invalid_arg "Map.remove_min_binding"
      | Node(Empty, _, _, r, _) -> r
      | Node(l, x, d, r, _) -> bal (remove_min_binding l) x d r

    (* Merge two trees t1 and t2 into one.
       All keys of t1 must precede the keys of t2.
       Assumes | height t1 - height t2 | <= 2.
       The smallest binding of t2 becomes the new root. *)
    let merge t1 t2 =
      match (t1, t2) with
        (Empty, t) -> t
      | (t, Empty) -> t
      | (_, _) ->
          let (x, d) = min_binding t2 in
          bal t1 x d (remove_min_binding t2)

    (* Same as merge, but does not assume anything about the heights of
       t1 and t2. *)
    let concat t1 t2 =
      match (t1, t2) with
        (Empty, t) -> t
      | (t, Empty) -> t
      | (_, _) ->
          let (x, d) = min_binding t2 in
          join t1 x d (remove_min_binding t2)

    (* Splitting.  [split x m] returns [(l, b, r)] where [l] holds the
       bindings of [m] whose key is smaller than [x], [r] those whose key
       is greater, and [b] is [Some (k, d)] if [m] binds a key [k] equal
       to [x] (to [d]), [None] otherwise. *)
    let rec split x = function
        Empty ->
          (Empty, None, Empty)
      | Node(l, v, d, r, _) ->
          let c = Ord.compare x v in
          if c = 0 then (l, Some (v, d), r)
          else if c < 0 then
            let (ll, vl, rl) = split x l in (ll, vl, join rl v d r)
          else
            let (lr, vr, rr) = split x r in (join l v d lr, vr, rr)

    (* Implementation of the map operations *)

    let empty = Empty

    let is_empty = function Empty -> true | _ -> false

    (* [add x data m]: bind [x] to [data], replacing any previous binding
       of [x].  Replacing in place does not change the shape of the tree,
       so the cached height is reused. *)
    let rec add x data = function
        Empty ->
          Node(Empty, x, data, Empty, 1)
      | Node(l, v, d, r, h) ->
          let c = Ord.compare x v in
          if c = 0 then
            Node(l, x, data, r, h)
          else if c < 0 then
            bal (add x data l) v d r
          else
            bal l v d (add x data r)

    (* [find x m]: value bound to [x]; raises [Not_found] if unbound. *)
    let rec find x = function
        Empty ->
          raise Not_found
      | Node(l, v, d, r, _) ->
          let c = Ord.compare x v in
          if c = 0 then d
          else find x (if c < 0 then l else r)

    (* [remove x m]: [m] without the binding of [x], if any. *)
    let rec remove x = function
        Empty ->
          Empty
      | Node(l, v, d, r, _) ->
          let c = Ord.compare x v in
          if c = 0 then
            merge l r
          else if c < 0 then
            bal (remove x l) v d r
          else
            bal l v d (remove x r)

    (* [update x fn gn m]: rebind [x] to [fn d] if it was bound to [d],
       to [gn ()] if it was unbound.  Exceptions raised by [fn] or [gn]
       escape, leaving [m] untouched (maps are immutable). *)
    let rec update x fn gn = function
        Empty ->
          Node(Empty, x, gn (), Empty, 1)
      | Node(l, v, d, r, h) ->
          let c = Ord.compare x v in
          if c = 0 then
            Node(l, x, fn d, r, h)
          else if c < 0 then
            bal (update x fn gn l) v d r
          else
            bal l v d (update x fn gn r)

    (* [union fn s1 s2]: all bindings of both maps; a key bound in both
       (to [d1] in [s1] and [d2] in [s2]) is bound to [fn d1 d2]. *)
    let rec union fn s1 s2 =
      match (s1, s2) with
        (Empty, t2) -> t2
      | (t1, Empty) -> t1
      | (Node(l1, v1, d1, r1, _), t2) ->
          let (l2, q, r2) = split v1 t2 in
          let d = match q with
            None -> d1
          | Some(_, d2) -> fn d1 d2 in
          join (union fn l1 l2) v1 d (union fn r1 r2)

    (* [inter fn s1 s2]: only the keys bound in both maps, each bound to
       [fn d1 d2] where [d1] comes from [s1] and [d2] from [s2]. *)
    let rec inter fn s1 s2 =
      match (s1, s2) with
        (Empty, _) -> Empty
      | (_, Empty) -> Empty
      | (Node(l1, v1, d1, r1, _), t2) ->
          match split v1 t2 with
            (l2, None, r2) ->
              concat (inter fn l1 l2) (inter fn r1 r2)
          | (l2, Some (_, d2), r2) ->
              join (inter fn l1 l2) v1 (fn d1 d2) (inter fn r1 r2)

    (* [diff fn s1 s2]: the bindings of [s1] whose key is unbound in [s2]
       are kept as they are; for a key bound in both, [fn d1 d2] decides:
       [Some d] rebinds the key to [d], [None] drops it. *)
    let rec diff fn s1 s2 =
      match (s1, s2) with
        (Empty, _) -> Empty
      | (t1, Empty) -> t1
      | (Node(l1, v1, d1, r1, _), t2) ->
          match split v1 t2 with
            (l2, None, r2) ->
              join (diff fn l1 l2) v1 d1 (diff fn r1 r2)
          | (l2, Some (_, d2), r2) ->
              match fn d1 d2 with
                Some d ->
                  join (diff fn l1 l2) v1 d (diff fn r1 r2)
              | None ->
                  concat (diff fn l1 l2) (diff fn r1 r2)

    (* Lexicographic comparison of the bindings of two maps, in increasing
       key order, independently of the shapes of the trees.  Each list is
       a stack of subtrees still to be visited.  A node with a non-empty
       left subtree is unfolded into that subtree followed by a node with
       an empty left subtree; such temporary nodes carry a dummy height 0,
       which is never read here. *)
    let rec compare_aux fn l1 l2 =
        match (l1, l2) with
        ([], []) -> 0
      | ([], _) -> -1
      | (_, []) -> 1
      | (Empty :: t1, Empty :: t2) ->
          compare_aux fn t1 t2
      | (Node(Empty, v1, d1, r1, _) :: t1, Node(Empty, v2, d2, r2, _) :: t2) ->
          let c = Ord.compare v1 v2 in
          if c <> 0 then c else
          let c = fn d1 d2 in
          if c <> 0 then c else compare_aux fn (r1::t1) (r2::t2)
      | (Node(l1, v1, d1, r1, _) :: t1, t2) ->
          compare_aux fn (l1 :: Node(Empty, v1, d1, r1, 0) :: t1) t2
      | (t1, Node(l2, v2, d2, r2, _) :: t2) ->
          compare_aux fn t1 (l2 :: Node(Empty, v2, d2, r2, 0) :: t2)

    (* [compare fn s1 s2]: total ordering on maps; keys are compared with
       [Ord.compare] and the values of equal keys with [fn]. *)
    let compare fn s1 s2 =
      compare_aux fn [s1] [s2]

    (* [equal fn s1 s2]: true iff both maps bind the same keys and [fn]
       holds for the two values of every key.  Implemented with [compare]
       by turning [fn] into a comparison that returns 0 exactly when [fn]
       holds. *)
    let equal fn s1 s2 =
      let fn' x y = if fn x y then 0 else 1 in
      compare fn' s1 s2 = 0

    (* [iter f m]: apply [f] to every binding, in increasing key order,
       discarding the results. *)
    let rec iter f = function
        Empty -> ()
      | Node(l, v, d, r, _) ->
          iter f l; ignore (f v d); iter f r

    (* [map f m]: same keys and same tree shape as [m], each value [d]
       bound to [x] being replaced by [f x d]. *)
    let rec map f = function
        Empty -> Empty
      | Node(l, v, d, r, h) ->
          Node (map f l, v, f v d, map f r, h)

    (* [fold f m accu]: fold [f] over the bindings of [m], visiting the
       right subtree first, then the root, then the left subtree. *)
    let rec fold f m accu =
      match m with
        Empty -> accu
      | Node(l, v, d, r, _) ->
          fold f l (f v d (fold f r accu))

end
- ----8<-----------------8<----

- ----8<---- multiset.ml ----8<----
(* Multisets over ordered types *)

(* Input signature of the functor [Make]: the type [t] of the multiset
   elements and a comparison function [compare] on it.  The functor passes
   this module unchanged to [Map.Make], so it must satisfy the
   requirements of [Map.OrderedType]. *)
module type OrderedType =
  sig
    type t
    val compare: t -> t -> int
  end

(* Output signature of the functor [Make]: multisets (bags) of elements
   of type [elt]. *)
module type S =
  sig
    type elt
        (* The type of the multiset elements. *)
    type t
        (* The type of multisets. *)
    val empty: t
        (* The empty multiset. *)
    val is_empty: t -> bool
        (* Test whether a multiset is empty or not. *)
    val mem: elt -> t -> bool
        (* [mem x s] is true iff [count x s] is strictly positive. *)
    val count: elt -> t -> int
        (* [count x s] is the number of occurrences of [x] in [s],
           0 if [x] does not occur in [s]. *)
    val add: elt -> t -> t
        (* [add x s] is [s] with one more occurrence of [x]. *)
    val multiple_add: elt -> int -> t -> t
        (* [multiple_add x n s] is [s] with [n] more occurrences of [x].
           [n] is presumably meant to be strictly positive. *)
    val remove: elt -> t -> t
        (* [remove x s] is [s] with one occurrence of [x] less.
           Raises [Not_found] if [x] does not occur in [s]. *)
    val union: t -> t -> t
        (* Multiset sum: the number of occurrences of an element in the
           result is the sum of its numbers of occurrences in the two
           arguments. *)
    val inter: t -> t -> t
        (* Multiset intersection: an element occurs in the result only if
           it occurs in both arguments, as many times as the smaller of
           its two numbers of occurrences. *)
    val diff: t -> t -> t
        (* [diff s s']: each element occurs in the result as many times as
           in [s] minus as many times as in [s'], and does not occur at
           all when this difference is not strictly positive. *)
    val compare: t -> t -> int
        (* Total ordering between multisets. *)
    val equal: t -> t -> bool
        (* [equal s s'] tests whether [s] and [s'] contain the same
           elements with the same numbers of occurrences. *)
    val iter: (elt -> int -> 'a) -> t -> unit
        (* [iter f s] applies [f x n] once for each distinct element [x]
           of [s], where [n] is the number of occurrences of [x]. *)
    val fold: (elt -> int -> 'a -> 'a) -> t -> 'a -> 'a
        (* [fold f s a] folds [f x n] over the distinct elements [x] of
           [s], where [n] is the number of occurrences of [x]. *)
    val cardinal: t -> int
        (* Total number of occurrences: the sum of the counts of all the
           elements. *)
  end

module Make(Ord: OrderedType) =
  struct
    type elt = Ord.t

    (* [Map] is the extended map library (with [update], [union], [inter],
       [diff], [compare] and [equal]), not the original one. *)
    module M = Map.Make (Ord)

    (* A multiset is a map from each element to its number of occurrences.
       Invariant: every stored count is strictly positive; an element that
       does not occur has no binding at all.  [is_empty], [equal] and
       [compare] rely on this. *)
    type t = int M.t

    let empty = M.empty

    let is_empty = M.is_empty

    (* Number of occurrences of [x] in [s]; 0 when [x] has no binding. *)
    let count x s = try M.find x s with Not_found -> 0

    let mem x s = count x s > 0

    (* One more occurrence of [x]. *)
    let add x s = M.update x (fun n -> n+1) (fun () -> 1) s

    (* [multiple_add x n s] adds [n] to the count of [x].  When the
       resulting count is not strictly positive ([n] is zero or negative
       and [x] does not occur often enough), the binding of [x] is dropped
       instead of storing a zero or negative count, which would break the
       invariant above ([mem] false but [is_empty] false).  [Exit] is only
       raised by the two local closures, to abandon the update. *)
    let multiple_add x n s =
      try
        M.update x
          (fun n' ->
            let total = n + n' in
            if total > 0 then total else raise Exit)
          (fun () -> if n > 0 then n else raise Exit) s
      with
        Exit -> M.remove x s

    (* [remove x s] removes one occurrence of [x].  When the last
       occurrence goes away the update is abandoned with [Exit] and the
       binding is removed.  Raises [Not_found] if [x] does not occur
       in [s]. *)
    let remove x s =
      try
        M.update x (fun n -> if n = 1 then raise Exit else n-1)
                   (fun () -> raise Not_found) s
      with
        Exit -> M.remove x s

    (* Multiset sum: counts are added. *)
    let union = M.union (+)

    (* Multiset intersection: the smaller of the two counts. *)
    let inter = M.inter (min : int -> int -> int)

    (* Multiset difference: counts are subtracted, and an element whose
       count would not be strictly positive is dropped. *)
    let diff = M.diff (fun n n' -> let r = n - n' in
                        if r <= 0 then None else Some r)

    (* Counts are strictly positive, so the subtraction cannot overflow
       and its sign orders them correctly. *)
    let compare = M.compare (-)

    let equal = M.equal ((=) : int -> int -> bool)

    (* [iter] and [fold] present each distinct element with its count. *)
    let iter = M.iter

    let fold = M.fold

    (* Total number of occurrences, i.e. the sum of all counts. *)
    let cardinal s = fold (fun _ n n' -> n + n') s 0

  end
- ----8<-----------------8<----

- --PAA11081.872600455/concorde.inria.fr--
------- End of forwarded message -------

- ----
Christophe Raffalli
Universite Paris XII

URL: http://www.logique.jussieu.fr/www.raffalli



This archive was generated by hypermail 2b29 : Sun Jan 02 2000 - 11:58:12 MET