package marytts.machinelearning;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.PrintWriter;
import java.io.StringReader;
import java.util.Scanner;
import marytts.features.FeatureDefinition;
import marytts.unitselection.select.Target;

/* loaded from: input_file:marytts/machinelearning/SoP.class */
/**
 * Linear model evaluated over the byte-valued features of a {@link Target} (the class name
 * presumably stands for "sum of products").
 *
 * <p>The model is a list of (coefficient, factor) pairs. A factor is the name of a feature in the
 * model's {@link FeatureDefinition}; its resolved index is cached in {@code factorsIndex}. When
 * {@code interceptTerm} is true, entry 0 is the intercept: its factor name is the placeholder
 * {@code "_"} and its feature index is {@code -1}.
 *
 * <p>Textual form of the coefficients, as written by {@link #saveSelectedFeatures(PrintWriter)} and
 * read by {@link #setCoeffsAndFactors(String)}: {@code coeff0 factor0 coeff1 factor1 ...} on a
 * single line.
 */
public class SoP {
    private double[] coeffs;
    private String[] factors;
    private int[] factorsIndex;
    boolean interceptTerm;
    double correlation;
    double rmse;
    // Last value computed by one of the solve(...) methods (before any exp()).
    double solution;
    FeatureDefinition featureDefinition;

    /**
     * Stores a correlation value for this model. This class only stores it; it never computes it.
     *
     * @param correlation the correlation to store
     */
    public void setCorrelation(double correlation) {
        this.correlation = correlation;
    }

    /**
     * Stores an RMSE value for this model. This class only stores it; it never computes it.
     *
     * @param rmse the root-mean-square error to store
     */
    public void setRMSE(double rmse) {
        this.rmse = rmse;
    }

    /**
     * Returns the coefficient array. This is the internal array, not a copy.
     *
     * @return the coefficients, or {@code null} if none have been set
     */
    public double[] getCoeffs() {
        return this.coeffs;
    }

    /** @return the stored correlation value */
    public double getCorrelation() {
        return this.correlation;
    }

    /** @return the stored RMSE value */
    public double getRMSE() {
        return this.rmse;
    }

    /**
     * Returns the feature index of each factor ({@code -1} for the intercept). This is the internal
     * array, not a copy.
     *
     * @return the factor indices, or {@code null} if none have been set
     */
    public int[] getFactorsIndex() {
        return this.factorsIndex;
    }

    /** Creates an empty model with no feature definition. */
    public SoP() {
        this.featureDefinition = null;
    }

    /**
     * Creates an empty model bound to a feature definition.
     *
     * @param featureDefinition definition used to resolve factor names to feature indices
     */
    public SoP(FeatureDefinition featureDefinition) {
        this.featureDefinition = featureDefinition;
    }

    /**
     * Sets the model from a regression result.
     *
     * @param coefficients    regression coefficients; when {@code withIntercept} is true,
     *                        {@code coefficients[0]} is the intercept and {@code coefficients[k + 1]}
     *                        belongs to {@code selectedFactors[k]}; otherwise {@code coefficients[k]}
     *                        belongs to {@code selectedFactors[k]}
     * @param selectedFactors indices into {@code factorNames} of the factors that were selected
     * @param factorNames     names of all candidate factors
     * @param withIntercept   whether {@code coefficients} starts with an intercept term
     * @throws Exception if no {@link FeatureDefinition} has been set on this model
     */
    public void setCoeffsAndFactors(double[] coefficients, int[] selectedFactors, String[] factorNames,
            boolean withIntercept) throws Exception {
        if (this.featureDefinition == null) {
            throw new Exception("FeatureDefinition not defined in SoP");
        }
        this.interceptTerm = withIntercept;
        // With an intercept, every selected factor is shifted one slot to the right.
        int offset = withIntercept ? 1 : 0;
        int size = selectedFactors.length + offset;
        this.coeffs = new double[size];
        this.factors = new String[size];
        this.factorsIndex = new int[size];
        if (withIntercept) {
            this.coeffs[0] = coefficients[0];
            this.factors[0] = "_";
            this.factorsIndex[0] = -1;
        }
        for (int k = 0; k < selectedFactors.length; k++) {
            int slot = k + offset;
            this.coeffs[slot] = coefficients[slot];
            this.factors[slot] = factorNames[selectedFactors[k]];
            this.factorsIndex[slot] = this.featureDefinition.getFeatureIndex(this.factors[slot]);
        }
    }

    /**
     * Loads a model from a text file.
     *
     * <p>The file starts with a feature definition that is terminated by the first blank line. The
     * next non-blank line holds the coefficients in the format read by
     * {@link #setCoeffsAndFactors(String)}. Errors are reported via {@code printStackTrace()} and
     * otherwise ignored (best-effort load); the file is always closed.
     *
     * @param fileName path of the file to read (opened with the platform default charset)
     */
    public void load(String fileName) {
        Scanner scanner = null;
        try {
            scanner = new Scanner(new BufferedReader(new FileReader(fileName)));
            StringBuilder header = new StringBuilder();
            while (scanner.hasNext()) {
                String line = scanner.nextLine();
                if (line.trim().isEmpty()) {
                    break;
                }
                header.append(line).append('\n');
            }
            this.featureDefinition = new FeatureDefinition(new BufferedReader(new StringReader(header.toString())), false);
            // Tolerate extra blank lines between the header and the coefficient line.
            while (scanner.hasNextLine()) {
                String line = scanner.nextLine();
                if (!line.trim().isEmpty()) {
                    setCoeffsAndFactors(line);
                    break;
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            if (scanner != null) {
                scanner.close();
            }
        }
    }

    /**
     * Sets the model from its textual form {@code coeff0 factor0 coeff1 factor1 ...}.
     *
     * <p>A factor named {@code "_"} marks the intercept term. Tokens may be separated by any run of
     * whitespace. On failure, the model is left unchanged.
     *
     * @param str the coefficient line
     * @throws NumberFormatException    if a coefficient token is not a number
     * @throws IllegalArgumentException if the line does not consist of (coefficient, factor) pairs
     * @throws IllegalStateException    if a non-intercept factor must be resolved but no
     *                                  {@link FeatureDefinition} has been set
     */
    public void setCoeffsAndFactors(String str) {
        parseCoeffsAndFactors(str);
    }

    /**
     * Creates a model from its textual form {@code coeff0 factor0 coeff1 factor1 ...}.
     *
     * @param str               the coefficient line
     * @param featureDefinition definition used to resolve factor names to feature indices
     * @throws NumberFormatException    if a coefficient token is not a number
     * @throws IllegalArgumentException if the line does not consist of (coefficient, factor) pairs
     * @throws IllegalStateException    if a non-intercept factor is present and
     *                                  {@code featureDefinition} is null
     */
    public SoP(String str, FeatureDefinition featureDefinition) {
        this.featureDefinition = featureDefinition;
        parseCoeffsAndFactors(str);
    }

    /**
     * Parses a coefficient line and installs it. The fields are assigned only after the whole line
     * has been parsed, so a failure leaves the model unchanged.
     */
    private void parseCoeffsAndFactors(String line) {
        String trimmed = line.trim();
        String[] tokens = trimmed.isEmpty() ? new String[0] : trimmed.split("\\s+");
        if (tokens.length % 2 != 0) {
            throw new IllegalArgumentException("SoP coefficient line must contain (coefficient, factor) pairs, but has "
                    + tokens.length + " tokens: " + line);
        }
        int pairs = tokens.length / 2;
        double[] newCoeffs = new double[pairs];
        String[] newFactors = new String[pairs];
        int[] newIndices = new int[pairs];
        boolean hasIntercept = false;
        for (int k = 0; k < pairs; k++) {
            newCoeffs[k] = Double.parseDouble(tokens[2 * k]);
            String factor = tokens[2 * k + 1];
            newFactors[k] = factor;
            if (factor.equals("_")) {
                // NOTE(review): solve(...) assumes the intercept is entry 0; saveSelectedFeatures
                // writes it there, but this parser does not enforce it.
                hasIntercept = true;
                newIndices[k] = -1;
            } else {
                if (this.featureDefinition == null) {
                    throw new IllegalStateException("FeatureDefinition not defined in SoP");
                }
                newIndices[k] = this.featureDefinition.getFeatureIndex(factor);
            }
        }
        this.coeffs = newCoeffs;
        this.factors = newFactors;
        this.factorsIndex = newIndices;
        this.interceptTerm = hasIntercept;
    }

    /** @return the feature definition used to resolve factor names, possibly {@code null} */
    public FeatureDefinition getFeatureDefinition() {
        return this.featureDefinition;
    }

    /**
     * Evaluates the model for {@code target}.
     *
     * <p>The sum is {@code coeffs[0] + sum(coeffs[i] * f_i)} with an intercept term, or
     * {@code sum(coeffs[i] * f_i)} without one, where {@code f_i} is the byte value of factor
     * {@code i} in the target's feature vector. If the final sum is negative, the method:
     * <ul>
     *   <li>prints a warning;</li>
     *   <li>replaces the sum with the last positive partial sum seen while adding up the factor
     *       terms, or 0 if there was none.</li>
     * </ul>
     *
     * @param target            target whose feature vector supplies the factor values
     * @param featureDefinition not used by this overload
     * @param log               if true, return {@code exp(sum)} instead of the sum
     * @return the (possibly clamped, possibly exponentiated) value; the value before
     *         exponentiation is also stored in {@code solution}
     */
    public double solve(Target target, FeatureDefinition featureDefinition, boolean log) {
        int first = this.interceptTerm ? 1 : 0;
        this.solution = this.interceptTerm ? this.coeffs[0] : 0.0d;
        double lastPositive = 0.0d;
        for (int i = first; i < this.coeffs.length; i++) {
            this.solution += this.coeffs[i] * target.getFeatureVector().getByteFeature(this.factorsIndex[i]);
            if (this.solution > 0.0d) {
                lastPositive = this.solution;
            }
        }
        // Warn once, and only when the final sum is actually negative (and hence gets clamped);
        // intermediate partial sums may legitimately dip below zero.
        if (this.solution < 0.0d) {
            System.out.println("WARNING: sop solution negative");
            this.solution = lastPositive;
        }
        return log ? Math.exp(this.solution) : this.solution;
    }

    /**
     * Evaluates the model for {@code target}, optionally tracing each step to standard output.
     *
     * <p>With an intercept term, any factor whose value is not a valid value of that feature in
     * {@code featureDefinition} is skipped with a warning. Without an intercept term, all factors
     * are used as-is.
     *
     * <p>Unlike {@link #solve(Target, FeatureDefinition, boolean)}, a negative sum is not clamped.
     * The result is the same whether or not {@code debug} is set.
     *
     * @param target            target whose feature vector supplies the factor values
     * @param featureDefinition definition used to validate and print the factor values
     * @param log               if true, return {@code exp(sum)} instead of the sum
     * @param debug             if true, print each accumulation step
     * @return the (possibly exponentiated) sum; the value before exponentiation is also stored in
     *         {@code solution}
     */
    public double solve(Target target, FeatureDefinition featureDefinition, boolean log, boolean debug) {
        this.solution = 0.0d;
        if (this.interceptTerm) {
            this.solution = this.coeffs[0];
            if (debug) {
                System.out.format("   solution = %.3f (coeff[0])\n", Double.valueOf(this.coeffs[0]));
            }
            for (int i = 1; i < this.coeffs.length; i++) {
                int featureIndex = this.factorsIndex[i];
                byte value = target.getFeatureVector().getByteFeature(featureIndex);
                String valueString = target.getFeatureVector().getFeatureAsString(featureIndex, featureDefinition);
                if (!featureDefinition.hasFeatureValue(featureIndex, valueString)) {
                    System.out.format("WARNING: Feature value for %s = %s is not valid\n", this.factors[i], valueString);
                    continue;
                }
                if (debug) {
                    System.out.format("   %.3f + (%.3f * %d (%s) = ", Double.valueOf(this.solution),
                            Double.valueOf(this.coeffs[i]), Byte.valueOf(value), this.factors[i]);
                }
                this.solution += this.coeffs[i] * value;
                if (debug) {
                    System.out.format("%.3f  featureIndex=%d  feaValStr=%s \n", Double.valueOf(this.solution),
                            Integer.valueOf(featureIndex), valueString);
                }
            }
        } else {
            for (int i = 0; i < this.coeffs.length; i++) {
                this.solution += this.coeffs[i] * target.getFeatureVector().getByteFeature(this.factorsIndex[i]);
            }
        }
        return log ? Math.exp(this.solution) : this.solution;
    }

    /**
     * Evaluates the model for {@code target} using this model's own feature definition, with no
     * exponentiation.
     *
     * @param target target whose feature vector supplies the factor values
     * @return the value computed by {@link #solve(Target, FeatureDefinition, boolean)}
     */
    public double interpret(Target target) {
        return solve(target, this.featureDefinition, false);
    }

    /**
     * Writes the coefficients as a single line: {@code "coeff factor "} for each entry, followed by
     * a line terminator.
     *
     * @param printWriter destination writer
     */
    public void saveSelectedFeatures(PrintWriter printWriter) {
        for (int i = 0; i < this.coeffs.length; i++) {
            printWriter.print(this.coeffs[i] + " " + this.factors[i] + " ");
        }
        printWriter.println();
    }

    /** Prints each coefficient with its factor name and feature index to standard output. */
    public void printCoefficients() {
        if (this.coeffs == null) {
            System.out.println("There is no coefficients to print (coeffs=null).");
            return;
        }
        System.out.println("SoP coefficients (factor : factorIndex in FeatureDefinition)");
        for (int i = 0; i < this.coeffs.length; i++) {
            System.out.format(" %.5f (%s : %d)\n", Double.valueOf(this.coeffs[i]), this.factors[i],
                    Integer.valueOf(this.factorsIndex[i]));
        }
    }
}
