package org.encog.app.analyst.commands;

import org.encog.app.analyst.AnalystError;
import org.encog.app.analyst.EncogAnalyst;
import org.encog.app.analyst.script.prop.ScriptProperties;
import org.encog.ml.MLMethod;
import org.encog.ml.MLResettable;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.bayesian.BayesianNetwork;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.folded.FoldedDataSet;
import org.encog.ml.factory.MLTrainFactory;
import org.encog.ml.train.MLTrain;
import org.encog.neural.networks.training.cross.CrossValidationKFold;
import org.encog.persist.EncogDirectoryPersistence;
import org.encog.util.logging.EncogLogging;
import org.encog.util.simple.EncogUtility;
import org.encog.util.validate.ValidateNetwork;

/**
 * Analyst command ({@value #COMMAND_NAME}) that trains a machine-learning method.
 *
 * <p>The method is loaded from the file named by the
 * {@code ScriptProperties.ML_CONFIG_MACHINE_LEARNING_FILE} property, trained against the data
 * loaded (via {@code EncogUtility.loadEGB2Memory}) from the file named by
 * {@code ScriptProperties.ML_CONFIG_TRAINING_FILE}, and the trained method is saved back to the
 * machine-learning file. K-fold cross-validation can be enabled through
 * {@code ScriptProperties.ML_TRAIN_CROSS} using the (case-insensitive) form {@code kfold:<k>}.
 */
public class CmdTrain extends Cmd {
    /** The script name of this command. */
    public static final String COMMAND_NAME = "TRAIN";

    /** Case-insensitive prefix of the only supported cross-validation setting. */
    private static final String KFOLD_PREFIX = "kfold:";

    /** Number of cross-validation folds; 0 disables cross-validation. Set by executeCommand. */
    private int kfold;

    /**
     * Construct the command.
     *
     * @param encogAnalyst the analyst this command belongs to
     */
    public CmdTrain(EncogAnalyst encogAnalyst) {
        super(encogAnalyst);
    }

    /**
     * Builds the trainer described by the script's training type and training arguments.
     *
     * @param method the method to train
     * @param trainingSet the data to train on (a {@link FoldedDataSet} when cross-validating)
     * @return the trainer, wrapped in a {@link CrossValidationKFold} when {@code kfold > 0}
     */
    private MLTrain createTrainer(MLMethod method, MLDataSet trainingSet) {
        String type = getProp().getPropertyString(ScriptProperties.ML_TRAIN_TYPE);
        String args = getProp().getPropertyString(ScriptProperties.ML_TRAIN_ARGUMENTS);
        EncogLogging.log(EncogLogging.LEVEL_DEBUG, "training type:" + type);
        EncogLogging.log(EncogLogging.LEVEL_DEBUG, "training args:" + args);

        // Resettable methods are registered with the analyst before the trainer is created.
        if (method instanceof MLResettable) {
            getAnalyst().setMethod(method);
        }

        MLTrain trainer = new MLTrainFactory().create(method, trainingSet, type, args);
        return this.kfold > 0 ? new CrossValidationKFold(trainer, this.kfold) : trainer;
    }

    /**
     * Reads the cross-validation setting from the script.
     *
     * @return the number of folds, or 0 when the setting is absent or blank
     * @throws AnalystError if the setting is not {@code kfold:<k>} or {@code k} is not a
     *     non-negative integer
     */
    private int obtainCross() {
        String cross = getProp().getPropertyString(ScriptProperties.ML_TRAIN_CROSS);
        if (cross == null) {
            return 0;
        }
        cross = cross.trim();
        if (cross.isEmpty()) {
            return 0;
        }
        // Locale-independent, case-insensitive prefix test.
        if (!cross.regionMatches(true, 0, KFOLD_PREFIX, 0, KFOLD_PREFIX.length())) {
            throw new AnalystError("Unknown cross validation: " + cross);
        }
        // Tolerate whitespace after the colon, e.g. "kfold: 5".
        String count = cross.substring(KFOLD_PREFIX.length()).trim();
        int folds;
        try {
            folds = Integer.parseInt(count);
        } catch (NumberFormatException e) {
            throw new AnalystError("Invalid kfold :" + count);
        }
        // A negative count would otherwise silently disable cross-validation.
        if (folds < 0) {
            throw new AnalystError("Invalid kfold :" + count);
        }
        return folds;
    }

    /**
     * Loads the method to train from the script's machine-learning file.
     *
     * @return the loaded method
     * @throws AnalystError if the file does not contain an {@link MLMethod}
     */
    private MLMethod obtainMethod() {
        String resourceId =
                getProp().getPropertyString(ScriptProperties.ML_CONFIG_MACHINE_LEARNING_FILE);
        // Load as Object and check before casting; casting first would throw
        // ClassCastException instead of the intended AnalystError.
        Object loaded = EncogDirectoryPersistence.loadObject(getScript().resolveFilename(resourceId));
        if (!(loaded instanceof MLMethod)) {
            String found = loaded == null ? "null" : loaded.getClass().getSimpleName();
            throw new AnalystError(
                    "The object to be trained must be an instance of MLMethod. " + found);
        }
        return (MLMethod) loaded;
    }

    /**
     * Loads the training data from the script's training file into memory.
     *
     * @return the data set, wrapped in a {@link FoldedDataSet} when {@code kfold > 0}
     */
    private MLDataSet obtainTrainingSet() {
        String resourceId = getProp().getPropertyString(ScriptProperties.ML_CONFIG_TRAINING_FILE);
        MLDataSet data = EncogUtility.loadEGB2Memory(getScript().resolveFilename(resourceId));
        return this.kfold > 0 ? new FoldedDataSet(data) : data;
    }

    /**
     * Runs the training loop and reports progress to the analyst.
     *
     * <p>One-pass trainers run a single iteration. Other trainers iterate until one of the
     * following holds: the error reaches the script's target error, the analyst asks to stop,
     * the trainer reports it is done, or the analyst's max iteration count (unless -1) is reached.
     *
     * @param trainer the trainer to run
     * @param method the method being trained, validated against the data first
     * @param trainingSet the training data
     */
    private void performTraining(MLTrain trainer, MLMethod method, MLDataSet trainingSet) {
        ValidateNetwork.validateMethodToData(method, trainingSet);
        double targetError = getProp().getPropertyDouble(ScriptProperties.ML_TRAIN_TARGET_ERROR);
        getAnalyst().reportTrainingBegin();
        int maxIteration = getAnalyst().getMaxIteration();

        if (trainer.getImplementationType() == TrainingImplementationType.OnePass) {
            trainer.iteration();
            getAnalyst().reportTraining(trainer);
        } else {
            boolean done;
            do {
                trainer.iteration();
                getAnalyst().reportTraining(trainer);
                // Evaluated in the same short-circuit order as before; -1 means "no limit".
                done = trainer.getError() <= targetError
                        || getAnalyst().shouldStopCommand()
                        || trainer.isTrainingDone()
                        || (maxIteration != -1 && trainer.getIteration() >= maxIteration);
            } while (!done);
        }

        trainer.finishTraining();
        getAnalyst().reportTrainingEnd();
        getAnalyst().setMethod(trainer.getMethod());
    }

    /**
     * Executes the TRAIN command: loads data and method, trains, and saves the result.
     *
     * @param args command arguments (not used by this command)
     * @return the analyst's {@code shouldStopCommand()} value after training
     */
    @Override
    public final boolean executeCommand(String args) {
        this.kfold = obtainCross();
        MLDataSet trainingSet = obtainTrainingSet();
        // Close the training data even if loading the method, training, or saving fails.
        try {
            MLMethod method = obtainMethod();
            MLTrain trainer = createTrainer(method, trainingSet);
            if (method instanceof BayesianNetwork) {
                ((BayesianNetwork) method).defineClassificationStructure(
                        getProp().getPropertyString(ScriptProperties.ML_CONFIG_QUERY));
            }
            EncogLogging.log(EncogLogging.LEVEL_DEBUG, "Beginning training");
            performTraining(trainer, method, trainingSet);

            String resourceId =
                    getProp().getPropertyString(ScriptProperties.ML_CONFIG_MACHINE_LEARNING_FILE);
            EncogDirectoryPersistence.saveObject(
                    getAnalyst().getScript().resolveFilename(resourceId), trainer.getMethod());
            EncogLogging.log(EncogLogging.LEVEL_DEBUG, "save to:" + resourceId);
        } finally {
            trainingSet.close();
        }
        return getAnalyst().shouldStopCommand();
    }

    /**
     * Returns the script name of this command.
     *
     * @return {@value #COMMAND_NAME}
     */
    @Override
    public final String getName() {
        return COMMAND_NAME;
    }
}
