#include "defs.h"
#include "mp.e"
#include "mp.h"


mp_float
mp_root		WITH_3_ARGS(
	mp_float,	x,
	mp_int,		n,
	mp_float,	y
)
/*
Returns y = x^(1/n) for integer n, mp x and y using Newton's method without
divisions.  The time taken is O(M(t)) unless int_abs(n) is large (in which
case mp_power() is used instead).  Accumulator operations are performed.

    x	- the mp number whose n-th root is required.
    n	- the root index.  n must be non-zero; a negative n gives
	  x^(-1/int_abs(n)).
    y	- receives the result, rounded with the rounding mode that was in
	  force on entry.  y is also the return value.

The following argument combinations are reported via mp_error():
n == 0, x == 0 with n < 0, and x < 0 with n even.
*/
{
    mp_ptr_type		xp = mp_ptr(x), yp = mp_ptr(y);
    mp_round_type	save_round = round;
    mp_sign_type	x_sign;
    mp_base_type	b;
    mp_length		t, new_t, new_t2, extra;
    mp_acc_float	x_copy, result, value_of_f, upper, mid;
    mp_ptr_type		x_copy_ptr, result_ptr, value_of_f_ptr;
    mp_int		abs_n;

    mp_check_2("mp_root", xp, yp);

    DEBUG_BEGIN(DEBUG_ROOT);
    DEBUG_PRINTF_1("+root {\n");
    DEBUG_1("x = ", xp);
    DEBUG_PRINTF_2("n = %d\n\n", n);

    if (n == 1)
    {
	/*
	Simply copy x if n is 1.
	*/

	mp_copy_ptr(xp, yp);

	DEBUG_1("-} y = ", yp);
	DEBUG_END();

	return y;
    }

    /*
    Check for various illegal argument combinations.
    */

    abs_n = int_abs(n);
    x_sign = mp_sign(xp);

    if (n == 0)
	mp_error("mp_root: n is zero");

    if (x_sign == 0)
    {
	/*
	The root of zero is zero for positive n; a negative n would
	require a reciprocal of zero.
	*/

	if (n < 0)
	    mp_error("mp_root: x is zero and n is less than zero");

	mp_set_sign(yp, 0);

	/*
	Terminate the trace line like every other debug message here.
	*/

	DEBUG_PRINTF_1("-} y = 0\n");
	DEBUG_END();

	return y;
    }

    if (x_sign < 0 && abs_n % 2 == 0)
	mp_error("mp_root: x is less than zero and n is even");


    b = mp_b(xp);
    t = mp_t(xp);

    /*
    Calculate sufficient guard digits and allocate 2 temporary floats.
    */

    extra = 1 + mp_extra_guard_digits(abs_n, b);
    new_t = t + extra;

    mp_change_up();
    mp_acc_float_alloc_2(b, new_t, x_copy, result);


    /*
    Work with int_abs(x), fix up sign later.
    */

    mp_move(x, x_copy);
    mp_set_sign(mp_acc_float_ptr(x_copy), int_abs(x_sign));


    /*
    Check for large int_abs(n).
    */

    if (abs_n >= MAX_EXPT / 4)
    {
	/*
	It's more efficient to use mp_power() for very large n:
	result = int_abs(x)^(1/n).
	*/

	mp_q_to_mp(1, n, result);
	mp_power(x_copy, result, result);

	result_ptr = mp_acc_float_ptr(result);
    }
    else
    {
	mp_int		ex, j, k, limit, e, ep, last_k; 
	mp_bool		first;


	/*
	Main case: We calculate y = x^(-1/abs_n) by solving the
	equation f(y) = 0 where f(y) = x * y^abs_n - 1 (x constant,
	i.e. input).  First the exponent of x_copy is reduced as much as
	possible.  This ensures that the (scaled) result is between 1/b
	and 1.  Next, if abs_n is 1 or 2, an approximation is made to y
	(by guessing 1/x_copy or 1/sqrt(x_copy)).  Otherwise the
	bisection method is applied to f(y) with initial lower and upper
	values of y being 1/b and 1 until int_abs(f(mid point)) < 1/2.

	Having now a good approximation for y, Newton's method is
	applied to find the same root.  Finally, if n > 0, y
	(== x^(-1/n)) is converted to x^(1/n) by raising it to the
	(n - 1)th power and multiplying by x_copy.

	If abs_n <= 2 we only need value_of_f of the following floats,
	but since it is efficient to allocate many together, we allocate
	3 temporary floats.  upper and mid are only used in the
	bisection phase.
	*/

	mp_acc_float_alloc_3(b, new_t, value_of_f, upper, mid);

	/*
	Set pointers.  Accumulator operations may move the floats, so the
	macro fix_pointers() re-reads the pointers whenever mp_has_changed()
	reports a change; this keeps the later code simpler to read.
	*/

	x_copy_ptr = mp_acc_float_ptr(x_copy);
	result_ptr = mp_acc_float_ptr(result);
	value_of_f_ptr = mp_acc_float_ptr(value_of_f);

	/*
	Clear all new_t digits of result so that, as its working precision
	is raised later, the newly exposed trailing digits are zero.
	*/

	mp_set_digits_zero(mp_digit_ptr(result_ptr, 0), new_t);


#define fix_pointers()	if (mp_has_changed()) { \
			    x_copy_ptr = mp_acc_float_ptr(x_copy);	   \
			    result_ptr = mp_acc_float_ptr(result);	   \
			    value_of_f_ptr = mp_acc_float_ptr(value_of_f); \
			}


	e = -mp_expt(x_copy_ptr);

	/*
	Compute ep = floor(e/abs_n) + 1.  C division truncates towards
	zero, so the test below covers both signs of e.
	*/

	ep = e / abs_n;

	if (abs_n * ep <= e)
	    ep++;


	/*
	Scale to avoid under/overflow: scaled int_abs(x) is between 1 and b^abs_n.
	*/

	mp_expt(x_copy_ptr) += ep * abs_n;


	/*
	Lower and upper bounds on int_abs(scaled result) are now 1/b and 1.
	Set limit to the maximum number of iterations allowed.
	*/

	limit = t * mp_times_log2_b(1, b);


	/*
	Reduce t at first.  Lowering mp_t() keeps the trailing digits of
	x_copy in place, so they are used again once t is raised.
	*/

	new_t2 = 2 + extra;


	fix_pointers();

	mp_t(x_copy_ptr) = mp_t(result_ptr) = mp_t(value_of_f_ptr) = new_t2;



	/*
	To speed up reciprocals and square roots, treat abs_n = 1 or 2
	as special cases and get better starting approximations for them.
	*/

	if (abs_n == 1)
	    mp_q_to_mp(b, b * mp_digit(x_copy_ptr, 0) +
			    mp_digit(x_copy_ptr, 1), result);

	else if (abs_n == 2)
	{
	    /*
	    Compute j = upper bound on 4 * int_abs(scaled x).
	    */

	    j = 4 * (mp_to_int(x_copy) + 1);

	    /*
	    Compute k = upper bound on sqrt(j) using integer Newton iteration.
	    */

	    k = j;

	    do
	    {
		last_k = k;
		k = (k + j / k) / 2;
	    } while (k < last_k);

	    if (k * k < j)
		k++;


	    /*
	    Now get a lower bound on scaled int_abs(result).
	    */

	    mp_q_to_mp(2, k, result);
	}
	else
	{
	    /*
	    Here abs_n > 2.  Apply bisection method to f(y) = x * y^n - 1 = 0.
	    result is used as the lower value of y.  Lower and upper bounds
	    on int_abs(scaled result) are 1/b and 1.
	    */

	    mp_ptr_type		mid_ptr, upper_ptr;


	    mp_q_to_mp(1, b, result);
	    mp_int_to_mp(1, upper);

	    upper_ptr = mp_acc_float_ptr(upper);
	    mid_ptr = mp_acc_float_ptr(mid);

	    mp_t(upper_ptr) = mp_t(mid_ptr) = new_t2;


	    do
	    {
		/*
		Check for infinite loop.
		*/

		if (--limit < 0)
		    mp_bug("mp_root: iteration not converging (bisection)");

		/*
		Compute mid point of result (lower) and upper.
		*/

		mp_sub(upper, result, mid);
		mp_div_int_eq(mid, 2);
		mp_add_eq(mid, result);


		/*
		Compute f(mid).
		*/

		mp_abs_int_power(mid, abs_n, value_of_f);
		mp_mul_eq(value_of_f, x_copy);
		mp_add_int_eq(value_of_f, -1);


		/*
		Fix pointers if necessary (including the two
		bisection-only floats that fix_pointers() does not cover).
		*/

		if (mp_has_changed())
		{
		    x_copy_ptr = mp_acc_float_ptr(x_copy);
		    result_ptr = mp_acc_float_ptr(result);
		    value_of_f_ptr = mp_acc_float_ptr(value_of_f);
		    upper_ptr = mp_acc_float_ptr(upper);
		    mid_ptr = mp_acc_float_ptr(mid);
		}

		/*
		f(mid) <= 0 moves the lower bound up; f(mid) >= 0 moves
		the upper bound down (both when f(mid) == 0).
		*/

		if (!mp_is_pos(value_of_f_ptr))
		    mp_copy_ptr(mid_ptr, result_ptr);

		if (!mp_is_neg(value_of_f_ptr))
		    mp_copy_ptr(mid_ptr, upper_ptr);

		/*
		Repeat bisection if int_abs(residual) >= 1/2.
		*/

	    } while (!mp_is_zero(value_of_f_ptr) &&
			(mp_expt(value_of_f_ptr) > 0 ||
			    mp_expt(value_of_f_ptr) == 0 &&
			    2 * mp_digit(value_of_f_ptr, 0) >= b
			)
		    );

	    DEBUG_PRINTF_1("bisection finished\n");
	    DEBUG_1("result = ", result_ptr);
	    DEBUG_1("value_of_f = ", value_of_f_ptr);

	    /*
	    Now Newton's method should converge.  Make result equal to mid
	    so that value_of_f holds f(result) for the first Newton step.
	    */

	    if (!mp_is_zero(value_of_f_ptr))
		mp_copy_ptr(mid_ptr, result_ptr);
	}

	/*
	ex holds the exponent of the latest residual f(result); it drives
	both the working precision and the stopping test below.
	*/

	first = TRUE;
	ex = -new_t2;

	fix_pointers();

	/*
	Newton loop - the iteration is:

		new y = y - y / abs_n * (x * y^abs_n - 1)
		      = y - y / abs_n * f(y).

	This is Newton's method applied to y^(-abs_n) - x = 0, which has
	the same root as f(y) = 0 but needs no division by x.
	*/

	do
	{
	    /*
	    If abs_n > 2, and first is true, we just calculated f(result)
	    in the bisection loop, so we can skip this step.
	    */

	    if (abs_n > 2 && first)
		first = FALSE;
	    
	    else
	    {
		/*
		Choose good t.
		*/
	    
		mp_length	t = new_t2 + 4 * int_abs(ex);

		if (t > new_t)
		    t = new_t;

		fix_pointers();

		mp_t(x_copy_ptr) = mp_t(result_ptr) = mp_t(value_of_f_ptr) = t;


		/*
		Calculate f(result).
		*/

		mp_abs_int_power(result, abs_n, value_of_f);
		mp_mul_eq(value_of_f, x_copy);
		mp_add_int_eq(value_of_f, -1);

		ex = -t;
	    }

	    fix_pointers();

	    if (!mp_is_zero(value_of_f_ptr))
		ex = mp_expt(value_of_f_ptr);

	    /*
	    Check for infinite loop.  A residual with positive exponent
	    (int_abs(f) >= 1) is also treated as divergence.
	    */

	    if (--limit < 0 || ex > 0)
		mp_bug("mp_root: iteration not converging (newton)");


	    /*
	    Set y -= y/abs_n * f(y).
	    */

	    mp_mul_eq(value_of_f, result);
	    mp_div_int_eq(value_of_f, abs_n);
	    mp_sub_eq(result, value_of_f);

	    /*
	    Stop once 2 * ex + new_t <= 0, presumably because the quadratic
	    convergence of the step just applied then gives about new_t
	    correct digits.
	    */

	} while (2 * ex + new_t > 0);


	/*
	Correct exponent of y.
	*/

	fix_pointers();
	mp_expt(result_ptr) += ep;


	/*
	Correct result if n is positive.
	*/

	if (n > 0)
	{
	    /*
	    Fix exponent of x_copy and set result = x_copy * result^(n - 1),
	    i.e. x * x^(-(n - 1)/n) = x^(1/n).
	    */

	    mp_expt(x_copy_ptr) -= ep * abs_n;

	    mp_abs_int_power(result, n - 1, result);
	    mp_mul_eq(result, x_copy);

	    fix_pointers();
	}


	mp_acc_float_delete(value_of_f);
	mp_acc_float_delete(upper);
	mp_acc_float_delete(mid);

	/*
	fix_pointers() is only meaningful inside this function.
	*/

#undef fix_pointers
    }

    /*
    Restore the sign of x (only odd n can reach here with x < 0).
    */

    mp_sign(result_ptr) *= x_sign;

    round = save_round;
    mp_move_round(result, y);

    mp_acc_float_delete(result);
    mp_acc_float_delete(x_copy);

    mp_change_down();

    DEBUG_1("-} y = ", yp);
    DEBUG_END();

    return y;
}


mp_float
mp_sqrt		WITH_2_ARGS(
	mp_float,	x,
	mp_float,	y
)
/*
Returns y = sqrt(x) for mp x and y.

    x	- the mp number whose square root is required; a negative x is
	  reported via mp_error().
    y	- receives the result and is also the return value.

A zero x gives y = 0 directly; otherwise the work is done by
mp_root(x, 2, y).
*/
{
    mp_sign_type	x_sign;

    x_sign = mp_sign(mp_ptr(x));

    if (x_sign < 0)
	mp_error("mp_sqrt: x is negative");

    if (x_sign != 0)
	mp_root(x, 2, y);
    else
	mp_set_sign(mp_ptr(y), 0);

    return y;
}
