/*
 * Decompiled with CFR 0.152.
 */
package simulationMethod;

import PBN.BitSetPBN;
import PBN.PBN;
import PBN.Property;
import PBN.StateBit;
import functionLib.Parameters;
import functionLib.RandomProvider;
import functionLib.StatUtil;
import java.io.FileWriter;
import java.io.PrintWriter;
import java.lang.management.ManagementFactory;
import java.lang.management.ThreadMXBean;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Date;
import java.util.List;
import java.util.Random;
import org.apache.commons.math3.stat.StatUtils;
import org.apache.commons.math3.stat.correlation.Covariance;
import simulationMethod.BitSetPBNSimulationEngine;
import simulationMethod.SimulationMethod;
import userInterface.AssaLog;

public class Gelman
extends SimulationMethod {
    // Most recently allocated per-chain BitSet; only used as a scratch reference
    // while the trajectories list is being populated.
    private BitSet trajectory;
    // One BitSet per chain. During burn-in, updateTransition() stores the result of
    // evaluateState() for each sample at consecutive bit positions.
    private List<BitSet> trajectories;
    // Network being simulated.
    private BitSetPBN pbn;
    // Name of the log file written by run().
    private String outputName;
    // Number of chains (taken from Parameters.NUM_PARALLEL in initialise()).
    private int m;
    // Current half-length of the burn-in window; the statistics use bits [n, twon).
    private int n;
    // Current full length of the burn-in window (2 * n while burn-in is extending).
    private int twon;
    // Index of the chain currently being simulated; read by updateTransition().
    private int count;
    // Next bit position updateTransition() writes to during burn-in.
    private int traIndex;
    // Trajectory position at which the next burn-in extension starts.
    private int previousIndex;
    // Node indices that must be set for evaluateState() to return true.
    private List<Integer> positiveIndex;
    // Node indices that must be clear for evaluateState() to return true.
    private List<Integer> negativeIndex;
    // Precision used in the sample-size formula in run().
    private double precision = 0.01;
    // Confidence level fed to the inverse CDF in run().
    private double confidence = 0.95;
    // NOTE(review): presumably the cap on the burn-in window length; run() compares
    // the window length against the same value.
    private final int maxBurnIn = 5000;
    // Tolerance on |1 - psrf| below which burn-in is considered finished.
    private double threshold = 0.001;
    // false: updateTransition() records trajectory bits; true: it counts samples
    // and transitions per chain.
    private boolean finishBurnIn = false;
    // Number of samples in the retained burn-in window that satisfy the property.
    private long meta1;
    // Number of samples in the retained burn-in window that do not satisfy it.
    private long meta2;
    // Indicator (0/1) of the previous sample while run() counts transitions.
    int bridge;
    // Per-chain count of post-burn-in samples that satisfy the property.
    private long[] stateA;
    // Per-chain count of post-burn-in samples that do not satisfy the property.
    private long[] stateB;
    // Per-chain indicator (0/1) of the most recently seen sample.
    private int[] index1;
    // Per-chain indicator (0/1) of the sample currently being counted.
    private int[] index2;
    // Per-chain transition counts after burn-in, indexed [chain][from][to].
    private int[][][] transitionsLastChain;

    /**
     * Creates a simulation method bound to the given network.
     *
     * <p>The log file name is built from the literal prefix {@code gelman}, a
     * minute-resolution timestamp, a random integer taken from the shared
     * {@link RandomProvider}, and the {@code .txt} extension.
     *
     * @param pbn     network to simulate
     * @param assalog log sink handed to the superclass constructor
     */
    public Gelman(BitSetPBN pbn, AssaLog assalog) {
        super(assalog);
        this.pbn = pbn;
        String timestamp = new SimpleDateFormat("yyyyMMdd-HHmm").format(new Date());
        int randomSuffix = RandomProvider.getInstance().getRandom().nextInt();
        this.outputName = "gelman" + timestamp + randomSuffix + ".txt";
        this.setInstanceName();
    }

    /**
     * Resets the per-run state: chain count, initial burn-in window size, the
     * burn-in flag and all per-chain counters.
     *
     * <p>The counter arrays are freshly allocated, so every element starts at
     * zero without an explicit fill loop.
     */
    private void initialise() {
        this.m = Parameters.NUM_PARALLEL;
        this.n = 10;
        this.twon = 2 * this.n;
        // Without this reset a second run on the same instance would start with
        // finishBurnIn == true, so updateTransition() would count samples instead
        // of recording the burn-in trajectory.
        this.finishBurnIn = false;
        this.index1 = new int[this.m];
        this.index2 = new int[this.m];
        this.stateA = new long[this.m];
        this.stateB = new long[this.m];
        this.transitionsLastChain = new int[this.m][2][2];
    }

    /**
     * Sets the two tunable parameters read by {@link #run()}.
     *
     * @param precision  value stored as the precision used in the sample-size formula
     * @param confidence value stored as the confidence level used for the inverse CDF
     */
    public void setParameters(double precision, double confidence) {
        this.confidence = confidence;
        this.precision = precision;
    }

    /**
     * Runs the full analysis and appends a report to the log file.
     *
     * <p>Phases:
     * <ol>
     *   <li>Burn-in: all {@code m} chains are extended and the potential scale
     *       reduction factor (PSRF) is recomputed on the second half of each
     *       trajectory until {@code |1 - psrf|} is within {@code threshold}, or
     *       the window would exceed {@code maxBurnIn} samples.</li>
     *   <li>The retained window is scanned to count satisfying / non-satisfying
     *       samples and the transitions between them, giving initial estimates
     *       {@code alpha} and {@code beta} and a required sample size {@code N}.</li>
     *   <li>The chains are extended and {@code N} is re-estimated until the number
     *       of collected samples reaches {@code N} or {@code N} leaves the range
     *       {@code [0, 0x3FFFFFFFFFFFFFFF]}.</li>
     * </ol>
     *
     * <p>NOTE(review): the burn-in phase relies on the engine calling
     * {@link #updateTransition(StateBit, int)} once per simulated step - confirm
     * against {@code BitSetPBNSimulationEngine}.
     *
     * @return a four-element array: [0] fraction of counted samples that satisfy
     *         the property, [1] the value logged as "Total simulation steps",
     *         [2] CPU time of the current thread in seconds, [3] wall-clock time
     *         in seconds
     * @throws Exception if the log file cannot be opened or the simulation fails
     */
    public double[] run() throws Exception {
        ThreadMXBean thread = ManagementFactory.getThreadMXBean();
        long cpu = thread.getCurrentThreadCpuTime();
        long time = System.currentTimeMillis();
        // Initialise first so the header below reports the real number of chains.
        this.initialise();
        PrintWriter pw = new PrintWriter(new FileWriter(this.outputName, true));
        try {
            this.logRunHeader(pw);
            StateBit[] initialState = new StateBit[this.m];
            double[] sj = new double[this.m];
            double[] mean = new double[this.m];
            double[] meansquare = new double[this.m];
            int statesize = this.pbn.getN();
            double[] result = new double[4];
            Covariance covariance = new Covariance();
            BitSetPBNSimulationEngine engine = new BitSetPBNSimulationEngine(this.pbn, this);
            this.trajectories = new ArrayList<BitSet>();
            for (int i = 0; i < this.m; ++i) {
                // Chain i starts from the state whose bit pattern encodes i.
                initialState[i] = new StateBit(statesize);
                initialState[i].putLongFromTo(i, 0, statesize - 1);
                this.trajectory = new BitSet(this.twon);
                this.trajectories.add(this.trajectory);
            }

            // ---- Phase 1: burn-in with the PSRF convergence check ----
            long extension = this.n;
            this.previousIndex = 0;
            boolean done = false;
            while (!done) {
                for (int i = 0; i < this.m; ++i) {
                    this.count = i;
                    // Every chain appends at the same trajectory offset.
                    this.traIndex = this.previousIndex;
                    initialState[i] = engine.simulate(extension * 2L, initialState[i]);
                }
                this.previousIndex = this.traIndex;
                double[] values = new double[this.n];
                for (int i = 0; i < this.m; ++i) {
                    BitSet chain = this.trajectories.get(i);
                    // Discard the first half so that only bits [n, twon) contribute.
                    chain.clear(0, this.twon - this.n);
                    mean[i] = (double) chain.cardinality() / (double) this.n;
                    meansquare[i] = Math.pow(mean[i], 2.0);
                    for (int j = this.n; j < this.twon; ++j) {
                        values[j - this.n] = chain.get(j) ? 1.0 : 0.0;
                    }
                    sj[i] = StatUtils.variance(values);
                }
                double within = StatUtils.mean(sj);
                double grandmean = StatUtils.mean(mean);
                double between = (double) this.n * StatUtils.variance(mean);
                double varianceW = StatUtils.variance(sj) / (double) this.m;
                double varianceB = 2.0 * between * between / (double) (this.m - 1);
                double covwxsquare = covariance.covariance(sj, meansquare);
                double covwx = covariance.covariance(sj, mean);
                // n / m and 1 + 1 / m must be evaluated in floating point; integer
                // division would truncate them (1 / m is 0 for every m >= 2).
                double covwb = (double) this.n / (double) this.m * (covwxsquare - 2.0 * grandmean * covwx);
                double chainFactor = 1.0 + 1.0 / (double) this.m;
                double variance = (1.0 - 1.0 / (double) this.n) * within
                        + (double) (this.m + 1) / (double) (this.m * this.n) * between;
                double varianceV = (Math.pow(this.n - 1, 2.0) * varianceW
                        + Math.pow(chainFactor, 2.0) * varianceB
                        + 2.0 * chainFactor * (double) (this.n - 1) * covwb) / Math.pow(this.n, 2.0);
                double df = 2.0 * variance * variance / varianceV;
                double psrf = Math.sqrt(variance / within * (df + 3.0) / (df + 1.0));
                if (Math.abs(1.0 - psrf) > this.threshold || varianceV == 0.0) {
                    // Not converged: double the window, unless that exceeds the cap.
                    extension = this.n;
                    this.n = (int) ((long) this.n + extension);
                    this.twon = this.n * 2;
                    if (this.twon > this.maxBurnIn) {
                        // Give up extending and fall back to the window just analysed.
                        done = true;
                        this.twon = this.n;
                        this.n = (int) ((long) this.n - extension);
                    }
                } else {
                    done = true;
                }
            }
            this.finishBurnIn = true;

            // ---- Phase 2: initial alpha / beta / N from the retained window ----
            double invPhi = StatUtil.getInvCDF(0.5 * (1.0 + this.confidence), false);
            long[][] transitionsLast = new long[2][2];
            this.meta1 = 0L;
            this.meta2 = 0L;
            for (int i = 0; i < this.m; ++i) {
                BitSet chain = this.trajectories.get(i);
                // The first retained sample has no predecessor, so it only seeds index1.
                int indicator = chain.get(this.n) ? 1 : 0;
                if (indicator == 1) {
                    ++this.meta1;
                } else {
                    ++this.meta2;
                }
                this.index1[i] = indicator;
                for (int j = 1; j < this.n; ++j) {
                    this.bridge = this.index1[i];
                    indicator = chain.get(this.n + j) ? 1 : 0;
                    if (indicator == 1) {
                        ++this.meta1;
                    } else {
                        ++this.meta2;
                    }
                    this.index1[i] = indicator;
                    this.index2[i] = indicator;
                    if (this.traIndex != 0) {
                        transitionsLast[this.bridge][indicator]++;
                    }
                }
            }
            double beta = estimateSwitchProbability(transitionsLast[0], 1);
            double alpha = estimateSwitchProbability(transitionsLast[1], 0);
            pw.println("alpha=" + alpha + ",beta=" + beta);
            long N = this.estimateRequiredSamples(alpha, beta, invPhi);
            pw.println("N=" + N);

            // ---- Phase 3: extend the chains until enough samples were collected ----
            long current = this.twon - 1;
            this.count = 0;
            int round = 0;
            long totalStateA = this.meta1;
            long totalStateB = this.meta2;
            long[][] totalTransitionsLast = new long[2][2];
            while (N > current) {
                ++round;
                extension = (int) Math.ceil((double) (N - current) / (double) this.m);
                // The per-chain counters are cumulative over all rounds, so the
                // totals are rebuilt from the burn-in counts every round.
                totalStateA = this.meta1;
                totalStateB = this.meta2;
                for (int j = 0; j < 2; ++j) {
                    for (int k = 0; k < 2; ++k) {
                        totalTransitionsLast[j][k] = transitionsLast[j][k];
                    }
                }
                for (int i = 0; i < this.m; ++i) {
                    this.count = i;
                    initialState[i] = engine.simulate(extension, initialState[i]);
                    totalStateA += this.stateA[i];
                    totalStateB += this.stateB[i];
                    for (int j = 0; j < 2; ++j) {
                        for (int k = 0; k < 2; ++k) {
                            totalTransitionsLast[j][k] += (long) this.transitionsLastChain[i][j][k];
                        }
                    }
                }
                beta = estimateSwitchProbability(totalTransitionsLast[0], 1);
                alpha = estimateSwitchProbability(totalTransitionsLast[1], 0);
                N = this.estimateRequiredSamples(alpha, beta, invPhi);
                current += extension * (long) this.m;
                pw.println("Re-estimating round " + round + ": alpha=" + alpha + ", beta=" + beta + ", N=" + N + ", current length=" + current + ".");
                if (N > 0x3FFFFFFFFFFFFFFFL || N < 0L) {
                    this.assalog.println("N is too large. I will stop simulation.");
                    break;
                }
            }
            pw.println("Re-estimation times=" + round);
            pw.println("final N=" + N);
            result[0] = (double) totalStateA / (double) (totalStateA + totalStateB);
            result[1] = totalStateA + totalStateB + (long) (this.m * this.n);
            result[2] = (double) (thread.getCurrentThreadCpuTime() - cpu) / 1.0E9;
            result[3] = (double) (System.currentTimeMillis() - time) / 1000.0;
            pw.println("Total simulation steps = " + result[1] + ".");
            return result;
        } finally {
            // Always release the log file, including when the simulation throws.
            pw.close();
        }
    }

    /**
     * Writes the start time, chain count, network size, checked property and
     * accuracy settings to the log. A {@code null} index list is treated like an
     * empty one, matching {@link #evaluateState(StateBit)}.
     */
    private void logRunHeader(PrintWriter pw) {
        SimpleDateFormat df1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
        pw.println("Program started at " + df1.format(new Date()));
        pw.println("Starting the two-state parallel kenel. Number of cores = " + this.m + ".");
        pw.println("Total number of nodes is " + this.pbn.getN() + ".");
        pw.print("Checking states where ");
        if (this.positiveIndex != null && !this.positiveIndex.isEmpty()) {
            pw.print("nodes with index ");
            for (Integer index : this.positiveIndex) {
                pw.print(index + " ");
            }
            pw.print("should have value 1 and ");
        }
        if (this.negativeIndex != null && !this.negativeIndex.isEmpty()) {
            pw.print("nodes with index ");
            for (Integer index : this.negativeIndex) {
                pw.print(index + " ");
            }
            pw.println("should have value 0.");
        } else {
            pw.println("no requirements on other nodes.");
        }
        pw.println("Approximation is " + this.precision + ", confidence is " + this.confidence + ", fetch sample every 1 step.");
    }

    /**
     * Returns {@code row[target] / (row[0] + row[1])}, or 0 when the row holds no
     * observations.
     */
    private static double estimateSwitchProbability(long[] row, int target) {
        long observed = row[0] + row[1];
        return observed == 0L ? 0.0 : (double) row[target] / (double) observed;
    }

    /**
     * Evaluates the sample-size formula for the given transition estimates and
     * the inverse-CDF value of the configured confidence level.
     */
    private long estimateRequiredSamples(double alpha, double beta, double invPhi) {
        return (long) Math.ceil(alpha * beta * (2.0 - alpha - beta)
                / (Math.pow(alpha + beta, 3.0) * Math.pow(this.precision / invPhi, 2.0)));
    }

    /**
     * Extends all chains 500 times and recomputes the potential scale reduction
     * factor (PSRF) after every extension.
     *
     * <p>The first extension simulates {@code 2 * n} steps per chain; afterwards
     * the half-window {@code n} grows by 5 per iteration. Initial states are
     * spread over the state space in steps of {@code 2^N / m}.
     *
     * <p>NOTE(review): the PSRF is computed but not stored or reported anywhere
     * in this method; the only lasting effects are the updated trajectories and
     * window sizes. Confirm whether reporting was intended.
     *
     * @throws Exception if the simulation engine fails
     */
    public void runPlot() throws Exception {
        this.initialise();
        StateBit[] initialState = new StateBit[this.m];
        double[] sj = new double[this.m];
        double[] mean = new double[this.m];
        int statesize = this.pbn.getN();
        BitSetPBNSimulationEngine engine = new BitSetPBNSimulationEngine(this.pbn, this);
        long statespace = (long) Math.pow(2.0, this.pbn.getN());
        statespace /= (long) this.m;
        this.trajectories = new ArrayList<BitSet>();
        for (int i = 0; i < this.m; ++i) {
            initialState[i] = new StateBit(statesize);
            initialState[i].putLongFromTo(statespace * (long) i, 0, statesize - 1);
            this.trajectory = new BitSet(this.twon);
            this.trajectories.add(this.trajectory);
        }
        this.previousIndex = 0;
        int extension = this.n;
        for (int k = 1; k < 501; ++k) {
            for (int i = 0; i < this.m; ++i) {
                this.count = i;
                // Every chain appends at the same trajectory offset.
                this.traIndex = this.previousIndex;
                initialState[i] = engine.simulate(extension * 2, initialState[i]);
            }
            this.previousIndex = this.traIndex;

            // All accumulators are per-iteration; carrying the mean of squares over
            // from the previous iteration would bias the covariance term.
            double within = 0.0;
            double grandmean = 0.0;
            double meanMeanSquare = 0.0;
            for (int i = 0; i < this.m; ++i) {
                BitSet chain = this.trajectories.get(i);
                // Keep only the second half, as run() does, so the chain mean and
                // the chain variance are taken over the same samples.
                chain.clear(0, this.twon - this.n);
                mean[i] = (double) chain.cardinality() / (double) this.n;
                meanMeanSquare += mean[i] * mean[i];
                grandmean += mean[i];
                double squaredDeviations = 0.0;
                for (int j = this.n; j < this.twon; ++j) {
                    double value = chain.get(j) ? 1.0 : 0.0;
                    squaredDeviations += Math.pow(value - mean[i], 2.0);
                }
                sj[i] = squaredDeviations / (double) (this.n - 1);
                within += sj[i];
            }
            within /= (double) this.m;
            grandmean /= (double) this.m;
            meanMeanSquare /= (double) this.m;

            double between = 0.0;
            double varianceW = 0.0;
            double covwxsquare = 0.0;
            double covwx = 0.0;
            for (int i = 0; i < this.m; ++i) {
                between += Math.pow(grandmean - mean[i], 2.0);
                varianceW += Math.pow(sj[i] - within, 2.0);
                covwxsquare += (sj[i] - within) * (mean[i] * mean[i] - meanMeanSquare);
                covwx += (sj[i] - within) * (mean[i] - grandmean);
            }
            between = between * (double) this.n / (double) (this.m - 1);
            varianceW /= (double) (this.m - 1);
            varianceW /= (double) this.m;
            covwxsquare /= (double) (this.m - 1);
            covwx /= (double) (this.m - 1);
            double varianceB = 2.0 * between * between / (double) (this.m - 1);
            // n / m and 1 + 1 / m must be evaluated in floating point; integer
            // division would truncate them (1 / m is 0 for every m >= 2).
            double covwb = (double) this.n / (double) this.m * (covwxsquare - 2.0 * grandmean * covwx);
            double chainFactor = 1.0 + 1.0 / (double) this.m;
            double varianceV = (Math.pow(this.n - 1, 2.0) * varianceW
                    + Math.pow(chainFactor, 2.0) * varianceB
                    + 2.0 * chainFactor * (double) (this.n - 1) * covwb) / Math.pow(this.n, 2.0);
            double variance = (1.0 - 1.0 / (double) this.n) * within
                    + (double) (this.m + 1) / (double) (this.m * this.n) * between;
            double df = 2.0 * variance * variance / varianceV;
            double psrf = Math.sqrt(variance / within * (df + 3.0) / (df + 1.0));

            extension = 5;
            this.n += extension;
            this.twon = this.n * 2;
        }
    }

    /**
     * Records one simulated state for the chain selected by {@code count}.
     *
     * <p>Before burn-in has finished, the evaluation of the state is written to
     * the chain's trajectory at {@code traIndex}, which is then advanced. After
     * burn-in, the matching per-chain sample counter and the transition counter
     * from the previous indicator to the current one are incremented.
     *
     * @param st   state to evaluate
     * @param para not used by this implementation
     */
    @Override
    public void updateTransition(StateBit st, int para) {
        if (!this.finishBurnIn) {
            BitSet chainTrajectory = this.trajectories.get(this.count);
            chainTrajectory.set(this.traIndex, this.evaluateState(st));
            this.traIndex++;
            return;
        }
        boolean satisfied = this.evaluateState(st);
        int previous = this.index1[this.count];
        int indicator = satisfied ? 1 : 0;
        if (satisfied) {
            this.stateA[this.count]++;
        } else {
            this.stateB[this.count]++;
        }
        this.index1[this.count] = indicator;
        this.index2[this.count] = indicator;
        this.transitionsLastChain[this.count][previous][indicator]++;
    }

    /**
     * Tests whether a state satisfies the configured property.
     *
     * @param st state to test
     * @return {@code true} when every node listed in {@code positiveIndex} is set
     *         and every node listed in {@code negativeIndex} is clear; a
     *         {@code null} list imposes no constraint
     */
    public boolean evaluateState(StateBit st) {
        if (this.positiveIndex != null) {
            for (Integer node : this.positiveIndex) {
                if (!st.get(node)) {
                    return false;
                }
            }
        }
        if (this.negativeIndex != null) {
            for (Integer node : this.negativeIndex) {
                if (st.get(node)) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Averages the list elements in the half-open index range {@code [start, end)}.
     *
     * @param yjm   values to average
     * @param start index of the first element included
     * @param end   index one past the last element included
     * @return the sum of the selected elements divided by {@code end - start}
     */
    public double mean(List<Double> yjm, int start, int end) {
        double sum = 0.0;
        for (int idx = start; idx < end; ++idx) {
            sum += yjm.get(idx);
        }
        double span = (double) (end - start);
        return sum / span;
    }

    /**
     * Replaces the network under analysis and runs the full analysis on it.
     *
     * @param pbn network to analyse; it is cast to {@link BitSetPBN}
     * @return the result array produced by {@link #run()}
     * @throws Exception propagated from {@link #run()}
     */
    @Override
    public double[] run(PBN pbn) throws Exception {
        BitSetPBN network = (BitSetPBN) pbn;
        this.pbn = network;
        double[] outcome = this.run();
        return outcome;
    }

    /**
     * Stores this method's display name in the inherited {@code instanceName} field.
     */
    @Override
    public void setInstanceName() {
        final String displayName = "Gelman and Rubin";
        this.instanceName = displayName;
    }

    /**
     * Takes the positive and negative node index lists from the first property
     * in the list; any further properties are ignored.
     *
     * @param properties properties to check; element 0 must exist
     */
    @Override
    public void setExpressions(List<Property> properties) {
        Property first = properties.get(0);
        this.positiveIndex = first.getPositiveIndex();
        this.negativeIndex = first.getNegativeIndex();
    }

    /**
     * Returns the name of the log file that {@link #run()} appends to.
     *
     * @return the file name generated in the constructor
     */
    @Override
    public String getLogFile() {
        String logFileName = this.outputName;
        return logFileName;
    }

    /**
     * Array-based parameter setter. The body is empty in this class, so the
     * supplied values are ignored; precision and confidence are only changed
     * through {@link #setParameters(double, double)}.
     *
     * @param parameters ignored
     * @throws Exception never thrown by this implementation
     */
    @Override
    public void setParameters(double[] parameters) throws Exception {
        // Intentionally left empty.
    }
}

