(* Copyright (c) 1991 by Carnegie Mellon University *)

(* integrate.sml
 *   Decide which known functions to integrate into the bodies of other
 *   functions for sml2c.
 *
 * Author: David Tarditi
 *         Carnegie Mellon University
 *         tarditi@cs.cmu.edu
 *
 *
 * There is one function in this structure:
 *     info : ((lvar * lvar list * cexp) * bool) list ->
 *                {integrate : CPS.lvar -> bool,
 *                 includefuns : CPS.lvar -> CPS.lvar list}
 *
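 * Assumption: There are no unused functions in the function list (in
 * particular, every known function should be reachable through calls
 * from some escaping function).  This code works only if this is true.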
 *
 * info takes a list of functions, each paired with a flag that is true
 * iff the function is known, and returns a record of two functions.
 * The first function, integrate, returns true for an lvar iff the
 * function named by the lvar should be integrated into the body of
 * another function.  The second function, includefuns, returns the list
 * of functions to integrate into function f, given the lvar for f.
 *
 *
 *  We use dominators to decide whether we can integrate a known function
 *  into the body of another function.  See Aho, Sethi, and Ullman,
 *  "Compilers: Principles, Tools, and Techniques", Sections 10.4 and 10.9
 *  (pp. 602 and 670), for more information on dominators.
 *
 *  First, here are some definitions.  Given an initial node s in a directed
 *  graph, we say that node m dominates node n iff  all paths from s to n must
 *  go through m.  Define a node m to be a maximal dominator if no node
 *  dominates it except itself.
 *
 *  We use the set of escaping functions as our initial nodes and the call
 *  graph as our directed graph.  No escaping function can dominate another
 *  escaping function, but escaping and known functions can dominate known
 *  functions.
 *
 *  A known function f can be integrated into the body of its maximal
 *  dominator d.  This holds because every function that can directly
 *  call f is itself dominated by d (or is d), and so can also be
 *  integrated into the body of d.  Suppose not.  Then there exists some
 *  function e, other than d, which can call f but which is not dominated
 *  by d.  Some path from an initial node to e then avoids d, and
 *  extending it with the call from e to f gives a path to f that avoids
 *  d.  This contradicts the assumption that d dominates f.
 *
 *  Computing dominators:
 *
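 *  The following algorithm computes the set of nodes which dominate each
 *  node in a graph, given an initial node n0:
 *
 *   D(n0) = {n0}
 *   for n in N-{n0} do D(n) := N
 *   while any changes to D(n) occur do:
 *	for n in N-{n0} do
 *	   D(n) := {n} U  intersection(D(p) | p is a predecessor of n)
 *
 *   With several initial nodes (here, the escaping functions), every
 *   initial node e starts with D(e) = {e} and is never updated; only the
 *   known functions are iterated over.
 *
 *   We implement this algorithm using bit sets.*)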

(* CALL: decides which known functions sml2c should integrate into the
 * bodies of other functions.
 *
 * info receives every function as ((name, lvar list, body), isKnown)
 * and returns two queries:
 *   includefuns - given the lvar of a function f, the lvars of the
 *                 functions to integrate into the body of f;
 *   integrate   - true for the lvar of a function that should be
 *                 integrated into the body of some other function.
 *)
signature CALL =
sig
  structure CPS : CPS

  val info :
      ((CPS.lvar * CPS.lvar list * CPS.cexp) * bool) list
      -> {includefuns : CPS.lvar -> CPS.lvar list,
          integrate   : CPS.lvar -> bool}
end

	   
(* CallFun: implementation of CALL.
 *
 * Functor arguments:
 *   CPS    - the CPS language; the sharing constraint requires lvars to
 *            be ints (presumably so they can be used directly as Intmap
 *            keys -- confirm against INTMAP).
 *   Intmap - maps each function lvar to a dense index (see hash below).
 *   Bitset - represents the dominator sets.
 *   debug  - when !debug is true, info prints diagnostics to std_out.
 *)
functor CallFun(structure CPS : CPS
	        sharing type CPS.lvar = int
	        structure Intmap : INTMAP
		structure Bitset : BITSET
		val debug : bool ref) : CALL =
  struct
     structure CPS = CPS
     open CPS

     (* split pred l: partition l into (elements satisfying pred,
        elements not satisfying pred), keeping the original order
        within each part. *)

     fun split pred nil = (nil,nil)
       | split pred (a::r) = let val (x,y) = split pred r
			 in if pred a then (a::x, y) else (x, a::y)
		        end

     (* sieve pred l: the elements of l that satisfy pred, in their
        original order (a filter). *)

     fun sieve pred nil = nil
       | sieve pred (a::r) = if pred a then a::sieve pred r else sieve pred r

     (* info funcs: each element of funcs pairs a function
        (name, lvar list, body) with a flag that is true iff the
        function is known.  Returns {integrate, includefuns} as
        described in the file header.  Both returned functions look
        their argument up with hash, so they expect the lvar of a
        function in funcs.  NOTE(review): for any other lvar the lookup
        presumably raises the local exception Func -- confirm. *)

     val info = fn (funcs : ((lvar * lvar list * cexp) * bool) list) =>
	let val say = outputc std_out
	    val total = length funcs
            val (known,unknown) = split #2 funcs

	    (* list of lvars of known functions *)

            val known = map (fn ((a,_,_),_) => a) known

	    val _ = if !debug then
		      (say ("# known = "^(makestring (length known))^"\n");
		       say ("# unknown = "^(makestring (length unknown))^"\n"))
		    else ()

	    (* list of lvars of unknown functions *)

	    val unknown = map (fn ((a,_,_),_) => a) unknown

	    (* hash lvars of functions to numbers between 0 and total-1.
	       Known functions get the indices 0 .. count_known-1 and
	       unknown functions get count_known .. total-1, so the test
	       i<count_known used throughout means index i is a known
	       function. *)

	    local
	       exception Func
               val functable : int Intmap.intmap = Intmap.new(32, Func)
	       val unhashtable : lvar array = array(total,0)
	       val add = Intmap.add functable
	       (* insert (a,num): give lvar a the index num and return the
		  next free index, so fold threads the counter through the
		  list. *)
	       val insert : CPS.lvar * int -> int =
		   fn (a,num) => (add(a,num); update(unhashtable,num,a);
				  num+1)
	       val n = fold insert known 0
	       val _ = fold insert unknown n
            in 
	       val count_known = n
	       val hash = Intmap.map functable
	       (* NOTE(review): unhash is not used anywhere below. *)
	       val unhash = fn i => unhashtable sub i
	    end

            (* for g: apply g to every known-function index,
               0 .. count_known-1, in increasing order. *)

            val for = fn g =>
	        let fun f i = if i<count_known then (g i; f (i+1)) else ()
	        in f 0
	        end

            (* pred: give the (hashed) immediate predecessors of an lvar.
	       An immediate predecessor of a function f is a function that
	       contains a call to f.  Only direct calls APP (LABEL v,_) to
	       known functions are recorded; other APP forms are ignored.
	       Code inside a nested FIX is scanned as part of the enclosing
	       function, so its calls are attributed to that function.
	       pred is only meaningful for known-function indices (callers
	       has count_known entries). *)

	    val pred : int -> int list =
	       let val callers = array(count_known, nil : int list)
	           fun mash (func,cexp) =
	             let val func' = hash func
			 fun f cexp =
	                 case cexp
	                 of RECORD (_,_,c) => f c
	                  | SELECT (_,_,_,c) => f c
	                  | OFFSET (_,_,_,c) => f c
	                  | APP (LABEL v,_) => 

			    (* add func (the caller) to the predecessor
			       list for v, if v is a known function *)

	                    let val v' = hash v
	                    in if v'<count_known 
			        then update(callers,v',func'::(callers sub v'))
			        else ()
	                    end
	                  | APP _ => ()
	                  | FIX (l,c) => (app (fn (_,_,a) => f a) l; f c)
	                  | SWITCH (_,cl) => app f cl
	                  | PRIMOP (_,_,_,cl) => app f cl
	             in f cexp
	             end

	           (* debug: print each known index followed by the
		      indices of its callers *)

	           val dumpcallers = fn () => for (fn i=>
		      (say (makestring i ^ ":");
		       app (fn a => say ((makestring a)^ " ")) (callers sub i);
		       say "\n"))

	       in app (fn ((func,_,cexp),_) => mash(func,cexp)) funcs;

                  (* remove duplicate entries in predecessor lists by
		     merging singleton lists with SortedList.foldmerge
		     (which presumably also leaves them sorted) *)

		  for (fn i =>
	                  update(callers,i,
		             SortedList.foldmerge (map (fn i => [i])
						   (callers sub i))));
		  if !debug then dumpcallers() else ();
		  fn i => callers sub i
	       end

		        
               (* compute dominator information: dom i is the set of
		  (hashed) functions that dominate function i.  Every
		  array entry starts out as one shared set, but each entry
		  is replaced by a fresh set before use (known entries by
		  the for loop, unknown entries by g), so no two entries
		  alias. *)

               val dom : int -> Bitset.set =
		  let val dom = array(total,Bitset.all total)
		      val _ = for (fn i => update(dom,i,Bitset.all total))

                 (* set up dominator information for unknown functions.
		    Unknown functions are only dominated by themselves,
		    so each gets the singleton set {i}. *)
			  
		      fun g i = if i<total then
			   let val a = Bitset.empty total
			   in update(dom,i,a); Bitset.add(i,a); g(i+1)
			   end
		                else ()
		      val _ = g count_known
		      val repeat = ref false

                  (* iteratively compute dominator information for
		      known functions.  For each known i, s (= dom sub i)
		      is updated in place: it is intersected with the set
		      of every predecessor (the inner i ranges over the
		      predecessors) and then i is added back.  repeat
		      records whether any set changed, and passes continue
		      until a full pass changes nothing.
		      NOTE(review): this relies on Bitset.intersect(a,b)
		      storing its result in b and on Bitset.add mutating
		      its set argument -- confirm against BITSET. *)

	              fun loop _ =
	                (for (fn i =>
			      let val s = dom sub i
				  val initial = Bitset.copy s
			      in app (fn i=>Bitset.intersect(dom sub i,s))
				     (pred i);
				 Bitset.add(i,s);
				 repeat := (!repeat orelse
				            not (Bitset.equal(s,initial)))
			      end);
	                       if (!repeat)
				   then (repeat := false; loop ())
			           else ())
		  in loop ();
		     fn i => dom sub i
	          end

              (* compute the maximal dominators of known functions:
		 maxdom i is the maximal dominator of known function i;
		 for an unknown function index, maxdom i = i. *)

	       val maxdom : int -> int =
		 let val maxdom = array(count_known,0)
		     val is_maxdom = array(count_known,false)

                     (* decide which functions are maximal dominators.
		        A function is a maximal dominator iff it is
			dominated by only itself *)

		     val _ = for (fn i => case Bitset.elems (dom i)
				          of [h] =>
					      if h=i then
						  update(is_maxdom,i,true)
					      else ()
					   | _ => ())

		     (* decide the maximal dominators for the known
		        functions.  Take the list of dominators of
			a known function and remove all non-maximal
			dominators.  Unknown functions are always kept,
			since they are dominated only by themselves.
			Exactly one candidate should remain, because the
			dominators of a node form a chain; any other
			outcome means the header's assumption about unused
			functions was violated, and ErrorMsg.impossible is
			called. *)

		     val _ = for (fn i =>
				  case sieve (fn j =>
					  if j<count_known then
					      is_maxdom sub j
					  else true)
				       (Bitset.elems (dom i))
				  of [h] => update(maxdom,i,h)
				| l => (ErrorMsg.impossible "integrate: 181"))
		 in fn i => if i<count_known
				then maxdom sub i
				else i
		 end
				      
              (* integrate all functions which are not maximal dominators.
		 Unknown functions map to themselves under maxdom, so
		 integrate is false for them. *)

              val integrate : CPS.lvar -> bool = 
		fn l =>
		  let val l' = hash l
                  in maxdom l' <> l'
		  end

              (* integratefuns: the known functions to be integrated *)

	      val integratefuns : CPS.lvar list = sieve integrate known

              (* compute includefuns: find the maximal dominator g for
	         each function f to be included, add f to the list
		 of functions to be integrated into g.  include_map is
		 indexed by hashed function; functions with nothing to
		 include map to nil. *)

	      val includefuns : CPS.lvar -> (CPS.lvar list) = 
	        let val include_map = array(total,nil : CPS.lvar list)
                    val _ = app  (fn l : lvar =>
				    let val parent = maxdom (hash l)
				    in update(include_map,parent,
					      l :: (include_map sub parent))
				    end) integratefuns
	        in fn l : lvar => include_map sub (hash l)
                end

	       (* debug: print the functions to be integrated and, for
		  every function, its include list *)

	       val _ = if !debug then
		       (say "functions to be integrated:";
		        app (fn a=> say (makestring a^ " ")) integratefuns;
		        say "\ninclude map:";
		        app (fn a =>
			   (say (makestring a); say ":";
			    app (fn a=> say(makestring a^ " "))
			         (includefuns a);
			    say "\n"))
			   (known @ unknown))
		       else ()
       in {integrate=integrate,includefuns=includefuns}
       end
end
