// 3way.cpp - modified by Wei Dai from a file of unknown source,
//            presumed to be in the public domain

#include "misc.h"
#include "3way.h"

// Starting values passed to GenerateRoundConstants() for each direction.
static const word32 START_E = 0x0B0B;   // round constant of the first encryption round
static const word32 START_D = 0xB1B1;   // round constant of the first decryption round

// Returns a with the order of its bits reversed.
// The three shuffles below reverse the bits inside every byte; the final
// Invert() call (from misc.h, expected to reverse the byte order) completes
// the reversal across the whole word.
static inline word32 reverseBits(word32 a)
{
    word32 v = a;
    v = ((v >> 1) & 0x55555555L) | ((v << 1) & 0xAAAAAAAAL);   // exchange neighbouring bits
    v = ((v >> 2) & 0x33333333L) | ((v << 2) & 0xCCCCCCCCL);   // exchange neighbouring bit pairs
    v = ((v >> 4) & 0x0F0F0F0FL) | ((v << 4) & 0xF0F0F0F0L);   // exchange the nibbles of each byte
    return Invert(v);
}

// Inverts the order of the bits of the three-word state a[0..2]:
// every word is bit-reversed and the outer two words trade places.
static inline void mu(word32 *a)
{
    const word32 front  = reverseBits(a[0]);
    const word32 middle = reverseBits(a[1]);
    const word32 back   = reverseBits(a[2]);

    a[0] = back;
    a[1] = middle;
    a[2] = front;
}


// The nonlinear step: each word of a[0..2] is XORed with the OR of the next
// word and the complement of the one after it (indices taken cyclically).
static inline void gamma(word32 *a)
{
    // Work from a snapshot so every output uses the unmodified inputs.
    const word32 x0 = a[0];
    const word32 x1 = a[1];
    const word32 x2 = a[2];

    a[0] = x0 ^ (x1 | (~x2));
    a[1] = x1 ^ (x2 | (~x0));
    a[2] = x2 ^ (x0 | (~x1));
}

// The linear step: replaces a[0..2] with an XOR combination of shifted and
// rotated copies of the three input words.
static inline void theta(word32 *a)
{
    // Snapshot of the input; all outputs are functions of these values only.
    const word32 x0 = a[0];
    const word32 x1 = a[1];
    const word32 x2 = a[2];

    // Term shared by all three outputs.
    const word32 common = rotr(x0 ^ x1 ^ x2, 16U);

    // Pairwise sums that feed the 24-bit and 8-bit shifts.
    const word32 x01 = x0 ^ x1;
    const word32 x02 = x0 ^ x2;
    const word32 x12 = x1 ^ x2;

    a[0] = x0 ^ common ^ (x12 >> 24) ^ (x02 << 8) ^ (x2 >> 8) ^ (x0 << 24);
    a[1] = x1 ^ common ^ (x02 >> 24) ^ (x01 << 8) ^ (x0 >> 8) ^ (x1 << 24);
    a[2] = x2 ^ common ^ (x01 >> 24) ^ (x12 << 8) ^ (x1 >> 8) ^ (x2 << 24);
}

// First permutation: rotates a[0] left by 22 bits and a[2] left by 1 bit.
// a[1] is left untouched.
static inline void pi_1(word32 *a)
{
    word32 &first = a[0];
    word32 &third = a[2];

    first = rotl(first, 22U);
    third = rotl(third, 1U);
}

// Second permutation: rotates a[0] left by 1 bit and a[2] left by 22 bits
// (the rotation amounts of pi_1 with the two words exchanged).
// a[1] is left untouched.
static inline void pi_2(word32 *a)
{
    word32 &first = a[0];
    word32 &third = a[2];

    first = rotl(first, 1U);
    third = rotl(third, 22U);
}

// The round function: applies, in order, the linear step, the first
// permutation, the nonlinear step and the second permutation to a[0..2].
static inline void rho(word32 *a)
{
    theta(a);   // linear step
    pi_1(a);    // first permutation
    gamma(a);   // nonlinear step
    pi_2(a);    // second permutation
}

// Fills rtab[0..rounds] (rounds+1 entries) with the round constants.
// The first entry is strt; each following entry is the previous one shifted
// left by one bit, with 0x11011 XORed in whenever the shift sets bit 16
// (which clears that bit again and folds it back into the low 16 bits).
static void GenerateRoundConstants(word32 strt, word32 *rtab, unsigned int rounds)
{
    word32 current = strt;
    for (unsigned idx = 0; idx <= rounds; ++idx)
    {
        rtab[idx] = current;

        const word32 shifted = current << 1;
        current = (shifted & 0x10000) ? (shifted ^ 0x11011) : shifted;
    }
}

// Sets up an encryption object.
//   uk     - user key; 12 bytes are read and packed into three words,
//            most significant byte first
//   rounds - number of rounds; rounds+1 round constants are generated
ThreeWayEncryption::ThreeWayEncryption(const byte *uk, unsigned rounds)
    : rounds(rounds), rc(new word32[rounds+1])
{
    GenerateRoundConstants(START_E, rc, rounds);

    const byte *src = uk;
    for (int w = 0; w < 3; ++w, src += 4)
    {
        k[w] = ((word32)src[0] << 24)
             | ((word32)src[1] << 16)
             | ((word32)src[2] << 8)
             |  (word32)src[3];
    }
}

// Erases the key words and releases the round-constant table.
ThreeWayEncryption::~ThreeWayEncryption()
{
    // Write the zeros through a volatile pointer: plain stores to an object
    // that is about to be destroyed may be removed by the optimizer as dead
    // stores, which would leave the key material in memory.
    volatile word32 *wipe = k;
    for (int i = 0; i < 3; i++)
        wipe[i] = 0;

    delete [] rc;
}

void ThreeWayEncryption::ProcessBlock(const byte *in, byte * out)
{
    word32 a[3];

#ifdef LITTLE_ENDIAN
    a[0] = Invert(*(word32 *)in);
    a[1] = Invert(*(word32 *)(in+4));
    a[2] = Invert(*(word32 *)(in+8));
#else
    a[0] = *(word32 *)in;
    a[1] = *(word32 *)(in+4);
    a[2] = *(word32 *)(in+8);
#endif

    for(unsigned i=0; i<rounds; i++)
    {
        a[0] ^= k[0] ^ (rc[i]<<16);
        a[1] ^= k[1];
        a[2] ^= k[2] ^ rc[i];
        rho(a);
    }
    a[0] ^= k[0] ^ (rc[rounds]<<16);
    a[1] ^= k[1];
    a[2] ^= k[2] ^ rc[rounds];
    theta(a);

#ifdef LITTLE_ENDIAN
    *(word32 *)out = Invert(a[0]);
    *(word32 *)(out+4) = Invert(a[1]);
    *(word32 *)(out+8) = Invert(a[2]);
#else
    *(word32 *)out = a[0];
    *(word32 *)(out+4) = a[1];
    *(word32 *)(out+8) = a[2];
#endif
}

// Sets up a decryption object.
//   uk     - user key; 12 bytes are read and packed into three words,
//            most significant byte first
//   rounds - number of rounds; rounds+1 round constants are generated
// The packed key is then passed through theta and mu to obtain the key
// words used by ProcessBlock.
ThreeWayDecryption::ThreeWayDecryption(const byte *uk, unsigned rounds)
    : rounds(rounds), rc(new word32[rounds+1])
{
    GenerateRoundConstants(START_D, rc, rounds);

    const byte *src = uk;
    for (int w = 0; w < 3; ++w, src += 4)
    {
        k[w] = ((word32)src[0] << 24)
             | ((word32)src[1] << 16)
             | ((word32)src[2] << 8)
             |  (word32)src[3];
    }

    theta(k);
    mu(k);
}

// Erases the key words and releases the round-constant table.
ThreeWayDecryption::~ThreeWayDecryption()
{
    // Write the zeros through a volatile pointer: plain stores to an object
    // that is about to be destroyed may be removed by the optimizer as dead
    // stores, which would leave the key material in memory.
    volatile word32 *wipe = k;
    for (int i = 0; i < 3; i++)
        wipe[i] = 0;

    delete [] rc;
}

void ThreeWayDecryption::ProcessBlock(const byte *in, byte * out)
{
    word32 a[3];

#ifdef LITTLE_ENDIAN
    a[0] = Invert(*(word32 *)in);
    a[1] = Invert(*(word32 *)(in+4));
    a[2] = Invert(*(word32 *)(in+8));
#else
    a[0] = *(word32 *)in;
    a[1] = *(word32 *)(in+4);
    a[2] = *(word32 *)(in+8);
#endif

    mu(a);
    for(unsigned i=0; i<rounds; i++)
    {
        a[0] ^= k[0] ^ (rc[i]<<16);
        a[1] ^= k[1];
        a[2] ^= k[2] ^ rc[i];
        rho(a);
    }
    a[0] ^= k[0] ^ (rc[rounds]<<16);
    a[1] ^= k[1];
    a[2] ^= k[2] ^ rc[rounds];
    theta(a);
    mu(a);

#ifdef LITTLE_ENDIAN
    *(word32 *)out = Invert(a[0]);
    *(word32 *)(out+4) = Invert(a[1]);
    *(word32 *)(out+8) = Invert(a[2]);
#else
    *(word32 *)out = a[0];
    *(word32 *)(out+4) = a[1];
    *(word32 *)(out+8) = a[2];
#endif
}
