(* Copyright (c) 1991 by Carnegie Mellon University *)

(* cbignum.sml
 *
 * Arbitrary precision integers (bignums) and associated operations.
 *
 * Author: David Tarditi
 *         Carnegie Mellon University
 *         tarditi@cs.cmu.edu
 *
 * The operations on bignums are:
 *
 * val + : bignum * bignum -> bignum
 * val - : bignum * bignum -> bignum
 * val * : bignum * bignum -> bignum
 * val div : bignum * bignum -> bignum
 *     This rounds towards zero, unlike the version of div in the
 *     Definition.  The remainder always has the same sign as the
 *     dividend.  For i <> 0, div(d,i) and the remainder r satisfy
 *             div(d,i)*i + r = d and
 *             if d<0  then -|i| < r <= 0
 *             if d>=0 then 0 <= r < |i|
 *     div raises Div when i is zero.
 * val mod : bignum * bignum -> bignum
 *     mod(a,b) = a - (a div b) * b, i.e. the remainder r described above.
 * val ~ : bignum -> bignum
 * val > : bignum * bignum -> bool
 * val >= : bignum * bignum -> bool
 * val < : bignum * bignum -> bool
 * val <= : bignum * bignum -> bool
 * 
 * The polymorphic ML equality function can be used to test for equality.
 * Bignums are equal iff their concrete representations are identical
 *
 * val makestring : bignum -> string
 * val inttobignum : int -> bignum
 *    This converts an integer to a bignum.
 *
 * Bignums are implemented using boolean values to represent signs and
 * lists of integers to represent the magnitude.  Each integer d in the list
 * satisfies 0 <= d < b, where b is a power of 10 chosen so that b*b does
 * not cause an overflow.  The magnitudes of bignums can be viewed as digit
 * strings in base b.  The digits of a magnitude are stored in the list from
 * least significant digit to most significant digit and are normalized
 * (the most significant digit is not zero; zero is the empty list).
 *
 * Division is slow.  The number representation was chosen to make
 * it easy to implement and easy to print numbers.
 *)

(* CBIGNUM: arbitrary-precision signed integers.  The operators carry the
   same fixities as the built-in integer operators, so client code can
   write expressions such as a + b * c directly on bignums. *)
signature CBIGNUM =
sig
    (* fixities: multiplicative, additive, relational *)
    infix 7 * div mod
    infix 6 + -
    infix 4 > < >= <=

    (* The representation is normalized, so polymorphic equality (=)
       decides numeric equality. *)
    eqtype bignum

    (* raised by division by zero *)
    exception Div

    (* conversions *)
    val inttobignum : int -> bignum
    val makestring  : bignum -> string

    (* arithmetic: div truncates toward zero; mod = a - (a div b) * b *)
    val ~   : bignum -> bignum
    val +   : bignum * bignum -> bignum
    val -   : bignum * bignum -> bignum
    val *   : bignum * bignum -> bignum
    val div : bignum * bignum -> bignum
    val mod : bignum * bignum -> bignum

    (* ordering *)
    val <   : bignum * bignum -> bool
    val <=  : bignum * bignum -> bool
    val >   : bignum * bignum -> bool
    val >=  : bignum * bignum -> bool
end

(* abstraction on eqtypes is broken *)

structure CBignum : CBIGNUM =
  struct

      infix 7 * div mod
      infix 6 + -
      infix 4 > < >= <=

      (* bignum_wordsize: the number of decimal digits packed into one
         magnitude digit.  It is half (rounded down) of the largest d for
         which 10^d is representable as an int, so 10^(2*bignum_wordsize)
         -- i.e. base*base -- fits in an integer.  The probe stops at
         d = 100 in case integers never overflow. *)

      val bignum_wordsize =
        let val wordsize =
              (* f(i,n), where n = 10^i: returns the largest i such that
                 10^i was computed without raising Overflow (at most 100) *)
              let fun f (100,n) = 100
                    | f (i,n) = f(i+1,n*10) handle Overflow => i
              in f(0,1)
              end
        in wordsize div 2
        end

      (* base = 10 ^ bignum_wordsize.  Every magnitude digit d satisfies
         0 <= d < base.  A local exception is raised if not even one
         decimal digit per word is possible (i.e. 10*10 overflows). *)

      val base =
        let exception TooSmall
            fun exp 0 = 1
              | exp n = 10 * exp(n-1)
        in exp (if bignum_wordsize < 1 then
                    raise TooSmall
                else bignum_wordsize)
        end

      (* Bignums are represented in a sign and magnitude format.
         The boolean value is true if the number is positive or zero. It is
         false if it is negative.  For each integer i in the list of integers,
         0 <= i < base must be true. The list of integers [a0...an] denotes
         the bignum b whose magnitude =

                  a0 * base^0 + a1 * base^1 .. + an * base^n

         Zero is always represented as BIG(true,nil).
      *)

      datatype bignum = BIG of bool * int list

      (* makestring: convert each digit to a decimal string.  Every digit
         except the most significant is left-padded with zeros to exactly
         bignum_wordsize characters; the most significant digit keeps no
         leading zeros.  Negative numbers get a "~" prefix. *)

      val makestring =
        fn (BIG (_,nil)) => "0"
         | (BIG (sign,l)) =>
          let exception Pad
              (* pad: l is lsd first, so each non-final string x is followed
                 in the list by its padding zeros; after the final rev the
                 zeros end up in front of x *)
              fun pad (arg as (x :: nil)) = arg
                | pad (x :: y) =
                   let val diff = bignum_wordsize - String.length x
                   in x :: (if Integer.>(diff,0)
                            then let fun gen 0 = pad y
                                       | gen n = "0" :: gen (n-1)
                                 in gen diff
                                 end
                            else pad y)
                   end
                | pad nil = raise Pad
          in (if sign then "" else "~") ^
              (implode (rev (pad (map Integer.makestring l))))
          end

      (* inttobignum: the least significant digit of i is i rem base; the
         remaining digits are those of i quot base.

         Correctness of f is by induction on i >= 0: f 0 = nil is the
         normalized magnitude of zero, and for i > 0,
         (i rem base) :: f(i quot base) lists the digits of i, lsd first. *)

      val inttobignum = fn i =>
        let fun f 0 = nil
              | f i = Integer.rem(i,base) :: f(Integer.quot(i,base))
        in if Integer.>=(i,0) then BIG(true, f i)
           else
             (* For negative i, peel off the lowest digit before negating:
                negating i itself (~i) overflows when i = MININT, whereas
                ~(i quot base) and ~(i rem base) cannot. *)
             BIG(false,~(Integer.rem(i,base)) :: f(~(Integer.quot(i,base))))
        end

      (* >=(a,b): when the signs differ they decide the result.  Otherwise
         scan both magnitudes together from lsd to msd, carrying c, which is
         true iff a' >= b' for the low-order portions a', b' seen so far; a
         more significant unequal digit overrides c.  Because magnitudes are
         normalized, a longer magnitude is larger.  For two negative numbers
         the magnitude comparison is reversed. *)

      val (op >=) = fn (BIG(sx,x),BIG(sy,y)) =>
        let fun f (a::ar,b::br,c) =
                  f(ar,br,if a=b then c else Integer.>(a,b))
              | f (nil,nil,c) = c
              | f (_ :: _, nil, _) = true
              | f (nil, _ :: _,_) = false
        in case (sx,sy)
             of (false,false) => f(y,x,true)
              | (false,true) => false
              | (true,false) => true
              | (true,true) => f(x,y,true)
        end

      (* mag_gt: decides whether magnitude x > magnitude y.  Same scan as
         >=, but starting from "not greater" since equal portions do not
         make x larger. *)

      val mag_gt = fn (x,y) =>
        let fun f (a::ar,b::br,c) =
                  f(ar,br,if a=b then c else Integer.>(a,b))
              | f (_ :: _, nil, _) = true
              | f (nil, _ :: _,_) = false
              | f (nil, nil,c) = c
        in f(x,y,false)
        end

      (* >: signed comparison built on mag_gt *)

      val (op >) = fn (BIG(sx,x),BIG(sy,y)) =>
        case (sx,sy)
          of (false,false) => mag_gt(y,x)
           | (false,true) => false
           | (true,false) => true
           | (true,true) => mag_gt(x,y)

      val (op <) = fn args => not ((op >=) args)
      val (op <=) = fn args => not ((op >) args)

      (* normalize: strip zero digits from the most significant end of the
         magnitude.  A magnitude that is all zeros becomes the canonical
         zero BIG(true,nil), whatever the original sign. *)

      val normalize : bignum -> bignum =
        let val zero = BIG(true,nil)
        in fn (BIG(sx,x)) =>
             let fun strip (0 :: r) = strip r
                   | strip (arg as (n :: r)) = BIG(sx,rev arg)
                   | strip nil = zero
             in strip (rev x)
             end
        end

      (* addwc: add two magnitudes with carry (lsd first).  The result is
         normalized when both inputs are. *)

      fun addwc(a : int list,b : int list) : int list =
        let fun f (a :: ar, b :: br, c) = cont_f(a+b+c,ar,br)
              | f (nil,r,0) = r
              | f (r, nil,0) = r
              | f (nil,nil,c) = [c]
              | f (a as nil, b::br,c) = cont_f(b+c,a,br)
              | f (a::ar, b as nil,c) = cont_f(a+c,ar,b)
            (* emit one digit of the sum and continue with the carry *)
            and cont_f(sum,a,b) =
                  if Integer.>=(sum,base) then
                    Integer.-(sum,base) :: f(a,b,1)
                  else sum :: f(a,b,0)
        in f(a,b,0)
        end

      (* subwb (subtract with borrow): compute the magnitude a-b, which
         requires a >= b (otherwise the local exception Subwb is raised).
         Digits are accumulated msd first in `result`, so leading zeros can
         be stripped before the final rev; the result is normalized. *)

      fun subwb(a : int list,b: int list) : int list =
        let exception Subwb
            fun f (a :: ar, b :: br, c, result) =
                  cont_f(a-b-c,ar,br,result)
              | f (a :: ar,nil,c,result) =
                  cont_f(a-c,ar,nil,result)
              | f (nil,nil,0,result) = result
              | f _ = raise Subwb
            and cont_f(diff,a,b,result) =
                  if Integer.<(diff,0) then
                    f(a,b,1,Integer.+(diff,base)::result)
                  else f(a,b,0,diff::result)
            and normalize (0 :: result) = normalize result
              | normalize result = rev result
        in normalize (f(a,b,0,nil))
        end

      (* multwc (multiply with carry): multiply a magnitude l by a single
         digit x, where 0 <= x < base.  Preserves normalization.

         With a, x, c <= base-1, prod = a*x+c never overflows and the next
         carry is < base:
             a*x + c <= (base-1)^2 + (base-1) = base^2 - base,
         which fits because base^2 fits, and
             c' = prod quot base <= (base^2 - base) quot base = base - 1.
      *)

      fun multwc(_,0)  : int list = nil
        | multwc(l : int list, x : int) : int list =
          let fun f(a::ar,c) =
                    let val prod = Integer.+(Integer.*(a,x),c)
                    in Integer.rem(prod,base)::f(ar,Integer.quot(prod,base))
                    end
                | f (nil,0) = nil
                | f (nil,c) = [c]
          in f(l,0)
          end

      (* base_mult(mag,pow) = mag * base^pow, where pow >= 0: prepend pow
         zero digits.  A zero magnitude (nil) stays nil, so normalization
         is preserved. *)

      fun base_mult(nil,_) = nil
        | base_mult (num:int list, pow) =
          let fun f 0 = num
                | f n = 0 :: f(Integer.-(n,1))
          in f pow
          end

      (* equiv: boolean equivalence; gives the sign of a product or
         quotient (true = nonnegative) from the signs of its operands *)

      fun equiv (a : bool, b : bool) = (a = b)

      (* ~: negation.  Zero (nil magnitude) is returned unchanged so it
         keeps the canonical nonnegative sign. *)

      fun ~ (arg as (BIG(_,nil))) = arg
        | ~ (BIG(s,r)) = BIG(not s,r)

      (* +: same signs add magnitudes.  Different signs subtract the smaller
         magnitude from the larger and take the sign of the operand with
         the larger magnitude; exact cancellation yields BIG(true,nil)
         because subwb normalizes the magnitude to nil. *)

      fun op + (BIG(true,x),BIG(true,y)) = BIG(true,addwc(x,y))
        | op + (BIG(false,x),BIG(false,y)) = BIG(false,addwc(x,y))
        | op + (BIG(false,x),BIG(true,y)) =
            if mag_gt(x,y) then BIG(false,subwb(x,y))
            else BIG(true,subwb(y,x))
        | op + (BIG(true,x),BIG(false,y)) =
            if mag_gt(y,x) then BIG(false,subwb(y,x))
            else BIG(true,subwb(x,y))

      fun op - (x,y) = x + ~y

      (* multiplication: schoolbook method using partial sums.  We don't
         need to normalize the magnitude since every helper used here
         preserves that invariant; only a zero product needs the
         canonical sign. *)

      local
          val zero = BIG(true,nil)
      in
          fun op * (BIG(sx,x),BIG(sy,y)) =
              let (* for each digit d of x (lsd first, at position pow),
                     add d * y * base^pow to the running sum *)
                  fun f (d::dr,pow,partial_sum) =
                        f(dr,Integer.+(pow,1),
                          addwc(base_mult(multwc(y,d),pow),partial_sum))
                    | f (nil,_,partial_sum) = partial_sum
                  val magnitude = f (x,0,[])
                  val sign = equiv(sx,sy)
              in case magnitude
                   of nil => zero
                    | _ => BIG(sign,magnitude)
              end
      end

      (* findgtdiv: given normalized magnitudes j <> 0 and k, return the
         greatest integer i with 0 <= i < base and i * j <= k.  The binary
         search keeps lo * j <= k, and either hi = base or hi * j > k. *)

      val findgtdiv = fn (j,k) =>
        let fun binsearch(lo,hi) =
              if hi=Integer.+(lo,1) then lo
              else let val mid = Integer.quot(Integer.+(hi,lo),2)
                       val result = multwc(j,mid)
                   in if mag_gt(result,k)
                      then binsearch(lo,mid)
                      else binsearch(mid,hi)
                   end
        in binsearch(0,base)
        end

      exception Div

      (* div: quotient truncated toward zero; raises Div on a zero divisor.

         We use long division.  x can be regarded as a sum of values of y
         multiplied by base^p * d, where d is a single digit.  We start with
         the maximum p such that x >= y*base^p is possible and find the
         maximum digit d such that d*y*base^p <= x.  We then compute the
         remainder r = x - d*y*base^p.  If r < y we stop and pad out the
         quotient with p zero digits.  If r >= y, we repeat with p-1.
         Quotient digits are produced msd first, hence the final rev; a
         leading zero digit is removed by normalize. *)

      fun op div (_,BIG(sy,nil)) = raise Div
        | op div (BIG(sx,x),BIG(sy,y)) =
          let fun f(x,pow) =
                let val ny = base_mult(y,pow)
                    val d = findgtdiv(ny,x)
                    val remainder = subwb(x,multwc(ny,d))
                in d :: (if mag_gt(y,remainder) then
                           let fun g 0 = nil
                                 | g i = 0 :: g(Integer.-(i,1))
                           in g pow
                           end
                         else f(remainder,Integer.-(pow,1)))
                end
              val result = rev(f(x,max(Integer.-(length x,length y),0)))
              val sign = equiv(sx,sy)
          in normalize(BIG(sign,result))
          end

      (* mod: remainder with the sign of the dividend, a - (a div b) * b;
         raises Div when b is zero *)

      val op mod = fn (a,b) => a - (a div b)*b
end
