package edu.cmu.sphinx.research.parallel;

import edu.cmu.sphinx.decoder.pruner.Pruner;
import edu.cmu.sphinx.decoder.scorer.AcousticScorer;
import edu.cmu.sphinx.decoder.scorer.Scoreable;
import edu.cmu.sphinx.decoder.search.ActiveList;
import edu.cmu.sphinx.decoder.search.ActiveListFactory;
import edu.cmu.sphinx.decoder.search.SearchManager;
import edu.cmu.sphinx.decoder.search.Token;
import edu.cmu.sphinx.frontend.Data;
import edu.cmu.sphinx.linguist.Linguist;
import edu.cmu.sphinx.linguist.SearchState;
import edu.cmu.sphinx.linguist.SearchStateArc;
import edu.cmu.sphinx.linguist.flat.SentenceHMMState;
import edu.cmu.sphinx.linguist.flat.SentenceHMMStateArc;
import edu.cmu.sphinx.result.Result;
import edu.cmu.sphinx.util.LogMath;
import edu.cmu.sphinx.util.Timer;
import edu.cmu.sphinx.util.TimerPool;
import edu.cmu.sphinx.util.props.PropertyException;
import edu.cmu.sphinx.util.props.PropertySheet;
import edu.cmu.sphinx.util.props.S4Boolean;
import edu.cmu.sphinx.util.props.S4Component;
import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:edu/cmu/sphinx/research/parallel/ParallelSearchManager.class */
/**
 * A search manager that decodes several feature streams in parallel over a flat sentence-HMM
 * search graph supplied by a {@link ParallelSimpleLinguist}.
 * <p>
 * States of the graph are colored. GREEN states belong to a single feature stream and carry
 * {@link ParallelToken}s, which live in that stream's own active list. RED states gather the
 * per-stream tokens into a single {@link CombineToken}, whose combined score is computed by a
 * {@link FeatureScoreCombiner}. Combine tokens created while growing the per-stream lists are
 * queued on a delayed-expansion list. They are expanded only after every stream has been grown
 * for the current frame.
 */
public class ParallelSearchManager implements SearchManager {

    /** Factory used to create the combined, delayed-expansion and per-stream active lists. */
    @S4Component(type = ActiveListFactory.class)
    public static final String PROP_ACTIVE_LIST_FACTORY = "activeListFactory";

    /** Whether each stream's active list is pruned with the feature score pruner after scoring. */
    @S4Boolean(defaultValue = false)
    public static final String PROP_DO_FEATURE_PRUNING = "doFeaturePruning";

    /** Pruner applied to every feature stream's active list; only read when feature pruning is on. */
    @S4Component(type = FeatureScorePruner.class)
    public static final String PROP_FEATURE_SCORE_PRUNER = "featureScorePruner";

    /** Whether the delayed-expansion list of combine tokens is pruned before it is expanded. */
    @S4Boolean(defaultValue = false)
    public static final String PROP_DO_COMBINE_PRUNING = "doCombinePruning";

    /**
     * Pruner applied to the delayed-expansion list of combine tokens. It is only read when
     * {@link #PROP_DO_COMBINE_PRUNING} is true.
     */
    @S4Component(type = Pruner.class)
    public static final String PROP_COMBINED_SCORE_PRUNER = "combinedScorePruner";

    /** The acoustic scorer used to score every stream's active list. */
    @S4Component(type = AcousticScorer.class)
    public static final String PROP_SCORER = "scorer";

    /** The linguist; it must be a {@link ParallelSimpleLinguist}. */
    @S4Component(type = Linguist.class)
    public static final String PROP_LINGUIST = "linguist";

    /** The log math instance handed to every {@link Result}. */
    @S4Component(type = LogMath.class)
    public static final String PROP_LOG_MATH = "logMath";

    private ParallelSimpleLinguist linguist;
    private AcousticScorer scorer;
    private Pruner featureScorePruner;
    private Pruner combinedScorePruner;
    private LogMath logMath;
    private int currentFrameNumber;
    private ActiveListFactory activeListFactory;
    private ActiveList combinedActiveList;
    private ActiveList delayedExpansionList;
    private List<Token> resultList;
    // Best token seen so far per search state; frame numbers distinguish stale entries.
    private Map<SearchState, Token> bestTokenMap;
    private Timer scoreTimer;
    private Timer pruneTimer;
    private Timer growTimer;
    private boolean doFeaturePruning;
    private boolean doCombinePruning;

    /**
     * Reads the configuration.
     * <p>
     * The feature and combined score pruners are only looked up when the matching
     * {@code do*Pruning} flag is set.
     *
     * @param propertySheet the property sheet holding this component's configuration
     * @throws PropertyException if a property cannot be resolved
     */
    @Override // edu.cmu.sphinx.util.props.Configurable
    public void newProperties(PropertySheet propertySheet) throws PropertyException {
        this.logMath = (LogMath) propertySheet.getComponent(PROP_LOG_MATH);
        this.linguist = (ParallelSimpleLinguist) propertySheet.getComponent(PROP_LINGUIST);
        this.scorer = (AcousticScorer) propertySheet.getComponent(PROP_SCORER);
        this.activeListFactory = (ActiveListFactory) propertySheet.getComponent(PROP_ACTIVE_LIST_FACTORY);
        this.doFeaturePruning = propertySheet.getBoolean(PROP_DO_FEATURE_PRUNING).booleanValue();
        this.doCombinePruning = propertySheet.getBoolean(PROP_DO_COMBINE_PRUNING).booleanValue();
        if (this.doFeaturePruning) {
            this.featureScorePruner = (FeatureScorePruner) propertySheet.getComponent(PROP_FEATURE_SCORE_PRUNER);
        }
        if (this.doCombinePruning) {
            // Previously this field was never assigned, so enabling combine pruning
            // caused a NullPointerException in allocate().
            this.combinedScorePruner = (Pruner) propertySheet.getComponent(PROP_COMBINED_SCORE_PRUNER);
        }
        this.scoreTimer = TimerPool.getTimer(this, "Score");
        this.pruneTimer = TimerPool.getTimer(this, "Prune");
        this.growTimer = TimerPool.getTimer(this, "Grow");
    }

    /**
     * Allocates the linguist, the enabled pruners and the scorer.
     *
     * @throws RuntimeException wrapping any {@link IOException} raised during allocation
     */
    @Override // edu.cmu.sphinx.decoder.search.SearchManager
    public void allocate() {
        this.bestTokenMap = new HashMap<SearchState, Token>();
        try {
            this.linguist.allocate();
            if (this.doFeaturePruning) {
                this.featureScorePruner.allocate();
            }
            if (this.doCombinePruning) {
                this.combinedScorePruner.allocate();
            }
            this.scorer.allocate();
        } catch (IOException e) {
            throw new RuntimeException(toString() + ": allocation of search manager resources failed", e);
        }
    }

    /** Debug hook; intentionally a no-op. */
    private void debugPrint(String str) {
    }

    /** Resets the frame counter, starts all sub-components and seeds the initial token lists. */
    @Override // edu.cmu.sphinx.decoder.search.SearchManager
    public void startRecognition() {
        this.currentFrameNumber = 0;
        this.linguist.startRecognition();
        if (this.doFeaturePruning) {
            this.featureScorePruner.startRecognition();
        }
        if (this.doCombinePruning) {
            this.combinedScorePruner.startRecognition();
        }
        this.scorer.startRecognition();
        createInitialLists();
    }

    /**
     * Builds a combine token on the graph's initial state. The token holds one parallel token
     * per feature stream, and each stream gets a fresh active list. The combined score is
     * computed and the token is then grown.
     */
    private void createInitialLists() {
        this.combinedActiveList = this.activeListFactory.newInstance();
        this.delayedExpansionList = this.activeListFactory.newInstance();
        SentenceHMMState initialState = (SentenceHMMState) this.linguist.getSearchGraph().getInitialState();
        CombineToken initialToken = new CombineToken(null, initialState, this.currentFrameNumber);
        setBestToken(initialState, initialToken);
        Iterator<FeatureStream> streams = this.linguist.getFeatureStreams();
        while (streams.hasNext()) {
            FeatureStream stream = streams.next();
            stream.setActiveList(this.activeListFactory.newInstance());
            ParallelToken streamToken = new ParallelToken(initialState, stream, this.currentFrameNumber);
            streamToken.setLastCombineTime(this.currentFrameNumber);
            initialToken.addParallelToken(stream, streamToken);
        }
        this.resultList = new LinkedList<Token>();
        calculateCombinedScore(initialToken);
        growCombineToken(initialToken);
    }

    /**
     * Decodes up to {@code nFrames} frames, stopping early when the scorer runs out of data.
     *
     * @param nFrames the maximum number of frames to process
     * @return a result built from the current combined active list and final tokens
     */
    @Override // edu.cmu.sphinx.decoder.search.SearchManager
    public Result recognize(int nFrames) {
        boolean done = false;
        for (int frame = 0; frame < nFrames && !done; frame++) {
            done = recognize();
        }
        return new Result(this.combinedActiveList, this.resultList, this.currentFrameNumber, done, this.logMath);
    }

    /**
     * Processes a single frame: score, then prune, grow and advance the frame counter if
     * scoring produced data.
     *
     * @return true if recognition is finished (no more data was scored)
     */
    private boolean recognize() {
        debugPrint("-----\nFrame: " + this.currentFrameNumber);
        boolean moreFeatures = score();
        if (moreFeatures) {
            prune();
            grow();
            this.currentFrameNumber++;
        }
        debugPrint("-----");
        return !moreFeatures;
    }

    /**
     * Scores the active list of every feature stream.
     * <p>
     * NOTE(review): only the last stream's outcome determines the return value, which presumably
     * relies on all streams running out of data on the same frame. Confirm this assumption.
     *
     * @return true if the (last) scored stream produced a {@link Scoreable}
     */
    private boolean score() {
        this.scoreTimer.start();
        debugPrint("Scoring");
        boolean moreFeatures = false;
        Iterator<FeatureStream> streams = this.linguist.getFeatureStreams();
        while (streams.hasNext()) {
            Data scored = this.scorer.calculateScores(streams.next().getActiveList().getTokens());
            moreFeatures = scored instanceof Scoreable;
        }
        debugPrint(" done Scoring");
        this.scoreTimer.stop();
        return moreFeatures;
    }

    /** Applies the feature score pruner to every stream's active list, if feature pruning is on. */
    private void prune() {
        this.pruneTimer.start();
        debugPrint("Pruning");
        if (this.doFeaturePruning) {
            Iterator<FeatureStream> streams = this.linguist.getFeatureStreams();
            while (streams.hasNext()) {
                FeatureStream stream = streams.next();
                stream.setActiveList(this.featureScorePruner.prune(stream.getActiveList()));
            }
        }
        debugPrint(" done Pruning");
        this.pruneTimer.stop();
    }

    /** Sends the sizes of the combined and per-stream active lists to {@link #debugPrint}. */
    private void printActiveLists() {
        debugPrint(" CombinedActiveList: " + this.combinedActiveList.size());
        Iterator<FeatureStream> streams = this.linguist.getFeatureStreams();
        while (streams.hasNext()) {
            FeatureStream stream = streams.next();
            debugPrint(" ActiveList, " + stream.getName() + ": " + stream.getActiveList().size());
        }
    }

    /**
     * Grows every stream's active list into a fresh list, then expands the delayed combine
     * tokens. Finally, tokens that were marked pruned (superseded) during growth are dropped
     * from each stream's list.
     */
    private void grow() {
        this.growTimer.start();
        debugPrint("Growing");
        this.resultList = new LinkedList<Token>();
        this.combinedActiveList = this.activeListFactory.newInstance();
        this.delayedExpansionList = this.activeListFactory.newInstance();
        Iterator<FeatureStream> streams = this.linguist.getFeatureStreams();
        while (streams.hasNext()) {
            FeatureStream stream = streams.next();
            ActiveList previous = stream.getActiveList();
            stream.setActiveList(this.activeListFactory.newInstance());
            growActiveList(previous);
        }
        growDelayedExpansionList();
        Iterator<FeatureStream> streamsToFilter = this.linguist.getFeatureStreams();
        while (streamsToFilter.hasNext()) {
            FeatureStream stream = streamsToFilter.next();
            ActiveList survivors = this.activeListFactory.newInstance();
            // ActiveList iterates Token, so each element is cast explicitly.
            for (Token token : stream.getActiveList()) {
                if (!((ParallelToken) token).isPruned()) {
                    survivors.add(token);
                }
            }
            stream.setActiveList(survivors);
        }
        debugPrint(" done Growing");
        this.growTimer.stop();
    }

    /**
     * Computes combined scores for the delayed combine tokens and optionally prunes them.
     * Each remaining token is stamped with the current frame as its last combine time and
     * then grown.
     * <p>
     * NOTE(review): growing a combine token can reach {@link #growParallelToken}, which may add
     * to {@code delayedExpansionList} while it is being iterated. Verify that the
     * {@link ActiveList} implementation in use tolerates this.
     */
    private void growDelayedExpansionList() {
        for (Token token : this.delayedExpansionList) {
            calculateCombinedScore((CombineToken) token);
        }
        if (this.doCombinePruning) {
            this.delayedExpansionList = this.combinedScorePruner.prune(this.delayedExpansionList);
        }
        for (Token token : this.delayedExpansionList) {
            CombineToken combineToken = (CombineToken) token;
            combineToken.setLastCombineTime(this.currentFrameNumber);
            growCombineToken(combineToken);
        }
    }

    /** Computes and stores the combined score of the given combine token. */
    private void calculateCombinedScore(CombineToken combineToken) {
        new FeatureScoreCombiner().combineScore(combineToken);
    }

    /** Grows every (parallel) token in the given active list. */
    private void growActiveList(ActiveList activeList) {
        for (Token token : activeList) {
            growParallelToken((ParallelToken) token);
        }
    }

    /**
     * Expands a parallel token along all successor arcs of its state.
     * <p>
     * For a RED successor, the token is merged into that state's combine token for the target
     * frame, creating and queuing the combine token on the delayed-expansion list if needed.
     * For a GREEN successor, a new parallel token is created if it is at least as good as the
     * state's current best. An emitting token is recorded as the state's best and added to the
     * stream and combined active lists; a non-emitting token is grown recursively.
     *
     * @throws IllegalStateException if a successor state is neither RED nor GREEN
     */
    private void growParallelToken(ParallelToken parallelToken) {
        assert !parallelToken.isFinal();
        int nextFrame = parallelToken.getFrameNumber();
        if (parallelToken.isEmitting()) {
            nextFrame++;
        }
        for (SearchStateArc arc : ((SentenceHMMState) parallelToken.getSearchState()).getSuccessors()) {
            SentenceHMMState nextState = (SentenceHMMState) arc.getState();
            float score = parallelToken.getScore() + arc.getProbability();
            Token bestToken = getBestToken(nextState);
            // The stored best token is stale if it belongs to a different frame.
            boolean firstThisFrame = bestToken == null || bestToken.getFrameNumber() != nextFrame;
            if (nextState.getColor() == SentenceHMMState.Color.RED) {
                CombineToken combineToken;
                if (firstThisFrame) {
                    combineToken = new CombineToken(parallelToken, nextState, nextFrame);
                    setBestToken(nextState, combineToken);
                    this.delayedExpansionList.add(combineToken);
                } else {
                    combineToken = (CombineToken) bestToken;
                }
                assert combineToken.getFrameNumber() == nextFrame;
                ParallelToken existing = combineToken.getParallelToken(parallelToken.getFeatureStream());
                if (firstThisFrame || existing == null || existing.getScore() <= score) {
                    ParallelToken merged = new ParallelToken(parallelToken, nextState, score,
                            parallelToken.getCombinedScore(), nextFrame, parallelToken.getLastCombineTime());
                    combineToken.addParallelToken(merged.getFeatureStream(), merged);
                }
            } else {
                if (nextState.getColor() != SentenceHMMState.Color.GREEN) {
                    throw new IllegalStateException("Color of state " + nextState.getName() + " not RED or GREEN, its " + nextState.getColor() + '!');
                }
                if (firstThisFrame || bestToken.getScore() <= score) {
                    ParallelToken successor = new ParallelToken(parallelToken, nextState, score,
                            parallelToken.getCombinedScore(), nextFrame, parallelToken.getLastCombineTime());
                    if (successor.isEmitting()) {
                        replaceParallelToken(nextState, successor);
                        this.combinedActiveList.add(successor);
                    } else {
                        growParallelToken(successor);
                    }
                }
            }
        }
    }

    /** Returns the best token recorded for the given state, or null if there is none. */
    private Token getBestToken(SearchState searchState) {
        return this.bestTokenMap.get(searchState);
    }

    /**
     * Records the best token for the given state.
     *
     * @return the previously recorded token, or null
     */
    private Token setBestToken(SearchState searchState, Token token) {
        return this.bestTokenMap.put(searchState, token);
    }

    /**
     * Expands a non-emitting combine token within its own frame. A final token is first added
     * to the result list.
     * <p>
     * RED successors receive a new combine token carrying transitioned copies of all parallel
     * tokens, which is grown recursively. GREEN successors (which must be {@link ParallelState}s)
     * continue only the parallel token of that state's feature stream.
     *
     * @throws IllegalStateException if a successor state is neither RED nor GREEN
     */
    private void growCombineToken(CombineToken combineToken) {
        if (combineToken.isFinal()) {
            this.resultList.add(combineToken);
        }
        int frameNumber = combineToken.getFrameNumber();
        assert !combineToken.isEmitting();
        for (SearchStateArc searchStateArc : combineToken.getSearchState().getSuccessors()) {
            SentenceHMMStateArc arc = (SentenceHMMStateArc) searchStateArc;
            SentenceHMMState nextState = (SentenceHMMState) arc.getState();
            Token bestToken = getBestToken(nextState);
            boolean firstThisFrame = bestToken == null || bestToken.getFrameNumber() != frameNumber;
            if (nextState.getColor() == SentenceHMMState.Color.RED) {
                float score = combineToken.getScore() + arc.getProbability();
                if (firstThisFrame || bestToken.getScore() <= score) {
                    CombineToken successor = new CombineToken(combineToken, nextState, frameNumber);
                    successor.setScore(score);
                    transitionParallelTokens(combineToken, successor, arc.getProbability());
                    setBestToken(nextState, successor);
                    growCombineToken(successor);
                }
            } else {
                if (nextState.getColor() != SentenceHMMState.Color.GREEN) {
                    throw new IllegalStateException("Color of state " + nextState.getName() + " not RED or GREEN, its " + nextState.getColor() + '!');
                }
                ParallelToken streamToken = combineToken.getParallelToken(((ParallelState) nextState).getFeatureStream());
                if (streamToken != null) {
                    float featureScore = arc.getProbability() + streamToken.getFeatureScore();
                    // For GREEN states, best tokens are recorded by replaceParallelToken.
                    ParallelToken bestParallel = (ParallelToken) bestToken;
                    if (firstThisFrame || bestParallel.getFeatureScore() <= featureScore) {
                        ParallelToken successor = new ParallelToken(streamToken, nextState, featureScore,
                                streamToken.getCombinedScore(), frameNumber, streamToken.getLastCombineTime());
                        if (nextState.isEmitting()) {
                            replaceParallelToken(nextState, successor);
                            this.combinedActiveList.add(successor);
                        } else {
                            growParallelToken(successor);
                        }
                    }
                }
            }
        }
    }

    /**
     * Copies every parallel token of {@code source} onto {@code target}'s state. Each feature
     * score is increased by the arc log probability {@code arcProbability}.
     */
    private void transitionParallelTokens(CombineToken source, CombineToken target, float arcProbability) {
        Iterator<ParallelToken> tokens = source.iterator();
        while (tokens.hasNext()) {
            ParallelToken token = tokens.next();
            ParallelToken moved = new ParallelToken(token, (SentenceHMMState) target.getSearchState(),
                    token.getFeatureScore() + arcProbability, token.getCombinedScore(),
                    token.getFrameNumber(), token.getLastCombineTime());
            target.addParallelToken(moved.getFeatureStream(), moved);
        }
    }

    /**
     * Makes {@code parallelToken} the best token of {@code state} and adds it to its stream's
     * active list. Any token it replaces is marked pruned, so {@link #grow} filters it out.
     */
    private void replaceParallelToken(SentenceHMMState state, ParallelToken parallelToken) {
        ParallelToken replaced = (ParallelToken) setBestToken(state, parallelToken);
        parallelToken.getFeatureStream().getActiveList().add(parallelToken);
        if (replaced != null) {
            replaced.setPruned(true);
        }
    }

    /** Stops all sub-components and discards the best-token table. */
    @Override // edu.cmu.sphinx.decoder.search.SearchManager
    public void stopRecognition() {
        this.scorer.stopRecognition();
        if (this.doFeaturePruning) {
            this.featureScorePruner.stopRecognition();
        }
        if (this.doCombinePruning) {
            this.combinedScorePruner.stopRecognition();
        }
        this.linguist.stopRecognition();
        this.bestTokenMap = new HashMap<SearchState, Token>();
    }

    /** Deallocates the scorer, the enabled pruners and the linguist. */
    @Override // edu.cmu.sphinx.decoder.search.SearchManager
    public void deallocate() {
        this.scorer.deallocate();
        if (this.doFeaturePruning) {
            this.featureScorePruner.deallocate();
        }
        if (this.doCombinePruning) {
            this.combinedScorePruner.deallocate();
        }
        this.linguist.deallocate();
    }
}
