
/*****************************************************************************\

MODULE: ZZ

SUMMARY:

The class ZZ is used to represent signed, arbitrary length integers.
Here is a simple program that reads two numbers and prints their product:

#include "ZZ.h"

int main()
{
   ZZ a, b, prod; 

   cin >> a; 
   cin >> b; 
   mul(prod, a, b);  
   cout << prod << "\n";
}

Routines are provided for all of the basic arithmetic operations,
as well as for some more advanced operations such as primality testing.
Space is automatically managed by the constructors and destructors
(but routines are provided for explicit management).

This module also provides routines for
generating small primes, and fast routines for performing
modular arithmetic on single-precision numbers.

A ZZ is represented as a sequence of digits (plus a sign bit)
using a radix ZZ_RADIX, which is defined as 1L << ZZ_NBITS.
On 32-bit machines, ZZ_NBITS is 30, or 26 if using the SINGLE_MUL option;
on 64-bit machines ZZ_NBITS is 50.


\*****************************************************************************/

#include <iostream.h>
#include "tools.h"


// Signed, arbitrary-length integer. Storage is managed automatically by
// the constructors and destructor; kill() and SetSize() allow explicit
// management.
class ZZ {
public:

   ZZ(); // initial value 0
   ZZ(const ZZ& a); // copy constructor: initial value a
   void operator=(const ZZ& a); 
   // *this = a. Note: returns void, so assignments cannot be chained.
   ~ZZ(); // releases any allocated space


   ZZ(INIT_SIZE_TYPE, long k);
   // initial value is 0, but space is pre-allocated so that numbers
   // x with x.size() <= k can be stored without re-allocation.
   // Invoke as ZZ(INIT_SIZE, k).
   // The purpose of the INIT_SIZE tag argument is to prevent automatic
   // type conversion from long to ZZ.


   ZZ(INIT_VAL_TYPE, long a); // initial value a. Invoke as ZZ(INIT_VAL, a)
   // (the INIT_VAL tag likewise prevents implicit conversion from long)

   void kill();
   // Space is freed and the value becomes 0.

   void SetSize(long k);
   // pre-allocates space for numbers x with x.size() <= k;  
   // does not change the value.

   long size() const; 
   // returns the number of (ZZ_NBITS-bit) digits of |*this|;
   // the size of 0 is 0.

   static const ZZ& zero();
   // a read-only reference to the value 0
};


void clear(ZZ& x); // x = 0

void set(ZZ& x);   // x = 1

void swap(ZZ& x, ZZ& y); 
// swaps the values of x and y by swapping internal pointers
// if possible; otherwise by copying.



/*****************************************************************************\

                                  Conversion

\*****************************************************************************/


// Conversions are written as `x << a`, meaning "store the value of a in x".
void operator<<(ZZ& x, long a); // x = a

void operator<<(ZZ& x, int a);  // x = a 

void operator<<(ZZ& x, double a);
// x = floor(a);

// The conversions below do not check for overflow: if the value of a
// does not fit in the target type, the result is not detected as an error.
long Long(const ZZ& a); // returns a, no overflow check

void operator<<(long& x, const ZZ& a); // x = a, no overflow check

void operator<<(int& x, const ZZ& a); // x = a, no overflow check

double Double(const ZZ& a); // returns a, no overflow check

void operator<<(double& x, const ZZ& a); // x = a, no overflow check

void operator<<(float& x, const ZZ& a); //  x = a, no overflow check



/*****************************************************************************\

                                 Comparison

\*****************************************************************************/


long sign(const ZZ& a); // returns sign of a (-1, 0, or +1)

long compare(const ZZ& a, const ZZ& b); // returns sign of a-b (-1, 0, or +1).

long IsZero(const ZZ& a); // tests whether a == 0

long IsOne(const ZZ& a); // tests whether a == 1

/* The usual comparison operators (they return long rather than bool;
   presumably 1 when the relation holds and 0 otherwise -- confirm) */
   
long operator==(const ZZ& a, const ZZ& b);
long operator!=(const ZZ& a, const ZZ& b);
long operator<(const ZZ& a, const ZZ& b);
long operator>(const ZZ& a, const ZZ& b);
long operator<=(const ZZ& a, const ZZ& b);
long operator>=(const ZZ& a, const ZZ& b);

/* single-precision versions of the above (b is a long) */

long compare(const ZZ& a, long b);
long operator==(const ZZ& a, long b);
long operator!=(const ZZ& a, long b);
long operator<(const ZZ& a, long b);
long operator>(const ZZ& a, long b);
long operator<=(const ZZ& a, long b);
long operator>=(const ZZ& a, long b);



/*****************************************************************************\

                                 Addition

\*****************************************************************************/


void add(ZZ& x, const ZZ& a, const ZZ& b); // x = a + b
void add(ZZ& x, const ZZ& a, long b); // x = a + b
void add(ZZ& x, long a, const ZZ& b); // x = a + b

void sub(ZZ& x, const ZZ& a, const ZZ& b); // x = a - b
void sub(ZZ& x, const ZZ& a, long b); // x = a - b
void sub(ZZ& x, long a, const ZZ& b); // x = a - b

void SubPos(ZZ& x, const ZZ& a, const ZZ& b);
// x = a - b; assumes a >= b >= 0 (the precondition is the caller's
// responsibility).

void negate(ZZ& x, const ZZ& a); // x = -a

void abs(ZZ& x, const ZZ& a); // x = |a|




/*****************************************************************************\

                             Multiplication

\*****************************************************************************/

void mul(ZZ& x, const ZZ& a, const ZZ& b); // x = a * b

void mul(ZZ& x, const ZZ& a, long b); // x = a * b
void mul(ZZ& x, long a, const ZZ& b); // x = a * b

void sqr(ZZ& x, const ZZ& a); // x = a^2 (same result as mul(x, a, a))



/*****************************************************************************\

                                 Division

\*****************************************************************************/


// All division routines use floor semantics: the quotient is rounded
// toward minus infinity, so a nonzero remainder has the sign of b.
// NOTE(review): b != 0 is presumably required -- confirm the error behavior.

void DivRem(ZZ& q, ZZ& r, const ZZ& a, const ZZ& b);
// q = floor(a/b), r = a - b*q.
// This implies that:
//    |r| < |b|, and if r != 0, sign(r) = sign(b)

void div(ZZ& q, const ZZ& a, const ZZ& b);
// q = floor(a/b)

void rem(ZZ& r, const ZZ& a, const ZZ& b);
// r = a - b*floor(a/b) (the DivRem remainder, without the quotient)

void QuickRem(ZZ& r, const ZZ& b);
// r = r - b*floor(r/b), i.e., r is reduced modulo b in place.
// Assumes b > 0 and r >= 0.
// Division is performed in place and may cause r to be re-allocated.

long divide(ZZ& q, const ZZ& a, const ZZ& b);
// if b | a, sets q = a/b and returns 1; otherwise returns 0.

long divide(const ZZ& a, const ZZ& b);
// if b | a, returns 1; otherwise returns 0.


/* single-precision versions */

long DivRem(ZZ& q, const ZZ& a, long b);
// q = floor(a/b), r = a - b*q, return value is r.

void div(ZZ& q, const ZZ& a, long b);
// q = floor(a/b).

long rem(const ZZ& a, long b);
// returns r = a - b*floor(a/b).

long divide(ZZ& q, const ZZ& a, long b);
// if b | a, sets q = a/b and returns 1; otherwise returns 0.

long divide(const ZZ& a, long b);
// if b | a, returns 1; otherwise returns 0.



/*****************************************************************************\

                                    GCD's

\*****************************************************************************/


void GCD(ZZ& d, const ZZ& a, const ZZ& b);
// d = gcd(a, b) (which is always non-negative)

void XGCD(ZZ& d, ZZ& s, ZZ& t, const ZZ& a, const ZZ& b);
//  d = gcd(a, b) = a*s + b*t (extended Euclid);
//  The coefficients s and t are defined according to
//  the standard Euclidean algorithm applied to |a| and |b|,
//  with the signs then adjusted according to the signs 
//  of a and b.

/* single-precision versions */

long GCD(long a, long b);
// return value is gcd(a, b) (which is always non-negative)

void XGCD(long& d, long& s, long& t, long a, long b);
//  d = gcd(a, b) = a*s + b*t (extended Euclid);
//  The coefficients s and t are defined according to
//  the standard Euclidean algorithm applied to |a| and |b|,
//  with the signs then adjusted according to the signs 
//  of a and b.



/*****************************************************************************\

                              Bit Operations

\*****************************************************************************/


void LeftShift(ZZ& x, const ZZ& a, long k);
// x = a left-shifted k bits (k < 0 implies right shift)

void RightShift(ZZ& x, const ZZ& a, long k);
// x = a right-shifted k bits (k < 0 implies left shift)

long MakeOdd(ZZ& x);
// removes factors of 2 from x, returns the number of 2's removed
// returns 0 if x == 0

long IsOdd(const ZZ& a); // tests whether a is odd

long NumBits(const ZZ& a);
// returns the number of bits in the binary representation of |a|; 
// NumBits(0) = 0

long bit(const ZZ& a, long k);
// returns bit k of a, position 0 being the low-order bit
// (presumably of |a|, like the other bit routines here -- confirm)

long digit(const ZZ& a, long k);
// returns k-th ZZ_NBITS-digit of |a|, position 0 being the low-order digit.

void LowBits(ZZ& x, const ZZ& a, long k);
// x = low order k bits of |a| 

long LowBits(const ZZ& a, long k);
// returns low order k bits of |a|

long SetBit(ZZ& x, long p);
// returns the original value of the p-th bit of |x|, and sets
// the p-th bit of x to 1 if it was zero;
// p starts counting at 0;
// error if p < 0 

long SwitchBit(ZZ& x, long p);
// returns the original value of the p-th bit of |x|, and flips
// the value of the p-th bit of x;
// p starts counting at 0;
// error if p < 0 

long weight(long a); // returns Hamming weight of |a|

long weight(const ZZ& a); // returns Hamming weight of |a|

// NOTE(review): in standard C++, `and`, `or` and `xor` are alternative
// tokens for &&, || and ^, so the three declarations below are ill-formed
// on a conforming compiler. They only compile where these words are not
// treated as keywords (e.g. pre-standard compilers). Consider renaming
// them in the real header, e.g. to bit_and / bit_or / bit_xor.
void and(ZZ& x, const ZZ& a, const ZZ& b); // x = |a| AND |b|

void or(ZZ& x, const ZZ& a, const ZZ& b); // x = |a| OR |b|

void xor(ZZ& x, const ZZ& a, const ZZ& b); // x = |a| XOR |b|

/* single-precision versions */

long NumBits(long a);
// returns the number of bits in the binary representation of |a|; 
// NumBits(0) = 0

long bit(long a, long k);
// returns bit k of a, position 0 being the low-order bit

long NextPowerOfTwo(long m);
// returns least nonnegative k such that 2^k >= m



/*****************************************************************************\

                            Pseudo-Random Numbers

\*****************************************************************************/


void SetSeed(const ZZ& s); // initialize generator with a "seed"

void RandomBnd(ZZ& x, const ZZ& n);
// x = pseudo-random number in the range 0..n-1, or 0 if n <= 0

void RandomBnd(ZZ& x, long n);
// single-precision bound: same as above with n given as a long

void RandomLen(ZZ& x, long NumBits);
// x = pseudo-random number with precisely NumBits bits,
// i.e. NumBits(x) == NumBits

/* single-precision versions */

long RandomBnd(long n);
// returns pseudo-random number in the range 0..n-1, or 0 if n <= 0

long RandomLen(long l);
// returns pseudo-random number with precisely l bits


/*****************************************************************************\

             Incremental Chinese Remaindering

\*****************************************************************************/

long CRT(ZZ& a, ZZ& p, const ZZ& A, const ZZ& P);
long CRT(ZZ& a, ZZ& p, long A, long P);
// Incremental Chinese remaindering.
// Assumes 0 <= A < P and gcd(p, P) = 1;
// computes b such that b = a mod p, b = A mod P,
//   and -p*P/2 < b <= p*P/2;
// sets a = b, p = p*P, and returns 1 if a's value
//   has changed, otherwise 0




/*****************************************************************************\

                                Primality Testing 

\*****************************************************************************/


long ProbPrime(const ZZ& n, long NumTrials = 10);
// tests if n is prime;  performs a little trial division,
// followed by a single-precision MillerWitness test, followed by
// up to NumTrials general MillerWitness tests.
// As the name says, a "prime" answer is probabilistic (n is probably prime).

long MillerWitness(const ZZ& n, const ZZ& w);
long MillerWitness(const ZZ& n, long w);
// Tests if w is a witness to the compositeness of n a la Miller
// (i.e., whether w proves that n is composite).
// Assumption: n is odd and positive, 0 <= w < n.
// NOTE(review): presumably a nonzero return means n is definitely
// composite -- confirm against the implementation.

void RandomPrime(ZZ& n, long l, long NumTrials=10);
// n =  random l-bit prime
// Uses ProbPrime with NumTrials.

void NextPrime(ZZ& n, const ZZ& m, long NumTrials=10);
// n = smallest prime >= m.
// Uses ProbPrime with NumTrials.

/* single-precision versions */

long ProbPrime(long n, long NumTrials = 10);
// tests if n is prime;  performs a little trial division,
// followed by a single-precision MillerWitness test, followed by
// up to NumTrials general MillerWitness tests.

long RandomPrime(long l, long NumTrials=10);
// returns random l-bit prime
// Uses ProbPrime with NumTrials.

long NextPrime(long l, long NumTrials=10);
// returns smallest prime >= l.
// Uses ProbPrime with NumTrials.


/*****************************************************************************\

                               Exponentiation

\*****************************************************************************/


void power(ZZ& x, const ZZ& a, long e); // x = a^e (requires e >= 0)

void power(ZZ& x, long a, long e); // x = a^e (requires e >= 0); a is single-precision 



/*****************************************************************************\

                               Square Roots

\*****************************************************************************/


void SqrRoot(ZZ& x, const ZZ& a); // x = floor(a^{1/2}) (requires a >= 0)

long SqrRoot(long a); // returns floor(a^{1/2}) (requires a >= 0)



/*****************************************************************************\

                    Small Prime Generation

\*****************************************************************************/


// Generator for the sequence of small primes.
// Primes are generated in sequence, starting at 2, 
// and up to a maximum that is near min(ZZ_RADIX, 2^30).

class PrimeSeq {
public:
   PrimeSeq();
   ~PrimeSeq();

   long next();
   // returns next prime in the sequence.
   // returns 0 if list of small primes is exhausted.

   void reset(long b);
   // resets generator so that the next prime in the sequence
   // is the smallest prime >= b.

private:
   // Copying is disabled: the copy constructor and assignment are
   // declared private (the pre-C++11 non-copyable idiom).
   PrimeSeq(const PrimeSeq&);        // disabled
   void operator=(const PrimeSeq&);  // disabled

};



/*****************************************************************************\

                             Modular Arithmetic

The following routines perform arithmetic mod n, where n > 1.
All arguments (other than exponents) are assumed to be in the range 0..n-1.
ALIAS RESTRICTION: in all of these routines, it is 
                   assumed that n is not aliased by 
                   any of the outputs.

\*****************************************************************************/


// In the comments below, "% n" denotes reduction into the range 0..n-1
// (not C's % operator, which can return negative values).

void AddMod(ZZ& x, const ZZ& a, const ZZ& b, const ZZ& n); 
// x = (a+b)%n

void SubMod(ZZ& x, const ZZ& a, const ZZ& b, const ZZ& n);
// x = (a-b)%n

void NegateMod(ZZ& x, const ZZ& a, const ZZ& n);
// x = -a % n


void AddMod(ZZ& x, const ZZ& a, long b, const ZZ& n);
// x = (a+b)%n
void AddMod(ZZ& x, long a, const ZZ& b, const ZZ& n);
// x = (a+b)%n
void SubMod(ZZ& x, const ZZ& a, long b, const ZZ& n);
// x = (a-b)%n
void SubMod(ZZ& x, long a, const ZZ& b, const ZZ& n);
// x = (a-b)%n

void MulMod(ZZ& x, const ZZ& a, const ZZ& b, const ZZ& n);
// x = (a*b)%n

void MulMod(ZZ& x, const ZZ& a, long b, const ZZ& n);
// x = (a*b)%n
void MulMod(ZZ& x, long a, const ZZ& b, const ZZ& n);
// x = (a*b)%n

void SqrMod(ZZ& x, const ZZ& a, const ZZ& n);
// x = a^2 % n

void InvMod(ZZ& x, const ZZ& a, const ZZ& n);
// x = a^{-1} mod n (0 <= x < n)
// an error is raised if the inverse is not defined, i.e. gcd(a, n) != 1

long InvModStatus(ZZ& x, const ZZ& a, const ZZ& n);
// if gcd(a, n) = 1, then return-value = 0, x = a^{-1} mod n
// otherwise, return-value = 1, x = gcd(a, n)
// (a non-raising alternative to InvMod)

void PowerMod(ZZ& x, const ZZ& a, long e, const ZZ& n);
// x = a^e % n (e >= 0)

void PowerMod(ZZ& x, const ZZ& a, const ZZ& e, const ZZ& n);
// x = a^e % n (e >= 0)

void PowerMod(ZZ& x, long a, long e, const ZZ& n);
// x = a^e % n (e >= 0)

// NOTE(review): declared inline, so its definition must be visible in every
// translation unit that calls it -- presumably the real header provides it.
inline void PowerMod(ZZ& x, long a, const ZZ& e, const ZZ& n);
// x = a^e % n (e >= 0)




/*****************************************************************************\

                    Jacobi symbol and modular square roots

\*****************************************************************************/


long Jacobi(const ZZ& a, const ZZ& n);
//  computes the Jacobi symbol (a/n);
//  assumes 0 <= a < n and n odd

void SqrRootMod(ZZ& x, const ZZ& a, const ZZ& n);
//  x = a square root of a mod n;
//  assumes n is an odd prime, and that a is a square mod n



/*****************************************************************************\

                                     Input/Output

I/O Format:

Numbers are written in base 10, with an optional minus sign.

\*****************************************************************************/

istream& operator>>(istream& s, ZZ& x);  
// reads a base-10 number, with an optional leading minus sign, into x
ostream& operator<<(ostream& s, const ZZ& a); 
// writes a in base 10, with a leading minus sign if a < 0





/*****************************************************************************\

                        Single-precision modular arithmetic

These routines implement single-precision modular arithmetic.
If n is the modulus, all inputs should be in the range 0..n-1.
The number n itself should be in the range 1..2^{ZZ_NBITS}-1.

\*****************************************************************************/




// In the comments below, "% n" denotes reduction into the range 0..n-1
// (not C's % operator, which can return negative values for a-b < 0).

long AddMod(long a, long b, long n); // return (a+b)%n

long SubMod(long a, long b, long n); // return (a-b)%n

long MulMod(long a, long b, long n); // return (a*b)%n

long MulMod(long a, long b, long n, double ninv);
// return (a*b)%n.
// ninv = 1/((double) n), precomputed by the caller.
// This is faster if n is fixed for many multiplications.


long MulMod2(long a, long b, long n, double bninv);
// return (a*b)%n.
// bninv = ((double) b)/((double) n), precomputed by the caller.
// This is faster if both n and b are fixed for many multiplications.


long MulDivRem(long& q, long a, long b, long n, double bninv);
// return (a*b)%n, and set q = floor((a*b)/n).
// bninv = ((double) b)/((double) n)

long InvMod(long a, long n);
// computes a^{-1} mod n.  Error is raised if undefined.

long PowerMod(long a, long e, long n);
// computes a^e mod n, e >= 0

