package fr.inria.mochy.statsAndTasks;

import fr.inria.mochy.core.equalization.EquNetNeural;
import fr.inria.mochy.ui.FullNeuralNetPipelineController;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javafx.util.Pair;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.util.TransferFunctionType;

/* loaded from: input_file:fr/inria/mochy/statsAndTasks/FullNnetPipelineTask.class */
/**
 * Background task that searches for a neural network to drive an {@link EquNetNeural} net.
 *
 * <p>The task works in two steps: it first generates {@link #nnetGeneration} fresh perceptrons and
 * keeps the {@link #nnetChosenNb} with the lowest simulation score, then refines each survivor by
 * nudging its weights one at a time. The best refined network is saved to
 * {@code neural/neuralNet.nnet}.
 *
 * <p>Scores come from {@link #launchSim()}; everywhere in this class a lower score is treated as
 * better.
 */
public class FullNnetPipelineTask extends AbstractStats {
    /** Net under simulation; read from the current simulation at the start of {@link #call()}. */
    EquNetNeural n;
    /** Not used by the code visible in this class. */
    NeuralNetwork currentNnet;
    /** Copied from {@code FullNeuralNetPipelineController.standardDeviation} in {@link #call()}. */
    int standardDeviationTarget;
    /** Copied from {@code FullNeuralNetPipelineController.maximumSteps} in {@link #call()}. */
    int maximumSteps;
    /** Copied from {@code FullNeuralNetPipelineController.startTokensNb} in {@link #call()}. */
    int startTokensNb;
    /** Copied from {@code FullNeuralNetPipelineController.endTokensNb} in {@link #call()}. */
    int endTokensNb;
    /** Target speed of {@link #n}; the reference value used by {@link #calculVariance(ArrayList)}. */
    int targetSpeed;
    /** Copied from {@code FullNeuralNetPipelineController.stepToken} in {@link #call()}. */
    int stepTokens;
    // NOTE(review): runsNb and stepsAfterEqualization are not read by the visible code; they are
    // presumably consumed by launchSim(), whose body was not recovered - confirm.
    int runsNb = 100;
    /** Number of perceptrons created during the generation step. */
    int nnetGeneration = 1000;
    int stepsAfterEqualization = 1000;
    /** Number of full sweeps over the weights performed when refining one network. */
    int pass = 3;
    /** Number of best-scoring generated networks handed over to the optimization step. */
    int nnetChosenNb = 5;

    /**
     * Runs the whole pipeline: reads the settings, generates candidate networks, optimizes the
     * best ones and saves the winner to {@code neural/neuralNet.nnet}.
     *
     * @return always {@code null}; the result of the task is the saved network file
     * @throws Exception if a simulation run fails
     */
    @Override // javafx.concurrent.Task
    protected Object call() throws Exception {
        this.n = (EquNetNeural) getSimu().getN();
        this.targetSpeed = this.n.getTargetSpeed();

        // Snapshot the settings published by the UI controller.
        this.standardDeviationTarget = FullNeuralNetPipelineController.standardDeviation;
        this.maximumSteps = FullNeuralNetPipelineController.maximumSteps;
        this.startTokensNb = FullNeuralNetPipelineController.startTokensNb;
        this.endTokensNb = FullNeuralNetPipelineController.endTokensNb;
        this.stepTokens = FullNeuralNetPipelineController.stepToken;

        updateMessage("Step 1/2 : generation");
        ArrayList<NeuralNetwork> candidates = generateNnets();

        updateMessage("Step 2/2 : optimization");
        NeuralNetwork winner = optimizeNeuralNets(candidates);
        if (winner != null) {
            winner.save("neural/neuralNet.nnet");
        }

        updateProgress(this.nnetGeneration, this.nnetGeneration);
        return null;
    }

    /**
     * Creates {@link #nnetGeneration} fresh perceptrons, scores each one with
     * {@link #launchSim()} and returns the best {@link #nnetChosenNb} of them.
     *
     * <p>Networks scoring {@code Float.POSITIVE_INFINITY} are discarded. Networks are indexed by
     * score, so when several networks obtain exactly the same score only the most recently
     * generated one is kept.
     *
     * @return the retained networks, ordered from lowest to highest score; may be empty
     * @throws Exception if a simulation run fails
     */
    ArrayList<NeuralNetwork> generateNnets() throws Exception {
        Map<Float, NeuralNetwork> netsByScore = new HashMap<>();
        for (int generated = 0; generated < this.nnetGeneration; generated++) {
            MultiLayerPerceptron candidate = new MultiLayerPerceptron(TransferFunctionType.TANH, 6, 10, 10, 10, 1);
            this.n.setNeuralNetwork(candidate);
            float score = launchSim().floatValue();
            if (score != Float.POSITIVE_INFINITY) {
                netsByScore.put(Float.valueOf(score), candidate);
            }
            if (isCancelled()) {
                break;
            }
            // "generated" is zero-based: one more network has been evaluated at this point.
            updateProgress(generated + 1, this.nnetGeneration);
        }

        // Rank the surviving networks by ascending score.
        ArrayList<Map.Entry<Float, NeuralNetwork>> ranked = new ArrayList<>(netsByScore.entrySet());
        ranked.sort(Map.Entry.comparingByKey());

        ArrayList<NeuralNetwork> selected = new ArrayList<>();
        for (int rank = 0; rank < this.nnetChosenNb && rank < ranked.size(); rank++) {
            Map.Entry<Float, NeuralNetwork> entry = ranked.get(rank);
            selected.add(entry.getValue());
            System.out.println(entry.getKey());
        }
        return selected;
    }

    /**
     * Refines every given network with {@link #optimizeNeuralNet(NeuralNetwork)} and returns the
     * refined network with the lowest score.
     *
     * <p>If the task is cancelled, the networks refined so far are still compared. When two
     * refined networks have the same score, the one processed last is returned.
     *
     * @param arrayList the networks to refine
     * @return the best refined network, or {@code null} when {@code arrayList} is empty
     * @throws Exception if a simulation run fails
     */
    NeuralNetwork optimizeNeuralNets(ArrayList<NeuralNetwork> arrayList) throws Exception {
        if (arrayList.isEmpty()) {
            // Nothing survived the generation step; the caller treats null as "nothing to save".
            return null;
        }

        Pair<Float, NeuralNetwork> best = null;
        int processed = 0;
        updateMessage("optimize nnet " + (processed + 1) + "/" + arrayList.size());
        for (NeuralNetwork candidate : arrayList) {
            Pair<Float, NeuralNetwork> optimized = optimizeNeuralNet(candidate);
            // "<=" lets a later network replace an earlier one with the same score.
            if (best == null || optimized.getKey().compareTo(best.getKey()) <= 0) {
                best = optimized;
            }
            processed++;
            updateProgress(processed, arrayList.size());
            if (isCancelled()) {
                break;
            }
            updateMessage("optimize nnet " + (processed + 1) + "/" + arrayList.size());
        }
        updateProgress(100L, 100L);
        return best.getValue();
    }

    /**
     * Refines one network by coordinate-wise hill climbing: for each weight, in {@link #pass}
     * successive sweeps, the weight is increased by 0.1 and, if that does not lower the score,
     * decreased by 0.1; a change is kept only when it strictly lowers the score.
     *
     * @param neuralNetwork the network to start from
     * @return the best score found together with the network that achieved it
     * @throws Exception if a simulation run fails
     */
    Pair<Float, NeuralNetwork> optimizeNeuralNet(NeuralNetwork neuralNetwork) throws Exception {
        NeuralNetwork best = neuralNetwork;
        this.n.setNeuralNetwork(best);
        float bestScore = launchSim().floatValue();
        System.out.println("original value = " + bestScore);

        for (int sweep = 1; sweep <= this.pass; sweep++) {
            // The weight count is re-read on every iteration because "best" may be replaced.
            for (int weight = 0; weight < best.getWeights().length; weight++) {
                if (isCancelled()) {
                    return new Pair<>(Float.valueOf(bestScore), best);
                }

                NeuralNetwork increased = mutant(best, weight, true);
                this.n.setNeuralNetwork(increased);
                float increasedScore = launchSim().floatValue();
                if (increasedScore < bestScore) {
                    bestScore = increasedScore;
                    best = increased;
                    System.out.println("mutant plus on weight " + weight + " : " + bestScore);
                } else {
                    NeuralNetwork decreased = mutant(best, weight, false);
                    this.n.setNeuralNetwork(decreased);
                    float decreasedScore = launchSim().floatValue();
                    if (decreasedScore < bestScore) {
                        bestScore = decreasedScore;
                        best = decreased;
                        System.out.println("mutant moins on weight " + weight + " : " + bestScore);
                    }
                }

                // Weights handled in earlier sweeps plus those handled in the current one.
                int weightsNb = best.getWeights().length;
                updateProgress((sweep - 1) * weightsNb + weight + 1, this.pass * weightsNb);
            }
        }
        return new Pair<>(Float.valueOf(bestScore), best);
    }

    /**
     * Builds a copy of a network in which a single weight is shifted by 0.1.
     *
     * <p>The copy is created with the same layer sizes as the networks produced by
     * {@link #generateNnets()}, so that the copied weight vector lines up with the connections
     * of the new network.
     *
     * @param neuralNetwork the network whose weights are copied; it is not modified
     * @param i index of the weight to change in the array returned by {@code getWeights()}
     * @param z {@code true} to add 0.1 to the weight, {@code false} to subtract 0.1
     * @return a new network carrying the modified weights
     */
    NeuralNetwork mutant(NeuralNetwork neuralNetwork, int i, boolean z) {
        MultiLayerPerceptron copy = new MultiLayerPerceptron(TransferFunctionType.TANH, 6, 10, 10, 10, 1);
        double[] weights = Stream.of(neuralNetwork.getWeights()).mapToDouble(Double::doubleValue).toArray();
        weights[i] += z ? 0.1d : -0.1d;
        copy.setWeights(weights);
        return copy;
    }

    /* JADX WARN: Code restructure failed: missing block: B:36:0x00e4, code lost:
    
        if (isCancelled() == false) goto L26;
     */
    /*
        Code decompiled incorrectly, please refer to instructions dump.
        To view partially-correct add '--show-bad-code' argument
    */
    /**
     * Scores the network currently installed on {@link #n}.
     *
     * <p>The body of this method was not recovered by the decompiler: this stub always throws
     * and has to be restored from the original bytecode before the task can run. The callers in
     * this class treat the returned value as a score where lower is better, and
     * {@link #generateNnets()} discards networks that score {@code Float.POSITIVE_INFINITY}.
     *
     * @return the score of the simulation
     * @throws Exception declared by the original method
     */
    java.lang.Float launchSim() throws java.lang.Exception {
        /*
            Method dump skipped, instructions count: 302
            To view this dump add '--comments-level debug' option
        */
        throw new UnsupportedOperationException("Method not decompiled: fr.inria.mochy.statsAndTasks.FullNnetPipelineTask.launchSim():java.lang.Float");
    }

    /**
     * Computes the mean squared deviation of the given values from {@link #targetSpeed}.
     *
     * @param arrayList the values to evaluate; must not contain {@code null}
     * @return the sum of squared differences to {@link #targetSpeed} divided by the number of
     *     values; {@code NaN} when the list is empty
     */
    float calculVariance(ArrayList<Float> arrayList) {
        float squaredDeviationSum = 0.0f;
        for (Float value : arrayList) {
            squaredDeviationSum = (float) (squaredDeviationSum + Math.pow(value.floatValue() - this.targetSpeed, 2.0d));
        }
        return squaredDeviationSum / arrayList.size();
    }
}
