/*
 * Decompiled with CFR 0.152.
 */
package umontreal.iro.lecuyer.probdistmulti;

import optimization.Uncmin_f77;
import optimization.Uncmin_methods;
import umontreal.iro.lecuyer.probdistmulti.ContinuousDistributionMulti;
import umontreal.iro.lecuyer.util.Num;

/**
 * Dirichlet distribution with parameter vector {@code alpha = (alpha_1, ..., alpha_d)}, every
 * component of which must be strictly positive.
 *
 * <p>With {@code alpha0 = sum_i alpha_i}, the density computed by this class is
 * {@code Gamma(alpha0) / prod_i Gamma(alpha_i) * prod_i x_i^(alpha_i - 1)}.
 */
public class DirichletDist
extends ContinuousDistributionMulti {
    /** The parameter vector; {@link #setParams(double[])} stores a private copy here. */
    protected double[] alpha;

    /**
     * Creates a Dirichlet distribution with parameters {@code alpha}.
     *
     * @param alpha the parameter vector (copied, not retained)
     * @throws IllegalArgumentException if some {@code alpha[i]} is NaN or {@code <= 0}
     */
    public DirichletDist(double[] alpha) {
        this.setParams(alpha);
    }

    /**
     * Returns the density at {@code x} using this distribution's parameters.
     *
     * @param x the point; must have the same length as the parameter vector
     * @return the density, or 0 if some {@code x[i] < 0}
     * @throws IllegalArgumentException if {@code x} has the wrong dimension
     */
    public double density(double[] x) {
        return DirichletDist.density_(this.alpha, x);
    }

    /** Returns the mean vector {@code alpha[i] / alpha0}. */
    public double[] getMean() {
        return DirichletDist.getMean_(this.alpha);
    }

    /** Returns the {@code d x d} covariance matrix of this distribution. */
    public double[][] getCovariance() {
        return DirichletDist.getCovariance_(this.alpha);
    }

    /** Returns the {@code d x d} correlation matrix of this distribution. */
    public double[][] getCorrelation() {
        return DirichletDist.getCorrelation_(this.alpha);
    }

    /**
     * Checks that every component of {@code alpha} is a strictly positive number.
     *
     * @throws IllegalArgumentException naming the first offending index
     */
    private static void verifParam(double[] alpha) {
        for (int i = 0; i < alpha.length; ++i) {
            // NaN compares false against everything, so it must be rejected explicitly.
            if (Double.isNaN(alpha[i])) {
                throw new IllegalArgumentException("alpha[" + i + "] is NaN");
            }
            if (alpha[i] <= 0.0) {
                throw new IllegalArgumentException("alpha[" + i + "] <= 0");
            }
        }
    }

    /** Returns {@code alpha0}, the sum of all components of {@code alpha}. */
    private static double alphaSum(double[] alpha) {
        double alpha0 = 0.0;
        for (double a : alpha) {
            alpha0 += a;
        }
        return alpha0;
    }

    /**
     * Evaluates the density without validating {@code alpha}.
     *
     * <p>The sum of {@code x} is not checked; callers are expected to pass a point on the
     * simplex.
     */
    private static double density_(double[] alpha, double[] x) {
        if (alpha.length != x.length) {
            throw new IllegalArgumentException("alpha and x must have the same dimension");
        }
        double sumLnGamma = 0.0;
        double sumAlphaLnXi = 0.0;
        for (int i = 0; i < alpha.length; ++i) {
            if (x[i] < 0.0) {
                // Outside the support: components of a Dirichlet vector are never negative.
                return 0.0;
            }
            sumLnGamma += Num.lnGamma(alpha[i]);
            // For alpha[i] == 1 the factor x[i]^0 is exactly 1. Skipping the term avoids
            // 0 * log(0) = 0 * -Infinity = NaN when x[i] == 0.
            if (alpha[i] != 1.0) {
                sumAlphaLnXi += (alpha[i] - 1.0) * Math.log(x[i]);
            }
        }
        return Math.exp(Num.lnGamma(alphaSum(alpha)) - sumLnGamma + sumAlphaLnXi);
    }

    /**
     * Returns the Dirichlet density with parameters {@code alpha} at {@code x}.
     *
     * @param alpha the parameter vector; every component must be positive
     * @param x the point; must have the same length as {@code alpha}
     * @return the density, or 0 if some {@code x[i] < 0}
     * @throws IllegalArgumentException if {@code alpha} is invalid or the dimensions differ
     */
    public static double density(double[] alpha, double[] x) {
        DirichletDist.verifParam(alpha);
        return DirichletDist.density_(alpha, x);
    }

    /** Computes the covariance matrix without validating {@code alpha}. */
    private static double[][] getCovariance_(double[] alpha) {
        final int d = alpha.length;
        final double alpha0 = alphaSum(alpha);
        // Shared denominator alpha0^2 (alpha0 + 1) of the off-diagonal entries.
        final double denom = alpha0 * alpha0 * (alpha0 + 1.0);
        double[][] cov = new double[d][d];
        for (int i = 0; i < d; ++i) {
            for (int j = 0; j < d; ++j) {
                cov[i][j] = -(alpha[i] * alpha[j]) / denom;
            }
            // Variance: (alpha_i / alpha0) (1 - alpha_i / alpha0) / (alpha0 + 1).
            cov[i][i] = alpha[i] / alpha0 * (1.0 - alpha[i] / alpha0) / (alpha0 + 1.0);
        }
        return cov;
    }

    /**
     * Returns the covariance matrix of a Dirichlet distribution with parameters {@code alpha}.
     *
     * @throws IllegalArgumentException if some {@code alpha[i]} is NaN or {@code <= 0}
     */
    public static double[][] getCovariance(double[] alpha) {
        DirichletDist.verifParam(alpha);
        return DirichletDist.getCovariance_(alpha);
    }

    /** Computes the correlation matrix without validating {@code alpha}. */
    private static double[][] getCorrelation_(double[] alpha) {
        final int d = alpha.length;
        final double alpha0 = alphaSum(alpha);
        double[][] corr = new double[d][d];
        for (int i = 0; i < d; ++i) {
            for (int j = 0; j < d; ++j) {
                corr[i][j] = -Math.sqrt(alpha[i] * alpha[j] / ((alpha0 - alpha[i]) * (alpha0 - alpha[j])));
            }
            corr[i][i] = 1.0;
        }
        return corr;
    }

    /**
     * Returns the correlation matrix of a Dirichlet distribution with parameters {@code alpha}.
     *
     * @throws IllegalArgumentException if some {@code alpha[i]} is NaN or {@code <= 0}
     */
    public static double[][] getCorrelation(double[] alpha) {
        DirichletDist.verifParam(alpha);
        return DirichletDist.getCorrelation_(alpha);
    }

    /**
     * Same as {@link #getMLE(double[][], int, int)}.
     *
     * @deprecated use {@link #getMLE(double[][], int, int)} instead
     */
    @Deprecated
    public static double[] getMaximumLikelihoodEstimate(double[][] x, int n, int d) {
        return DirichletDist.getMLE(x, n, d);
    }

    /**
     * Estimates the parameters by numerically maximizing the log-likelihood of the first
     * {@code n} rows and {@code d} columns of {@code x}, using {@code Uncmin_f77.optif0_f77}.
     *
     * <p>The starting point is a method-of-moments estimate based on the first column.
     *
     * @param x the observations, one per row
     * @param n the number of observations to use
     * @param d the dimension of each observation
     * @return the estimated parameter vector, of length {@code d}
     * @throws IllegalArgumentException if {@code n <= 0} or {@code d <= 0}
     */
    public static double[] getMLE(double[][] x, int n, int d) {
        if (n <= 0) {
            throw new IllegalArgumentException("n <= 0");
        }
        if (d <= 0) {
            throw new IllegalArgumentException("d <= 0");
        }
        // Per-column sample means of log(x), which is all the likelihood needs, and of x,
        // which is used for the starting point.
        double[] logP = new double[d];
        double[] mean = new double[d];
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < d; ++j) {
                logP[j] += Math.log(x[i][j]);
                mean[j] += x[i][j];
            }
        }
        for (int j = 0; j < d; ++j) {
            logP[j] /= n;
            mean[j] /= n;
        }
        // Only the (biased) sample variance of column 0 is needed below.
        double var0 = 0.0;
        for (int i = 0; i < n; ++i) {
            double dev = x[i][0] - mean[0];
            var0 += dev * dev;
        }
        var0 /= n;

        // Moment matching on Var(X_1) = m_1 (1 - m_1) / (alpha0 + 1).
        double alpha0 = mean[0] * (1.0 - mean[0]) / var0 - 1.0;
        if (!(alpha0 > 0.0) || Double.isInfinite(alpha0)) {
            // Small samples (or a constant first column) can give a non-positive, NaN or
            // infinite estimate. Optim rejects non-positive parameters, so such a start
            // would be useless; fall back to a positive scale instead.
            alpha0 = 1.0;
        }

        // These arrays are used 1-based (indices 1..d), matching Optim.f_to_minimize,
        // which starts at index 1; index 0 is unused.
        double[] start = new double[d + 1];
        for (int j = 1; j <= d; ++j) {
            start[j] = mean[j - 1] * alpha0;
        }
        double[] xpls = new double[d + 1];
        double[] fpls = new double[d + 1];
        double[] gpls = new double[d + 1];
        int[] itrcmd = new int[2];
        double[][] a = new double[d + 1][d + 1];
        double[] udiag = new double[d + 1];
        Uncmin_f77.optif0_f77(d, start, new Optim(logP, n), xpls, fpls, gpls, itrcmd, a, udiag);

        double[] parameters = new double[d];
        System.arraycopy(xpls, 1, parameters, 0, d);
        return parameters;
    }

    /** Computes {@code alpha[i] / alpha0} without validating {@code alpha}. */
    private static double[] getMean_(double[] alpha) {
        final double alpha0 = alphaSum(alpha);
        double[] mean = new double[alpha.length];
        for (int i = 0; i < alpha.length; ++i) {
            mean[i] = alpha[i] / alpha0;
        }
        return mean;
    }

    /**
     * Returns the mean vector of a Dirichlet distribution with parameters {@code alpha}.
     *
     * @throws IllegalArgumentException if some {@code alpha[i]} is NaN or {@code <= 0}
     */
    public static double[] getMean(double[] alpha) {
        DirichletDist.verifParam(alpha);
        return DirichletDist.getMean_(alpha);
    }

    /** Returns a copy of the parameter vector; modifying it does not affect this object. */
    public double[] getAlpha() {
        return this.alpha.clone();
    }

    /** Returns the parameter {@code alpha[i]}. */
    public double getAlpha(int i) {
        return this.alpha[i];
    }

    /**
     * Replaces the parameter vector with a copy of {@code alpha} and updates the dimension.
     *
     * @throws IllegalArgumentException if some {@code alpha[i]} is NaN or {@code <= 0};
     *     in that case this object is left unchanged
     */
    public void setParams(double[] alpha) {
        // Validate before assigning any field, so a rejected call cannot leave a
        // half-initialized parameter array or a mismatched dimension behind.
        DirichletDist.verifParam(alpha);
        this.alpha = alpha.clone();
        this.dimension = alpha.length;
    }

    /**
     * Objective passed to {@code Uncmin_f77}: the negative log-likelihood of {@code n}
     * observations whose per-component mean log-values are {@code logP}.
     */
    private static class Optim
    implements Uncmin_methods {
        private final double[] logP;
        private final int n;

        public Optim(double[] logP, int n) {
            this.n = n;
            this.logP = logP.clone();
        }

        /**
         * Returns {@code -n * (lnGamma(sum alpha) - sum lnGamma(alpha_i) + sum (alpha_i - 1) logP_i)}.
         * {@code alpha} is read 1-based: {@code alpha[1..]} pairs with {@code logP[0..]}.
         */
        public double f_to_minimize(double[] alpha) {
            double sumAlpha = 0.0;
            double sumLnGammaAlpha = 0.0;
            double sumAlphaLnP = 0.0;
            for (int i = 1; i < alpha.length; ++i) {
                if (alpha[i] <= 0.0) {
                    // Outside the parameter space: return a huge penalty instead of lnGamma.
                    return 1.0E200;
                }
                sumAlpha += alpha[i];
                sumLnGammaAlpha += Num.lnGamma(alpha[i]);
                sumAlphaLnP += (alpha[i] - 1.0) * this.logP[i - 1];
            }
            return -this.n * (Num.lnGamma(sumAlpha) - sumLnGammaAlpha + sumAlphaLnP);
        }

        /**
         * Intentionally empty. NOTE(review): presumably optif0_f77 approximates the gradient
         * numerically; confirm against the Uncmin documentation.
         */
        public void gradient(double[] alpha, double[] g) {
        }

        /**
         * Intentionally empty. NOTE(review): presumably optif0_f77 approximates the Hessian
         * numerically; confirm against the Uncmin documentation.
         */
        public void hessian(double[] alpha, double[][] h) {
        }
    }
}

