(* Copyright (c) 1991 by Carnegie Mellon University *)

(* bitset.sml
 *
 * An implementation of sets of integers between 0..n-1 and operations on
 * those sets.
 * 
 * Author: David Tarditi
 *         Carnegie Mellon University
 *         tarditi@cs.cmu.edu
 *
 * The operations on sets are:
 *
 *  val empty : int -> set          empty s: create an empty set over 0..s-1
 *  val all: int -> set             all s: create the full set {0,...,s-1}
 *  intersect : set * set -> unit   intersect(a,b): b <- intersection(a,b)
 *  union : set * set -> unit       union(a,b): b <- union(a,b)
 *  add : int * set -> unit         add(a,s): s <- union(s,{a})
 *  elems : set -> int list         Creates a list of the integers in the set,
 *                                  in increasing order
 *  copy : set -> set               Creates a copy of a set
 *  equal : set * set -> bool       Set equality
 *  exception Size                  Raised when an operation is passed two
 *                                  sets of different sizes, or when add is
 *                                  passed an integer outside 0..size-1
 * 
 * intersect, union, and add are side-effecting: they modify a set argument
 * in place instead of returning a new set.
 *
 * Sets are implemented as bitsets using bitwise operations on integers 
 * stored in arrays.  This representation is extremely compact -- it
 * takes up 1/100 of the memory that a list-based representation would
 * for non-trivial sized sets (sets with more than 30 elements).
 * 
 * This code is specific to SML/NJ and is not Standard ML code. It uses arrays
 * and bitwise operations, neither of which is in the Definition.  It also
 * assumes that integers have at least 31 bits (30 value bits plus a sign bit).
 *)

signature BITSET =
  sig
    (* A set of integers drawn from 0..size-1, where size is fixed when the
       set is created. *)
    type set

    (* Raised when two sets of different sizes are combined or compared,
       and when add is given an integer outside 0..size-1. *)
    exception Size

    (* Constructors: empty s and all s build a set over 0..s-1; copy makes
       an independent duplicate. *)
    val empty : int -> set
    val all   : int -> set
    val copy  : set -> set

    (* Destructive updates: the set in the second position is modified in
       place.  add(k,s) inserts k; union(a,b) and intersect(a,b) store the
       result of combining a with b into b. *)
    val add       : int * set -> unit
    val union     : set * set -> unit
    val intersect : set * set -> unit

    (* Queries: elems lists the members in increasing order; equal tests
       whether two sets of the same size have the same members. *)
    val elems : set -> int list
    val equal : set * set -> bool
  end

abstraction Bitset : BITSET =
  struct

     (* n: the number of bits of each array word that hold set members.
        Member k lives in bit (k mod n) of word (k quot n).  The largest
        word value used is 2^n - 1, so integers must have at least n value
        bits plus a sign bit. *)

     val n = 30

     (* type set: the set represents integers between 0 and size-1 *)

     type set  = {size : int, elems : int array}

     (* Size is raised when two sets of different sizes are combined or
        compared, when add is given an integer outside 0..size-1, and when
        a set is requested with a negative size. *)

     exception Size

     (* words s: the number of array words allocated for a set of size s.
        This is always at least one (including s = 0 and exact multiples
        of n), so every set has a last word that all can write. *)

     fun words s = (s quot n)+1

     (* checkSize s: reject negative sizes.  Without this, all could index
        word -1 (s <= ~n) or set bits that elems never reports (~n < s < 0). *)

     fun checkSize s = if s < 0 then raise Size else ()

     (* ones k: an integer with bits 0..k-1 set to 1; k is expected to be
        between 0 and n *)

     fun ones k =
         let fun f (i,r) = if i<k then f(i+1,Bits.orb(Bits.lshift(r,1),1))
                           else r
         in f (0,0)
         end

     (* allones: a word with all n member bits set *)

     val allones = ones n

     (* empty s: create a set of size s with no members *)

     fun empty s =
         (checkSize s; {size=s,elems=array(words s,0)})

     (* all s: create a set of size s containing every integer 0..s-1.
        Every word except the last is full; the last word holds only the
        s mod n members that remain. *)

     fun all s =
         let val _ = checkSize s
             val c = words s
             val elems = array(c,allones)
         in update(elems,c-1,ones (s mod n));
            {size=s,elems=elems}
         end

     (* combine oper (a,b): b <- oper(a,b), applied word by word.
        Raises Size if a and b have different sizes; equal sizes imply
        equal array lengths, so indexing e2 by e1's length is safe. *)

     fun combine oper ({size=s1,elems=e1} : set,{size=s2,elems=e2} : set) =
         if s1<>s2 then raise Size
         else let val j = Array.length e1
                  fun f i =
                      if i<j then (update(e2,i,oper(e1 sub i,e2 sub i));
                                   f (i+1))
                      else ()
              in f 0
              end

     (* intersect(a,b): b<-intersection(a,b) *)

     val intersect = combine Bits.andb

     (* union (a,b): b<-union(a,b) *)

     val union = combine Bits.orb

     (* add (new,s): s<-union(s,{new}); raises Size unless 0 <= new < size *)

     fun add (new,{size,elems} : set) =
         if new<0 orelse new>=size then raise Size
         else let val i = new quot n
                  val j = new - i * n
              in update(elems,i,Bits.orb(elems sub i,Bits.lshift(1,j)))
              end

     (* copy s: create an independent copy of s *)

     fun copy (arg as {size=s,...} : set) =
         let val newset = empty s
         in (union(arg,newset); newset)
         end

     (* equal (a,b): true iff a and b have the same members; raises Size
        if the sets have different sizes *)

     fun equal ({size=s1,elems=e1} : set,{size=s2,elems=e2} : set) =
         if s1<>s2 then raise Size
         else let val j = Array.length e1
                  fun f i =
                      if i>=j then true
                      else if (e1 sub i) <> (e2 sub i) then false
                      else f(i+1)
              in f 0
              end

     (* elems s: the members of s, as a list in increasing order.

        f scans the word array.  pos is the index of the current word,
        count is the integer represented by the next bit to examine, and
        r holds, in reverse order, the members found so far (all below
        count).

        g scans the bits of one word.  bound is one past the largest
        integer this word can contribute (capped at size).  pattern masks
        out every bit except the one for count.  When a word is exhausted,
        g either finishes (count has reached size) or moves on to the next
        word. *)

     fun elems ({size,elems} : set) =
         let fun min (i,j) = if i<j then i else j
             fun f (pos,count,r) =
                 let val elem = elems sub pos
                     fun g (bound,pattern,count,r) =
                         if count<bound then
                             g(bound,Bits.lshift(pattern,1),count+1,
                               if Bits.andb(elem,pattern)<>0 then count :: r
                               else r)
                         else if count>=size then rev r
                         else f(pos+1,count,r)
                 in g(min(size,count+n),1,count,r)
                 end
         in if size<>0 then f(0,0,nil) else nil
         end
  end



