package org.encog.neural.networks.training.propagation.quick;

import org.encog.ml.data.MLDataSet;
import org.encog.neural.flat.FlatNetwork;
import org.encog.neural.networks.ContainsFlat;
import org.encog.neural.networks.training.LearningRate;
import org.encog.neural.networks.training.TrainingError;
import org.encog.neural.networks.training.propagation.Propagation;
import org.encog.neural.networks.training.propagation.TrainingContinuation;
import org.encog.util.EngineArray;
import org.encog.util.validate.ValidateNetwork;

/* loaded from: classes2.dex */
public class QuickPropagation extends Propagation implements LearningRate {
    /** Key under which the last gradients are stored in a {@link TrainingContinuation}. */
    public static final String LAST_GRADIENTS = "LAST_GRADIENTS";

    /** Factor multiplied by the current weight and added to the negated gradient in {@link #updateWeight}. */
    private double decay;
    /** Linear-term factor; set to {@code outputEpsilon / recordCount} by {@link #initOthers()}. */
    private double eps;
    /** Step applied to each weight on the previous update, one entry per network weight. */
    private double[] lastDelta;
    /** Multiplier applied to the previous step when the quadratic estimate is not used. */
    private double learningRate;
    /** Base epsilon that {@link #initOthers()} divides by the training record count to obtain {@code eps}. */
    private double outputEpsilon;
    /** Threshold factor; recomputed as {@code learningRate / (1 + learningRate)} by {@link #initOthers()}. */
    private double shrink;

    /**
     * Creates a quickprop trainer with a learning rate of 2.0.
     *
     * @param flatMethod  the network to train
     * @param trainingSet the training data
     */
    public QuickPropagation(ContainsFlat flatMethod, MLDataSet trainingSet) {
        this(flatMethod, trainingSet, 2.0d);
    }

    /**
     * Creates a quickprop trainer.
     *
     * <p>Uses a decay of 1.0E-4 and an output epsilon of 0.35. The network is validated against the
     * data set, and one {@code lastDelta} slot is allocated per network weight.
     *
     * @param flatMethod      the network to train
     * @param trainingSet     the training data
     * @param theLearningRate multiplier applied to the previous step when taking a full step
     */
    public QuickPropagation(ContainsFlat flatMethod, MLDataSet trainingSet, double theLearningRate) {
        super(flatMethod, trainingSet);
        this.decay = 1.0E-4d;
        this.outputEpsilon = 0.35d;
        ValidateNetwork.validateMethodToData(flatMethod, trainingSet);
        this.learningRate = theLearningRate;
        this.lastDelta = new double[this.network.getFlat().getWeights().length];
    }

    /**
     * Reports whether this trainer supports continuation.
     *
     * @return always {@code false}
     */
    @Override
    public final boolean canContinue() {
        return false;
    }

    /**
     * Returns the per-weight steps applied on the previous update.
     *
     * @return the internal array itself (not a copy); writes affect the next update
     */
    public final double[] getLastDelta() {
        return this.lastDelta;
    }

    /** @return the current learning rate */
    @Override
    public final double getLearningRate() {
        return this.learningRate;
    }

    /** @return the base epsilon used to derive the linear-term factor */
    public double getOutputEpsilon() {
        return this.outputEpsilon;
    }

    /** @return the current shrink threshold factor */
    public double getShrink() {
        return this.shrink;
    }

    /**
     * Derives {@code eps} from the output epsilon and the number of training records, and
     * recomputes {@code shrink} from the learning rate.
     *
     * <p>NOTE(review): this overwrites any value set via {@link #setShrink(double)}; when
     * {@code Propagation} invokes this hook is not visible here — confirm against the base class.
     */
    @Override
    public void initOthers() {
        this.eps = this.outputEpsilon / getTraining().getRecordCount();
        this.shrink = this.learningRate / (1.0d + this.learningRate);
    }

    /**
     * Checks whether a continuation object can be used to resume this trainer.
     *
     * @param trainingContinuation the saved training state
     * @return {@code true} if it holds a {@code double[]} under {@link #LAST_GRADIENTS}, was produced
     *     by this trainer class, and has one entry per network weight
     */
    public final boolean isValidResume(TrainingContinuation trainingContinuation) {
        if (!trainingContinuation.getContents().containsKey(LAST_GRADIENTS)) {
            return false;
        }
        // Null-safe: a continuation with no recorded training type is simply invalid.
        if (!getClass().getSimpleName().equals(trainingContinuation.getTrainingType())) {
            return false;
        }
        final Object saved = trainingContinuation.get(LAST_GRADIENTS);
        // Reject malformed payloads instead of failing with a ClassCastException.
        if (!(saved instanceof double[])) {
            return false;
        }
        final int weightCount = ((ContainsFlat) getMethod()).getFlat().getWeights().length;
        return ((double[]) saved).length == weightCount;
    }

    /**
     * Captures the training state.
     *
     * <p>Only the last gradients are stored; {@code lastDelta} is not persisted.
     *
     * @return a continuation tagged with this class's simple name
     */
    @Override
    public final TrainingContinuation pause() {
        final TrainingContinuation state = new TrainingContinuation();
        state.setTrainingType(getClass().getSimpleName());
        state.set(LAST_GRADIENTS, getLastGradient());
        return state;
    }

    /**
     * Restores the last gradients from a continuation produced by {@link #pause()}.
     *
     * @param trainingContinuation the saved training state
     * @throws TrainingError if {@link #isValidResume(TrainingContinuation)} rejects the state
     */
    @Override
    public final void resume(TrainingContinuation trainingContinuation) {
        if (!isValidResume(trainingContinuation)) {
            throw new TrainingError("Invalid training resume data length");
        }
        EngineArray.arrayCopy((double[]) trainingContinuation.get(LAST_GRADIENTS), getLastGradient());
    }

    /**
     * Sets the learning rate.
     *
     * <p>{@code shrink} is derived from it only when {@link #initOthers()} runs.
     *
     * @param rate the new learning rate
     */
    @Override
    public final void setLearningRate(double rate) {
        this.learningRate = rate;
    }

    /**
     * Sets the base epsilon; {@code eps} is re-derived from it in {@link #initOthers()}.
     *
     * @param epsilon the new output epsilon
     */
    public void setOutputEpsilon(double epsilon) {
        // Previously this method called itself unconditionally and overflowed the stack.
        this.outputEpsilon = epsilon;
    }

    /**
     * Sets the shrink threshold factor.
     *
     * <p>Note: {@link #initOthers()} overwrites this value.
     *
     * @param value the new shrink factor
     */
    public void setShrink(double value) {
        this.shrink = value;
    }

    /**
     * Computes the quickprop step for a single weight.
     *
     * <p>Updates {@code lastDelta[index]} and {@code getLastGradient()[index]} as side effects.
     *
     * @param currentGradients  gradients for this iteration; the entry at {@code index} is saved as
     *     the new last gradient
     * @param previousGradients gradients from the previous iteration
     * @param index             index of the weight being updated
     * @return the step to apply to the weight
     */
    @Override
    public final double updateWeight(double[] currentGradients, double[] previousGradients, int index) {
        final double weight = this.network.getFlat().getWeights()[index];
        final double previousStep = this.lastDelta[index];
        // Current slope: weight-decay term plus the negated gradient.
        final double slope = (weight * this.decay) + (-this.gradients[index]);
        final double previousSlope = -previousGradients[index];
        double nextStep = 0.0d;

        if (previousStep < 0.0d) {
            if (slope > 0.0d) {
                // Add the linear term when the slope is positive after a negative step.
                nextStep -= this.eps * slope;
            }
            if (slope >= this.shrink * previousSlope) {
                nextStep += this.learningRate * previousStep;
            } else {
                // NOTE(review): divides by zero if previousSlope == slope.
                nextStep += (previousStep * slope) / (previousSlope - slope);
            }
        } else if (previousStep > 0.0d) {
            if (slope < 0.0d) {
                // Add the linear term when the slope is negative after a positive step.
                nextStep -= this.eps * slope;
            }
            if (slope <= this.shrink * previousSlope) {
                nextStep += this.learningRate * previousStep;
            } else {
                // NOTE(review): divides by zero if previousSlope == slope.
                nextStep += (previousStep * slope) / (previousSlope - slope);
            }
        } else {
            // No previous step: use only the linear term.
            nextStep -= this.eps * slope;
        }

        this.lastDelta[index] = nextStep;
        getLastGradient()[index] = currentGradients[index];
        return nextStep;
    }
}
