package edu.cmu.sphinx.linguist.lextree;

import edu.cmu.sphinx.linguist.WordSequence;
import edu.cmu.sphinx.linguist.acoustic.HMM;
import edu.cmu.sphinx.linguist.acoustic.HMMPool;
import edu.cmu.sphinx.linguist.acoustic.HMMPosition;
import edu.cmu.sphinx.linguist.acoustic.Unit;
import edu.cmu.sphinx.linguist.dictionary.Dictionary;
import edu.cmu.sphinx.linguist.dictionary.Pronunciation;
import edu.cmu.sphinx.linguist.dictionary.Word;
import edu.cmu.sphinx.linguist.language.ngram.LanguageModel;
import edu.cmu.sphinx.util.LogMath;
import edu.cmu.sphinx.util.TimerPool;
import edu.cmu.sphinx.util.Utilities;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.logging.Logger;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:edu/cmu/sphinx/linguist/lextree/HMMTree.class */
public class HMMTree {
    private final HMMPool hmmPool;
    private InitialWordNode initialNode;
    /** Used only while the tree is being built; released by {@link #freeze()}. */
    private Dictionary dictionary;
    /** Used only while the tree is being built; released by {@link #freeze()}. */
    private LanguageModel lm;
    private final boolean addFillerWords;
    /** Lazily computed word set; released by {@link #freeze()}. */
    private Set<Word> allWords;
    private EntryPointTable entryPointTable;
    private boolean debug;
    private final float languageWeight;
    private WordNode sentenceEndWordNode;
    /** When filler words are not added, the dictionary's silence word is still added if this is set. */
    private final boolean addSilenceWord = true;
    /** First units of every pronunciation in the tree. */
    private final Set<Unit> entryPoints = new HashSet<>();
    /** Last units of every pronunciation in the tree; released by {@link #freeze()}. */
    private Set<Unit> exitPoints = new HashSet<>();
    /** Cache of END-position HMM nodes, keyed by {@code EndNode.getKey()}. */
    // NOTE(review): plain HashMap filled lazily by getHMMNodes; confirm callers are single-threaded.
    private final Map<Object, HMMNode[]> endNodeMap = new HashMap<>();
    private Logger logger = Logger.getLogger(HMMTree.class.getSimpleName());

    /**
     * The tree entry for one base (first) unit. It collects the subtrees of all multi-unit
     * pronunciations starting with that unit, plus the single-unit pronunciations, and builds
     * one entry node per possible left context (exit unit).
     */
    public class EntryPoint {
        final Unit baseUnit;
        int nodeCount;
        /** Cached set of right-context units that follow {@link #baseUnit}; rebuilt on demand. */
        Set<Unit> rcSet;
        final Node baseNode = new Node(LogMath.getLogZero());
        final Map<Unit, Node> unitToEntryPointMap = new HashMap<>();
        /** Pronunciations consisting of {@link #baseUnit} alone; released by {@link #freeze()}. */
        List<Pronunciation> singleUnitWords = new ArrayList<>();
        /** Highest probability seen for any pronunciation starting with this unit. */
        float totalProbability = LogMath.getLogZero();

        EntryPoint(Unit unit) {
            this.baseUnit = unit;
        }

        /**
         * Returns the entry node for the given left-context unit.
         *
         * @param leftContext the last unit of the preceding word
         * @return the entry node, or {@code null} if no map entry exists for that unit
         */
        Node getEntryPointsFromLeftContext(Unit leftContext) {
            return unitToEntryPointMap.get(leftContext);
        }

        /** Raises the stored probability to {@code probability} if it is larger (max, not sum). */
        void addProbability(float probability) {
            if (probability > totalProbability) {
                totalProbability = probability;
            }
        }

        float getProbability() {
            return totalProbability;
        }

        /** Freezes every per-left-context entry node and drops build-time-only state. */
        void freeze() {
            for (Node entryNode : unitToEntryPointMap.values()) {
                entryNode.freeze();
            }
            singleUnitWords = null;
            rcSet = null;
        }

        Node getNode() {
            return baseNode;
        }

        void addSingleUnitWord(Pronunciation pronunciation) {
            singleUnitWords.add(pronunciation);
        }

        /** Returns the base units of the nodes directly following {@link #baseNode}, caching the result. */
        private Collection<Unit> getEntryPointRC() {
            if (rcSet == null) {
                rcSet = new HashSet<>();
                for (Node successor : baseNode.getSuccessorMap().values()) {
                    rcSet.add(((UnitNode) successor).getBaseUnit());
                }
            }
            return rcSet;
        }

        /**
         * Builds one entry node per exit unit (left context). Each entry node gets a BEGIN HMM
         * node for every right context that follows the base unit. Identical HMMs are shared
         * across left contexts through a local map. Single-unit words are then attached.
         */
        void createEntryPointMap() {
            Map<HMM, Node> beginNodes = new HashMap<>();
            Map<HMM, HMMNode> singleUnitNodes = new HashMap<>();
            for (Unit leftContext : HMMTree.this.exitPoints) {
                Node entryNode = new Node(LogMath.getLogZero());
                for (Unit rightContext : getEntryPointRC()) {
                    HMM hmm = HMMTree.this.hmmPool.getHMM(baseUnit, leftContext, rightContext, HMMPosition.BEGIN);
                    Node beginNode = beginNodes.get(hmm);
                    if (beginNode == null) {
                        beginNode = entryNode.addSuccessor(hmm, getProbability());
                        beginNodes.put(hmm, beginNode);
                    } else {
                        // Reuse the node already created for this HMM under another left context.
                        entryNode.putSuccessor(hmm, beginNode);
                    }
                    nodeCount++;
                    connectEntryPointNode(beginNode, rightContext);
                }
                connectSingleUnitWords(leftContext, entryNode, singleUnitNodes);
                unitToEntryPointMap.put(leftContext, entryNode);
            }
        }

        /**
         * Attaches SINGLE-position HMM nodes (one per entry unit as right context) under
         * {@code node}, and hangs every single-unit pronunciation off each of them. The
         * sentence-start word becomes the tree's initial node instead of a word successor.
         */
        private void connectSingleUnitWords(Unit leftContext, Node node, Map<HMM, HMMNode> singleUnitNodes) {
            if (singleUnitWords.isEmpty()) {
                return;
            }
            // Invariant across both loops below.
            Word sentenceStartWord = HMMTree.this.dictionary.getSentenceStartWord();
            Word sentenceEndWord = HMMTree.this.dictionary.getSentenceEndWord();
            for (Unit rightContext : HMMTree.this.entryPoints) {
                HMM hmm = HMMTree.this.hmmPool.getHMM(baseUnit, leftContext, rightContext, HMMPosition.SINGLE);
                HMMNode tailNode = singleUnitNodes.get(hmm);
                if (tailNode == null) {
                    tailNode = (HMMNode) node.addSuccessor(hmm, getProbability());
                    singleUnitNodes.put(hmm, tailNode);
                } else {
                    node.putSuccessor(hmm, tailNode);
                }
                tailNode.addRC(rightContext);
                nodeCount++;
                for (Pronunciation pronunciation : singleUnitWords) {
                    Word word = pronunciation.getWord();
                    if (word == sentenceStartWord) {
                        HMMTree.this.initialNode = new InitialWordNode(pronunciation, tailNode);
                    } else {
                        WordNode wordNode = tailNode.addSuccessor(pronunciation, HMMTree.this.getWordUnigramProbability(word));
                        if (word == sentenceEndWord) {
                            HMMTree.this.sentenceEndWordNode = wordNode;
                        }
                    }
                    nodeCount++;
                }
            }
        }

        /** Links {@code node} to every successor of {@link #baseNode} whose base unit is {@code rightContext}. */
        private void connectEntryPointNode(Node node, Unit rightContext) {
            for (Node successor : baseNode.getSuccessors()) {
                UnitNode unitNode = (UnitNode) successor;
                if (unitNode.getBaseUnit() == rightContext) {
                    node.addSuccessor(unitNode);
                }
            }
        }

        /** Prints this entry point and its right-context units to standard output. */
        void dump() {
            Collection<Unit> rightContexts = getEntryPointRC();
            System.out.println("EntryPoint " + baseUnit + " RC Followers: " + rightContexts.size());
            int column = 0;
            System.out.print("    ");
            for (Unit rightContext : rightContexts) {
                System.out.print(Utilities.pad(rightContext.getName(), 4));
                // Post-increment test: a line break follows every 13th name.
                if (column++ >= 12) {
                    column = 0;
                    System.out.println();
                    System.out.print("    ");
                }
            }
            System.out.println();
        }
    }

    /** Maps each entry (first) unit to its {@link EntryPoint}. */
    public class EntryPointTable {
        private final Map<Unit, EntryPoint> entryPoints = new HashMap<>();

        /** Creates one (initially empty) entry point per unit in {@code units}. */
        EntryPointTable(Collection<Unit> units) {
            for (Unit unit : units) {
                entryPoints.put(unit, new EntryPoint(unit));
            }
        }

        /** Returns the entry point for {@code unit}, or {@code null} if there is none. */
        EntryPoint getEntryPoint(Unit unit) {
            return entryPoints.get(unit);
        }

        void createEntryPointMaps() {
            for (EntryPoint entryPoint : entryPoints.values()) {
                entryPoint.createEntryPointMap();
            }
        }

        void freeze() {
            for (EntryPoint entryPoint : entryPoints.values()) {
                entryPoint.freeze();
            }
        }

        void dump() {
            for (EntryPoint entryPoint : entryPoints.values()) {
                entryPoint.dump();
            }
        }
    }

    /**
     * Builds the HMM tree immediately (timed under "Create HMM Tree").
     *
     * @param hMMPool        source of the HMMs placed in the tree
     * @param dictionary     dictionary used to look up words; released after construction
     * @param languageModel  language model whose vocabulary defines the words; released after construction
     * @param z              if true, all dictionary filler words are added; otherwise only the silence word
     * @param f              language weight; multiplies the unigram log probability of non-filler words
     */
    public HMMTree(HMMPool hMMPool, Dictionary dictionary, LanguageModel languageModel, boolean z, float f) {
        this.hmmPool = hMMPool;
        this.dictionary = dictionary;
        this.lm = languageModel;
        this.addFillerWords = z;
        this.languageWeight = f;
        TimerPool.getTimer(this, "Create HMM Tree").start();
        compile();
        TimerPool.getTimer(this, "Create HMM Tree").stop();
    }

    /**
     * Returns the nodes that start a word with base unit {@code unit2}, given left context {@code unit}.
     *
     * @param unit  left-context unit (the last unit of the preceding word)
     * @param unit2 base unit (the first unit of the next word)
     * @return the successors of the matching entry node
     * @throws IllegalArgumentException if no entry point exists for the base unit or left context
     */
    public Node[] getEntryPoint(Unit unit, Unit unit2) {
        EntryPoint entryPoint = entryPointTable.getEntryPoint(unit2);
        if (entryPoint == null) {
            throw new IllegalArgumentException("No entry point for base unit " + unit2);
        }
        Node entryNode = entryPoint.getEntryPointsFromLeftContext(unit);
        if (entryNode == null) {
            throw new IllegalArgumentException("No entry point for base unit " + unit2 + " with left context " + unit);
        }
        return entryNode.getSuccessors();
    }

    /**
     * Returns the END-position HMM nodes for {@code endNode}, building and caching them on
     * first request. One HMM node is created per distinct HMM over all entry units used as
     * right context. Each node is linked to every word successor of {@code endNode}.
     */
    public HMMNode[] getHMMNodes(EndNode endNode) {
        HMMNode[] results = endNodeMap.get(endNode.getKey());
        if (results == null) {
            Map<HMM, HMMNode> nodesByHmm = new HashMap<>();
            Unit baseUnit = endNode.getBaseUnit();
            Unit leftContext = endNode.getLeftContext();
            Node[] wordSuccessors = endNode.getSuccessors();
            for (Unit rightContext : entryPoints) {
                HMM hmm = hmmPool.getHMM(baseUnit, leftContext, rightContext, HMMPosition.END);
                HMMNode hmmNode = nodesByHmm.get(hmm);
                if (hmmNode == null) {
                    hmmNode = new HMMNode(hmm, LogMath.getLogOne());
                    nodesByHmm.put(hmm, hmmNode);
                }
                hmmNode.addRC(rightContext);
                for (Node successor : wordSuccessors) {
                    hmmNode.addSuccessor((WordNode) successor);
                }
            }
            results = nodesByHmm.values().toArray(new HMMNode[nodesByHmm.size()]);
            endNodeMap.put(endNode.getKey(), results);
        }
        return results;
    }

    /** Returns the word node for the sentence-end word (asserted non-null when assertions are enabled). */
    public WordNode getSentenceEndWordNode() {
        assert sentenceEndWordNode != null;
        return sentenceEndWordNode;
    }

    /** Collects entry/exit units, adds every word, builds the entry-point maps, then freezes. */
    private void compile() {
        collectEntryAndExitUnits();
        entryPointTable = new EntryPointTable(entryPoints);
        addWords();
        entryPointTable.createEntryPointMaps();
        freeze();
    }

    /** Prints the tree reachable from the initial node to standard output. */
    void dumpTree() {
        System.out.println("Dumping Tree ...");
        dumpTree(0, getInitialNode(), new HashSet<Node>());
        System.out.println("... done Dumping Tree");
    }

    /** Depth-first print of each node once. Recursion stops at word nodes. */
    private void dumpTree(int level, Node node, Set<Node> visited) {
        if (!visited.add(node)) {
            return;
        }
        System.out.println(Utilities.pad(level) + node);
        if (node instanceof WordNode) {
            return;
        }
        for (Node successor : node.getSuccessors()) {
            dumpTree(level + 1, successor, visited);
        }
    }

    /** Records the first unit of every pronunciation as an entry point and the last unit as an exit point. */
    private void collectEntryAndExitUnits() {
        for (Word word : getAllWords()) {
            for (Pronunciation pronunciation : word.getPronunciations()) {
                Unit[] units = pronunciation.getUnits();
                entryPoints.add(units[0]);
                exitPoints.add(units[units.length - 1]);
            }
        }
        if (debug) {
            System.out.println("Entry Points: " + entryPoints.size());
            System.out.println("Exit Points: " + exitPoints.size());
        }
    }

    /** Freezes the entry points and drops references needed only while building. */
    private void freeze() {
        entryPointTable.freeze();
        dictionary = null;
        lm = null;
        exitPoints = null;
        allWords = null;
    }

    private void addWords() {
        for (Word word : getAllWords()) {
            addWord(word);
        }
    }

    /** Adds every pronunciation of {@code word}, each weighted by the word's unigram probability. */
    private void addWord(Word word) {
        float probability = getWordUnigramProbability(word);
        for (Pronunciation pronunciation : word.getPronunciations()) {
            addPronunciation(pronunciation, probability);
        }
    }

    /**
     * Inserts one pronunciation into the tree. A single-unit pronunciation is recorded on its
     * entry point for later attachment. Otherwise:
     * <ul>
     *   <li>an INTERNAL HMM node is added for each unit strictly between the first and last;</li>
     *   <li>a missing HMM is logged and skipped;</li>
     *   <li>an {@link EndNode} for the last unit follows, and then the word node.</li>
     * </ul>
     */
    private void addPronunciation(Pronunciation pronunciation, float f) {
        Unit[] units = pronunciation.getUnits();
        Unit firstUnit = units[0];
        EntryPoint entryPoint = entryPointTable.getEntryPoint(firstUnit);
        entryPoint.addProbability(f);
        if (units.length <= 1) {
            entryPoint.addSingleUnitWord(pronunciation);
            return;
        }
        Node currentNode = entryPoint.getNode();
        Unit leftContext = firstUnit;
        for (int i = 1; i < units.length - 1; i++) {
            Unit base = units[i];
            Unit rightContext = units[i + 1];
            HMM hmm = hmmPool.getHMM(base, leftContext, rightContext, HMMPosition.INTERNAL);
            if (hmm == null) {
                logger.severe("Missing HMM for unit " + base.getName() + " with lc=" + leftContext.getName() + " rc=" + rightContext.getName());
            } else {
                currentNode = currentNode.addSuccessor(hmm, f);
            }
            leftContext = base;
        }
        Unit lastUnit = units[units.length - 1];
        WordNode wordNode = currentNode.addSuccessor(new EndNode(lastUnit, leftContext, f), f).addSuccessor(pronunciation, f);
        if (wordNode.getWord().isSentenceEndWord()) {
            sentenceEndWordNode = wordNode;
        }
    }

    /**
     * Returns the language-weighted unigram log probability of {@code word}.
     * Filler words always get {@code LogMath.getLogOne()}.
     */
    public float getWordUnigramProbability(Word word) {
        if (word.isFiller()) {
            return LogMath.getLogOne();
        }
        return lm.getProbability(new WordSequence(new Word[]{word})) * languageWeight;
    }

    /**
     * Returns, computing once, the set of words in the tree:
     * <ul>
     *   <li>every LM vocabulary spelling found in the dictionary;</li>
     *   <li>plus either all filler words or just the silence word.</li>
     * </ul>
     */
    private Set<Word> getAllWords() {
        if (allWords == null) {
            allWords = new HashSet<>();
            for (String spelling : lm.getVocabulary()) {
                Word word = dictionary.getWord(spelling);
                if (word != null) {
                    allWords.add(word);
                }
            }
            if (addFillerWords) {
                allWords.addAll(Arrays.asList(dictionary.getFillerWords()));
            } else if (addSilenceWord) {
                allWords.add(dictionary.getSilenceWord());
            }
        }
        return allWords;
    }

    /** Returns the node for the sentence-start word, set while single-unit words are connected. */
    public InitialWordNode getInitialNode() {
        return initialNode;
    }
}
