/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.discrete;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.branchratemodel.DifferentiableBranchRates;
import dr.evomodel.treedatalikelihood.BeagleDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.discrete.DiscreteTraitBranchRateDelegate;
import java.util.Arrays;

/**
 * Derivatives of the discrete-trait data log-likelihood with respect to internal node heights.
 *
 * <p>Branch-wise derivatives come from {@link DiscreteTraitBranchRateDelegate} and are mapped onto
 * node heights with the chain rule. Each child branch of a node contributes
 * {@code +rate * branchDerivative}, and the node's own (parent) branch contributes
 * {@code -rate * branchDerivative}. This is consistent with branch lengths of the form
 * {@code rate * (parentHeight - childHeight)}.
 *
 * <p>The second-derivative output has one entry per internal node. It is assembled from BEAGLE
 * post-/pre-order partials, transition matrices and infinitesimal-matrix buffers.
 *
 * <p>NOTE(review): the second-derivative code only looks at children 0 and 1 of each internal
 * node, i.e. it assumes a bifurcating tree. It also assumes {@code tree.getNode(i).getNumber() == i},
 * because BEAGLE buffers are addressed by loop index and by node number interchangeably.
 */
public class DiscreteTraitNodeHeightDelegate
extends DiscreteTraitBranchRateDelegate {
    static final String GRADIENT_TRAIT_NAME = "NodeHeightGradient";
    static final String HESSIAN_TRAIT_NAME = "NodeHeightHessian";

    /** Branch rates used to convert branch-length derivatives into node-height derivatives. */
    private final DifferentiableBranchRates branchRates;

    /**
     * @param name               forwarded unchanged to the superclass (presumably the delegate/trait name)
     * @param tree               tree whose internal node heights are differentiated
     * @param likelihoodDelegate BEAGLE likelihood delegate supplying partials and matrices
     * @param branchRates        differentiable branch-rate model
     */
    DiscreteTraitNodeHeightDelegate(String name, Tree tree, BeagleDataLikelihoodDelegate likelihoodDelegate, DifferentiableBranchRates branchRates) {
        super(name, tree, likelihoodDelegate);
        this.branchRates = branchRates;
    }

    /** One derivative per internal node. */
    @Override
    protected int getGradientLength() {
        return this.tree.getInternalNodeCount();
    }

    /**
     * Fills node-height derivatives, indexed by internal-node offset
     * ({@code node number - externalNodeCount}).
     *
     * @param tree   tree to differentiate
     * @param first  receives first derivatives, or {@code null} to skip them
     * @param second receives second derivatives, or {@code null} to skip them
     */
    @Override
    protected void getNodeDerivatives(Tree tree, double[] first, double[] second) {
        final int branchCount = tree.getNodeCount() - 1;
        final double[] branchGradient = new double[branchCount];
        final double[] branchHessian = second == null ? null : new double[branchCount];
        super.getNodeDerivatives(tree, branchGradient, branchHessian);

        if (first != null) {
            fillNodeHeightGradient(tree, branchGradient, first);
        }
        if (second != null) {
            fillNodeHeightSecondDerivatives(tree, second);
        }
    }

    /** Applies the chain rule to turn branch-length gradients into node-height gradients. */
    private void fillNodeHeightGradient(Tree tree, double[] branchGradient, double[] gradient) {
        Arrays.fill(gradient, 0.0);
        final int externalCount = tree.getExternalNodeCount();
        for (int i = 0; i < tree.getInternalNodeCount(); ++i) {
            final NodeRef node = tree.getNode(i + externalCount);
            for (int j = 0; j < tree.getChildCount(node); ++j) {
                final NodeRef child = tree.getChild(node, j);
                gradient[i] += branchGradient[getParameterIndex(child, tree)] * branchRates.getBranchRate(tree, child);
            }
            if (!tree.isRoot(node)) {
                // Raising this node shortens the branch above it.
                gradient[i] -= branchGradient[getParameterIndex(node, tree)] * branchRates.getBranchRate(tree, node);
            }
        }
    }

    /** Computes per-internal-node second derivatives of the log-likelihood w.r.t. node heights. */
    private void fillNodeHeightSecondDerivatives(Tree tree, double[] hessian) {
        final int nodeCount = tree.getNodeCount();
        final int partialLength = patternCount * stateCount * categoryCount;
        final int matrixLength = stateCount * stateCount * categoryCount;

        final double[][] postOrder = new double[nodeCount][partialLength];
        final double[][] preOrder = new double[nodeCount][partialLength];
        final double[][] transition = new double[nodeCount][matrixLength];
        for (int i = 0; i < nodeCount; ++i) {
            beagle.getPartials(getPostOrderPartialIndex(i), -1, postOrder[i]);
            beagle.getTransitionMatrix(evolutionaryProcessDelegate.getMatrixIndex(i), transition[i]);
            beagle.getPartials(getPreOrderPartialIndex(i), -1, preOrder[i]);
        }

        // Per-branch, per-pattern first/second derivative ratios, indexed by getParameterIndex.
        final double[][] firstRatio = new double[nodeCount - 1][];
        final double[][] secondRatio = new double[nodeCount - 1][];
        computeBranchDerivativeRatios(tree, postOrder, preOrder, firstRatio, secondRatio);

        final double[] qNode = new double[matrixLength];
        final double[] qChild0 = new double[matrixLength];
        final double[] qChild1 = new double[matrixLength];
        final double[] pPost0 = new double[partialLength];
        final double[] pPost1 = new double[partialLength];
        final double[] qpPost0 = new double[partialLength];
        final double[] qpPost1 = new double[partialLength];
        final double[] qTransposePre = new double[partialLength];
        final double[] numerator = new double[patternCount];
        final double[] denominator = new double[patternCount];
        final double[] ratio = new double[patternCount];
        final double[] product = new double[patternCount];

        final int externalCount = tree.getExternalNodeCount();
        for (int i = 0; i < tree.getInternalNodeCount(); ++i) {
            final NodeRef node = tree.getNode(i + externalCount);
            final NodeRef child0 = tree.getChild(node, 0);
            final NodeRef child1 = tree.getChild(node, 1);
            final int nodeNumber = node.getNumber();
            final int child0Number = child0.getNumber();
            final int child1Number = child1.getNumber();
            final int child0Index = getParameterIndex(child0, tree);
            final int child1Index = getParameterIndex(child1, tree);
            final double rate0 = branchRates.getBranchRate(tree, child0);
            final double rate1 = branchRates.getBranchRate(tree, child1);

            beagle.getTransitionMatrix(evolutionaryProcessDelegate.getInfinitesimalMatrixBufferIndex(child0Number), qChild0);
            beagle.getTransitionMatrix(evolutionaryProcessDelegate.getInfinitesimalMatrixBufferIndex(child1Number), qChild1);

            // P * post for each child, and Q applied to that product.
            getMatrixVectorProduct(transition[child0Number], postOrder[child0Number], pPost0);
            getMatrixVectorProduct(qChild0, pPost0, qpPost0);
            getMatrixVectorProduct(transition[child1Number], postOrder[child1Number], pPost1);
            getMatrixVectorProduct(qChild1, pPost1, qpPost1);

            final double[] nodePre = preOrder[nodeNumber];
            // Shared per-pattern normaliser for every cross term at this node.
            getTripleVectorReduction(pPost0, pPost1, nodePre, denominator);

            // Cross term between the two sibling (child) branches.
            getTripleVectorReduction(qpPost0, qpPost1, nodePre, numerator);
            getVectorVectorDivision(numerator, denominator, ratio);
            getVectorVectorProduct(firstRatio[child0Index], firstRatio[child1Index], product);
            final double[] siblingCross = getVectorMinusVector(ratio, product);

            getVectorPlusScaledVector(secondRatio[child0Index], secondRatio[child1Index], rate0 * rate0, rate1 * rate1, ratio);
            getVectorPlusScaledVector(ratio, siblingCross, 1.0, 2.0 * rate0 * rate1, product);
            hessian[i] = getVectorPatternReduction(product);

            if (tree.isRoot(node)) {
                continue; // The root has no parent branch, so no parent/child terms.
            }

            final int nodeIndex = getParameterIndex(node, tree);
            final double rateNode = branchRates.getBranchRate(tree, node);
            beagle.getTransitionMatrix(evolutionaryProcessDelegate.getInfinitesimalMatrixBufferIndex(nodeNumber), qNode);
            getMatrixTransformVectorProduct(qNode, nodePre, qTransposePre);

            // Cross term between the parent branch and child 0.
            getTripleVectorReduction(qpPost0, pPost1, qTransposePre, numerator);
            getVectorVectorDivision(numerator, denominator, ratio);
            getVectorVectorProduct(firstRatio[child0Index], firstRatio[nodeIndex], product);
            final double[] parentCross0 = getVectorMinusVector(ratio, product);

            // Cross term between the parent branch and child 1.
            getTripleVectorReduction(qpPost1, pPost0, qTransposePre, numerator);
            getVectorVectorDivision(numerator, denominator, ratio);
            getVectorVectorProduct(firstRatio[child1Index], firstRatio[nodeIndex], product);
            final double[] parentCross1 = getVectorMinusVector(ratio, product);

            getVectorPlusScaledVector(secondRatio[nodeIndex], parentCross0, rateNode * rateNode, -2.0 * rateNode * rate0, ratio);
            getVectorPlusScaledVector(ratio, parentCross1, 1.0, -2.0 * rateNode * rate1, product);
            hessian[i] += getVectorPatternReduction(product);
        }
    }

    /**
     * For every non-root node, computes per-pattern ratios of the (Q-weighted) branch likelihood to
     * the plain branch likelihood, using the branch's infinitesimal-matrix buffer.
     *
     * <ul>
     *   <li>{@code firstRatio} receives {@code sum(pre * Q post) / L}.</li>
     *   <li>{@code secondRatio} receives {@code sum(pre * Q Q post) / L - firstRatio^2}.</li>
     * </ul>
     *
     * <p>Results are stored at {@link #getParameterIndex}. This matches how they are read back, for
     * any position of the root in the node numbering.
     */
    private void computeBranchDerivativeRatios(Tree tree, double[][] postOrder, double[][] preOrder,
                                               double[][] firstRatio, double[][] secondRatio) {
        final int partialLength = patternCount * stateCount * categoryCount;
        final double[] infinitesimal = new double[stateCount * stateCount * categoryCount];
        final double[] qPost = new double[partialLength];
        final double[] qqPost = new double[partialLength];
        for (int nodeNumber = 0; nodeNumber < tree.getNodeCount(); ++nodeNumber) {
            final NodeRef node = tree.getNode(nodeNumber);
            if (tree.isRoot(node)) {
                continue; // No branch above the root.
            }
            final int index = getParameterIndex(node, tree);
            beagle.getTransitionMatrix(evolutionaryProcessDelegate.getInfinitesimalMatrixBufferIndex(nodeNumber), infinitesimal);
            getMatrixVectorProduct(infinitesimal, postOrder[nodeNumber], qPost);
            getMatrixVectorProduct(infinitesimal, qPost, qqPost);

            final double[] likelihood = getVectorStateReduction(getVectorVectorProduct(postOrder[nodeNumber], preOrder[nodeNumber]));
            firstRatio[index] = getVectorVectorDivision(getVectorStateReduction(getVectorVectorProduct(preOrder[nodeNumber], qPost)), likelihood);
            secondRatio[index] = getVectorMinusVector(
                    getVectorVectorDivision(getVectorStateReduction(getVectorVectorProduct(preOrder[nodeNumber], qqPost)), likelihood),
                    getVectorVectorProduct(firstRatio[index], firstRatio[index]));
        }
    }

    /** Returns a new array holding the element-wise difference {@code a - b}. */
    private double[] getVectorMinusVector(double[] a, double[] b) {
        double[] out = new double[a.length];
        for (int i = 0; i < a.length; ++i) {
            out[i] = a[i] - b[i];
        }
        return out;
    }

    /**
     * Writes {@code a * scaleA + b * scaleB} into {@code out}. If {@code b} is {@code null}, writes
     * {@code a * scaleA} instead.
     */
    private void getVectorPlusScaledVector(double[] a, double[] b, double scaleA, double scaleB, double[] out) {
        if (b != null) {
            for (int i = 0; i < a.length; ++i) {
                out[i] = a[i] * scaleA + b[i] * scaleB;
            }
        } else {
            for (int i = 0; i < a.length; ++i) {
                out[i] = a[i] * scaleA;
            }
        }
    }

    /**
     * Maps a node number onto a contiguous branch index that skips the root. Nodes numbered below
     * the root keep their number; nodes numbered above it are shifted down by one.
     */
    private int getParameterIndex(NodeRef nodeRef, Tree tree) {
        return nodeRef.getNumber() < tree.getRoot().getNumber() ? nodeRef.getNumber() : nodeRef.getNumber() - 1;
    }

    /**
     * Sums a partials-shaped vector over states for each pattern, weighting each rate category by
     * its proportion.
     *
     * <p>Layout is {@code [category][pattern][state]}. Returns a new array of length
     * {@code patternCount}.
     */
    private double[] getVectorStateReduction(double[] v) {
        double[] out = new double[patternCount];
        for (int c = 0; c < categoryCount; ++c) {
            final double proportion = siteRateModel.getProportionForCategory(c);
            for (int p = 0; p < out.length; ++p) {
                double sum = 0.0;
                for (int s = 0; s < stateCount; ++s) {
                    sum += v[c * patternCount * stateCount + p * stateCount + s];
                }
                out[p] += sum * proportion;
            }
        }
        return out;
    }

    /** Returns a new array holding the element-wise product {@code a * b}. */
    private double[] getVectorVectorProduct(double[] a, double[] b) {
        double[] out = new double[a.length];
        for (int i = 0; i < a.length; ++i) {
            out[i] = a[i] * b[i];
        }
        return out;
    }

    /** Writes the element-wise product {@code a * b} into {@code out}. */
    private void getVectorVectorProduct(double[] a, double[] b, double[] out) {
        for (int i = 0; i < a.length; ++i) {
            out[i] = a[i] * b[i];
        }
    }

    /**
     * For each category and pattern, writes {@code M * v} into {@code out}.
     *
     * <p>{@code M} is laid out {@code [category][row][col]}; {@code v} and {@code out} are laid out
     * {@code [category][pattern][state]}.
     */
    private void getMatrixVectorProduct(double[] matrix, double[] v, double[] out) {
        assert (v.length == out.length);
        for (int c = 0; c < categoryCount; ++c) {
            for (int p = 0; p < patternCount; ++p) {
                for (int row = 0; row < stateCount; ++row) {
                    double sum = 0.0;
                    for (int col = 0; col < stateCount; ++col) {
                        sum += matrix[c * stateCount * stateCount + row * stateCount + col] * v[c * patternCount * stateCount + p * stateCount + col];
                    }
                    out[c * patternCount * stateCount + p * stateCount + row] = sum;
                }
            }
        }
    }

    /**
     * For each category and pattern, writes {@code M^T * v} into {@code out}. Layouts are as in
     * {@link #getMatrixVectorProduct}.
     */
    private void getMatrixTransformVectorProduct(double[] matrix, double[] v, double[] out) {
        assert (v.length == out.length);
        for (int c = 0; c < categoryCount; ++c) {
            for (int p = 0; p < patternCount; ++p) {
                for (int col = 0; col < stateCount; ++col) {
                    double sum = 0.0;
                    for (int row = 0; row < stateCount; ++row) {
                        sum += matrix[c * stateCount * stateCount + row * stateCount + col] * v[c * patternCount * stateCount + p * stateCount + row];
                    }
                    out[c * patternCount * stateCount + p * stateCount + col] = sum;
                }
            }
        }
    }

    /** Returns the sum over patterns of {@code v[p]} times the pattern weight. */
    private double getVectorPatternReduction(double[] v) {
        double sum = 0.0;
        for (int p = 0; p < patternCount; ++p) {
            sum += v[p] * patternList.getPatternWeight(p);
        }
        return sum;
    }

    /**
     * Writes, per pattern, the category-proportion-weighted sum over states of {@code a * b * c}
     * into {@code out} (length {@code patternCount}). The inputs are laid out
     * {@code [category][pattern][state]}.
     */
    private void getTripleVectorReduction(double[] a, double[] b, double[] c, double[] out) {
        assert (a.length == b.length);
        assert (b.length == c.length);
        Arrays.fill(out, 0.0);
        for (int cat = 0; cat < categoryCount; ++cat) {
            final double proportion = siteRateModel.getProportionForCategory(cat);
            for (int p = 0; p < patternCount; ++p) {
                double sum = 0.0;
                for (int s = 0; s < stateCount; ++s) {
                    final int k = cat * stateCount * patternCount + p * stateCount + s;
                    sum += a[k] * b[k] * c[k];
                }
                out[p] += proportion * sum;
            }
        }
    }

    /** Returns a new array holding the element-wise quotient {@code a / b}. */
    private double[] getVectorVectorDivision(double[] a, double[] b) {
        double[] out = new double[a.length];
        for (int i = 0; i < a.length; ++i) {
            out[i] = a[i] / b[i];
        }
        return out;
    }

    /** Writes the element-wise quotient {@code a / b} into {@code out}. */
    private void getVectorVectorDivision(double[] a, double[] b, double[] out) {
        for (int i = 0; i < a.length; ++i) {
            out[i] = a[i] / b[i];
        }
    }

    @Override
    protected String getGradientTraitName() {
        return GRADIENT_TRAIT_NAME;
    }

    @Override
    protected String getHessianTraitName() {
        return HESSIAN_TRAIT_NAME;
    }
}

