Java Code Examples for org.ejml.ops.CommonOps#mult()

The following examples show how to use org.ejml.ops.CommonOps#mult(). You can vote up the examples you like or vote down the ones you don't, and go to the original project or source file by following the links above each example. You may also check out related API usage in the sidebar.
Example 1
Source File: SafeMultivariateActualizedWithDriftIntegrator.java — from beast-mcmc (GNU Lesser General Public License v2.1), 6 votes
/**
 * Computes the expected value at the end of a branch: the actualized parent value plus the drift.
 * <p>
 * expectation = actualization * parentValue + displacement
 *
 * @param actualization dimTrait x dimTrait actualization matrix, read from offset 0
 * @param parentValue   value at the parent node, length at least dimTrait
 * @param displacement  drift displacement, length at least dimTrait
 * @param expectation   output buffer, written from offset 0, length at least dimTrait
 */
@Override
public void getBranchExpectation(double[] actualization, double[] parentValue, double[] displacement,
                                 double[] expectation) {

    // Buffer preconditions (checked only when assertions are enabled).
    assert expectation != null && expectation.length >= dimTrait;
    assert actualization != null && actualization.length >= dimTrait * dimTrait;
    assert parentValue != null && parentValue.length >= dimTrait;
    assert displacement != null && displacement.length >= dimTrait;

    final DenseMatrix64F actualizationMatrix = wrap(actualization, 0, dimTrait, dimTrait);
    final DenseMatrix64F parentVector = wrap(parentValue, 0, dimTrait, 1);
    final DenseMatrix64F expected = new DenseMatrix64F(dimTrait, 1);

    // expected <- Q * parent
    CommonOps.mult(actualizationMatrix, parentVector, expected);

    // expected <- expected + displacement
    final DenseMatrix64F displacementVector = wrap(displacement, 0, dimTrait, 1);
    CommonOps.addEquals(expected, displacementVector);

    unwrap(expected, expectation, 0);
}
 
Example 2
Source File: SafeMultivariateActualizedWithDriftIntegrator.java — from beast-mcmc (GNU Lesser General Public License v2.1), 6 votes
/**
 * Inverts the partitioned matrix [[A, B], [C, D]] block-wise through the Schur complement of A
 * and writes the four inverse blocks to {@code destination} via {@code blockUnwrap}.
 * <p>
 * Writing S^{-1} for the matrix returned by {@code getSchurInverseComplement} (presumably
 * (D - C A^{-1} B)^{-1} — confirm against that helper), the blocks are:
 * <pre>
 *   top-left     = A^{-1} + A^{-1} B S^{-1} C A^{-1}
 *   top-right    = -A^{-1} B S^{-1}
 *   bottom-left  = -S^{-1} C A^{-1}
 *   bottom-right = S^{-1}
 * </pre>
 * All blocks are dimProcess x dimProcess. The return value of {@code CommonOps.invert} is not checked.
 */
private void schurComplementInverse(final DenseMatrix64F A, final DenseMatrix64F D,
                                    final DenseMatrix64F C, final DenseMatrix64F B,
                                    final double[] destination, final int offset) {
    final int n = dimProcess;

    final DenseMatrix64F aInverse = new DenseMatrix64F(n, n);
    CommonOps.invert(A, aInverse);
    final DenseMatrix64F blockBottomRight = getSchurInverseComplement(aInverse, D, C, B);

    // top-right = -(A^{-1} B) S^{-1}
    final DenseMatrix64F aInverseTimesB = new DenseMatrix64F(n, n);
    CommonOps.mult(aInverse, B, aInverseTimesB);
    final DenseMatrix64F blockTopRight = new DenseMatrix64F(n, n);
    CommonOps.mult(-1.0, aInverseTimesB, blockBottomRight, blockTopRight);

    // bottom-left = -S^{-1} (C A^{-1})
    final DenseMatrix64F cTimesAInverse = new DenseMatrix64F(n, n);
    CommonOps.mult(C, aInverse, cTimesAInverse);
    final DenseMatrix64F blockBottomLeft = new DenseMatrix64F(n, n);
    CommonOps.mult(-1.0, blockBottomRight, cTimesAInverse, blockBottomLeft);

    // top-left = A^{-1} - (top-right) (C A^{-1})
    final DenseMatrix64F blockTopLeft = new DenseMatrix64F(n, n);
    CommonOps.mult(-1.0, blockTopRight, cTimesAInverse, blockTopLeft);
    CommonOps.addEquals(blockTopLeft, aInverse);

    blockUnwrap(blockTopLeft, blockBottomRight, blockBottomLeft, blockTopRight, destination, offset);
}
 
Example 3
Source File: SafeMultivariateActualizedWithDriftIntegrator.java — from beast-mcmc (GNU Lesser General Public License v2.1), 6 votes
/**
 * Expands the OU actualization stored at {@code scaledOffset} in {@code actualizations} into the
 * four-block actualization of the integrated OU process and writes it back at the same offset.
 * <p>
 * Blocks passed to {@code blockUnwrap}, in the order YY, XX, XY, YX:
 * <pre>
 *   YY = actualizationOU (the matrix already stored at scaledOffset)
 *   XX = I
 *   XY = inverseSelectionStrength * (I - actualizationOU)
 *   YX = 0
 * </pre>
 *
 * @param scaledOffset             offset of the dimProcess x dimProcess OU actualization
 * @param inverseSelectionStrength dimProcess x dimProcess matrix multiplied on the left of (I - Q)
 */
private void computeIOUActualization(final int scaledOffset,
                                     DenseMatrix64F inverseSelectionStrength) {
    // YY: the OU actualization currently stored in the buffer.
    DenseMatrix64F actualizationOU = wrap(actualizations, scaledOffset, dimProcess, dimProcess);

    // XY: inverseSelectionStrength * (I - actualizationOU).
    // This block goes into the XY slot of blockUnwrap below; the XX block is the identity.
    DenseMatrix64F temp = CommonOps.identity(dimProcess);
    CommonOps.addEquals(temp, -1.0, actualizationOU);
    DenseMatrix64F actualizationIOU = new DenseMatrix64F(dimProcess, dimProcess);
    CommonOps.mult(inverseSelectionStrength, temp, actualizationIOU);

    // YX is zero (freshly allocated) and XX is the identity.
    DenseMatrix64F actualizationYX = new DenseMatrix64F(dimProcess, dimProcess); // zeros
    DenseMatrix64F actualizationXX = CommonOps.identity(dimProcess);

    blockUnwrap(actualizationOU, actualizationXX, actualizationIOU, actualizationYX, actualizations, scaledOffset);
}
 
Example 4
Source File: BranchRateGradient.java — from beast-mcmc (GNU Lesser General Public License v2.1), 5 votes
/**
 * Fills the two gradient work matrices for a branch.
 * <p>
 * With Q the raw precision of the "above" statistics and Sigma the raw variance of the "branch"
 * statistics:
 * <pre>
 *   logDetComponent = Q * (differentialScaling * Sigma)
 *   matrix1         = logDetComponent * Q          (quadratic component)
 * </pre>
 * {@code matrix1} is also used as scratch for the scaled branch variance.
 */
public void makeGradientMatrices0(DenseMatrix64F matrix1, DenseMatrix64F logDetComponent,
                                  BranchSufficientStatistics statistics, double differentialScaling) {

    final NormalSufficientStatistics aboveStats = statistics.getAbove();
    final NormalSufficientStatistics branchStats = statistics.getBranch();
    final DenseMatrix64F precisionAbove = aboveStats.getRawPrecision();

    // matrix1 <- differentialScaling * Sigma_branch
    CommonOps.scale(differentialScaling, branchStats.getRawVariance(), matrix1);

    // logDetComponent <- Q * matrix1
    CommonOps.mult(precisionAbove, matrix1, logDetComponent);

    // matrix1 <- logDetComponent * Q   (quadratic component)
    CommonOps.mult(logDetComponent, precisionAbove, matrix1);
}
 
Example 5
Source File: OUDiffusionModelDelegate.java — from beast-mcmc (GNU Lesser General Public License v2.1), 5 votes
/**
 * Left-multiplies {@code gradient} in place by the branch matrix returned from
 * {@code cdi.getBranch1mActualization} for {@code nodeIndex}.
 * <p>
 * The "1m" presumably means "one minus", i.e. I - Q_i — confirm against
 * ContinuousDiffusionIntegrator.
 *
 * @param cdi       integrator that supplies the branch matrix
 * @param nodeIndex node whose branch matrix is used
 * @param gradient  dim x 1 vector, overwritten with (branch matrix) * gradient
 */
private void actualizeDisplacementGradient(ContinuousDiffusionIntegrator cdi,
                                           int nodeIndex, DenseMatrix64F gradient) {
    // q_i
    double[] qi = new double[dim * dim];
    cdi.getBranch1mActualization(getMatrixBufferOffsetIndex(nodeIndex), qi);
    DenseMatrix64F actu = wrap(qi, 0, dim, dim);

    // Multiply directly. Negating the matrix and then negating the product would cancel out.
    // A temporary is used so that the output does not alias the input vector.
    DenseMatrix64F tmp = new DenseMatrix64F(dim, 1);
    CommonOps.mult(actu, gradient, tmp);

    // mult above only succeeds when gradient is dim x 1, so copying dim entries is exact.
    System.arraycopy(tmp.getData(), 0, gradient.getData(), 0, dim);
}
 
Example 6
Source File: OUDiffusionModelDelegate.java — from beast-mcmc (GNU Lesser General Public License v2.1), 5 votes
/**
 * Returns the branch actualization of {@code nodeIndex} applied to {@code gradient}.
 *
 * @param cdi       integrator that supplies the dim x dim branch actualization
 * @param nodeIndex node whose branch actualization is used
 * @param gradient  dim x 1 vector (not modified)
 * @return the backing array of a new dim x 1 matrix holding actualization * gradient
 */
private double[] actualizeRootGradientFull(ContinuousDiffusionIntegrator cdi,
                                           int nodeIndex, DenseMatrix64F gradient) {
    final double[] actualizationBuffer = new double[dim * dim];
    cdi.getBranchActualization(getMatrixBufferOffsetIndex(nodeIndex), actualizationBuffer);
    final DenseMatrix64F actualizationMatrix = wrap(actualizationBuffer, 0, dim, dim);

    final DenseMatrix64F product = new DenseMatrix64F(dim, 1);
    CommonOps.mult(actualizationMatrix, gradient, product);
    return product.getData();
}
 
Example 7
Source File: MultivariateIntegrator.java — from beast-mcmc (GNU Lesser General Public License v2.1), 5 votes
/**
 * Runs the parent pre-order root computation, then, for every trait, left-multiplies the
 * root's pre-order precision block by the wrapped {@code diffusions} matrix and the root's
 * pre-order variance block by the wrapped {@code inverseDiffusions} matrix.
 * <p>
 * Per-trait layout of {@code preOrderPartials} assumed by the offsets below:
 * [mean (dimTrait) | precision (dimTrait^2) | variance (dimTrait^2) | ...], with consecutive
 * traits {@code dimPartialForTrait} apart.
 */
@Override
public void calculatePreOrderRoot(int priorBufferIndex, int rootNodeIndex, int precisionIndex) {

    super.calculatePreOrderRoot(priorBufferIndex, rootNodeIndex, precisionIndex);

    updatePrecisionOffsetAndDeterminant(precisionIndex);

    final DenseMatrix64F diffusionMatrix = wrap(diffusions, precisionOffset, dimTrait, dimTrait);
    final DenseMatrix64F inverseDiffusionMatrix = wrap(inverseDiffusions, precisionOffset, dimTrait, dimTrait);

    final int matrixSize = dimTrait * dimTrait;

    // TODO For each trait in parallel
    for (int trait = 0, traitOffset = dimPartial * rootNodeIndex;
         trait < numTraits;
         ++trait, traitOffset += dimPartialForTrait) {

        final int precisionStart = traitOffset + dimTrait;
        final int varianceStart = precisionStart + matrixSize;

        final DenseMatrix64F rootPrecision = wrap(preOrderPartials, precisionStart, dimTrait, dimTrait);
        final DenseMatrix64F rootVariance = wrap(preOrderPartials, varianceStart, dimTrait, dimTrait);

        // TODO Block below is for the conjugate prior ONLY
        final DenseMatrix64F scratch = matrix0;

        MissingOps.safeMult(diffusionMatrix, rootPrecision, scratch);
        unwrap(scratch, preOrderPartials, precisionStart);

        CommonOps.mult(inverseDiffusionMatrix, rootVariance, scratch);
        unwrap(scratch, preOrderPartials, varianceStart);
    }
}
 
Example 8
Source File: CompoundEigenMatrix.java — from beast-mcmc (GNU Lesser General Public License v2.1), 5 votes
/**
 * Recomputes {@code transformedMatrix = B * D * B^{-1}} and marks the composition as known.
 * <p>
 * B is built by {@code wrapSpherical} from the off-diagonal parameter and D by
 * {@code wrapDiagonal} from the diagonal parameter. B is inverted in place, which is presumably
 * safe because wrapSpherical returns a fresh matrix — confirm. {@code temp} is scratch. The
 * return value of {@code CommonOps.invert} is not checked.
 */
private void computeTransformedMatrix() {
    final DenseMatrix64F basis = wrapSpherical(offDiagonalParameter.getParameterValues(), 0, dim);
    final DenseMatrix64F diagonal = wrapDiagonal(diagonalParameter.getParameterValues(), 0, dim);

    // temp <- B * D
    CommonOps.mult(basis, diagonal, temp);

    // basis <- B^{-1}   (in place)
    CommonOps.invert(basis);

    // transformedMatrix <- (B * D) * B^{-1}
    CommonOps.mult(temp, basis, transformedMatrix);

    compositionKnown = true;
}
 
Example 9
Source File: SafeMultivariateActualizedWithDriftIntegrator.java — from beast-mcmc (GNU Lesser General Public License v2.1), 5 votes
/**
 * Overwrites {@code matrix} with {@code R^{-1} * matrix * R^{-T}}, where R is {@code rotation}.
 * <p>
 * {@code rotation} is not modified; its inverse is formed in a scratch matrix. The return value of
 * {@code CommonOps.invert} is not checked.
 */
private static void transformMatrixGeneral(DenseMatrix64F matrix, DenseMatrix64F rotation) {
    final int size = matrix.getNumRows();

    final DenseMatrix64F inverse = new DenseMatrix64F(size, size);
    CommonOps.invert(rotation, inverse);

    // partial <- R^{-1} * matrix
    final DenseMatrix64F partial = new DenseMatrix64F(size, size);
    CommonOps.mult(inverse, matrix, partial);

    // matrix <- partial * (R^{-1})^T
    CommonOps.multTransB(partial, inverse, matrix);
}
 
Example 10
Source File: SafeMultivariateActualizedWithDriftIntegrator.java — from beast-mcmc (GNU Lesser General Public License v2.1), 5 votes
/**
 * Replaces the dimTrait-long mean stored at {@code preOrderPartials[ibo]} with
 * {@code Q * mean + displacement}.
 *
 * @param ibo offset of the mean in preOrderPartials
 * @param imo offset of the dimTrait x dimTrait actualization Q in actualizations
 * @param ido offset of the displacement in displacements
 */
@Override
void scaleAndDriftMean(int ibo, int imo, int ido) {
    final DenseMatrix64F actualization = wrap(actualizations, imo, dimTrait, dimTrait);
    final DenseMatrix64F mean = wrap(preOrderPartials, ibo, dimTrait, 1);
    final DenseMatrix64F scaledMean = matrixNiacc;

    // mean <- Q * mean
    CommonOps.mult(actualization, mean, scaledMean);
    unwrap(scaledMean, preOrderPartials, ibo);

    // mean <- mean + displacement
    int target = ibo;
    int source = ido;
    final int end = ibo + dimTrait;
    while (target < end) {
        preOrderPartials[target++] += displacements[source++];
    }
}
 
Example 11
Source File: SafeMultivariateActualizedWithDriftIntegrator.java — from beast-mcmc (GNU Lesser General Public License v2.1), 4 votes
/**
 * Computes {@code QtPQ = Q * P * Q^T}, using {@code QtP} as scratch for the intermediate
 * {@code Q * P}.
 * <p>
 * Note: despite the parameter names, the leading Q is not transposed in either product.
 */
private void scaleVariance(DenseMatrix64F Q, DenseMatrix64F P,
                           DenseMatrix64F QtP, DenseMatrix64F QtPQ) {
    final DenseMatrix64F leftProduct = QtP;
    CommonOps.mult(Q, P, leftProduct);          // Q * P
    CommonOps.multTransB(leftProduct, Q, QtPQ); // (Q * P) * Q^T
}
 
Example 12
Source File: ConditionalVarianceAndTransform2.java — from beast-mcmc (GNU Lesser General Public License v2.1), 4 votes
public ConditionalVarianceAndTransform2(final DenseMatrix64F variance,
                                        final int[] missingIndices, final int[] notMissingIndices) {

    assert (missingIndices.length + notMissingIndices.length == variance.getNumRows());
    assert (missingIndices.length + notMissingIndices.length == variance.getNumCols());

    this.missingIndices = missingIndices;
    this.notMissingIndices = notMissingIndices;

    if (DEBUG) {
        System.err.println("variance:\n" + variance);
    }

    DenseMatrix64F S22 = new DenseMatrix64F(notMissingIndices.length, notMissingIndices.length);
    gatherRowsAndColumns(variance, S22, notMissingIndices, notMissingIndices);

    if (DEBUG) {
        System.err.println("S22:\n" + S22);
    }

    DenseMatrix64F S22Inv = new DenseMatrix64F(notMissingIndices.length, notMissingIndices.length);
    CommonOps.invert(S22, S22Inv);

    if (DEBUG) {
        System.err.println("S22Inv:\n" + S22Inv);
    }

    DenseMatrix64F S12 = new DenseMatrix64F(missingIndices.length, notMissingIndices.length);
    gatherRowsAndColumns(variance, S12, missingIndices, notMissingIndices);

    if (DEBUG) {
        System.err.println("S12:\n" + S12);
    }

    DenseMatrix64F S12S22Inv = new DenseMatrix64F(missingIndices.length, notMissingIndices.length);
    CommonOps.mult(S12, S22Inv, S12S22Inv);

    if (DEBUG) {
        System.err.println("S12S22Inv:\n" + S12S22Inv);
    }

    DenseMatrix64F S12S22InvS21 = new DenseMatrix64F(missingIndices.length, missingIndices.length);
    CommonOps.multTransB(S12S22Inv, S12, S12S22InvS21);

    if (DEBUG) {
        System.err.println("S12S22InvS21:\n" + S12S22InvS21);
    }

    sBar = new DenseMatrix64F(missingIndices.length, missingIndices.length);
    gatherRowsAndColumns(variance, sBar, missingIndices, missingIndices);
    CommonOps.subtract(sBar, S12S22InvS21, sBar);


    if (DEBUG) {
        System.err.println("sBar:\n" + sBar);
    }


    this.affineTransform = S12S22Inv;
    this.tempStorage = new double[missingIndices.length];

    this.numMissing = missingIndices.length;
    this.numNotMissing = notMissingIndices.length;

}
 
Example 13
Source File: SafeMultivariateActualizedWithDriftIntegrator.java — from beast-mcmc (GNU Lesser General Public License v2.1), 4 votes
/**
 * Builds the four-block branch variance of the integrated OU process, writes it to
 * {@code variances} at {@code destinationOffset} (order YY, XX, XY, YX), and writes its block
 * inverse to {@code precisions} at the same offset via {@code schurComplementInverse}.
 * <p>
 * Notation below: Q = actualization at destinationOffset, S = stationary variance at sourceOffset,
 * Ainv = inverseSelectionStrength, t = branchLength. All blocks are dimProcess x dimProcess.
 * NOTE(review): the YY block is taken as-is from {@code variances} at destinationOffset —
 * presumably the OU branch variance was already written there; confirm against the caller.
 * NOTE(review): {@code addTrans(M, out)} is not visible here; it presumably sets out = M + M^T.
 * temp, temp2 and temp3 are reused, so statement order matters.
 */
private void computeIOUVarianceBranch(final int sourceOffset,
                                      final int destinationOffset,
                                      double branchLength,
                                      DenseMatrix64F inverseSelectionStrength) {
    DenseMatrix64F actualization = wrap(actualizations, destinationOffset, dimProcess, dimProcess);
    DenseMatrix64F stationaryVariance = wrap(stationaryVariances, sourceOffset, dimProcess, dimProcess);

    // invAS = Ainv * S
    DenseMatrix64F invAS = new DenseMatrix64F(dimProcess, dimProcess);
    CommonOps.mult(inverseSelectionStrength, stationaryVariance, invAS);

    //// Variance YY
    DenseMatrix64F varianceYY = wrap(variances, destinationOffset, dimProcess, dimProcess);

    //// Variance XX
    DenseMatrix64F varianceXX = new DenseMatrix64F(dimProcess, dimProcess);
    // Variance 1: varianceXX = Ainv S Ainv^T - Q (Ainv S Ainv^T) Q^T
    CommonOps.multTransB(invAS, inverseSelectionStrength, varianceXX);
    DenseMatrix64F temp = new DenseMatrix64F(dimProcess, dimProcess);
    CommonOps.multTransB(varianceXX, actualization, temp);
    CommonOps.multAdd(-1.0, actualization, temp, varianceXX);
    // Delta: filled by addTrans from invAS (presumably invAS + invAS^T)
    DenseMatrix64F delta = new DenseMatrix64F(dimProcess, dimProcess);
    addTrans(invAS, delta);
    // Variance 2: varianceXX += t * delta
    CommonOps.addEquals(varianceXX, branchLength, delta);
    // Variance 3: temp3 = (I - Q) Ainv; temp2 = temp3 * delta; temp = addTrans(temp2)
    DenseMatrix64F temp2 = CommonOps.identity(dimProcess);
    CommonOps.addEquals(temp2, -1.0, actualization);
    DenseMatrix64F temp3 = new DenseMatrix64F(dimProcess, dimProcess);
    CommonOps.mult(temp2, inverseSelectionStrength, temp3);
    CommonOps.mult(temp3, delta, temp2);
    addTrans(temp2, temp);
    // All: varianceXX -= temp
    CommonOps.addEquals(varianceXX, -1.0, temp);

    //// Variance XY
    DenseMatrix64F varianceXY = new DenseMatrix64F(dimProcess, dimProcess);
    // Variance 1: varianceXY = S * temp3^T
    CommonOps.multTransB(stationaryVariance, temp3, varianceXY);
    // Variance 2: temp2 = temp3 * S * Q^T
    CommonOps.mult(temp3, stationaryVariance, temp);
    CommonOps.multTransB(temp, actualization, temp2);
    // All: varianceXY -= temp2
    CommonOps.addEquals(varianceXY, -1.0, temp2);

    //// Variance YX: transpose of XY
    DenseMatrix64F varianceYX = new DenseMatrix64F(dimProcess, dimProcess);
    CommonOps.transpose(varianceXY, varianceYX);

    // Store the variance blocks, then their block inverse as the precision.
    blockUnwrap(varianceYY, varianceXX, varianceXY, varianceYX, variances, destinationOffset);
    schurComplementInverse(varianceYY, varianceXX, varianceXY, varianceYX, precisions, destinationOffset);
}
 
Example 14
Source File: LKJTransformTest.java — from beast-mcmc (GNU Lesser General Public License v2.1), 4 votes
/**
 * Checks that the CPC -> correlation transform agrees with composing the CPC -> Cholesky and
 * Cholesky -> correlation transforms. It compares:
 * <ul>
 *   <li>inverse transforms (direct, two-step, and via Transform.ComposeMultivariable);</li>
 *   <li>per-row EuclideanToInfiniteNormUnitBallTransform inverses against the Cholesky values;</li>
 *   <li>log-Jacobians (the composed log-Jacobian must equal the sum of the parts);</li>
 *   <li>Jacobian matrices (product of the parts vs the direct one; only entries with j >= i);</li>
 *   <li>gradient updates (direct vs composed).</li>
 * </ul>
 * Values are compared after formatting with {@code format}.
 */
public void testJacobianComposition() {
    System.out.println("\nTest LKJ Composition Cholesky Jacobian.");

    // Transforms
    double[] cholValues = transformChol.inverse(CPCs, 0, CPCs.length);
    double[] corrValues = transform.inverse(CPCs, 0, CPCs.length);
    double[] corrValuesBis = transformCorrToChol.inverse(cholValues, 0, cholValues.length);

    Transform.MultivariableTransform transformComposition = new Transform.ComposeMultivariable(transformChol, transformCorrToChol);
    double[] corrValuesTer = transformComposition.inverse(CPCs, 0, CPCs.length);


    for (int k = 0; k < CPCs.length; k++) {
        assertEquals("inverse transform k=" + k,
                format.format(corrValues[k]),
                format.format(corrValuesBis[k]));
    }

    for (int k = 0; k < CPCs.length; k++) {
        assertEquals("inverse transform k=" + k,
                format.format(corrValues[k]),
                format.format(corrValuesTer[k]));
    }

    // Row-by-row check: gather the k entries belonging to row k of the packed CPC / Cholesky
    // vectors and compare the Euclidean -> infinite-norm unit-ball transform against them.
    double jacobianDet2ToInf = 0.0;
    for (int k = 1; k < dim; k++) {
        double[] tempCPC = new double[k];
        double[] tempChol = new double[k];
        int l = k - 1;
        for (int i = 0; i < k; i++) {
            tempCPC[i] = CPCs[l];
            tempChol[i] = cholValues[l];
            l += dim - i - 2;
        }
        EuclideanToInfiniteNormUnitBallTransform transform2ToInf = new EuclideanToInfiniteNormUnitBallTransform(k);
        double[] cholValuesBis = transform2ToInf.inverse(tempCPC, 0, tempCPC.length);
        jacobianDet2ToInf += transform2ToInf.getLogJacobian(tempChol, 0, tempChol.length);
        for (int i = 0; i < k; i++) {
            // Separate k and i so that failures identify the exact entry (k=1,i=12 vs k=11,i=2).
            assertEquals("spherical k=" + k + ", i=" + i,
                    format.format(tempChol[i]),
                    format.format(cholValuesBis[i]));
        }
    }

    // Determinant
    double jacobianDetCholToCPC = transformChol.getLogJacobian(cholValues, 0, CPCs.length);
    double jacobianDetCorrToChol = transformCorrToChol.getLogJacobian(corrValues, 0, CPCs.length);
    double jacobianDetCorrToCPC = transform.getLogJacobian(corrValues, 0, CPCs.length);
    double jacobianDetCorrToCPCComp = transformComposition.getLogJacobian(corrValues, 0, CPCs.length);

    System.out.println("Log Jacobiant Det Chol to CPC=" + jacobianDetCholToCPC);
    System.out.println("Log Jacobiant Det Corr to Chol=" + jacobianDetCorrToChol);
    System.out.println("Log Jacobiant Det Corr to CPC=" + jacobianDetCorrToCPC);
    System.out.println("Log Jacobiant Det Corr to CPC composition=" + jacobianDetCorrToCPCComp);

    assertEquals("jacobian log det",
            format.format(jacobianDetCorrToCPC),
            format.format(jacobianDetCholToCPC + jacobianDetCorrToChol));

    assertEquals("jacobian log det",
            format.format(jacobianDetCorrToCPC),
            format.format(jacobianDetCorrToCPCComp));

    assertEquals("jacobian log det",
            format.format(jacobianDetCholToCPC),
            format.format(jacobianDet2ToInf));

    // Matrices: the product of the partial Jacobians must match the direct Jacobian.
    DenseMatrix64F jacobianMatCholToCPC = new DenseMatrix64F(transformChol.computeJacobianMatrixInverse(CPCs));
    DenseMatrix64F jacobianMatCorrToChol = new DenseMatrix64F(transformCorrToChol.computeJacobianMatrixInverse(cholValues));
    DenseMatrix64F jacobianMatCorrToCPC = new DenseMatrix64F(transform.computeJacobianMatrixInverse(CPCs));

    DenseMatrix64F jacobianMatComposition = new DenseMatrix64F(jacobianMatCorrToCPC.numRows, jacobianMatCorrToCPC.numCols);
    CommonOps.mult(jacobianMatCholToCPC, jacobianMatCorrToChol, jacobianMatComposition);

    System.out.println("Jacobiant Corr to CPC=" + jacobianMatCorrToCPC);
    System.out.println("Jacobiant Composition=" + jacobianMatComposition);

    // NOTE(review): only entries with j >= i are compared — confirm this restriction is intended.
    for (int i = 0; i < jacobianMatCorrToCPC.numRows; i++) {
        for (int j = i; j < jacobianMatCorrToCPC.numCols; j++) {
            assertEquals("jacobian matrix compose (" + i + ", " + j + "): ",
                    format.format(jacobianMatComposition.get(i, j)),
                    format.format(jacobianMatCorrToCPC.get(i, j)));
        }
    }

    // Update
    double[] gradient = new double[CPCs.length];
    System.arraycopy(CPCsLimit, 0, gradient, 0, CPCs.length);

    double[] updated = transform.updateGradientLogDensity(gradient, corrValues, 0, gradient.length);
    double[] updatedComposition = transformComposition.updateGradientLogDensity(gradient, corrValues, 0, gradient.length);

    for (int k = 0; k < updated.length; k++) {
        assertEquals("updated gradient " + k + ": ",
                format.format(updated[k]),
                format.format(updatedComposition[k]));
    }
}
 
Example 15
Source File: SafeMultivariateActualizedWithDriftIntegrator.java — from beast-mcmc (GNU Lesser General Public License v2.1), 4 votes
/**
 * Overwrites {@code matrix} with {@code R * matrix * R^{-1}}, where R is {@code rotation}.
 * <p>
 * The inverse of R is formed in a separate scratch matrix, so {@code rotation} is left unchanged
 * for the caller. Repeated calls with the same rotation therefore see the same matrix.
 * Both arguments are dimProcess x dimProcess. The return value of {@code CommonOps.invert} is not
 * checked.
 */
private void transformMatrixBaseGeneral(DenseMatrix64F matrix, DenseMatrix64F rotation) {
    // tmp <- R * matrix
    DenseMatrix64F tmp = new DenseMatrix64F(dimProcess, dimProcess);
    CommonOps.mult(rotation, matrix, tmp);

    // Invert into scratch rather than in place, to avoid a side effect on the caller's rotation.
    DenseMatrix64F rotationInverse = new DenseMatrix64F(dimProcess, dimProcess);
    CommonOps.invert(rotation, rotationInverse);

    // matrix <- (R * matrix) * R^{-1}
    CommonOps.mult(tmp, rotationInverse, matrix);
}
 
Example 16
Source File: SafeMultivariateActualizedWithDriftIntegrator.java — from beast-mcmc (GNU Lesser General Public License v2.1), 4 votes
/**
 * Overwrites {@code matrix} with {@code R * matrix * R^T}, where R is {@code rotation}.
 * {@code rotation} is not modified.
 */
public static void transformMatrixBack(DenseMatrix64F matrix, DenseMatrix64F rotation) {
    final int size = matrix.getNumRows();
    final DenseMatrix64F rightProduct = new DenseMatrix64F(size, size);

    CommonOps.multTransB(matrix, rotation, rightProduct); // matrix * R^T
    CommonOps.mult(rotation, rightProduct, matrix);       // R * (matrix * R^T)
}
 
Example 17
Source File: SafeMultivariateActualizedWithDriftIntegrator.java — from beast-mcmc (GNU Lesser General Public License v2.1), 4 votes
/**
 * Overwrites {@code matrix} with {@code R^T * matrix * R}, where R is {@code rotation}.
 * <p>
 * Presumably used when R is orthogonal (so R^T = R^{-1}); that is not checked here.
 */
private static void transformMatrixSymmetric(DenseMatrix64F matrix, DenseMatrix64F rotation) {
    final int size = matrix.getNumRows();
    final DenseMatrix64F leftProduct = new DenseMatrix64F(size, size);

    CommonOps.multTransA(rotation, matrix, leftProduct); // R^T * matrix
    CommonOps.mult(leftProduct, rotation, matrix);       // (R^T * matrix) * R
}
 
Example 18
Source File: SafeMultivariateIntegrator.java — from beast-mcmc (GNU Lesser General Public License v2.1), 4 votes
/**
 * Integrates the partial at {@code ibo} along its branch and writes the resulting precision to
 * {@code Pip}.
 * <p>
 * If the partial's precision Pi has diagonal infinities, the variance path is used:
 * Pip = inverse(Vi + Vdi), computed by safeInvert2. Otherwise the precision path is used:
 * Pip = Pi * tmp1, where tmp1 = idMinusA((Pi + Pdi)^{-1} * Pi) (presumably I - (Pi + Pdi)^{-1} Pi).
 * The scratch fields matrix0 and matrix1 are overwritten.
 *
 * @param ibo            offset of the partial in {@code partials}
 * @param iBuffer        buffer index passed to getEffectiveDimension
 * @param Vdi            branch variance (used only on the variance path)
 * @param Pdi            branch precision (used only on the precision path)
 * @param Pip            output precision
 * @param getDeterminant whether determinant information should be computed
 * @return on the variance path, the result of safeInvert2; on the precision path, null when
 *         {@code getDeterminant} is false, otherwise the determinant result of tmp1 (combined
 *         with that of Pi when getEffectiveDimension(iBuffer) > 0)
 * @throws RuntimeException on the variance path if allZeroOrInfinite(Vi + Vdi) is true
 */
private InversionResult increaseVariances(int ibo,
                                          int iBuffer,
                                          final DenseMatrix64F Vdi,
                                          final DenseMatrix64F Pdi,
                                          final DenseMatrix64F Pip,
                                          final boolean getDeterminant) {

    if (TIMING) {
        startTime("peel1");
    }

    // A. Get current precision of i
    final DenseMatrix64F Pi = wrap(partials, ibo + dimTrait, dimTrait, dimTrait);

    if (TIMING) {
        endTime("peel1");
        startTime("peel2");
    }

    // B. Integrate along branch using two matrix inversions

    final boolean useVariancei = anyDiagonalInfinities(Pi);
    InversionResult ci = null;

    if (useVariancei) {

        // Variance path: Pip = (Vi + Vdi)^{-1}, with Vip held in scratch matrix0.
        final DenseMatrix64F Vip = matrix0;
        final DenseMatrix64F Vi = wrap(partials, ibo + dimTrait + dimTrait * dimTrait, dimTrait, dimTrait);
        CommonOps.add(Vi, Vdi, Vip);
        if (allZeroOrInfinite(Vip)) {
            throw new RuntimeException("Zero-length branch on data is not allowed.");
        }
        ci = safeInvert2(Vip, Pip, getDeterminant);

    } else {

        // Precision path: tmp1 = Pi + Pdi; tmp2 = tmp1^{-1}; tmp1 = tmp2 * Pi; tmp1 = idMinusA(tmp1).
        final DenseMatrix64F tmp1 = matrix0;
        CommonOps.add(Pi, Pdi, tmp1);
        final DenseMatrix64F tmp2 = matrix1;
        safeInvert2(tmp1, tmp2, false);
        CommonOps.mult(tmp2, Pi, tmp1);
        idMinusA(tmp1);
        if (getDeterminant) ci = safeDeterminant(tmp1, true);
        // Pip = Pi * tmp1
        CommonOps.mult(Pi, tmp1, Pip);
        if (getDeterminant && getEffectiveDimension(iBuffer) > 0) {
            InversionResult cP = safeDeterminant(Pi, true);
            ci = mult(ci, cP);
        }
    }

    if (TIMING) {
        endTime("peel2");
    }

    return ci;
}
 
Example 19
Source File: IntegratedOUDiffusionModelDelegate.java — from beast-mcmc (GNU Lesser General Public License v2.1), 4 votes
/**
 * Builds the (dim * ntaxa) x (dim * ntaxa) joint variance across tips.
 * <p>
 * The computation has three steps:
 * <ol>
 *   <li>Rotate the trait variance into the eigenbasis of the selection strength:
 *       Vinv * traitVariance * Vinv^T.</li>
 *   <li>For every pair of tips (i, j), scale each entry (p, q) of that rotated variance by a
 *       closed-form factor in the eigenvalues e_p, e_q, the shared lengths t_i, t_j, t_ij, and
 *       priorSampleSize.</li>
 *   <li>Rotate each (i, j) block back with V * block * V^T.</li>
 * </ol>
 * {@code treeVariance} is not used. Eigenvalues equal to zero lead to divisions by zero.
 */
@Override
public double[][] getJointVariance(final double priorSampleSize,
                                   final double[][] treeVariance, final double[][] treeSharedLengths,
                                   final double[][] traitVariance) {

    final double[] eigVals = this.getEigenValuesStrengthOfSelection();
    final DenseMatrix64F V = wrap(this.getEigenVectorsStrengthOfSelection(), 0, dim, dim);
    final DenseMatrix64F Vinv = new DenseMatrix64F(dim, dim);
    CommonOps.invert(V, Vinv);

    // rotatedTraitVariance <- Vinv * traitVariance * Vinv^T
    final DenseMatrix64F rotatedTraitVariance = new DenseMatrix64F(traitVariance);
    final DenseMatrix64F scratch = new DenseMatrix64F(dim, dim);
    CommonOps.mult(Vinv, rotatedTraitVariance, scratch);
    CommonOps.multTransB(scratch, Vinv, rotatedTraitVariance);

    final int ntaxa = tree.getExternalNodeCount();
    final DenseMatrix64F block = new DenseMatrix64F(dim, dim);
    final double[][] jointVariance = new double[dim * ntaxa][dim * ntaxa];

    for (int i = 0; i < ntaxa; ++i) {
        final double ti = treeSharedLengths[i][i];
        for (int j = 0; j < ntaxa; ++j) {
            final double tj = treeSharedLengths[j][j];
            final double tij = treeSharedLengths[i][j];

            // Entry-wise scaling in the eigenbasis.
            for (int p = 0; p < dim; ++p) {
                final double ep = eigVals[p];
                for (int q = 0; q < dim; ++q) {
                    final double eq = eigVals[q];
                    double cov = tij / ep / eq;
                    cov += (1 - Math.exp(ep * tij)) * Math.exp(-ep * ti) / ep / ep / eq;
                    cov += (1 - Math.exp(eq * tij)) * Math.exp(-eq * tj) / ep / eq / eq;
                    cov -= (1 - Math.exp((ep + eq) * tij)) * Math.exp(-ep * ti) * Math.exp(-eq * tj) / ep / eq / (ep + eq);
                    cov += (1 - Math.exp(-ep * ti)) * (1 - Math.exp(-eq * tj)) / ep / eq / priorSampleSize;
                    cov += 1 / priorSampleSize;
                    block.set(p, q, cov * rotatedTraitVariance.get(p, q));
                }
            }

            // Back to the original basis: V * block * V^T.
            CommonOps.mult(V, block, scratch);
            CommonOps.multTransB(scratch, V, block);

            // Copy the dim x dim block into position (i, j) of the joint matrix.
            for (int p = 0; p < dim; ++p) {
                final double[] row = jointVariance[i * dim + p];
                for (int q = 0; q < dim; ++q) {
                    row[j * dim + q] = block.get(p, q);
                }
            }
        }
    }
    return jointVariance;
}
 
Example 20
Source File: BranchRateGradient.java — from beast-mcmc (GNU Lesser General Public License v2.1), 4 votes
/**
 * Sets {@code additionalVariance = quadraticComponent * Sigma_joint}, where Sigma_joint is the raw
 * variance of {@code jointStatistics}.
 */
public void makeGradientMatrices1(DenseMatrix64F additionalVariance, DenseMatrix64F quadraticComponent,
                                  NormalSufficientStatistics jointStatistics) {
    final DenseMatrix64F jointVariance = jointStatistics.getRawVariance();
    CommonOps.mult(quadraticComponent, jointVariance, additionalVariance);
}