package org.encog.ml.hmm.train.bw;

import java.lang.reflect.Array;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.Iterator;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.MLSequenceSet;
import org.encog.ml.hmm.HiddenMarkovModel;
import org.encog.ml.hmm.alog.ForwardBackwardCalculator;
import org.encog.neural.flat.FlatNetwork;

/* loaded from: classes2.dex */
/**
 * Baum-Welch trainer for a {@link HiddenMarkovModel} over an {@link MLSequenceSet}.
 *
 * <p>Supplies the per-sequence pieces used by {@link BaseBaumWelch}:
 * <ul>
 *   <li>the pairwise estimates (xi);</li>
 *   <li>the single-state estimates (gamma) obtained by summing xi;</li>
 *   <li>a forward-backward calculator configured with every computation enabled.</li>
 * </ul>
 */
public class TrainBaumWelch extends BaseBaumWelch {

    /**
     * Creates a Baum-Welch trainer.
     *
     * @param hmm the model to train
     * @param training the observation sequences to train on
     */
    public TrainBaumWelch(HiddenMarkovModel hmm, MLSequenceSet training) {
        super(hmm, training);
    }

    /**
     * Derives gamma from xi by summing out one state index.
     *
     * <p>Each row of gamma comes from one xi slice:
     * <ul>
     *   <li>For every step {@code t < xi.length}, the source state is kept:
     *       {@code gamma[t][i] = sum_j xi[t][i][j]}.</li>
     *   <li>xi has one step fewer than gamma, so the final row keeps the target state of the
     *       last xi slice: {@code gamma[xi.length][j] = sum_i xi[xi.length - 1][i][j]}.</li>
     * </ul>
     *
     * @param xi the xi estimates indexed {@code [t][i][j]}; {@code xi[0].length} is used as
     *     the state count for both state dimensions, so the slices are assumed square
     * @param fbc not used by this implementation
     * @return gamma with {@code xi.length + 1} rows of {@code xi[0].length} entries
     * @throws IllegalArgumentException if {@code xi} contains no time steps
     */
    @Override
    protected double[][] estimateGamma(double[][][] xi, ForwardBackwardCalculator fbc) {
        if (xi.length == 0) {
            throw new IllegalArgumentException("xi must contain at least one time step");
        }
        final int steps = xi.length;
        final int stateCount = xi[0].length;
        final double[][] gamma = new double[steps + 1][stateCount];

        // Rows 0 .. steps-1: keep the source state i, sum over the target state j.
        for (int t = 0; t < steps; t++) {
            final double[][] slice = xi[t];
            for (int i = 0; i < stateCount; i++) {
                double sum = 0.0;
                for (int j = 0; j < stateCount; j++) {
                    sum += slice[i][j];
                }
                gamma[t][i] = sum;
            }
        }

        // Final row: xi has no slice for the last step, so keep the target state j of the
        // last slice and sum over the source state i instead.
        final double[][] lastSlice = xi[steps - 1];
        for (int j = 0; j < stateCount; j++) {
            double sum = 0.0;
            for (int i = 0; i < stateCount; i++) {
                sum += lastSlice[i][j];
            }
            gamma[steps][j] = sum;
        }
        return gamma;
    }

    /**
     * Computes the xi estimates for one observation sequence.
     *
     * <p>For {@code 0 <= t < sequence.size() - 1}, with {@code o} the observation at position
     * {@code t + 1}:
     * <pre>
     * xi[t][i][j] = fbc.alphaElement(t, i)
     *             * hmm.getTransitionProbability(i, j)
     *             * hmm.getStateDistribution(j).probability(o)
     *             * fbc.betaElement(t + 1, j)
     *             / fbc.probability()
     * </pre>
     *
     * <p>No guard is applied to the divisor. If {@code fbc.probability()} is 0, the entries
     * become NaN or infinite.
     *
     * @param sequence the observation sequence; must contain at least two observations
     * @param fbc a forward-backward calculator already evaluated on {@code sequence}
     * @param hmm the model supplying transition and emission probabilities
     * @return xi indexed {@code [t][i][j]}, of size
     *     {@code (sequence.size() - 1) x stateCount x stateCount}
     * @throws IllegalArgumentException if the sequence has one observation or fewer
     */
    @Override
    public double[][][] estimateXi(MLDataSet sequence, ForwardBackwardCalculator fbc, HiddenMarkovModel hmm) {
        if (sequence.size() <= 1) {
            throw new IllegalArgumentException("Must have more than one observation");
        }
        final int steps = sequence.size() - 1;
        final int stateCount = hmm.getStateCount();
        final double[][][] xi = new double[steps][stateCount][stateCount];
        final double probability = fbc.probability();
        // Emission probability of the next observation for each target state. It depends only
        // on (t, j), so it is computed once per step rather than once per (i, j) pair.
        final double[] emission = new double[stateCount];

        final Iterator<MLDataPair> observations = sequence.iterator();
        // xi[t] pairs step t with observation t + 1, so observation 0 is skipped.
        observations.next();

        for (int t = 0; t < steps; t++) {
            final MLDataPair next = observations.next();
            for (int j = 0; j < stateCount; j++) {
                emission[j] = hmm.getStateDistribution(j).probability(next);
            }
            for (int i = 0; i < stateCount; i++) {
                final double alpha = fbc.alphaElement(t, i);
                for (int j = 0; j < stateCount; j++) {
                    // Multiplication order matches the original so results are bit-identical.
                    xi[t][i][j] = alpha * hmm.getTransitionProbability(i, j) * emission[j]
                            * fbc.betaElement(t + 1, j) / probability;
                }
            }
        }
        return xi;
    }

    /**
     * Builds the forward-backward calculator used for a sequence.
     *
     * <p>Every {@link ForwardBackwardCalculator.Computation} value is requested, via
     * {@code EnumSet.allOf}.
     *
     * @param sequence the observation sequence
     * @param hmm the model to evaluate
     * @return a new calculator over {@code sequence} and {@code hmm}
     */
    @Override
    public ForwardBackwardCalculator generateForwardBackwardCalculator(MLDataSet sequence, HiddenMarkovModel hmm) {
        return new ForwardBackwardCalculator(sequence, hmm, EnumSet.allOf(ForwardBackwardCalculator.Computation.class));
    }
}
