(*------------------------------------------------------------------
   file  : co_thread.sml
   date  : August 22, 1990
   info  : goes with SML/NJ version 0.59
   author: Eric Cooper and Greg Morrisett
   desc  :

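         This file contains a functor implementing the Standard ML
         THREAD signature.  Threads are represented as first-class
         continuations, and scheduling is cooperative (no preemption):
         control passes to another thread only inside fork, yield,
         exit, or a blocking mutex/condition operation.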

         The corresponding signature is in the file:
         thread.sig.sml

         Special thanks to Andrzej Filinski for help with the
         per-thread state implementation.
 ------------------------------------------------------------------*)
import "queue.sig";
import "thread.sig";

functor Co_Thread (Queue : QUEUE) : THREAD =
    struct
        (************************************************)
        (* per-thread state                             *)
        (************************************************)

        (* A thread's identity for per-thread state: a fresh unit ref,
           compared by pointer equality. *)
        type env = unit ref

        (* A per-thread variable: an association list from each thread's
           env to that thread's value.  Entries are never removed, so the
           list grows by one for every thread that sets the variable. *)
        datatype 'a var = VAR of (env * 'a) list ref

        (* Raised by get when the calling thread has no binding. *)
        exception Undefined

        fun new_env () = ref ()

        (* The env of the thread that is currently running; updated by
           switch and by fork. *)
        val current_env = ref (new_env ())

        (* Create a variable bound to a for the calling thread only; every
           other thread sees Undefined until it calls set. *)
        fun var a = VAR (ref [(!current_env, a)])

        (* find env bindings: the value paired with env; raises Undefined
           if env has no entry. *)
        fun find _ [] = raise Undefined
          | find env ((e, a) :: rest) =
            if e = env then a else find env rest

        (* The calling thread's value of the variable. *)
        fun get (VAR v) = find (!current_env) (!v)

        (* replace env bindings a: bindings with env's value set to a,
           appending a new pair at the end if env had none. *)
        fun replace env [] a = [(env, a)]
          | replace env ((pair as (e, _)) :: rest) a =
            if e = env then (e, a) :: rest
            else pair :: replace env rest a

        (* Bind the calling thread's value of the variable to a. *)
        fun set (VAR v) a = (v := replace (!current_env) (!v) a)

        (************************************************)
        (* atomicity                                    *)
        (************************************************)

        (* bracket pre post obj body: run pre obj, then body (), then
           post obj.  If body raises, post obj still runs and the
           exception is re-raised.  If pre raises, neither body nor post
           runs. *)
        fun bracket pre post obj body =
            let val _ = pre obj
                val result = body () handle exn => (post obj; raise exn)
            in
                post obj;
                result
            end

        (* false while inside an atomic section.  The flag is consulted
           only by disable and enable, which raise on unmatched use. *)
        val enabled = ref true

        (* Raised by disable when already inside an atomic section;
           atomic sections do not nest. *)
        exception Disable

        fun disable () =
            if !enabled then enabled := false
            else raise Disable

        (* Raised by enable when not inside an atomic section. *)
        exception Enable

        fun enable () =
            if !enabled then raise Enable
            else enabled := true

        (* Run body () inside an atomic section, leaving it on return or
           on exception. *)
        fun atomically body = bracket disable enable () body

        (************************************************)
        (* thread creation, destruction, and scheduling *)
        (************************************************)

        (* A suspended thread: the continuation to resume plus its env. *)
        datatype thread = THREAD of unit cont * env

        (* Package continuation k as a thread owned by the caller's env. *)
        fun thread k = THREAD (k, !current_env)

        (* Threads ready to run; switch resumes whichever Queue.deq
           returns next. *)
        val run_queue : thread Queue.t = Queue.create ()

        fun reschedule thread = Queue.enq run_queue thread

        fun block thread q = Queue.enq q thread

        (* Raised by switch when no thread is ready to run. *)
        exception Deadlock

        fun next () =
            Queue.deq run_queue
            handle Queue.Deq => raise Deadlock

        (* Abandon the current thread and resume the next ready one; never
           returns (raises Deadlock if none is ready).  Every thread this
           functor puts in a queue was captured inside an atomic section,
           and the resumed thread calls enable on its way out of that
           section, so switch must be entered with enabled = false. *)
        fun switch () =
            let val THREAD (k, env) = next ()
            in
                current_env := env;
                throw k ()
            end

        (* fork f: start a new thread running f ().  The child runs at
           once with a fresh env; the caller is placed on the run queue
           and resumes later.  An exception escaping f is reported on
           standard output and the child terminates. *)
        fun fork f =
            atomically
            (fn () =>
             callcc (fn k =>
                     (reschedule (thread k);
                      current_env := new_env ();
                      enable ();
                      f () handle exn =>
                          (print "Unhandled exception ";
                           print (System.exn_name exn);
                           print " raised in thread.\n");
                      (* The child is finished.  Mark the state atomic
                         unconditionally, rather than via disable (which
                         raises Disable if f left it disabled), so the
                         thread resumed by switch can leave its own
                         atomic section. *)
                      enabled := false;
                      switch ())))

        (* Terminate the calling thread and resume the next ready one.
           May be called inside or outside an atomic section: the state
           is made atomic before switching, as switch requires, and any
           enclosing section is abandoned together with the thread. *)
        fun exit () =
            (enabled := false;
             switch ())

        (* Put the caller on the run queue and resume the next ready
           thread (possibly the caller itself, e.g. when no other thread
           is ready). *)
        fun yield () =
            atomically (fn () =>
                        callcc (fn k =>
                                (reschedule (thread k);
                                 switch ())))

        (************************************************)
        (* mutex locks                                  *)
        (************************************************)

        (* A lock: the held flag plus the threads blocked waiting for it.
           The owning thread is not recorded. *)
        datatype mutex = MUTEX of bool ref * thread Queue.t

        fun mutex () = MUTEX (ref false, Queue.create ())

        (* Take the lock if it is free; true on success.  Never blocks. *)
        fun try_acquire (MUTEX (held, _)) =
            if !held then false
            else (held := true; true)

        (* Take the lock, blocking while it is held.  A woken thread
           retries, since another thread may have taken the lock first. *)
        fun acquire (mutex as MUTEX (_, q)) =
            let fun loop () =
                    if try_acquire mutex then ()
                    else
                        (callcc (fn k =>
                                 (block (thread k) q;
                                  switch ()));
                         loop ())
            in
                atomically loop
            end

        (* Free the lock and make at most one blocked thread ready.  The
           callers here invoke it inside an atomic section.  It does not
           check that the lock was held. *)
        fun nonatomic_release (MUTEX (held, q)) =
            (held := false;
             reschedule (Queue.deq q)
             handle Queue.Deq => ())

        fun release mutex =
            atomically (fn () => nonatomic_release mutex)

        (* Run body () holding the lock; the lock is released even if
           body raises. *)
        fun with_mutex mutex body =
            bracket acquire release mutex body

        (************************************************)
        (* conditions                                   *)
        (************************************************)

        (* A condition: the waiting threads, each paired with the mutex it
           reacquires when woken. *)
        datatype condition = CONDITION of (thread * mutex) Queue.t

        fun condition () = CONDITION (Queue.create ())

        (* Move one waiter off condition_queue; raises Queue.Deq if there
           is none.  The waiter resumes inside wait, which reacquires the
           mutex itself, so the lock must not be taken on its behalf here:
           doing so left the woken thread blocked on a lock already marked
           held for it.  If the lock is held, park the waiter on the mutex
           queue so the next release wakes it; otherwise make it ready. *)
        fun awaken condition_queue =
            let val (thread, MUTEX (held, mutex_queue)) =
                    Queue.deq condition_queue
            in
                if !held then
                    block thread mutex_queue
                else
                    reschedule thread
            end

        (* Apply f forever; terminates only by an exception. *)
        fun repeat f =
            (f (); repeat f)

        (* Wake one waiter, if any. *)
        fun signal (CONDITION q) =
            atomically (fn () =>
                        awaken q handle Queue.Deq => ())

        (* Wake every waiter. *)
        fun broadcast (CONDITION q) =
            atomically (fn () =>
                        repeat (fn () => awaken q)
                        handle Queue.Deq => ())

        (* Release mutex and sleep on the condition as one atomic step;
           once woken, reacquire mutex before returning.  Another thread
           may run between the wakeup and the reacquire, so callers should
           re-check their predicate on return (as await does). *)
        fun wait mutex (CONDITION q) =
            (atomically (fn () =>
                         (nonatomic_release mutex;
                          callcc (fn k =>
                                  (Queue.enq q (thread k, mutex);
                                   switch ()))));
             acquire mutex)

        (* With mutex held, wait on cond until test () is true. *)
        fun await mutex cond test =
            if test () then ()
            else (wait mutex cond; await mutex cond test)
    end
