/*
    Theseus - maximum likelihood superpositioning of macromolecular structures

    Copyright (C) 2004-2013 Douglas L. Theobald

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the:

    Free Software Foundation, Inc.,
    59 Temple Place, Suite 330,
    Boston, MA  02111-1307  USA

    -/_|:|_|_\-
*/

#include <pthread.h>
#include <gsl/gsl_rng.h>
#include <gsl/gsl_randist.h>
#include "Threads.h"
#include "MultiPose_local.h"
#include "pdbStats.h"
#include "distfit.h"
#include "ProcGSLSVD.h"
#include "MultiPoseMix.h"

/* Functions implemented in other Theseus translation units. */
extern int      MultiPose(CdsArray *baseA);
extern void     PrintSuperposStats(CdsArray *cdsA);

/* Forward declarations of file-local (static) functions. */
static void     HierarchVars(CdsArray *cdsA);
static int      MultiPoseMix(CdsArray *baseA, const double *probs, double *vars);
static void     CalcWtsMix(CdsArray *cdsA, const double *probs);
static void    *MultiPoseMix_pth(void *mixdata_ptr);


/* Returns -1/2 times the sum of squared residuals in the contiguous run of
   3 * cnum entries of cdsA->residuals that belongs to position `atom`. */
static double
FrobTermAtom(CdsArray *cdsA, const int atom)
{
    const int       stride = 3 * cdsA->cnum;
    const double   *r = cdsA->residuals + atom * stride;
    double          sumsq = 0.0;
    int             k;

    for (k = 0; k < stride; ++k)
        sumsq += r[k] * r[k];

    return(-0.5 * sumsq);
}


/* Per-position log-likelihood contribution for the current Gaussian model.

   NOTE: everything this depends on (variances, residuals, hierarchical
   parameters, mean structure and transformations) must already have been
   computed by the caller; nothing is re-estimated here.

   Background: for the dimensionally weighted case the ML covariance estimate
   is itself iterative, even when the row-wise (atomic) covariance is taken to
   be diagonal or proportional to the identity. The superposition is rotated
   onto the principal axes of the dimensional covariance matrix. At the
   maximum, the first likelihood term (the Mahalanobis / Frobenius-norm term)
   normally equals NKD/2. With shrinkage or hierarchical covariance estimates
   that simplification no longer holds, so the doubly weighted Frobenius norm
   has to be evaluated explicitly (FrobTermAtom).

   Branches, as coded:
     - leastsquares == 1 : explicit Frobenius term, no log-determinant.
     - varweight == 1    : log(var[atom]) as the log-determinant; explicit
                           Frobenius term only when hierarch != 0, otherwise
                           the closed form -(3 * cnum) / 2.
     - otherwise         : both terms are zero.
*/
static double
CalcLogLAtom(CdsArray *cdsA, const int atom)
{
    Algorithm      *algo = cdsA->algo;
    const double    ncds = cdsA->cnum;
    const double    halfdim = 0.5 * (ncds * 3.0);
    const double    dimcnt = 3.0 * ncds;
    double          frob = 0.0;
    double          logdet = 0.0;

    if (algo->leastsquares == 1)
    {
        frob = FrobTermAtom(cdsA, atom);
    }
    else if (algo->varweight == 1)
    {
        logdet = log(cdsA->var[atom]);
        frob = (algo->hierarch != 0) ? FrobTermAtom(cdsA, atom) : -halfdim;
    }

    return(frob - halfdim * log(2.0*MY_PI) - 0.5 * dimcnt * logdet);
}


/* Returns 1 when every one of the first `len` entries of vec is within
   `precision` of 0 or of 1 (i.e. the vector is effectively binary), and
   0 as soon as an entry lies strictly inside (precision, 1 - precision). */
int
VecTestBinary(const double *vec, const int len, const double precision)
{
    int          idx = 0;

    while (idx < len)
    {
        const double v = vec[idx];
        const int    nearzero = !(v > precision);
        const int    nearone = !(v < 1.0 - precision);

        if (!nearzero && !nearone)
            return(0);

        ++idx;
    }

    return(1);
}


/* For each of the mixn components, stores in aveprobs[c] the arithmetic
   mean of probs[c][0..vlen-1]. */
void
AveProb(double *aveprobs, const int mixn, double **probs, const int vlen)
{
    int             comp;

    for (comp = 0; comp < mixn; ++comp)
    {
        const double *row = probs[comp];
        double        total = 0.0;
        int           k;

        for (k = 0; k < vlen; ++k)
            total += row[k];

        aveprobs[comp] = total / vlen;
    }
}


/* Mixture-level convergence test: delegates to VecEq() to compare the first
   len entries of vec1 and vec2 at the given precision. */
static int
CheckConvergenceMix(const double *vec1, const double *vec2, const int len, const double precision)
{
    const int       same = VecEq(vec1, vec2, len, precision);

    return(same);
}


/* E-step update of the per-position mixture weights.

   For every position m and component j, the new weight is the average over
   the cnum structures of the posterior responsibility

       r_ij = p_j(m) f_j(x_i) / sum_k p_k(m) f_k(x_i),

   where f_j is normal_pdf() of the distance of structure i from component
   j's mean at position m. All responsibilities for a position are computed
   from the OLD weights. The previous version overwrote probs[j][m] inside the
   component loop, so components after j saw already-updated weights. Each
   column is then normalised to sum to 1.

   Structures whose total density underflows to zero are skipped instead of
   producing 0/0. If no structure contributes at a position, the weights fall
   back to 1/mixn (the same convention as CalcMixProbs()).

   Side effect: recomputes each component's mean structure via AveCds(). */
void
NewCalcMixDens(CdsArray **mixA, const int mixn, double **probs)
{
    const int       vlen = mixA[0]->vlen, cnum = mixA[0]->cnum;
    double         *dens = malloc(mixn * sizeof(double));  /* weighted density per component */
    double         *newp = malloc(mixn * sizeof(double));  /* accumulated responsibilities */
    double          dist, summix, sump;
    int             i, j, m;

    if (dens == NULL || newp == NULL)
    {
        free(dens);
        free(newp);
        printf("\n  ERROR: out of memory in NewCalcMixDens()\n");
        fflush(NULL);
        exit(EXIT_FAILURE);
    }

    for (j = 0; j < mixn; ++j)
        AveCds(mixA[j]);

    for (m = 0; m < vlen; ++m)
    {
        for (j = 0; j < mixn; ++j)
            newp[j] = 0.0;

        for (i = 0; i < cnum; ++i)
        {
            /* one pass computes both the numerators and the normaliser */
            summix = 0.0;
            for (j = 0; j < mixn; ++j)
            {
                dist = SqrCdsDist(mixA[j]->cds[i], m, mixA[j]->avecds, m);
                dens[j] = probs[j][m] * normal_pdf(sqrt(dist), 0.0, 3.0 * mixA[j]->var[m]);
                summix += dens[j];
            }

            if (summix > 0.0)
                for (j = 0; j < mixn; ++j)
                    newp[j] += dens[j] / summix;
        }

        /* normalising also performs the average over structures */
        sump = 0.0;
        for (j = 0; j < mixn; ++j)
            sump += newp[j];

        for (j = 0; j < mixn; ++j)
            probs[j][m] = (sump > 0.0) ? newp[j] / sump : 1.0 / mixn;
    }

    free(dens);
    free(newp);
}


/* Legacy density update: after refreshing each component's mean (AveCds),
   multiplies probs[k][res] by normal_pdf() of the root summed squared
   distance of all cnum structures at that position (variance
   3 * cnum * var[res]), then normalises each position's column over the
   mixn components. */
void
OldCalcMixDens(CdsArray **mixA, const int mixn, double **probs)
{
    const int       vlen = mixA[0]->vlen;
    const int       cnum = mixA[0]->cnum;
    int             res, comp, c;

    for (comp = 0; comp < mixn; ++comp)
        AveCds(mixA[comp]);

    for (res = 0; res < vlen; ++res)
    {
        double      colsum = 0.0;

        for (comp = 0; comp < mixn; ++comp)
        {
            CdsArray   *cur = mixA[comp];
            double      sqdist = 0.0;

            for (c = 0; c < cnum; ++c)
                sqdist += SqrCdsDist(cur->cds[c], res, cur->avecds, res);

            probs[comp][res] *= normal_pdf(sqrt(sqdist), 0.0, 3.0 * cnum * cur->var[res]);
        }

        for (comp = 0; comp < mixn; ++comp)
            colsum += probs[comp][res];

        for (comp = 0; comp < mixn; ++comp)
            probs[comp][res] /= colsum;
    }
}


/* Single-component version of the density update: refreshes the mean
   structure, then scales probs[res] by normal_pdf() of the root summed
   squared distance over all structures at that position. aveprob is
   currently ignored (its use is disabled in the original formula). */
void
CalcMixDens(CdsArray *mixA, const double aveprob, double *probs)
{
    const int       npos = mixA->vlen;
    const int       ncds = mixA->cnum;
    int             res, c;

    (void) aveprob;

    AveCds(mixA);

    for (res = 0; res < npos; ++res)
    {
        double      sqdist = 0.0;

        for (c = 0; c < ncds; ++c)
            sqdist += SqrCdsDist(mixA->cds[c], res, mixA->avecds, res);

        probs[res] *= normal_pdf(sqrt(sqdist), 0.0, 3.0 * ncds * mixA->var[res]);
    }
}


/* Normalises each column probs[0..mixn-1][col] to sum to 1. A column whose
   entries sum to exactly zero is replaced by the uniform value 1/mixn. */
void
CalcMixProbs(double **probs, const int mixn, const int vlen)
{
    int             col, row;

    for (col = 0; col < vlen; ++col)
    {
        double      total = 0.0;

        for (row = 0; row < mixn; ++row)
            total += probs[row][col];

        for (row = 0; row < mixn; ++row)
            probs[row][col] = (total == 0.0) ? 1.0 / mixn : probs[row][col] / total;
    }
}


/* Hardens soft mixture memberships into a 0/1 assignment.

   For each position i, the component with the largest probs[j][i] gets 1.0
   and all others get 0.0; ties go to the lowest component index. The old
   "> 1/mixn" threshold could mark several components (e.g. 0.4/0.4/0.2 with
   three components) or none (a 0.5/0.5 tie) as members, so rows did not sum
   to 1.

   The resulting table is then printed to stdout, one "  i:" row per position. */
void
BinaryMixProbs(double **probs, const int mixn, const int vlen)
{
    int             i, j, best;

    for (i = 0; i < vlen; ++i)
    {
        best = 0;
        for (j = 1; j < mixn; ++j)
            if (probs[j][i] > probs[best][i])
                best = j;

        for (j = 0; j < mixn; ++j)
            probs[j][i] = (j == best) ? 1.0 : 0.0;
    }

    /* label first, then the values, so rows don't run into the next label */
    for (i = 0; i < vlen; ++i)
    {
        printf("\n%3d:", i+1);
        for (j = 0; j < mixn; ++j)
            printf(" % 6.4f", probs[j][i]);
    }
    printf("\n");
}


/* Seeds the mixture weights probs[0..mixn-1][0..vlen-1] from a single-model
   superposition of cdsA.

   - Runs MultiPose() on cdsA to obtain per-position variances.
   - Component 0 weight at position i: normal_pdf(sqrt(var[i]), 0, ave), where
     ave is the harmonic mean of the variances, shrunk as p / (p + mean(p)).
   - Component 1 is the complement 1 - probs[0][i].
   - Components 2..mixn-1 get uniform random values in [0, 1/mixn).
   - Each position's column is then normalised to sum to 1.

   Fixes: components >= 2 were updated with "+=". The callers in this file
   allocate probs[j] with malloc, so that read uninitialized memory. The GSL
   generator was also never freed.

   NOTE(review): indexes probs[1] unconditionally, so this assumes mixn >= 2. */
void
InitializeMix(CdsArray *cdsA, double **probs)
{
    int             i, j;
    const int       vlen = cdsA->vlen, mixn = cdsA->algo->mixture;
    double          ave, sum, aveprob;
    const gsl_rng_type *T = NULL;
    gsl_rng        *r2 = NULL;

    gsl_rng_env_setup();
    gsl_rng_default_seed = time(NULL);
    T = gsl_rng_ranlxs2;
    r2 = gsl_rng_alloc(T);

    if (r2 == NULL)
    {
        printf("\n  ERROR: could not allocate random number generator in InitializeMix()\n");
        fflush(NULL);
        exit(EXIT_FAILURE);
    }

    MultiPose(cdsA);

    /* harmonic mean of the per-position variances */
    ave = 0.0;
    for (i = 0; i < vlen; ++i)
        ave += 1.0 / cdsA->var[i];
    ave = vlen/ave;

    for (i = 0; i < vlen; ++i)
        probs[0][i] = normal_pdf(sqrt(cdsA->var[i]), 0.0, ave);

    aveprob = 0.0;
    for (i = 0; i < vlen; ++i)
        aveprob += probs[0][i];
    aveprob = aveprob / vlen;

    for (i = 0; i < vlen; ++i)
        probs[0][i] = probs[0][i] / (probs[0][i] + aveprob);

    for (i = 0; i < vlen; ++i)
        probs[1][i] = 1.0 - probs[0][i];

    /* assign, don't accumulate: these rows have not been written yet */
    for (j = 2; j < mixn; ++j)
        for (i = 0; i < vlen; ++i)
            probs[j][i] = uniform_dev(0.0, 1.0 / mixn, r2);

    gsl_rng_free(r2);

    for (i = 0; i < vlen; ++i)
    {
        sum = 0.0;
        for (j = 0; j < mixn; ++j)
            sum += probs[j][i];

        for (j = 0; j < mixn; ++j)
            probs[j][i] /= sum;
    }
}


/* Re-estimates the weights of the first (subset - 1) already-fitted
   components and seeds component (subset - 1) before it is added to the fit.

   Step 1: for each position m, each fitted component's weight becomes the
   average over structures of its posterior responsibility. All
   responsibilities are computed from the OLD weights; the previous version
   overwrote probs[j][m] inside the component loop. Structures whose total
   density is zero are skipped instead of producing 0/0.

   Step 2: the new component gets, at every position,
   exp(sum_{i,j} log(probs[j][i]) / vlen) over the fitted components.

   Step 3: each position's column over all `subset` components is normalised
   to sum to 1, falling back to 1/subset if the column sum is not positive.

   Assumes AveCds() has already been run on mixA[0..subset-2] (the callers in
   this file do so). */
void
CalcSubsetProbs(CdsArray **mixA, const int subset, double **probs)
{
    int             i, j, m;
    const int       vlen = mixA[0]->vlen, cnum = mixA[0]->cnum;
    const int       nfit = subset - 1;               /* components already fitted */
    const int       nbuf = (nfit > 0) ? nfit : 1;    /* avoid malloc(0) */
    double         *dens = malloc(nbuf * sizeof(double));
    double         *newp = malloc(nbuf * sizeof(double));
    double          summix, dist, aveprob, sum;

    if (dens == NULL || newp == NULL)
    {
        free(dens);
        free(newp);
        printf("\n  ERROR: out of memory in CalcSubsetProbs()\n");
        fflush(NULL);
        exit(EXIT_FAILURE);
    }

    for (m = 0; m < vlen; ++m)
    {
        for (j = 0; j < nfit; ++j)
            newp[j] = 0.0;

        for (i = 0; i < cnum; ++i)
        {
            summix = 0.0;
            for (j = 0; j < nfit; ++j)
            {
                dist = SqrCdsDist(mixA[j]->cds[i], m, mixA[j]->avecds, m);
                dens[j] = probs[j][m] * normal_pdf(sqrt(dist), 0.0, 3.0 * mixA[j]->var[m]);
                summix += dens[j];
            }

            if (summix > 0.0)
                for (j = 0; j < nfit; ++j)
                    newp[j] += dens[j] / summix;
        }

        for (j = 0; j < nfit; ++j)
            probs[j][m] = newp[j] / cnum;
    }

    free(dens);
    free(newp);

    aveprob = 0.0;
    for (i = 0; i < vlen; ++i)
        for (j = 0; j < nfit; ++j)
            aveprob += log(probs[j][i]);
    aveprob = exp(aveprob/vlen);

    for (i = 0; i < vlen; ++i)
        probs[nfit][i] = aveprob;

    for (i = 0; i < vlen; ++i)
    {
        sum = 0.0;
        for (j = 0; j < subset; ++j)
            sum += probs[j][i];

        for (j = 0; j < subset; ++j)
            probs[j][i] = (sum > 0.0) ? probs[j][i] / sum : 1.0 / subset;
    }
}


/* Serial incremental initialisation of an mixn-component mixture.

   1. InitializeMix() seeds all weights from a single-model fit of mixA[0].
   2. Components 0 and 1 are iterated (MultiPoseMix + NewCalcMixDens) until
      probs[0] is stable to 1e-2, for at least 6 rounds.
   3. For n = 3 .. mixn, component n-1 is seeded by CalcSubsetProbs(), and the
      first n components are iterated until probs[0] is stable to 1e-3, for at
      least 6 rounds.

   Fix: the loop bound was "n < mixn", so the last component (index mixn-1)
   was never given its subset initialisation.

   NOTE(review): the convergence loops have no iteration cap. */
void
InitializeSubsets(CdsArray **mixA, const int mixn, double **probs, double **vars)
{
    int             i, j, n, count;
    const int       vlen = mixA[0]->vlen;
    double         *oldprobs = calloc(vlen, sizeof(double));

    if (oldprobs == NULL)
    {
        printf("\n  ERROR: out of memory in InitializeSubsets()\n");
        fflush(NULL);
        exit(EXIT_FAILURE);
    }

    InitializeMix(mixA[0], probs);

    count = 0;
    while (1)
    {
        ++count;

        memcpy(oldprobs, probs[0], vlen * sizeof(double));

        for (i = 0; i < 2; ++i)
            MultiPoseMix(mixA[i], probs[i], vars[i]);

        NewCalcMixDens(mixA, 2, probs);

        if (CheckConvergenceMix(probs[0], oldprobs, vlen, 1e-2) == 1 && count > 5)
            break;
    }

    /* add components one at a time; n is the number of active components */
    for (n = 3; n <= mixn; ++n)
    {
        for (j = 0; j < n; ++j)
            AveCds(mixA[j]);

        CalcSubsetProbs(mixA, n, probs);

        count = 0;
        while (1)
        {
            ++count;

            memcpy(oldprobs, probs[0], vlen * sizeof(double));

            for (i = 0; i < n; ++i)
                MultiPoseMix(mixA[i], probs[i], vars[i]);

            NewCalcMixDens(mixA, n, probs);

            if (CheckConvergenceMix(probs[0], oldprobs, vlen, 1e-3) == 1 && count > 5)
                break;
        }
    }

    free(oldprobs);
}



/* Runs MultiPoseMix_pth on components 0..n-1 concurrently (one thread per
   component, using mixdata[i] as the thread argument) and joins them all.
   Any pthread failure prints the given ERROR code and exits the program. */
static void
InitSubsetsRunThreads(CdsArray **mixA, const int n, double **probs, double **vars,
                      pthread_t *callThd, pthread_attr_t *attr, MixData **mixdata,
                      const int createerr, const int joinerr)
{
    int             i, rc;

    for (i = 0; i < n; ++i)
    {
        mixdata[i]->cdsA = mixA[i];
        mixdata[i]->probs = probs[i];
        mixdata[i]->vars = vars[i];

        rc = pthread_create(&callThd[i], attr, MultiPoseMix_pth, (void *) mixdata[i]);

        if (rc)
        {
            printf("ERROR%d: return code from pthread_create() %d is %d\n", createerr, i, rc);
            fflush(NULL);
            exit(EXIT_FAILURE);
        }
    }

    for (i = 0; i < n; ++i)
    {
        rc = pthread_join(callThd[i], (void **) NULL);

        if (rc)
        {
            printf("ERROR%d: return code from pthread_join() %d is %d\n", joinerr, i, rc);
            fflush(NULL);
            exit(EXIT_FAILURE);
        }
    }
}


/* Threaded counterpart of InitializeSubsets(): the per-component
   MultiPoseMix fits in each round run in parallel threads. The schedule is
   otherwise identical:
     1. InitializeMix() seeds all weights.
     2. Components 0 and 1 are iterated until probs[0] is stable to 1e-2
        (at least 6 rounds).
     3. For n = 3 .. mixn, CalcSubsetProbs() seeds component n-1, and the
        first n components are iterated until probs[0] is stable to 1e-3
        (at least 6 rounds).

   Fix: the loop bound was "n < mixn", which skipped initialising the last
   component.

   callThd and mixdata must have at least mixn entries each. */
void
InitializeSubsets_pth(CdsArray **mixA, const int mixn, double **probs,
                      pthread_t *callThd, pthread_attr_t *attr,
                      MixData **mixdata, double **vars)
{
    int             j, n, count;
    const int       vlen = mixA[0]->vlen;
    double         *oldprobs = calloc(vlen, sizeof(double));

    if (oldprobs == NULL)
    {
        printf("\n  ERROR: out of memory in InitializeSubsets_pth()\n");
        fflush(NULL);
        exit(EXIT_FAILURE);
    }

    InitializeMix(mixA[0], probs);

    count = 0;
    while (1)
    {
        ++count;

        memcpy(oldprobs, probs[0], vlen * sizeof(double));

        InitSubsetsRunThreads(mixA, 2, probs, vars, callThd, attr, mixdata, 811, 812);

        NewCalcMixDens(mixA, 2, probs);

        if (CheckConvergenceMix(probs[0], oldprobs, vlen, 1e-2) == 1 && count > 5)
            break;
    }

    /* add components one at a time; n is the number of active components */
    for (n = 3; n <= mixn; ++n)
    {
        for (j = 0; j < n; ++j)
            AveCds(mixA[j]);

        CalcSubsetProbs(mixA, n, probs);

        count = 0;
        while (1)
        {
            ++count;

            memcpy(oldprobs, probs[0], vlen * sizeof(double));

            InitSubsetsRunThreads(mixA, n, probs, vars, callThd, attr, mixdata, 813, 814);

            NewCalcMixDens(mixA, n, probs);

            if (CheckConvergenceMix(probs[0], oldprobs, vlen, 1e-3) == 1 && count > 5)
                break;
        }
    }

    free(oldprobs);
}


/* Prints each maximal run of positions with prob[j] >= 0.5 as
   "firstResSeq-lastResSeq", using the residue numbers of
   mixA[0]->cds[0]. Runs are separated by `sep`, with no trailing separator.
   Bounds are checked before prob[] is read. */
static void
PrintProbRanges(CdsArray **mixA, const double *prob, const int vlen, const char sep)
{
    int             j = 0, start, first = 1;

    while (j < vlen)
    {
        while (j < vlen && prob[j] < 0.5)
            ++j;

        if (j >= vlen)
            break;

        start = j;
        while (j < vlen && prob[j] >= 0.5)
            ++j;

        if (!first)
            printf("%c", sep);

        printf("%d-%d", mixA[0]->cds[0]->resSeq[start], mixA[0]->cds[0]->resSeq[j - 1]);
        first = 0;
    }
}


/* For each mixture component, prints the residues whose membership
   probability is at least 0.5 in two forms: a "select a-b,c-d" line and a
   "-s a-b:c-d" option string.

   Fixes over the previous loops:
     - probs[i][vlen] is no longer read (the bound check now comes first);
     - a probability of exactly 0.5 no longer produces an empty run that
       printed resSeq[-1];
     - no trailing separator is printed. */
void
PrintMixtures(CdsArray **mixA, const int mixn, double **probs, const int vlen)
{
    int             i;

    for (i = 0; i < mixn; ++i)
    {
        printf("\nselect ");
        PrintProbRanges(mixA, probs[i], vlen, ',');

        printf("\n-s");
        PrintProbRanges(mixA, probs[i], vlen, ':');
    }

    printf("\n\n");
    fflush(NULL);
}


/* For each mixture component, prints a "-s a-b:c-d" option string listing
   the 1-based column ranges whose membership probability is at least 0.5.
   mixA is unused (kept for interface compatibility with PrintMixtures).

   Fixes: the bound is checked before probs[i][j] is read (previously
   probs[i][vlen] was read), a probability of exactly 0.5 no longer yields an
   empty run, and no trailing ':' is printed. */
void
PrintMixturesCols(CdsArray **mixA, const int mixn, double **probs, const int vlen)
{
    int             i, j, start, first;

    (void) mixA;

    for (i = 0; i < mixn; ++i)
    {
        const double *prob = probs[i];

        printf("\n-s");

        j = 0;
        first = 1;
        while (j < vlen)
        {
            while (j < vlen && prob[j] < 0.5)
                ++j;

            if (j >= vlen)
                break;

            start = j;
            while (j < vlen && prob[j] >= 0.5)
                ++j;

            if (!first)
                printf(":");

            /* columns are reported 1-based: start+1 .. j */
            printf("%d-%d", start + 1, j);
            first = 0;
        }
    }

    printf("\n\n");
    fflush(NULL);
}


/* Prints a per-residue table of component log-likelihoods.

   For each residue, every component's value is
   log(prob) + normal_lnpdf(root summed squared distance, 0, 3 * cnum * var).
   Each row ends with the spread (largest - smallest) of those values, flagged
   with '*' when it is >= 3. A final row gives each component's total over all
   residues and the grand total.

   Components with probs[j][i] == 0 print 0.000 as before. They are now
   excluded from the running totals and from the spread; previously an
   uninitialized logL[j] was used for them. Each residue row now ends with a
   newline. */
void
CalcLRTs(CdsArray **mixA, const int mixn, const int cnum, double **probs, const int vlen)
{
    double          dist, largest, smallest, logLR, sum;
    double         *logL = malloc(mixn * sizeof(double));
    double         *mxlogL = calloc(mixn, sizeof(double));
    int             i, j, k, nvalid;

    if (logL == NULL || mxlogL == NULL)
    {
        free(logL);
        free(mxlogL);
        printf("\n  ERROR: out of memory in CalcLRTs()\n");
        fflush(NULL);
        exit(EXIT_FAILURE);
    }

    for (i = 0; i < vlen; ++i)
    {
        printf("Res: %4d\n", mixA[0]->cds[0]->resSeq[i]);

        largest = -DBL_MAX;
        smallest = DBL_MAX;
        nvalid = 0;

        for (j = 0; j < mixn; ++j)
        {
            if (probs[j][i] == 0.0)
            {
                printf(" %10.3f", 0.0);
                continue;
            }

            dist = 0.0;
            for (k = 0; k < cnum; ++k)
                dist += SqrCdsDist(mixA[j]->cds[k], i, mixA[j]->avecds, i);

            logL[j] = log(probs[j][i]) + normal_lnpdf(sqrt(dist), 0.0, 3.0 * cnum * mixA[j]->var[i]);
            printf(" %10.3f", logL[j]);

            mxlogL[j] += logL[j];

            if (largest < logL[j])
                largest = logL[j];

            if (smallest > logL[j])
                smallest = logL[j];

            ++nvalid;
        }

        logLR = (nvalid > 0) ? largest - smallest : 0.0;
        printf(" %10.3f", logLR);

        if (logLR >= 3.0)
            printf(" *");

        printf("\n");
    }

    printf("         \n");
    for (j = 0; j < mixn; ++j)
        printf(" %10.3f", mxlogL[j]);

    sum = 0.0;
    for (j = 0; j < mixn; ++j)
        sum += mxlogL[j];

    printf(" %10.3f\n", sum);

    fflush(NULL);

    free(logL);
    free(mxlogL);
}


/* Small-sample AIC penalty term for one component.

   Stores the parameter count (CalcParamNum) and data count
   (3 * cnum * vlen) into cdsA->stats, then returns -p * n / (n - p - 1). */
static double
CalcAICcorrxn(CdsArray *cdsA)
{
    const double    nparams = CalcParamNum(cdsA);
    const double    ndata = 3.0 * cdsA->cnum * cdsA->vlen;

    cdsA->stats->nparams = nparams;
    cdsA->stats->ndata = ndata;

    return(-nparams * ndata / (ndata - nparams - 1));
}


/* Computes the omnibus mixture log-likelihood and its AIC.

   After refreshing each component's residuals (least-squares or weighted,
   according to that component's algo flags):

     logL = sum_i log( sum_j probs[j][i] * exp(CalcLogLAtom(mixA[j], i)) )
            + sum_j CalcHierarchLogL(mixA[j])
     AIC  = logL + sum_j CalcAICcorrxn(mixA[j]) - (mixn - 1) * vlen

   The inner log-sum is evaluated with a running log-sum-exp. The previous
   direct exp() underflowed to 0, and hence to log(0) = -inf, once
   per-atom log-likelihoods fell below about -745, which happens for modest
   numbers of structures. Components with zero weight contribute nothing. */
static void
CalcMixAICLogL(CdsArray **mixA, const int mixn, const double **probs,
               const int vlen, double *AIC, double *logL)
{
    double          a, amax, scaled;
    int             i, j, have;

    for (i = 0; i < mixn; ++i)
    {
        if (mixA[i]->algo->leastsquares == 1)
            CalcNormResidualsLS(mixA[i]);
        else
            CalcNormResiduals(mixA[i]);
    }

    *logL = 0.0;
    for (i = 0; i < vlen; ++i)
    {
        /* running log-sum-exp: sum = exp(amax) * scaled */
        have = 0;
        amax = 0.0;
        scaled = 0.0;

        for (j = 0; j < mixn; ++j)
        {
            if (!(probs[j][i] > 0.0))
                continue;

            a = log(probs[j][i]) + CalcLogLAtom(mixA[j], i);

            if (a == -HUGE_VAL)
                continue;

            if (!have)
            {
                amax = a;
                scaled = 1.0;
                have = 1;
            }
            else if (a > amax)
            {
                scaled = scaled * exp(amax - a) + 1.0;
                amax = a;
            }
            else
            {
                scaled += exp(a - amax);
            }
        }

        *logL += have ? amax + log(scaled) : -HUGE_VAL;
    }

    for (i = 0; i < mixn; ++i)
        *logL += CalcHierarchLogL(mixA[i]);

    *AIC = *logL;
    for (i = 0; i < mixn; ++i)
        *AIC += CalcAICcorrxn(mixA[i]);

    *AIC -= (mixn - 1) * vlen;
}


/* Serial driver for a cdsA->algo->mixture-component superposition mixture.

   Each component gets its own copy of cdsA. Weights are initialised with
   InitializeSubsets(). The loop then alternates per-component MultiPoseMix()
   fits with NewCalcMixDens() weight updates until probs[0] is stable to
   algo->precision (after at least 6 iterations). The component memberships
   and the omnibus logL/AIC are then printed. Returns the number of
   iterations.

   Changes:
     - per-component variances start at 1.0, as in Mixture_pth(); they were
       previously left as whatever MatAlloc() returned;
     - removed allocations that were never used (slope, newprobs, aveprobs,
       and a PDB array only needed by the disabled per-component PDB output);
     - allocations are checked.

   pdbA is kept for interface compatibility but is currently unused. */
int
Mixture(CdsArray *cdsA, PDBCdsArray *pdbA)
{
    int             count, i;
    const int       vlen = cdsA->vlen, cnum = cdsA->cnum, mixn = cdsA->algo->mixture;
    double        **probs = calloc(mixn, sizeof(double *));
    double         *oldprobs = calloc(vlen, sizeof(double));
    double        **vars = MatAlloc(mixn, vlen);
    CdsArray      **mixA = malloc(mixn * sizeof(CdsArray *));
    Algorithm      *algo = cdsA->algo;
    double          logL, AIC;

    (void) pdbA;

    if (probs == NULL || oldprobs == NULL || mixA == NULL)
    {
        printf("\n  ERROR: out of memory in Mixture()\n");
        fflush(NULL);
        exit(EXIT_FAILURE);
    }

    for (i = 0; i < mixn; ++i)
        memsetd(vars[i], 1.0, vlen);

    for (i = 0; i < mixn; ++i)
    {
        probs[i] = malloc(vlen * sizeof(double));

        if (probs[i] == NULL)
        {
            printf("\n  ERROR: out of memory in Mixture()\n");
            fflush(NULL);
            exit(EXIT_FAILURE);
        }

        mixA[i] = CdsArrayInit();
        CdsArrayAlloc(mixA[i], cnum, vlen);
        CdsArraySetup(mixA[i]);
        CdsArrayCopy(mixA[i], cdsA);
        mixA[i]->algo->write_file = 0;
    }

    printf("    Initializing mixture iterations ... \n");
    fflush(NULL);

    InitializeSubsets(mixA, mixn, probs, vars);

    printf("    Beginning mixture iterations ... \n");
    fflush(NULL);

    count = 0;
    while (1)
    {
        ++count;

        printf("    Iteration %d\n", count);
        fflush(NULL);

        memcpy(oldprobs, probs[0], vlen * sizeof(double));

        for (i = 0; i < mixn; ++i)
            MultiPoseMix(mixA[i], probs[i], vars[i]);

        NewCalcMixDens(mixA, mixn, probs);

        if (CheckConvergenceMix(probs[0], oldprobs, vlen, algo->precision) == 1 && count > 5)
            break;
    }

    PrintMixtures(mixA, mixn, probs, vlen);
    PrintMixturesCols(mixA, mixn, probs, vlen);
    CalcMixAICLogL(mixA, mixn, (const double **) probs, vlen, &AIC, &logL);

    printf("Omnibus mixture logL: %11.2f AIC: %11.2f\n", logL, AIC);

    for (i = 0; i < mixn; ++i)
    {
        CdsArrayDestroy(&mixA[i]);
        free(probs[i]);
    }

    free(mixA);
    free(probs);
    free(oldprobs);
    MatDestroy(&vars);

    return(count);
}


static void
*MultiPoseMix_pth(void *mixdata_ptr)
{
    MixData        *mixdata = (MixData *) mixdata_ptr;
/*     pthread_mutex_t mutexsum; */
/*  */
/*     pthread_mutex_init(&mutexsum, NULL); */
/*     pthread_mutex_lock(&mutexsum); */
    mixdata->rounds = MultiPoseMix(mixdata->cdsA, mixdata->probs, mixdata->vars);
/*     pthread_mutex_unlock(&mutexsum); */
/*     pthread_mutex_destroy(&mutexsum); */

    pthread_exit((void *) 0);
    /* return((void *) 0); */
}


/* Threaded driver for a cdsA->algo->mixture-component superposition mixture.

   Same algorithm as Mixture(), but each iteration fits all components in
   parallel (one joinable thread per component via MultiPoseMix_pth). Each
   iteration prints a status block, and the cursor is then moved back up so
   the next iteration overwrites that block. Returns the number of iterations.

   Changes:
     - the cursor-up sequence was "\033[<%d>A"; the "<n>" in terminal docs is
       a placeholder, and the actual ECMA-48 CUU sequence is "ESC [ n A";
     - the omnibus logL/AIC is computed by CalcMixAICLogL() instead of a
       duplicated copy of that code. It now reads each component's own
       algo->leastsquares flag rather than cdsA's;
     - removed unused allocations (slope, newprobs) and added allocation
       checks.

   pdbA is kept for interface compatibility but is currently unused. */
int
Mixture_pth(CdsArray *cdsA, PDBCdsArray *pdbA)
{
    int             count, i, j, rc;
    const int       vlen = cdsA->vlen, cnum = cdsA->cnum, mixn = cdsA->algo->mixture;
    /* lines printed per iteration: "Iteration", one per component, a blank line, "mxp" */
    const int       lineskip = 3 + mixn;
    double        **probs = calloc(mixn, sizeof(double *));
    double         *aveprobs = calloc(mixn, sizeof(double));
    double         *oldprobs = calloc(vlen, sizeof(double));
    double        **vars = MatAlloc(mixn, vlen);
    CdsArray      **mixA = malloc(mixn * sizeof(CdsArray *));
    MixData       **mixdata = malloc(mixn * sizeof(MixData *));
    pthread_t      *callThd = malloc(mixn * sizeof(pthread_t));
    Algorithm      *algo = cdsA->algo;
    double          logL, AIC;
    pthread_attr_t  attr;

    (void) pdbA;

    if (probs == NULL || aveprobs == NULL || oldprobs == NULL ||
        mixA == NULL || mixdata == NULL || callThd == NULL)
    {
        printf("\n  ERROR: out of memory in Mixture_pth()\n");
        fflush(NULL);
        exit(EXIT_FAILURE);
    }

    for (i = 0; i < mixn; ++i)
    {
        mixdata[i] = malloc(sizeof(MixData));

        if (mixdata[i] == NULL)
        {
            printf("\n  ERROR: out of memory in Mixture_pth()\n");
            fflush(NULL);
            exit(EXIT_FAILURE);
        }
    }

    pthread_attr_init(&attr);
    pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_JOINABLE);
    pthread_attr_setscope(&attr, PTHREAD_SCOPE_SYSTEM);

    for (i = 0; i < mixn; ++i)
        memsetd(vars[i], 1.0, vlen);

    for (i = 0; i < mixn; ++i)
    {
        probs[i] = malloc(vlen * sizeof(double));

        if (probs[i] == NULL)
        {
            printf("\n  ERROR: out of memory in Mixture_pth()\n");
            fflush(NULL);
            exit(EXIT_FAILURE);
        }

        mixA[i] = CdsArrayInit();
        CdsArrayAlloc(mixA[i], cnum, vlen);
        CdsArraySetup(mixA[i]);
        CdsArrayCopy(mixA[i], cdsA);
        mixA[i]->algo->write_file = 0;
    }

    printf("    Initializing mixture iterations ... \n");
    fflush(NULL);

    InitializeSubsets_pth(mixA, mixn, probs, callThd, &attr, mixdata, vars);

    printf("    Beginning mixture iterations ... \n");
    fflush(NULL);

    count = 0;
    while (1)
    {
        ++count;

        printf("    Iteration %d\n", count);
        fflush(NULL);

        memcpy(oldprobs, probs[0], vlen * sizeof(double));

        for (i = 0; i < mixn; ++i)
        {
            mixdata[i]->cdsA = mixA[i];
            mixdata[i]->probs = probs[i];
            mixdata[i]->vars = vars[i];

            rc = pthread_create(&callThd[i], &attr, MultiPoseMix_pth, (void *) mixdata[i]);

            if (rc)
            {
                printf("ERROR815: return code from pthread_create() %d is %d\n", i, rc);
                fflush(NULL);
                exit(EXIT_FAILURE);
            }
        }

        for (i = 0; i < mixn; ++i)
        {
            rc = pthread_join(callThd[i], (void **) NULL);

            if (rc)
            {
                printf("ERROR816: return code from pthread_join() %d is %d\n", i, rc);
                fflush(NULL);
                exit(EXIT_FAILURE);
            }
        }

        for (i = 0; i < mixn; ++i)
        {
            printf("    Mixture %d: %4d rounds\n", i, mixdata[i]->rounds);
            fflush(NULL);
        }

        NewCalcMixDens(mixA, mixn, probs);
        AveProb(aveprobs, mixn, probs, vlen);

        printf("\n    mxp:");
        for (j = 0; j < mixn; ++j)
            printf(" % 6.4f", aveprobs[j]);
        printf("\n");

        if (CheckConvergenceMix(probs[0], oldprobs, vlen, algo->precision) == 1 && count > 5)
            break;

        /* ECMA-48 CUU: move the cursor up lineskip lines */
        printf("\033[%dA", lineskip);
    }

    pthread_attr_destroy(&attr);

    PrintMixtures(mixA, mixn, probs, vlen);
    CalcMixAICLogL(mixA, mixn, (const double **) probs, vlen, &AIC, &logL);

    printf("  * Omnibus mixture logL: %11.2f AIC: %11.2f\n", logL, AIC);

    for (i = 0; i < mixn; ++i)
    {
        CdsArrayDestroy(&mixA[i]);
        free(probs[i]);
    }

    free(mixA);
    free(probs);
    free(aveprobs);
    free(oldprobs);
    for (i = 0; i < mixn; ++i)
        free(mixdata[i]);
    free(mixdata);
    free(callThd);
    MatDestroy(&vars);

    return(count);
}


/* Re-centres every structure in scratchA using the weights scratchA->w.

   While superimposing to an alignment (algo->alignment == 1) during the
   first three rounds, the occupancy-aware CenMassWtIpOcc() is used. Later
   rounds do not need occupancy weighting because the pseudo-coordinates
   come from the E-M expectation step, so plain CenMassWtIp() is used. */
static void
CalcTranslations(CdsArray *scratchA, Algorithm *algo)
{
    Cds           **each = scratchA->cds;
    int             n = 0;

    while (n < scratchA->cnum)
    {
        const int   useocc = (algo->alignment == 1 && algo->rounds < 3);

        if (useocc)
            CenMassWtIpOcc(each[n], scratchA->w);
        else
            CenMassWtIp(each[n], scratchA->w);

        ++n;
    }
}


/* Scales the x, y and z coordinate of each of the cds->vlen positions in cds
   by the per-position weight wtK[k], writing the result into outcds. */
static void
MatDiagMultCdsMultMatDiag(Cds *outcds, const double *wtK, const Cds *cds)
{
    const double   *srcx = (const double *) cds->x;
    const double   *srcy = (const double *) cds->y;
    const double   *srcz = (const double *) cds->z;
    int             k;

    for (k = 0; k < cds->vlen; ++k)
    {
        const double scale = wtK[k];

        outcds->x[k] = scale * srcx[k];
        outcds->y[k] = scale * srcy[k];
        outcds->z[k] = scale * srcz[k];
    }
}


/* Computes the optimal rotation of every structure onto the weighted mean.

   The mean structure is first scaled by the weights into cdsA->tcds. Each
   structure's rotation is then found with ProcGSLSVDvan() and stored in
   cds[n]->matrix; the coordinates themselves are not rotated here. Each
   structure's wRMSD_from_mean is set to sqrt(deviation / (3 * vlen)), and
   the sum of all deviations is returned. */
static double
CalcRotations(CdsArray *cdsA)
{
    Cds           **cds = cdsA->cds;
    Cds            *tcds = cdsA->tcds;
    double          total = 0.0;
    int             n;

    MatDiagMultCdsMultMatDiag(tcds, (const double *) cdsA->w, cdsA->avecds);

    for (n = 0; n < cdsA->cnum; ++n)
    {
        const double dev = ProcGSLSVDvan(cds[n], tcds, cds[n]->matrix,
                                         cdsA->tmpmat3a, cdsA->tmpmat3b,
                                         cdsA->tmpmat3c, cdsA->tmpvec3a);

        cds[n]->wRMSD_from_mean = sqrt(dev / (3 * cdsA->vlen));
        total += dev;
    }

    return(total);
}


/* Applies the hierarchical (shrinkage) model selected by algo->hierarch to
   the variance / covariance estimates held in cdsA. Case 0 does nothing; an
   unknown option prints an error and the usage text, then exits.

   Fix: case 10 fitted the inverse-Gaussian parameters into mean/lambda but
   then passed the uninitialized zeta/sigma to InvgaussAdjustVars(). It now
   passes the fitted values, matching the fit-then-adjust pattern of cases 8
   and 9. */
static void
HierarchVars(CdsArray *cdsA)
{
    int             i;
    double          mean, mu, lambda, b, c, zeta, sigma;

    switch(cdsA->algo->hierarch)
    {
        case 0:
            break;

        case 1: /* inverse gamma fit of covariance eigenvalues; after round 4,
                   InvGammaFitEvals() is called with 1 instead of 0 */
            /* Rationale: the smallest eigenvalue of the covariance matrix is
               always zero (the matrix has rank vlen - 1), so it should be
               excluded from the fit. */
            if (cdsA->algo->rounds > 4)
                InvGammaFitEvals(cdsA, 1);
            else
                InvGammaFitEvals(cdsA, 0);

            if (cdsA->algo->verbose != 0)
                printf("    HierarchVars() chi2:%f\n", cdsA->stats->hierarch_chi2);
            break;

        case 2: /* inverse gamma fit of the variances */
            InvGammaFitVars(cdsA, 1);
            if (cdsA->algo->verbose != 0)
                printf("    HierarchVars() chi2:%f\n", cdsA->stats->hierarch_chi2);
            break;

        case 3:
            InvGamma1FitEvals(cdsA, 1);
            break;

        case 4:
            InvGammaFitVars_minc(cdsA, 1.0, 1);
            break;

        case 5:
            InvGammaMMFitVars(cdsA, &b, &c);
            break;

        case 6:
            InvGammaStacyFitVars(cdsA, &b, &c);
            break;

        case 7: /* inverse gamma fit to the covariance diagonal, then rebuild the
                   covariance matrix from its correlations and the adjusted variances */
            for (i = 0; i < cdsA->vlen; ++i)
                cdsA->var[i] = cdsA->CovMat[i][i];
            cdsA->algo->covweight = 0;
            cdsA->algo->varweight = 1;
            InvGammaFitVars(cdsA, 1);
            cdsA->algo->covweight = 1;
            cdsA->algo->varweight = 0;
            CovMat2CorMat(cdsA->CovMat, cdsA->vlen);
            CorMat2CovMat(cdsA->CovMat, (const double *) cdsA->var, cdsA->vlen);
            break;

        case 8: /* ML fit of variances to a reciprocal inverse gaussian dist */
            RecipInvGaussFitVars(cdsA, &mu, &lambda);
            RecipInvGaussAdjustVars(cdsA, mu, lambda);
            break;

        case 9: /* ML fit of variances to a lognormal distribution */
            LognormalFitVars(cdsA, &zeta, &sigma);
            LognormalAdjustVars(cdsA, zeta, sigma);
            break;

        case 10: /* fit of variances to an inverse gaussian distribution */
            InvgaussFitVars(cdsA, &mean, &lambda);
            InvgaussAdjustVars(cdsA, mean, lambda);
            break;

        case 12: /* inv gamma fit to eigenvalues of covariance mat, but only weighting by variances */
            cdsA->algo->covweight = 1;
            cdsA->algo->varweight = 0;
            if (cdsA->algo->alignment == 1)
                CalcCovMatOcc(cdsA);
            else
                CalcCovMat(cdsA);
            InvGammaFitEvals(cdsA, 1);
            cdsA->algo->covweight = 0;
            cdsA->algo->varweight = 1;
            for (i = 0; i < cdsA->vlen; ++i)
                cdsA->var[i] = cdsA->CovMat[i][i];
            break;

        case 13: /* inv gamma fit of variances (InvGammaFitVars with 0, i.e. no
                    iterations) computed in covariance mode, then weighting only by
                    the resulting covariance diagonal */
            cdsA->algo->covweight = 1;
            cdsA->algo->varweight = 0;
            if (cdsA->algo->alignment == 1)
                CalcCovMatOcc(cdsA);
            else
                CalcCovMat(cdsA);
            InvGammaFitVars(cdsA, 0);
            cdsA->algo->covweight = 0;
            cdsA->algo->varweight = 1;
            for (i = 0; i < cdsA->vlen; ++i)
                cdsA->var[i] = cdsA->CovMat[i][i];
            break;

        default: /* includes the unassigned option 11 */
            printf("\n  ERROR:  Bad -g option \"%d\" \n", cdsA->algo->hierarch);
            Usage(0);
            exit(EXIT_FAILURE);
            break;
    }
}


/*
 * CheckConvergenceInner - inner-loop convergence test for MultiPoseMix.
 *
 * MultiPoseMix rotates every structure in place after each rotation fit,
 * so a freshly fitted rotation that is (numerically) the identity means no
 * further rotation is needed for that structure.
 *
 * Returns 1 when cds[k]->matrix passes TestIdentMat() at `precision` for
 * every structure k, and 0 as soon as one structure fails the test.
 */
static int
CheckConvergenceInner(CdsArray *cdsA, const double precision)
{
    int             k = 0;

    while (k < cdsA->cnum)
    {
        if (!TestIdentMat((const double **) cdsA->cds[k]->matrix, 3, precision))
            return(0);

        ++k;
    }

    return(1);
}


/*
 * CheckConvergenceOuter - outer-loop convergence test for MultiPoseMix.
 *
 * Returns 1 (stop) immediately when the iteration budget algo->iterations
 * has been reached or algo->abort == 1. Otherwise returns 0 for the first
 * six rounds. After that it stores in cdsA->stats->precision the mean,
 * over all structures, of FrobDiffNormIdentMat() of each current rotation
 * matrix. It returns 1 when that mean does not exceed `precision`, and 0
 * otherwise.
 */
static int
CheckConvergenceOuter(CdsArray *cdsA, int round, const double precision)
{
    const Algorithm *alg = cdsA->algo;
    double          frobsum = 0.0;
    int             k;

    /* hard stops: iteration budget exhausted, or an abort was requested */
    if (round >= alg->iterations || alg->abort == 1)
        return(1);

    /* never declare convergence during the first few rounds */
    if (round <= 6)
        return(0);

    /* mean distance of the latest rotations from the identity */
    for (k = 0; k < cdsA->cnum; ++k)
        frobsum += FrobDiffNormIdentMat((const double **) cdsA->cds[k]->matrix, 3);

    cdsA->stats->precision = frobsum / cdsA->cnum;

    return (cdsA->stats->precision > precision) ? 0 : 1;
}


/*
 * MultiPoseMix - iterative ML superposition of baseA, with each atom
 * weighted by a per-atom mixture probability.
 *
 * All work is done on a private scratch copy (scratchA). The final state
 * is copied back into baseA with CdsArrayCopy(), and then the scratch
 * copy is destroyed.
 *
 * Each outer round:
 *   (1) computes translations and centers every structure;
 *   (2) runs an inner loop that fits rotations, rotates the structures in
 *       place and recomputes the average. The inner loop ends when
 *       CheckConvergenceInner() succeeds or after more than 160 inner rounds;
 *   (3) holding the superposition fixed, recomputes the covariances and
 *       the mixture weights (CalcWtsMix()).
 * The outer loop ends when CheckConvergenceOuter() succeeds. That test is
 * made right after the first rotation fit of each round.
 *
 * baseA - structures to superpose. baseA->w receives a copy of probs.
 *         On return baseA holds the final scratch state.
 * probs - vlen per-atom mixture probabilities. They seed the weights and
 *         are passed to CalcWtsMix().
 * vars  - output only: receives vlen values from baseA->var on return.
 *
 * Returns the number of outer rounds performed.
 */
static int
MultiPoseMix(CdsArray *baseA, const double *probs, double *vars)
{
    int             i, round, innerround;
    int             slxn; /* index of the random structure used as the initial average */
    double          deviation_sum = 0.0;
    const int       cnum = baseA->cnum;
    const int       vlen = baseA->vlen;
    Algorithm      *algo = NULL;
    Statistics     *stats = NULL;
    Cds           **cds = NULL;
    Cds            *avecds = NULL;
    CdsArray       *scratchA = NULL;
    /* Not seeded explicitly here, so it starts from GSL's default seed. */
    gsl_rng        *r2 = gsl_rng_alloc(gsl_rng_ranlxs2);

    /* set up scratchA as a duplicate of baseA */
    scratchA = CdsArrayInit();
    CdsArrayAlloc(scratchA, cnum, vlen);
    CdsArraySetup(scratchA);

    baseA->scratchA = scratchA;

    CdsArrayCopy(scratchA, baseA);

    /* local aliases into scratchA */
    algo = scratchA->algo;
    stats = scratchA->stats;
    cds = scratchA->cds;
    avecds = scratchA->avecds;

    memcpy(scratchA->w, probs, vlen * sizeof(double));
    memcpy(baseA->w, probs, vlen * sizeof(double));
    CalcWtsMix(scratchA, probs);

    stats->hierarch_p1 = 0.0;
    stats->hierarch_p2 = 0.0;

    /* initial average: either an embedded ML average or a random structure */
    if (algo->embedave != 0)
    {
        printf("    Calculating distance matrix for embedding average ... \n");
        fflush(NULL);

        CdsCopyAll(avecds, cds[0]);
        DistMatsAlloc(scratchA);

        if (algo->alignment == 1)
            CalcMLDistMatOcc(scratchA);
        else
            CalcMLDistMat(scratchA);

        printf("    Embedding average structure (ML) ... \n");
        fflush(NULL);

        EmbedAveCds(scratchA);

        for (i = 0; i < vlen; ++i)
            avecds->resSeq[i] = i+1;

        printf("    Finished embedding \n");
        fflush(NULL);
    }
    else
    {
        slxn = gsl_rng_uniform_int(r2, cnum);
        CdsCopyAll(avecds, baseA->cds[slxn]);
    }

    if (algo->notrans == 0)
    {
        CenMassWtIp(avecds, scratchA->w);
        ApplyCenterIp(avecds);
    }

    /* The outer loop:
       (1) First calculates the translations
       (2) Does inner loop -- calc rotations and average till convergence
       (3) Holding the superposition constant, calculates the covariance
           matrices and corresponding weight matrices, looping till
           convergence when using a dimensional/axial covariance matrix
    */
    round = 0;
    while(1)
    {
        ++round;
        baseA->algo->rounds = algo->rounds = round;

        /* Find weighted center and translate all cds */
        CalcTranslations(scratchA, algo);
        for (i = 0; i < cnum; ++i)
            ApplyCenterIp(cds[i]);

        /* save the translation vector for each coord in the array */
        for (i = 0; i < cnum; ++i)
            memcpy(cds[i]->translation, cds[i]->center, 3 * sizeof(double));

        /* when superimposing to an alignment, use unweighted LS
           (all weights 1.0) for the first four rounds */
        if (algo->alignment == 1 && round < 5)
            memsetd(scratchA->w, 1.0, vlen);

        /* Inner loop:
           (1) Calc rotations given weights/weight matrices
           (2) Rotate cds with new rotations
           (3) Recalculate average

           Loops till convergence, holding constant the weights, variances, and covariances
           (and thus the translations too) */
        innerround = 0;
        do
        {
            ++innerround;
            /* NOTE(review): this adds the running counter, so innerrounds grows
               by 1+2+...+k over k inner rounds rather than by k; confirm intent. */
            algo->innerrounds += innerround;

            /* save the old rotation matrices */
            for (i = 0; i < cnum; ++i)
                MatCpySym(cds[i]->last_matrix, (const double **) cds[i]->matrix, 3);

            /* find the optimal rotation matrices */
            if (algo->alignment == 1)
                deviation_sum = CalcRotationsOcc(scratchA);
            else
                deviation_sum = CalcRotations(scratchA);

            if (innerround == 1 &&
                CheckConvergenceOuter(scratchA, round, algo->precision) == 1)
                   goto outsidetheloops;

            /* rotate the scratch cds with new rotation matrix */
            for (i = 0; i < cnum; ++i)
                RotateCdsIp(cds[i], (const double **) cds[i]->matrix);

            /* find global rmsd and average cds (both held in structure) */
            if (algo->alignment == 1)
            {
                AveCdsOcc(scratchA);
                EM_MissingCds(scratchA);
            }
            else
            {
                AveCds(scratchA);
            }

            stats->wRMSD_from_mean = sqrt(deviation_sum / (3 * vlen * cnum));

            /* hard cap on inner iterations; ',' marks that it was hit */
            if (innerround > 160)
            {
                putchar(',');
                fflush(NULL);
                break;
            }
        }
        while(CheckConvergenceInner(scratchA, algo->precision) == 0);

        /* Weighting by dimensional, axial Xi covariance matrix, here diagonal. */
        /* Holding the superposition constant, calculates the covariance
           matrices and corresponding weight matrices. */
        CalcCovariances(scratchA);

        /* calculate the weights/weight matrices */
        CalcWtsMix(scratchA, probs);
    }

    outsidetheloops:

    CdsArrayCopy(baseA, scratchA);

    /* report the final variances from baseA, after the copy-back above */
    memcpy(vars, baseA->var, vlen * sizeof(double));

    /* baseA->scratchA was pointed at the scratch array above; clear it so
       baseA is not left holding a dangling pointer once it is destroyed. */
    if (baseA->scratchA == scratchA)
        baseA->scratchA = NULL;

    CdsArrayDestroy(&scratchA);

    gsl_rng_free(r2);

    return(round);
}


/*
 * CalcWtsMix - set the per-atom weights cdsA->w from mixture probabilities,
 * using only the atomic (row-wise) variances.
 *
 * If algo->noave == 0, the average structure is first recomputed with AveCds().
 * NOTE(review): MultiPoseMix uses AveCdsOcc() when algo->alignment == 1;
 * confirm that AveCds() is intended here in that mode.
 *
 * - least squares (algo->leastsquares != 0): w[j] = probs[j].
 * - variance weighting (algo->varweight != 0):
 *     1. each variance is floored at probs[j] * algo->constant;
 *     2. HierarchVars() is applied (presumably it may rewrite cdsA->var);
 *     3. w[j] = probs[j] / var[j], or 0.0 when var[j] is 0.0 or >= DBL_MAX.
 * - otherwise the weights are left untouched.
 */
static void
CalcWtsMix(CdsArray *cdsA, const double *probs)
{
    Algorithm      *alg = cdsA->algo;
    double         *var = cdsA->var;
    double         *wt = cdsA->w;
    const int       n = cdsA->vlen;
    int             j;

    if (!alg->noave)
        AveCds(cdsA);

    if (alg->leastsquares)
    {
        for (j = 0; j < n; ++j)
            wt[j] = probs[j];

        return;
    }

    if (!alg->varweight)
        return;

    /* lower bound on each variance, scaled by its mixture probability */
    for (j = 0; j < n; ++j)
    {
        const double vmin = probs[j] * alg->constant;

        if (var[j] < vmin)
            var[j] = vmin;
    }

    HierarchVars(cdsA);

    /* infinite or zero variance gets zero weight */
    for (j = 0; j < n; ++j)
        wt[j] = (var[j] >= DBL_MAX || var[j] == 0.0) ? 0.0 : probs[j] / var[j];
}
