package fr.inria.cf.coldstart;

import cern.colt.matrix.impl.AbstractFormatter;
import fr.inria.cf.object.MatrixCF;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.Random;
import java.util.StringTokenizer;
import org.encog.Encog;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.engine.network.activation.ActivationTANH;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.ml.train.strategy.Greedy;
import org.encog.ml.train.strategy.HybridStrategy;
import org.encog.ml.train.strategy.StopTrainingStrategy;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.training.TrainingSetScore;
import org.encog.neural.networks.training.anneal.NeuralSimulatedAnnealing;
import org.encog.neural.networks.training.propagation.back.Backpropagation;
import org.encog.neural.networks.training.propagation.manhattan.ManhattanPropagation;
import org.encog.neural.networks.training.propagation.resilient.RPROPType;
import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation;
import org.encog.neural.pattern.ElmanPattern;

/* loaded from: input_file:fr/inria/cf/coldstart/ENCOGNeuralNetwork.class */
/**
 * Cold-start predictor that fills the empty (missing) instance columns of a performance matrix
 * using an Encog neural network.
 *
 * <p>Pipeline, driven by {@link #apply(MatrixCF, double[][])}:
 * <ol>
 *   <li>drop the empty columns and factorize the remaining matrix into U/V via
 *       {@code setUVMatricesUsingSVD} with truncation {@code kSVDTruncationVal};</li>
 *   <li>train a network that maps a (normalized) instance feature vector to a row of U;</li>
 *   <li>predict latent rows for the missing instances and multiply them by V to obtain the
 *       missing columns of the prediction matrix.</li>
 * </ol>
 *
 * <p>NOTE(review): {@code setUVMatricesUsingSVD} and {@code normalizeMatrix} are defined outside
 * this file (presumably in {@link ColdStartMethod}). The training set pairs the non-missing feature
 * rows, in ascending index order, with the rows of U — this assumes row r of U belongs to the r-th
 * non-missing instance; confirm against the SVD helper.
 */
public class ENCOGNeuralNetwork extends ColdStartMethod {
    /**
     * Learning rate used by the 6-argument constructor. Without a non-zero rate, back-propagation
     * and Manhattan propagation would never change any weight.
     */
    private static final double DEFAULT_LEARNING_RATE = 0.7d;

    /** Training stops early once the training error is at or below this value. */
    private static final double DEFAULT_ERROR_THRESHOLD = 1.0E-5d;

    private Random rand = new Random(12345678);
    private BasicNetwork network;
    /** Learning rate for back-propagation, Manhattan propagation and the hybrid trainer. */
    private double alpha;
    private double momentum;
    private int numOfMaxIterations;
    /** Early-stopping error threshold shared by every training loop. */
    private double errorThreshold = DEFAULT_ERROR_THRESHOLD;
    private int numOfLayers;
    private MLDataSet trainingSet;
    private MLDataSet testSet;
    private NeuralNetworkType neuralNetType;
    private int numOfFeatures;
    private int numOfTrainingInstances;
    private double[][] benchMatrix;
    private ArrayList<Integer> missingInstList;
    private double[][] missingInstAlgRows;
    private double[][] predictionMatrix;
    private double[][] newPredictionMatrix;
    private double[][] featureMatrix;
    private double[][] UMatrix;
    private double[][] VMatrix;
    private double[][] predictValuesArr;
    private double[][] errorArr;
    private double[] currentFitness;
    private double[] totalDiffArr;
    private int kSVDTruncationVal;

    /** Activation function choices (currently not consulted when building the network). */
    public enum ActivationFunctionType {
        Sigmoid,
        Tan;

        /**
         * Returns a fresh array of all constants. Kept for backward compatibility; equivalent to
         * {@link #values()}, which already returns a new array on every call.
         */
        public static ActivationFunctionType[] valuesCustom() {
            return values();
        }
    }

    /**
     * Creates a predictor using {@link #DEFAULT_LEARNING_RATE} as learning rate.
     *
     * @param neuralNetworkType training algorithm to use in {@link #apply}
     * @param i maximum number of training epochs
     * @param i2 total number of layers (input + hidden + output); must be at least 3
     * @param d momentum for back-propagation based trainers
     * @param i3 SVD truncation value, i.e. the number of network outputs
     * @param i4 number of instance features, i.e. the number of network inputs
     */
    public ENCOGNeuralNetwork(NeuralNetworkType neuralNetworkType, int i, int i2, double d, int i3, int i4) {
        this(neuralNetworkType, i, i2, d, i3, i4, DEFAULT_LEARNING_RATE);
    }

    /**
     * Creates a predictor with an explicit learning rate.
     *
     * @param neuralNetworkType training algorithm to use in {@link #apply}
     * @param maxIterations maximum number of training epochs
     * @param layers total number of layers (input + hidden + output); must be at least 3
     * @param momentum momentum for back-propagation based trainers
     * @param kTruncation SVD truncation value, i.e. the number of network outputs
     * @param features number of instance features, i.e. the number of network inputs
     * @param learningRate learning rate for back-propagation, Manhattan propagation and the
     *     hybrid trainer
     */
    public ENCOGNeuralNetwork(NeuralNetworkType neuralNetworkType, int maxIterations, int layers,
            double momentum, int kTruncation, int features, double learningRate) {
        this.neuralNetType = neuralNetworkType;
        this.kSVDTruncationVal = kTruncation;
        this.numOfFeatures = features;
        this.numOfMaxIterations = maxIterations;
        this.numOfLayers = layers;
        this.momentum = momentum;
        this.alpha = learningRate;
    }

    /** Returns a copy of {@code dArr} with the columns listed in {@code arrayList} removed. */
    private double[][] getPredictionMatrixWithoutMissingInstances(double[][] dArr, ArrayList<Integer> arrayList) {
        int numOfRows = dArr.length;
        double[][] reduced = new double[numOfRows][dArr[0].length - arrayList.size()];
        int target = 0;
        for (int col = 0; col < dArr[0].length; col++) {
            if (arrayList.contains(col)) {
                continue;
            }
            for (int row = 0; row < numOfRows; row++) {
                reduced[row][target] = dArr[row][col];
            }
            target++;
        }
        return reduced;
    }

    /**
     * Builds the feed-forward network and splits the feature matrix into training rows (paired
     * with U) and test rows (the missing instances, kept in ascending index order).
     *
     * @param i total number of layers; must be at least 3
     * @param i2 neurons in the input and every hidden layer
     * @param i3 neurons in the output layer
     * @throws IllegalArgumentException if fewer than 3 layers are requested
     */
    private void initialiseNN(int i, int i2, int i3) {
        if (i <= 2) {
            throw new IllegalArgumentException(
                    "numOfLayers = " + i + " is not large enough; at least 3 (input, hidden, output) are required");
        }
        this.network = new BasicNetwork();
        this.network.addLayer(new BasicLayer(null, true, i2));
        for (int hidden = 0; hidden < i - 2; hidden++) {
            this.network.addLayer(new BasicLayer(new ActivationTANH(), true, i2));
        }
        this.network.addLayer(new BasicLayer(new ActivationTANH(), false, i3));
        this.network.getStructure().finalizeStructure();
        this.network.reset();

        int numOfColumns = this.featureMatrix[0].length;
        double[][] trainingInput = new double[this.featureMatrix.length - this.missingInstList.size()][numOfColumns];
        double[][] testInput = new double[this.missingInstList.size()][numOfColumns];
        int trainRow = 0;
        int testRow = 0;
        for (int row = 0; row < this.featureMatrix.length; row++) {
            if (this.missingInstList.contains(row)) {
                System.arraycopy(this.featureMatrix[row], 0, testInput[testRow++], 0, numOfColumns);
            } else {
                System.arraycopy(this.featureMatrix[row], 0, trainingInput[trainRow++], 0, numOfColumns);
            }
        }
        this.trainingSet = new BasicMLDataSet(trainingInput, this.UMatrix);
        this.testSet = new BasicMLDataSet(testInput, null);
    }

    /** Returns the indices of the columns of {@code dArr} that have no non-empty entry. */
    private ArrayList<Integer> determineMissingInstanceList(double[][] dArr) {
        MatrixCF matrixCF = new MatrixCF(dArr);
        ArrayList<Integer> emptyColumns = new ArrayList<>();
        for (int col = 0; col < dArr[0].length; col++) {
            if (matrixCF.getNumOfNonEmptyEntriesPerColumn()[col] == 0) {
                emptyColumns.add(col);
            }
        }
        return emptyColumns;
    }

    /**
     * Reads a tab-separated matrix of the given size. On I/O failure the error is printed and the
     * JVM exits.
     *
     * @param str file path
     * @param i number of rows to allocate
     * @param i2 number of columns to allocate
     * @return the parsed matrix (unfilled cells stay 0.0)
     */
    public double[][] loadMatrixFromTxt(String str, int i, int i2) {
        double[][] matrix = new double[i][i2];
        try (BufferedReader reader = new BufferedReader(new FileReader(str))) {
            int row = 0;
            String line;
            while ((line = reader.readLine()) != null) {
                StringTokenizer tokens = new StringTokenizer(line, "\t");
                for (int col = 0; tokens.hasMoreTokens(); col++) {
                    matrix[row][col] = Double.parseDouble(tokens.nextToken());
                }
                row++;
            }
        } catch (IOException e) {
            System.out.println(e);
            System.exit(-1);
        }
        return matrix;
    }

    /**
     * Reads a tab-separated matrix whose first line holds "rows\tcols" and whose following lines
     * hold the values.
     *
     * @param str file path
     * @return the parsed matrix, or {@code null} if the file is empty; on I/O failure the error is
     *     printed and whatever was read so far (possibly {@code null}) is returned
     */
    public double[][] loadMatrixFromTxtWithFirstInfoLine(String str) {
        double[][] matrix = null;
        try (BufferedReader reader = new BufferedReader(new FileReader(str))) {
            int lineNo = 0;
            String line;
            while ((line = reader.readLine()) != null) {
                StringTokenizer tokens = new StringTokenizer(line, "\t");
                if (lineNo == 0) {
                    int rows = Integer.parseInt(tokens.nextToken());
                    int cols = Integer.parseInt(tokens.nextToken());
                    matrix = new double[rows][cols];
                } else {
                    for (int col = 0; tokens.hasMoreTokens(); col++) {
                        matrix[lineNo - 1][col] = Double.parseDouble(tokens.nextToken());
                    }
                }
                lineNo++;
            }
        } catch (IOException e) {
            System.out.println(e);
        }
        return matrix;
    }

    /** Runs the trained network on every test row and stores the first k outputs per row. */
    private void calculatePredictionValues() {
        for (int row = 0; row < this.predictValuesArr.length; row++) {
            // One forward pass per row; the original recomputed it for every output component.
            double[] output = this.network.compute(this.testSet.get(row).getInput()).getData();
            System.arraycopy(output, 0, this.predictValuesArr[row], 0, this.kSVDTruncationVal);
        }
    }

    /**
     * Computes half the mean squared error of latent component {@code i} over the missing
     * instances, also recording per-instance squared errors and the summed absolute difference.
     *
     * <p>NOTE(review): currently unused. {@code predictValuesArr} has one row per missing
     * instance and U has no rows for missing instances, yet both are indexed here by the instance
     * index — this looks inconsistent; confirm before use.
     */
    private double calculateFitness(int i) {
        this.currentFitness[i] = 0.0d;
        this.totalDiffArr[i] = 0.0d;
        for (int instIndex : this.missingInstList) {
            double diff = Math.abs(this.predictValuesArr[instIndex][i] - this.UMatrix[instIndex][i]);
            this.totalDiffArr[i] += diff; // accumulate: this is a total, not the last difference
            this.errorArr[instIndex][i] = Math.pow(diff, 2.0d);
            this.currentFitness[i] += this.errorArr[instIndex][i];
        }
        this.currentFitness[i] /= (2 * this.numOfTrainingInstances);
        return this.currentFitness[i];
    }

    /** Trains with Encog's default resilient-propagation variant and fills the missing columns. */
    public void applyResillientPropagation(int i) {
        applyResillientPropagationWithType(i, null);
    }

    /**
     * Trains with resilient propagation and fills the missing columns.
     *
     * @param i maximum number of epochs (training also stops once the error reaches the threshold)
     * @param rPROPType variant to use, or {@code null} for Encog's default
     */
    public void applyResillientPropagationWithType(int i, RPROPType rPROPType) {
        ResilientPropagation train = new ResilientPropagation(this.network, this.trainingSet);
        if (rPROPType != null) {
            train.setRPROPType(rPROPType);
        }
        for (int epoch = 1; epoch <= i; epoch++) {
            train.iteration();
            if (epoch % 10 == 0) {
                System.out.println("Epoch #" + epoch + " Error:" + train.getError());
            }
            if (train.getError() <= this.errorThreshold) {
                break;
            }
        }
        train.finishTraining();
        shutdownAndPredict();
    }

    /** Trains with back-propagation (learning rate + momentum) and fills the missing columns. */
    public void applyBackPropagation(int i) {
        Backpropagation train = new Backpropagation(this.network, this.trainingSet, this.alpha, this.momentum);
        for (int epoch = 1; epoch <= i; epoch++) {
            train.iteration();
            System.out.println("Epoch #" + epoch + " Error:" + train.getError());
            if (train.getError() <= this.errorThreshold) {
                break;
            }
        }
        train.finishTraining();
        shutdownAndPredict();
    }

    /** Trains with Manhattan propagation (fixed step = learning rate) and fills the missing columns. */
    public void applyManhattanPropagation(int i) {
        ManhattanPropagation train = new ManhattanPropagation(this.network, this.trainingSet, this.alpha);
        for (int epoch = 1; epoch <= i; epoch++) {
            train.iteration();
            System.out.println("Epoch #" + epoch + " Error:" + train.getError());
            if (train.getError() <= this.errorThreshold) {
                break;
            }
        }
        System.out.println("Neural Network Results:");
        train.finishTraining();
        shutdownAndPredict();
    }

    /**
     * Trains with back-propagation combined with greedy, simulated-annealing hybrid and
     * stop-training strategies, then fills the missing columns.
     *
     * <p>NOTE(review): despite its name, this trains the feed-forward network built by
     * {@code initialiseNN}; no recurrent (Elman) network is constructed.
     */
    public void applyElmanRecurrentNN(int i) {
        NeuralSimulatedAnnealing annealing = new NeuralSimulatedAnnealing(this.network, new TrainingSetScore(this.trainingSet), 10.0d, 2.0d, 100);
        Backpropagation train = new Backpropagation(this.network, this.trainingSet, this.alpha, this.momentum);
        train.addStrategy(new Greedy());
        train.addStrategy(new HybridStrategy(annealing));
        train.addStrategy(new StopTrainingStrategy());
        for (int epoch = 1; epoch <= i; epoch++) {
            train.iteration();
            System.out.println("Epoch #" + epoch + " Error:" + train.getError());
            // isTrainingDone() is what makes the StopTrainingStrategy take effect.
            if (train.getError() <= this.errorThreshold || train.isTrainingDone()) {
                break;
            }
        }
        System.out.println("Neural Network Results:");
        train.finishTraining();
        shutdownAndPredict();
    }

    /** Shuts down Encog's shared engine, predicts the latent rows and completes the matrix. */
    private void shutdownAndPredict() {
        Encog.getInstance().shutdown();
        calculatePredictionValues();
        completePredictionMatrix();
    }

    /**
     * Projects each predicted latent row back through V:
     * {@code missingInstAlgRows[m][c] = sum_f predictValuesArr[m][f] * V[f][c]}.
     */
    private void calculateLinePredictionUsingUVMatrices() {
        int numOfColumns = this.VMatrix[0].length;
        for (int m = 0; m < this.missingInstList.size(); m++) {
            double[] latent = this.predictValuesArr[m];
            double[] target = this.missingInstAlgRows[m];
            for (int col = 0; col < numOfColumns; col++) {
                double sum = 0.0d;
                for (int f = 0; f < this.kSVDTruncationVal; f++) {
                    sum += latent[f] * this.VMatrix[f][col];
                }
                // Assign rather than accumulate so repeated calls do not add up predictions.
                target[col] = sum;
            }
        }
    }

    /**
     * Writes the reconstructed rows into the missing columns of the prediction matrix. The m-th
     * missing column in ascending index order receives the m-th reconstructed row, matching the
     * order in which test rows were built in {@code initialiseNN}.
     */
    private void completePredictionMatrix() {
        calculateLinePredictionUsingUVMatrices();
        int m = 0;
        for (int col = 0; col < this.predictionMatrix[0].length; col++) {
            if (this.missingInstList.contains(col)) {
                for (int row = 0; row < this.predictionMatrix.length; row++) {
                    this.predictionMatrix[row][col] = this.missingInstAlgRows[m][row];
                }
                m++;
            }
        }
    }

    /**
     * Reads a tab-separated feature file, skipping the first (header) line and the first token
     * of every other line. On I/O failure the stack trace is printed and the JVM exits.
     *
     * @param str file path
     * @param i number of rows to allocate
     * @param i2 number of columns to allocate
     */
    private double[][] readFeatureMatrixFile(String str, int i, int i2) {
        double[][] matrix = new double[i][i2];
        try (BufferedReader reader = new BufferedReader(new FileReader(str))) {
            int lineNo = 0;
            String line;
            while ((line = reader.readLine()) != null) {
                if (lineNo != 0) {
                    StringTokenizer tokens = new StringTokenizer(line, "\t");
                    tokens.nextToken(); // leading label column
                    for (int col = 0; tokens.hasMoreTokens(); col++) {
                        matrix[lineNo - 1][col] = Double.parseDouble(tokens.nextToken());
                    }
                }
                lineNo++;
            }
        } catch (IOException e) {
            e.printStackTrace();
            System.exit(-1);
        }
        return matrix;
    }

    /** Prints the entry of {@code totalDiffArr} at index {@code i}. */
    public void printTotalDiffArr(int i) {
        System.out.println(" >> totalDiffArr (uInx=" + i + ") : " + this.totalDiffArr[i] + AbstractFormatter.DEFAULT_ROW_SEPARATOR);
    }

    /** Returns the (internal, mutable) prediction matrix. */
    public double[][] getPredictionMatrix() {
        return this.predictionMatrix;
    }

    /** Returns the (internal, mutable) list of missing instance indices. */
    public ArrayList<Integer> getMissingInstList() {
        return this.missingInstList;
    }

    /**
     * Writes at most the first {@code i} values of each row, tab-terminated, one row per line.
     * I/O errors are printed, not thrown.
     */
    public void writeXDUValuesToTxt(String str, double[][] dArr, int i) {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(str))) {
            for (double[] row : dArr) {
                int limit = Math.min(i, row.length);
                for (int col = 0; col < limit; col++) {
                    writer.write(String.valueOf(row[col]) + "\t");
                }
                writer.write(AbstractFormatter.DEFAULT_ROW_SEPARATOR);
            }
        } catch (IOException e) {
            System.out.println(e);
        }
    }

    /**
     * Fills the empty columns of {@code matrixCF} with network-based predictions.
     *
     * @param matrixCF matrix whose empty columns are the cold-start instances
     * @param dArr instance feature matrix, one row per column of {@code matrixCF}
     * @return a new {@link MatrixCF} wrapping the completed prediction matrix
     */
    @Override // fr.inria.cf.coldstart.ColdStartMethod
    public MatrixCF apply(MatrixCF matrixCF, double[][] dArr) {
        this.predictionMatrix = matrixCF.getMatrix();
        this.missingInstList = matrixCF.getEmptyColumnList();
        this.missingInstAlgRows = new double[this.missingInstList.size()][matrixCF.getNumOfRows()];
        this.newPredictionMatrix = getPredictionMatrixWithoutMissingInstances(this.predictionMatrix, this.missingInstList);
        ArrayList<double[][]> uVMatricesUsingSVD = setUVMatricesUsingSVD(this.newPredictionMatrix, this.kSVDTruncationVal);
        this.UMatrix = uVMatricesUsingSVD.get(0);
        this.VMatrix = uVMatricesUsingSVD.get(1);
        this.featureMatrix = normalizeMatrix(dArr);
        this.numOfTrainingInstances = dArr.length - this.missingInstList.size();
        this.predictValuesArr = new double[this.missingInstList.size()][this.kSVDTruncationVal];
        this.errorArr = new double[this.UMatrix.length][this.kSVDTruncationVal];
        this.currentFitness = new double[this.kSVDTruncationVal];
        this.totalDiffArr = new double[this.UMatrix.length];
        initialiseNN(this.numOfLayers, this.numOfFeatures, this.kSVDTruncationVal);
        if (this.neuralNetType == NeuralNetworkType.ResillientPropagation) {
            applyResillientPropagation(this.numOfMaxIterations);
        } else if (this.neuralNetType == NeuralNetworkType.ResillientPropagationWithTypeiRPROPp) {
            applyResillientPropagationWithType(this.numOfMaxIterations, RPROPType.iRPROPp);
        } else if (this.neuralNetType == NeuralNetworkType.ResillientPropagationWithTypeiRPROPm) {
            applyResillientPropagationWithType(this.numOfMaxIterations, RPROPType.iRPROPm);
        } else if (this.neuralNetType == NeuralNetworkType.ResillientPropagationWithTypeRPROPm) {
            applyResillientPropagationWithType(this.numOfMaxIterations, RPROPType.RPROPm);
        } else if (this.neuralNetType == NeuralNetworkType.BackPropagation) {
            applyBackPropagation(this.numOfMaxIterations);
        } else if (this.neuralNetType == NeuralNetworkType.ManhattanPropagation) {
            applyManhattanPropagation(this.numOfMaxIterations);
        } else if (this.neuralNetType == NeuralNetworkType.ElmanRecurrentNN) {
            applyElmanRecurrentNN(this.numOfMaxIterations);
        }
        // Any other type leaves the missing columns untouched.
        return new MatrixCF(this.predictionMatrix);
    }

    /**
     * Development entry point: loads a benchmark matrix and a feature file, then runs
     * {@link #apply}. Optional args: {@code [benchmarkFile [featureFile]]}.
     *
     * <p>NOTE(review): assumes the benchmark file starts with a "rows\tcols" info line and the
     * feature file has a header line plus a leading label column — confirm against the data.
     */
    public static void main(String[] strArr) {
        String baseDir = "E:/_DELLNtb-Offce/_Eclipse Helios/WorkSpaceINRIA/CFbasedPortfolioSelection/";
        String benchFile = strArr.length > 0 ? strArr[0] : baseDir + "data/SAT-HAN-rank-coldStart-bench-kFold-10-4-0.1.txt";
        String featureFile = strArr.length > 1 ? strArr[1] : baseDir + "HAN-features.txt";
        int numOfFeatures = 46;

        ENCOGNeuralNetwork predictor = new ENCOGNeuralNetwork(
                NeuralNetworkType.ResillientPropagationWithTypeiRPROPp, 100, 4, 0.5d, 5, numOfFeatures);
        double[][] bench = predictor.loadMatrixFromTxtWithFirstInfoLine(benchFile);
        if (bench == null || bench.length == 0) {
            System.out.println(" >> could not load benchmark matrix from " + benchFile);
            return;
        }
        // One feature row per instance, i.e. per column of the benchmark matrix.
        double[][] features = predictor.readFeatureMatrixFile(featureFile, bench[0].length, numOfFeatures);
        predictor.apply(new MatrixCF(bench), features);
        System.out.println(" >> predicted " + predictor.getMissingInstList().size() + " cold-start instances");
    }
}
