package org.encog.neural.networks.training.propagation;

import org.encog.engine.network.activation.ActivationFunction;
import org.encog.mathutil.error.ErrorCalculation;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataPair;
import org.encog.neural.error.ErrorFunction;
import org.encog.neural.flat.FlatNetwork;
import org.encog.util.EngineArray;
import org.encog.util.concurrency.EngineTask;

/* loaded from: classes2.dex */
/**
 * Computes error gradients for one contiguous slice of a training set.
 *
 * <p>{@link #run()} walks the records with indices {@code low..high} (inclusive). For each record it
 * does the following:
 * <ul>
 *   <li>runs the record forward through the network;</li>
 *   <li>back-propagates the error deltas level by level;</li>
 *   <li>accumulates the resulting weight gradients.</li>
 * </ul>
 * It then hands the accumulated gradients and the slice's error to the owning {@link Propagation}
 * via {@code report(...)}. If a {@link Throwable} is raised along the way, it is passed to the
 * owner's {@code report(...)} instead.
 *
 * <p>All scratch buffers are per-instance, so a single instance must not run concurrently with
 * itself.
 *
 * <p>NOTE(review): {@code layerOutput} and {@code layerSums} are the arrays returned by the
 * network's getters, and they are read right after {@code network.compute(...)}. The
 * {@link FlatNetwork} passed in therefore presumably must not be shared with another concurrently
 * running worker. Confirm this with the caller.
 */
public class GradientWorker implements EngineTask {
    /** Scratch buffer receiving the network output for the current record. */
    private final double[] actual;

    /** Accumulates the error over the records processed by one {@link #run()}. */
    private final ErrorCalculation errorCalculation = new ErrorCalculation();

    /** Fills the deltas from the ideal and actual outputs before back-propagation starts. */
    private final ErrorFunction errorFunction;

    /**
     * Per-layer constant added to each activation derivative. Index 0 is used for the output deltas;
     * index {@code level + 1} is used when back-propagating {@code level}. Presumably a "flat spot"
     * correction that keeps derivatives away from zero (confirm).
     */
    private final double[] flatSpot;

    /** Gradient accumulated for every network weight; zeroed at the end of each {@link #run()}. */
    private final double[] gradients;

    /** Index of the last training record in this worker's slice (inclusive). */
    private final int high;

    /** Per-layer neuron counts, as returned by {@code network.getLayerCounts()}. */
    private final int[] layerCounts;

    /** Error delta per neuron; sized and indexed like the network's layer-output array. */
    private final double[] layerDelta;

    /**
     * Per-layer feed counts, as returned by {@code network.getLayerFeedCounts()}. For each level,
     * this is the number of target neurons that deltas are propagated back from.
     */
    private final int[] layerFeedCounts;

    /** Start offset of each layer within the flat per-neuron arrays. */
    private final int[] layerIndex;

    /** Per-neuron outputs, as returned by {@code network.getLayerOutput()}. */
    private final double[] layerOutput;

    /** Per-neuron weighted sums, as returned by {@code network.getLayerSums()}. */
    private final double[] layerSums;

    /** Index of the first training record in this worker's slice. */
    private final int low;

    /** The network whose gradients are computed. */
    private final FlatNetwork network;

    /** Trainer that receives the results of each {@link #run()}. */
    private final Propagation owner;

    /** Reusable holder that each training record is loaded into. */
    private final MLDataPair pair;

    /** Source of the training records. */
    private final MLDataSet training;

    /** Start offset of each level's weights within {@link #weights} and {@link #gradients}. */
    private final int[] weightIndex;

    /** Weight array obtained from {@code network.getWeights()} at construction (not copied here). */
    private final double[] weights;

    /**
     * Creates a worker for the records {@code low..high} (inclusive) of a training set.
     *
     * @param flatNetwork   the network whose gradients are computed; its layout arrays are captured
     *                      here
     * @param propagation   the trainer that receives results via {@code report(...)}
     * @param mLDataSet     the training data
     * @param low           index of the first record to process
     * @param high          index of the last record to process (inclusive)
     * @param flatSpot      per-layer constant added to each activation derivative
     * @param errorFunction computes the initial deltas from ideal and actual outputs
     */
    public GradientWorker(FlatNetwork flatNetwork, Propagation propagation, MLDataSet mLDataSet,
            int low, int high, double[] flatSpot, ErrorFunction errorFunction) {
        this.network = flatNetwork;
        this.training = mLDataSet;
        this.low = low;
        this.high = high;
        this.owner = propagation;
        this.flatSpot = flatSpot;
        this.errorFunction = errorFunction;
        this.layerDelta = new double[this.network.getLayerOutput().length];
        this.gradients = new double[this.network.getWeights().length];
        this.actual = new double[this.network.getOutputCount()];
        this.weights = this.network.getWeights();
        this.layerIndex = this.network.getLayerIndex();
        this.layerCounts = this.network.getLayerCounts();
        this.weightIndex = this.network.getWeightIndex();
        this.layerOutput = this.network.getLayerOutput();
        this.layerSums = this.network.getLayerSums();
        this.layerFeedCounts = this.network.getLayerFeedCounts();
        this.pair = BasicMLDataPair.createPair(this.network.getInputCount(), this.network.getOutputCount());
    }

    /**
     * Processes one training record: runs the forward pass, updates the error, then back-propagates
     * deltas and gradients through the trainable levels.
     *
     * @param input        the record's input values
     * @param ideal        the record's expected output values
     * @param significance the record's significance; passed to the error update and multiplied into
     *                     every output delta
     */
    private void process(double[] input, double[] ideal, double significance) {
        this.network.compute(input, this.actual);
        this.errorCalculation.updateError(this.actual, ideal, significance);
        // Presumably writes the raw output error terms into layerDelta[0..actual.length);
        // the loop below scales exactly those entries.
        this.errorFunction.calculateError(ideal, this.actual, this.layerDelta);

        // The first actual.length flat entries are treated as the output layer, which uses
        // activation function 0 and flat spot 0. Both values are loop-invariant, so read them once.
        final ActivationFunction outputActivation = this.network.getActivationFunctions()[0];
        final double outputFlatSpot = this.flatSpot[0];
        for (int i = 0; i < this.actual.length; i++) {
            final double derivative = outputActivation.derivativeFunction(this.layerSums[i], this.layerOutput[i]);
            this.layerDelta[i] = (derivative + outputFlatSpot) * this.layerDelta[i] * significance;
        }

        final int endLevel = this.network.getEndTraining();
        for (int level = this.network.getBeginTraining(); level < endLevel; level++) {
            processLevel(level);
        }
    }

    /**
     * Back-propagates one level. Neurons of layer {@code level + 1} (the sources) feed the first
     * {@code layerFeedCounts[level]} neurons of layer {@code level} (the targets).
     *
     * <p>For each source neuron, this method does two things:
     * <ul>
     *   <li>adds {@code delta[target] * output[source]} to the gradient of every connecting
     *       weight;</li>
     *   <li>sets the source's delta to {@code sum(weight * delta[target])} multiplied by
     *       {@code (derivative + flatSpot[level + 1])}.</li>
     * </ul>
     *
     * @param level index of the level being processed
     */
    private void processLevel(int level) {
        final int sourceStart = this.layerIndex[level + 1];
        final int targetStart = this.layerIndex[level];
        final int sourceCount = this.layerCounts[level + 1];
        final int targetCount = this.layerFeedCounts[level];
        final int weightStart = this.weightIndex[level];
        final ActivationFunction activation = this.network.getActivationFunctions()[level + 1];
        final double levelFlatSpot = this.flatSpot[level + 1];

        for (int s = 0; s < sourceCount; s++) {
            final int source = sourceStart + s;
            final double output = this.layerOutput[source];
            double weightedDeltaSum = 0.0;
            // Weights from this source into consecutive targets sit sourceCount apart.
            int w = weightStart + s;
            for (int t = 0; t < targetCount; t++) {
                final double targetDelta = this.layerDelta[targetStart + t];
                this.gradients[w] += targetDelta * output;
                weightedDeltaSum += this.weights[w] * targetDelta;
                w += sourceCount;
            }
            final double derivative = activation.derivativeFunction(this.layerSums[source], this.layerOutput[source]);
            this.layerDelta[source] = weightedDeltaSum * (derivative + levelFlatSpot);
        }
    }

    /**
     * Returns the network this worker computes gradients for.
     *
     * @return the network passed to the constructor
     */
    public final FlatNetwork getNetwork() {
        return this.network;
    }

    /**
     * Returns the weight array captured from the network at construction time.
     *
     * @return the array obtained from {@code network.getWeights()}
     */
    public final double[] getWeights() {
        return this.weights;
    }

    /**
     * Processes every record in {@code low..high}, then reports the accumulated gradients and the
     * slice's error to the owner.
     *
     * <p>If anything throws, the throwable is reported instead, with {@code null} gradients and an
     * error of 0. The gradient buffer is zeroed in a {@code finally} block on every path, so a
     * failed pass cannot leave partial sums that would leak into the next run.
     *
     * <p>NOTE(review): the owner receives this worker's own gradient array, and the array is zeroed
     * as soon as {@code report(...)} returns. {@code report(...)} therefore presumably consumes the
     * array synchronously.
     */
    @Override // org.encog.util.concurrency.EngineTask
    public final void run() {
        try {
            this.errorCalculation.reset();
            for (int i = this.low; i <= this.high; i++) {
                this.training.getRecord(i, this.pair);
                process(this.pair.getInputArray(), this.pair.getIdealArray(), this.pair.getSignificance());
            }
            this.owner.report(this.gradients, this.errorCalculation.calculate(), null);
        } catch (Throwable th) {
            // Deliberately broad: the failure is handed to the owner rather than killing the task.
            this.owner.report(null, 0.0, th);
        } finally {
            EngineArray.fill(this.gradients, 0.0);
        }
    }
}
