package comirva.audio.util.gmm;

import comirva.audio.XMLSerializable;
import comirva.audio.util.PointList;
import comirva.audio.util.math.Matrix;
import java.io.IOException;
import java.lang.ref.SoftReference;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import javax.xml.stream.XMLStreamException;
import javax.xml.stream.XMLStreamReader;
import javax.xml.stream.XMLStreamWriter;

/* loaded from: input_file:comirva/audio/util/gmm/GaussianMixture.class */
public final class GaussianMixture implements XMLSerializable {
    private static final long serialVersionUID = 1;
    private int dimension;
    private GaussianComponent[] components;
    private static final int MAX_ITERATIONS = 80;
    private static double[][] p_ij = new double[1][1];
    private static SoftReference<double[][]> p_ij_SoftRef = new SoftReference<>(p_ij);
    private static Random rnd = new Random();

    public GaussianMixture(double[] dArr, Matrix[] matrixArr, Matrix[] matrixArr2) throws IllegalArgumentException {
        this.dimension = 0;
        this.components = new GaussianComponent[0];
        if (dArr.length != matrixArr.length || matrixArr.length != matrixArr2.length || dArr.length < 1) {
            throw new IllegalArgumentException("all arrays must have the same length with size greater than 0;");
        }
        this.components = new GaussianComponent[dArr.length];
        double d = 0.0d;
        for (int i = 0; i < this.components.length; i++) {
            if (matrixArr[i] == null || matrixArr2[i] == null) {
                throw new IllegalArgumentException("all mean and covarince matrices must not be null values;");
            }
            d += dArr[i];
            this.components[i] = new GaussianComponent(dArr[i], matrixArr[i], matrixArr2[i]);
        }
        if (d < 0.99d || d > 1.01d) {
            throw new IllegalArgumentException("the sum over all component weights must be in the interval [0.99, 1.01];");
        }
        this.dimension = this.components[0].getDimension();
        for (int i2 = 0; i2 < this.components.length; i2++) {
            if (this.components[i2].getDimension() != this.dimension) {
                throw new IllegalArgumentException("the dimensions of all components must be the same;");
            }
        }
    }

    public GaussianMixture(GaussianComponent[] gaussianComponentArr) throws IllegalArgumentException {
        this.dimension = 0;
        this.components = new GaussianComponent[0];
        if (gaussianComponentArr == null) {
            throw new IllegalArgumentException("the component array must not be null;");
        }
        double d = 0.0d;
        for (int i = 0; i < gaussianComponentArr.length; i++) {
            if (gaussianComponentArr[i] == null) {
                throw new IllegalArgumentException("all components in the array must not be null;");
            }
            d += gaussianComponentArr[i].getComponentWeight();
        }
        if (d < 0.99d || d > 1.01d) {
            throw new IllegalArgumentException("the sum over all component weights must be in the interval [0.99, 1.01];");
        }
        this.components = gaussianComponentArr;
        this.dimension = gaussianComponentArr[0].getDimension();
        for (GaussianComponent gaussianComponent : gaussianComponentArr) {
            if (gaussianComponent.getDimension() != this.dimension) {
                throw new IllegalArgumentException("the dimensions of all components must be the same;");
            }
        }
    }

    private GaussianMixture() {
        this.dimension = 0;
        this.components = new GaussianComponent[0];
    }

    public double getLogLikelihood(PointList pointList) {
        double d = 0.0d;
        for (int i = 0; i < pointList.size(); i++) {
            d += Math.log(getProbability(pointList.get(i)));
        }
        return d;
    }

    public double[] nextSample() {
        double d = 0.0d;
        double nextDouble = rnd.nextDouble();
        for (int i = 0; i < this.components.length; i++) {
            d += this.components[i].getComponentWeight();
            if (nextDouble < d) {
                return this.components[i].nextSample();
            }
        }
        if (this.components.length - 1 >= 0) {
            return this.components[this.components.length - 1].nextSample();
        }
        throw new IllegalStateException("gaussian components of this mixture not yet defined;");
    }

    public double getProbability(Matrix matrix) {
        double d = 0.0d;
        for (int i = 0; i < this.components.length; i++) {
            d += this.components[i].getWeightedSampleProbability(matrix);
        }
        return d;
    }

    public int getDimension() {
        return this.dimension;
    }

    public void print() {
        for (int i = 0; i < this.components.length; i++) {
            System.out.println("Component " + i + ":");
            this.components[i].print();
        }
    }

    public Matrix getMean(int i) {
        return this.components[i].getMean();
    }

    protected static void getBuffer(int i, int i2) {
        p_ij = p_ij_SoftRef.get();
        if (p_ij == null) {
            p_ij = new double[i][2 * i2];
            p_ij_SoftRef = new SoftReference<>(p_ij);
        }
        if (p_ij[0].length < i2 || p_ij.length < i) {
            if (p_ij[0].length < i2) {
                i2 += i2;
            }
            p_ij = new double[i][i2];
            p_ij_SoftRef = new SoftReference<>(p_ij);
            System.gc();
        }
    }

    public static GaussianMixture readGMM(XMLStreamReader xMLStreamReader) throws IOException, XMLStreamException {
        GaussianMixture gaussianMixture = new GaussianMixture();
        gaussianMixture.readXML(xMLStreamReader);
        return gaussianMixture;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v19 */
    /* JADX WARN: Type inference failed for: r0v3, types: [double[][]] */
    /* JADX WARN: Type inference failed for: r0v4, types: [java.lang.Throwable] */
    public void runEM(PointList pointList) throws CovarianceSingularityException {
        double d = -1.7976931348623157E308d;
        int i = 0;
        ?? r0 = p_ij;
        synchronized (r0) {
            getBuffer(this.components.length, pointList.size());
            do {
                double d2 = d;
                estimationStep(pointList);
                maximizationStep(pointList);
                i++;
                d = getLogLikelihood(pointList);
                System.out.print("*");
                if (d2 - d >= -0.1d) {
                    break;
                }
            } while (i < MAX_ITERATIONS);
            p_ij = new double[1][1];
            r0 = r0;
            System.out.println();
        }
    }

    private void estimationStep(PointList pointList) {
        for (int i = 0; i < pointList.size(); i++) {
            double d = 0.0d;
            for (int i2 = 0; i2 < this.components.length; i2++) {
                double weightedSampleProbability = this.components[i2].getWeightedSampleProbability(pointList.get(i));
                d += weightedSampleProbability;
                p_ij[i2][i] = weightedSampleProbability;
            }
            for (int i3 = 0; i3 < this.components.length; i3++) {
                double[] dArr = p_ij[i3];
                int i4 = i;
                dArr[i4] = dArr[i4] / d;
            }
        }
    }

    private void maximizationStep(PointList pointList) throws CovarianceSingularityException {
        for (int i = 0; i < this.components.length; i++) {
            try {
                this.components[i].maximise(pointList, p_ij[i]);
            } catch (CovarianceSingularityException e) {
                PointList pointList2 = new PointList(pointList.getDimension(), pointList.size());
                for (int i2 = 0; i2 < pointList.size(); i2++) {
                    if (p_ij[i][i2] < 0.95d) {
                        pointList2.add(pointList.get(i2).getColumnPackedCopy());
                    }
                }
                throw new CovarianceSingularityException(pointList2);
            }
        }
    }

    @Override // comirva.audio.XMLSerializable
    public void writeXML(XMLStreamWriter xMLStreamWriter) throws IOException, XMLStreamException {
        xMLStreamWriter.writeStartElement("gmm");
        for (int i = 0; i < this.components.length; i++) {
            this.components[i].writeXML(xMLStreamWriter);
        }
        xMLStreamWriter.writeEndElement();
    }

    @Override // comirva.audio.XMLSerializable
    public void readXML(XMLStreamReader xMLStreamReader) throws IOException, XMLStreamException {
        LinkedList linkedList = new LinkedList();
        xMLStreamReader.require(1, (String) null, "gmm");
        xMLStreamReader.next();
        while (xMLStreamReader.isStartElement()) {
            linkedList.add(GaussianComponent.readGC(xMLStreamReader));
            xMLStreamReader.next();
        }
        this.components = new GaussianComponent[linkedList.size()];
        for (int i = 0; i < linkedList.size(); i++) {
            this.components[i] = (GaussianComponent) linkedList.get(i);
        }
        if (this.components[0] != null) {
            this.dimension = this.components[0].getDimension();
        }
        xMLStreamReader.require(2, (String) null, "gmm");
    }
}
