Python torch.qr() Examples

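The following are 29 code examples of torch.qr(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch, or try the search function. Note that torch.qr() is deprecated in recent PyTorch releases in favor of torch.linalg.qr(); `Q, R = torch.qr(A)` can be replaced with `Q, R = torch.linalg.qr(A)`.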
Example #1
Source File: tensor.py    From tntorch with GNU Lesser General Public License v3.0 6 votes vote down vote up
def factor_orthogonalize(self, mu):
        """
        Move the non-orthogonal part of the mu-th factor into its core.

        ``self.Us[mu]`` is replaced by the Q of its QR decomposition, and R is
        contracted into the matching mode of ``self.cores[mu]``. Does nothing
        when the mu-th factor is None. Works in place.

        :param mu: an int between 0 and N-1
        """
        factor = self.Us[mu]
        if factor is None:
            return

        q_factor, r_factor = torch.qr(factor)
        self.Us[mu] = q_factor

        core = self.cores[mu]
        # Choose the contraction for the core layout: 2-D vs 3-D cores without a
        # batch axis, 3-D vs 4-D cores with a leading batch axis.
        if not self.batch:
            equation = 'jk,aj->ak' if core.dim() == 2 else 'ijk,aj->iak'
        else:
            equation = 'bjk,baj->bak' if core.dim() == 3 else 'bijk,baj->biak'
        self.cores[mu] = torch.einsum(equation, (core, r_factor))
Example #2
Source File: model.py    From glow-pytorch with MIT License 6 votes vote down vote up
def __init__(self, in_channel):
        super().__init__()

        # Start from a random orthogonal matrix and factor it as P @ L @ U with
        # scipy's LU decomposition (cast to float32 first, so every tensor below
        # derived from it is float32).
        gaussian = np.random.randn(in_channel, in_channel)
        q, _ = la.qr(gaussian)
        perm, lower, upper = la.lu(q.astype(np.float32))

        # Split U into its diagonal and its strictly upper-triangular part.
        diag = np.diag(upper)
        strict_upper = np.triu(upper, 1)

        # Masks selecting the strictly upper and strictly lower triangles.
        upper_mask = np.triu(np.ones_like(strict_upper), 1)
        lower_mask = upper_mask.T

        diag_t = torch.from_numpy(diag)

        # Non-trainable tensors: permutation, triangle masks, sign of diag(U)
        # and an identity matrix.
        self.register_buffer('w_p', torch.from_numpy(perm))
        self.register_buffer('u_mask', torch.from_numpy(upper_mask))
        self.register_buffer('l_mask', torch.from_numpy(lower_mask))
        self.register_buffer('s_sign', torch.sign(diag_t))
        self.register_buffer('l_eye', torch.eye(lower_mask.shape[0]))

        # Trainable tensors: L, the diagonal of U passed through logabs
        # (presumably log|.| — defined elsewhere) and the strict upper part of U.
        self.w_l = nn.Parameter(torch.from_numpy(lower))
        self.w_s = nn.Parameter(logabs(diag_t))
        self.w_u = nn.Parameter(torch.from_numpy(strict_upper))
Example #3
Source File: test_fisherreg_fd.py    From xfer with Apache License 2.0 5 votes vote down vote up
def test_fisher_matrix_matrix_matmul(self):
        hidden = 400
        model = torch.nn.Sequential(
            torch.nn.Linear(1, hidden),
            torch.nn.ELU(),
            torch.nn.Linear(hidden, hidden),
            torch.nn.ELU(),
            torch.nn.Linear(hidden, 1),
        )
        data = torch.randn(1500, 1)
        fvp = FVPR_FD(model, data)

        numpars = sum(p.numel() for p in model.parameters())

        # Random matrix V (numpars x 80) with orthonormal columns, and small
        # coefficients y (80 x 2).
        orthmat, _ = torch.qr(torch.randn(numpars, 80))
        emat = 1e-2 * torch.randn(80, 2)

        full_matmul = fvp.matmul(orthmat @ emat)
        split_matmul = fvp.matmul(orthmat) @ emat

        # F (V y) must agree with (F V) y; the loose tolerance presumably
        # absorbs FVPR_FD's approximation error.
        relative_error = torch.norm(full_matmul - split_matmul) / split_matmul.norm()
        self.assertLess(relative_error, 1e-2)

        # Multiplying a single column must reproduce that column of the product.
        column_error = torch.norm(full_matmul[:, 0] - fvp.matmul(orthmat @ emat[:, 0]))
        self.assertLess(column_error, 1e-5)
Example #4
Source File: glow.py    From Tacotron2-Mandarin with MIT License 5 votes vote down vote up
def __init__(self, c):
        super(Invertible1x1Conv, self).__init__()
        self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0,
                                    bias=False)

        # Random orthonormal c x c matrix: the Q factor of a Gaussian matrix.
        gaussian = torch.FloatTensor(c, c).normal_()
        W, _ = torch.qr(gaussian)

        # Negate the first column if needed so that det(W) is +1 rather than -1.
        if torch.det(W) < 0:
            W[:, 0] = W[:, 0].neg()
        # Conv1d weights are laid out as (out_channels, in_channels, kernel_size).
        self.conv.weight.data = W.view(c, c, 1)
Example #5
Source File: tensor.py    From tntorch with GNU Lesser General Public License v3.0 5 votes vote down vote up
def right_orthogonalize(self, mu):
        """
        Makes the mu-th core right-orthogonal and pushes the L factor to its left core. Note: this may change the ranks
         of the tensor.

        This method works in place.

        Note: internally, this method will turn CP (or CP-Tucker) cores into TT (or TT-Tucker) ones.

        :param mu: an int between 1 and N-1 (the L factor is absorbed by core mu-1, so mu = 0 is rejected)

        :return: the L factor
        """

        assert 1 <= mu < self.dim()
        self.factor_orthogonalize(mu)
        # Torch has no rq() decomposition: QR-factorise the transposed right unfolding
        # and transpose both factors back, so unfolding = L @ Q with L lower-triangular
        # and Q having orthonormal rows.
        unfolding = tn.right_unfolding(self.cores[mu], batch=self.batch)
        if self.batch:
            Q, L = torch.qr(unfolding.permute(0, 2, 1))
            L = L.permute(0, 2, 1)
            Q = Q.permute(0, 2, 1)
            # Leading (batch, new_rank) axes come from Q; the rest keep the core's shape.
            self.cores[mu] = torch.reshape(Q, Q.shape[:2] + self.cores[mu].shape[2:])
        else:
            Q, L = torch.qr(unfolding.permute(1, 0))
            L = L.permute(1, 0)
            Q = Q.permute(1, 0)
            self.cores[mu] = torch.reshape(Q, (Q.shape[0], ) + self.cores[mu].shape[1:])

        # Absorb L into the trailing rank of the left neighbour core.
        leftcoreL = tn.left_unfolding(self.cores[mu-1], batch=self.batch)
        if self.batch:
            self.cores[mu-1] = torch.reshape(torch.matmul(leftcoreL, L), self.cores[mu-1].shape[:-1] + (L.shape[2], ))
        else:
            self.cores[mu-1] = torch.reshape(torch.mm(leftcoreL, L), self.cores[mu-1].shape[:-1] + (L.shape[1], ))
        return L
Example #6
Source File: tensor.py    From tntorch with GNU Lesser General Public License v3.0 5 votes vote down vote up
def left_orthogonalize(self, mu):
        """
        Makes the mu-th core left-orthogonal and pushes the R factor to its right core. This may change the ranks
        of the cores.

        This method works in place.

        Note: internally, this method will turn CP (or CP-Tucker) cores into TT (or TT-Tucker) ones.

        :param mu: an int between 0 and N-2 (the R factor is absorbed by core mu+1, so the last core is rejected)

        :return: the R factor
        """

        assert 0 <= mu < self.dim()-1
        self.factor_orthogonalize(mu)
        # Left unfolding = Q @ R, with Q having orthonormal columns.
        Q, R = torch.qr(tn.left_unfolding(self.cores[mu], batch=self.batch))

        # The last axis of Q (the possibly reduced rank) replaces the core's trailing rank.
        new_rank = Q.shape[2] if self.batch else Q.shape[1]
        self.cores[mu] = torch.reshape(Q, self.cores[mu].shape[:-1] + (new_rank, ))

        # Absorb R into the leading rank of the right neighbour core.
        rightcoreR = tn.right_unfolding(self.cores[mu+1], batch=self.batch)

        if self.batch:
            self.cores[mu+1] = torch.reshape(torch.matmul(R, rightcoreR), (R.shape[0], R.shape[1]) + self.cores[mu+1].shape[2:])
        else:
            self.cores[mu+1] = torch.reshape(torch.mm(R, rightcoreR), (R.shape[0], ) + self.cores[mu+1].shape[1:])
        return R
Example #7
Source File: tensor.py    From tntorch with GNU Lesser General Public License v3.0 5 votes vote down vote up
def lstsq(b, A):
    """
    Solve the linear least-squares problem ``min_x ||A x - b||`` through a QR decomposition.

    :param b: right-hand side(s). For a 2-D ``A`` of shape (m, n), ``b`` is (m, k).
        For a batched 3-D ``A`` of shape (B, m, n), ``b`` is (B, m, k) or (B, m).
    :param A: system matrix, 2-D or batched 3-D; assumed to have full column rank
        (otherwise R is singular).
    :return: the solution with its last two axes swapped: (k, n) for a 2-D ``A``;
        (B, k, n) for (B, m, k) right-hand sides, or (n, B) for (B, m) ones.
    :raises RuntimeError: if ``A`` is neither 2-D nor 3-D.
    """
    if A.dim() == 3:
        batch = True
    elif A.dim() == 2:
        batch = False
    else:
        raise RuntimeError('Wrong shape of A')

    q, r = torch.qr(A)
    # A = Q R  =>  x = R^{-1} Q^T b. Back-substitute against the triangular R
    # instead of forming R^{-1} explicitly. This is cheaper and numerically more
    # stable, and it processes the whole batch in one call.
    if batch:
        # Per-sample vector right-hand sides get a temporary column axis.
        vector_rhs = b.dim() == 2
        rhs = q.transpose(-1, -2) @ (b.unsqueeze(-1) if vector_rhs else b)
        x = torch.linalg.solve_triangular(r, rhs, upper=True)
        if vector_rhs:
            x = x.squeeze(-1)
        return x.transpose(-1, -2)
    rhs = q.t() @ b
    return torch.linalg.solve_triangular(r, rhs, upper=True).transpose(-1, -2)
Example #8
Source File: model_utils.py    From forte with Apache License 2.0 5 votes vote down vote up
def orthonormal_init(param: nn.Parameter, n_blocks: int):
        """
        Fill a 2-D parameter with ``n_blocks`` stacked, independently drawn orthonormal blocks.

        The first dimension is split into ``n_blocks`` equal blocks of shape
        (size0 // n_blocks, size1). Each block is ``Q1[:, :k] @ Q2[:k, :]`` with
        ``k = min(block_rows, size1)``, where Q1 and Q2 are sign-corrected QR factors
        of Gaussian matrices. Each block therefore has orthonormal rows (wide
        blocks) or orthonormal columns (tall blocks).

        :param param: 2-D parameter, overwritten through ``param.data``; its dtype
            and device are kept.
        :param n_blocks: number of row blocks; must divide ``param.size(0)``.
        :raises ValueError: if ``param.size(0)`` is not a multiple of ``n_blocks``
            (assigning the concatenation would otherwise silently change the
            parameter's shape).
        """
        size0, size1 = param.size()
        if size0 % n_blocks != 0:
            raise ValueError(
                f"param.size(0)={size0} is not divisible by n_blocks={n_blocks}")
        size0 //= n_blocks
        size_min = min(size0, size1)
        init_values = []
        for _ in range(n_blocks):
            # Draw directly on the parameter's device so a GPU parameter does not get CPU data.
            m1 = torch.randn(size0, size0, dtype=param.dtype, device=param.device)
            m2 = torch.randn(size1, size1, dtype=param.dtype, device=param.device)
            q1, r1 = torch.qr(m1)
            q2, r2 = torch.qr(m2)
            # Scale column j by sign(R[j, j]) so Q does not depend on the QR sign convention.
            q1 *= torch.sign(torch.diag(r1))
            q2 *= torch.sign(torch.diag(r2))
            init_values.append(torch.mm(q1[:, :size_min], q2[:size_min, :]))
        param.data = torch.cat(init_values, dim=0)
Example #9
Source File: model_utils.py    From Greedy_InfoMax with MIT License 5 votes vote down vote up
def genOrthgonal(dim):
    """Return a random ``dim x dim`` orthogonal matrix.

    Q comes from the QR decomposition of a standard-normal matrix. Each column
    is multiplied by the sign of the matching diagonal entry of R, so the
    result does not depend on the QR routine's sign convention.
    """
    gaussian = torch.zeros((dim, dim)).normal_(0, 1)
    q, r = torch.qr(gaussian)
    signs = r.diagonal().sign()
    # A (1, dim) row of signs expanded over all rows scales column j by signs[j].
    q.mul_(signs.unsqueeze(0).expand(dim, dim))
    return q
Example #10
Source File: glow.py    From fac-via-ppg with Apache License 2.0 5 votes vote down vote up
def __init__(self, c):
        super(Invertible1x1Conv, self).__init__()
        self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0,
                                    bias=False)

        # Random orthonormal c x c matrix: the Q factor of a Gaussian matrix.
        gaussian = torch.FloatTensor(c, c).normal_()
        W, _ = torch.qr(gaussian)

        # Negate the first column if needed so that det(W) is +1 rather than -1.
        if torch.det(W) < 0:
            W[:, 0] = W[:, 0].neg()
        # Conv1d weights are laid out as (out_channels, in_channels, kernel_size).
        self.conv.weight.data = W.view(c, c, 1)
Example #11
Source File: model.py    From glow-pytorch with MIT License 5 votes vote down vote up
def __init__(self, in_channel):
        super().__init__()

        # Orthogonal in_channel x in_channel matrix: Q factor of a Gaussian matrix.
        gaussian = torch.randn(in_channel, in_channel)
        orthogonal, _ = torch.qr(gaussian)
        # Append two singleton axes to obtain a 1x1 convolution kernel.
        kernel = orthogonal.unsqueeze(2).unsqueeze(3)
        self.weight = nn.Parameter(kernel)
Example #12
Source File: glow.py    From FastSpeech with MIT License 5 votes vote down vote up
def __init__(self, c):
        super(Invertible1x1Conv, self).__init__()
        self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0,
                                    bias=False)

        # Random orthonormal c x c matrix: the Q factor of a Gaussian matrix.
        gaussian = torch.FloatTensor(c, c).normal_()
        W, _ = torch.qr(gaussian)

        # Negate the first column if needed so that det(W) is +1 rather than -1.
        if torch.det(W) < 0:
            W[:, 0] = W[:, 0].neg()
        # Conv1d weights are laid out as (out_channels, in_channels, kernel_size).
        self.conv.weight.data = W.view(c, c, 1)
Example #13
Source File: glow.py    From FastSpeech with MIT License 5 votes vote down vote up
def __init__(self, c):
        super(Invertible1x1Conv, self).__init__()
        self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0,
                                    bias=False)

        # Random orthonormal c x c matrix: the Q factor of a Gaussian matrix.
        gaussian = torch.FloatTensor(c, c).normal_()
        W, _ = torch.qr(gaussian)

        # Negate the first column if needed so that det(W) is +1 rather than -1.
        if torch.det(W) < 0:
            W[:, 0] = W[:, 0].neg()
        # Conv1d weights are laid out as (out_channels, in_channels, kernel_size).
        self.conv.weight.data = W.view(c, c, 1)
Example #14
Source File: core.py    From tensorqtl with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def __init__(self, C_t):
        # Orthonormal basis (columns of Q) for the mean-centred covariate columns.
        centered = C_t - C_t.mean(0)
        self.Q_t = torch.qr(centered)[0]
        # Residual degrees of freedom: rows minus columns minus 2 (the extra 2
        # presumably covers the intercept and the tested variable — confirm).
        n_samples, n_covariates = C_t.shape[0], C_t.shape[1]
        self.dof = n_samples - 2 - n_covariates
Example #15
Source File: waveglow.py    From NeMo with Apache License 2.0 5 votes vote down vote up
def __init__(self, c):
        super(Invertible1x1Conv, self).__init__()
        self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0, bias=False)

        # Random orthonormal c x c matrix: the Q factor of a Gaussian matrix.
        gaussian = torch.FloatTensor(c, c).normal_()
        W, _ = torch.qr(gaussian)

        # Negate the first column if needed so that det(W) is +1 rather than -1.
        if torch.det(W) < 0:
            W[:, 0] = W[:, 0].neg()
        # Conv1d weights are laid out as (out_channels, in_channels, kernel_size).
        self.conv.weight.data = W.view(c, c, 1)
Example #16
Source File: torchutils.py    From nsf with MIT License 5 votes vote down vote up
def random_orthogonal(size):
    """
    Draw a random orthogonal matrix.

    :param size: number of rows and columns.
    :return: a 2-dim tensor of shape [size, size]: the Q factor of the QR
        decomposition of a standard-normal matrix.
    """
    gaussian = torch.randn(size, size)
    return torch.qr(gaussian)[0]
Example #17
Source File: glow.py    From waveglow with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def __init__(self, c):
        super(Invertible1x1Conv, self).__init__()
        self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0,
                                    bias=False)

        # Random orthonormal c x c matrix: the Q factor of a Gaussian matrix.
        gaussian = torch.FloatTensor(c, c).normal_()
        W, _ = torch.qr(gaussian)

        # Negate the first column if needed so that det(W) is +1 rather than -1.
        if torch.det(W) < 0:
            W[:, 0] = W[:, 0].neg()
        # Conv1d weights are laid out as (out_channels, in_channels, kernel_size).
        self.conv.weight.data = W.view(c, c, 1)
Example #18
Source File: glow.py    From LightSpeech with MIT License 5 votes vote down vote up
def __init__(self, c):
        super(Invertible1x1Conv, self).__init__()
        self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0,
                                    bias=False)

        # Random orthonormal c x c matrix: the Q factor of a Gaussian matrix.
        gaussian = torch.FloatTensor(c, c).normal_()
        W, _ = torch.qr(gaussian)

        # Negate the first column if needed so that det(W) is +1 rather than -1.
        if torch.det(W) < 0:
            W[:, 0] = W[:, 0].neg()
        # Conv1d weights are laid out as (out_channels, in_channels, kernel_size).
        self.conv.weight.data = W.view(c, c, 1)
Example #19
Source File: glow.py    From LightSpeech with MIT License 5 votes vote down vote up
def __init__(self, c):
        super(Invertible1x1Conv, self).__init__()
        self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0,
                                    bias=False)

        # Random orthonormal c x c matrix: the Q factor of a Gaussian matrix.
        gaussian = torch.FloatTensor(c, c).normal_()
        W, _ = torch.qr(gaussian)

        # Negate the first column if needed so that det(W) is +1 rather than -1.
        if torch.det(W) < 0:
            W[:, 0] = W[:, 0].neg()
        # Conv1d weights are laid out as (out_channels, in_channels, kernel_size).
        self.conv.weight.data = W.view(c, c, 1)
Example #20
Source File: qr.py    From tensorgrad with Apache License 2.0 5 votes vote down vote up
def forward(self, A):
        # Factorise A = Q R; keep the input and both factors for the backward pass.
        factors = torch.qr(A)
        Q, R = factors
        self.save_for_backward(A, Q, R)
        return Q, R
Example #21
Source File: lazy_tensor.py    From gpytorch with MIT License 5 votes vote down vote up
def _inv_matmul_preconditioner(self):
        """
        (Optional) define a preconditioner that can be used for linear systems, but not necessarily
        for log determinants. By default, this can call :meth:`~gpytorch.lazy.LazyTensor._preconditioner`.

        When no base preconditioner exists and the default-preconditioner beta feature is on, a
        low-rank SVD is built by applying the operator to a random basis, orthonormalising the
        result with QR, and applying the operator again. The factors are cached in
        ``self._default_preconditioner_cache``, and the returned function computes
        ``U @ diag(1 / S) @ V^T @ v``.

        Returns:
            function: a function on x which performs P^{-1}(x)
        """
        base_precond, _, _ = self._preconditioner()

        if base_precond is not None:
            return base_precond
        elif gpytorch.beta_features.default_preconditioner.on():
            if hasattr(self, "_default_preconditioner_cache"):
                U, S, V = self._default_preconditioner_cache
            else:
                precond_basis_size = min(gpytorch.settings.max_preconditioner_size.value(), self.size(-1))
                random_basis = torch.randn(
                    self.batch_shape + torch.Size((self.size(-2), precond_basis_size)),
                    device=self.device,
                    dtype=self.dtype,
                )
                projected_mat = self._matmul(random_basis)
                # torch.qr returns a (Q, R) pair; only the orthonormal basis Q is used below.
                # Binding the pair itself would pass a tuple to _matmul / .matmul.
                proj_q, _ = torch.qr(projected_mat)
                orthog_projected_mat = self._matmul(proj_q).transpose(-2, -1)
                U, S, V = torch.svd(orthog_projected_mat)
                # Lift the left singular vectors back from the Q basis.
                U = proj_q.matmul(U)

                self._default_preconditioner_cache = (U, S, V)

            def preconditioner(v):
                res = V.transpose(-2, -1).matmul(v)
                res = (1 / S).unsqueeze(-1) * res
                res = U.matmul(res)
                return res

            return preconditioner
        else:
            return None
Example #22
Source File: added_diag_lazy_tensor.py    From gpytorch with MIT License 5 votes vote down vote up
def _init_cache_for_non_constant_diag(self, eye, batch_shape, n):
        """
        Build the QR-based preconditioner caches when the added diagonal is not constant.

        With D = ``self._noise`` and L = ``self._piv_chol_self``, this stacks
        ``D^{-1/2} L`` on top of ``eye`` along the row axis and QR-factorises the result.
        Assuming ``eye`` is the identity, R^T R = L^T D^{-1} L + I.

        Sets ``self._q_cache`` (first ``n`` rows of Q, scaled by D^{-1/2}),
        ``self._r_cache`` and ``self._precond_logdet_cache`` = log|D + L L^T|,
        shaped as ``batch_shape`` (a scalar when ``batch_shape`` is empty).
        """
        noise_sqrt = self._noise.sqrt()
        # With non-constant diagonals, we can't factor out the noise as easily.
        # Concatenate along the row axis (-2), not dim 0, so batched inputs work;
        # for 2-D inputs the two are the same.
        self._q_cache, self._r_cache = torch.qr(torch.cat((self._piv_chol_self / noise_sqrt, eye), dim=-2))
        self._q_cache = self._q_cache[..., :n, :] / noise_sqrt

        # Matrix determinant lemma: log|D + L L^T| = log|I + L^T D^{-1} L| + log|D|,
        # and log|I + L^T D^{-1} L| = 2 * sum(log|diag(R)|).
        logdet = self._r_cache.diagonal(dim1=-1, dim2=-2).abs().log().sum(-1).mul(2)
        logdet -= (1.0 / self._noise).log().sum([-1, -2])
        self._precond_logdet_cache = logdet.view(*batch_shape) if len(batch_shape) else logdet.squeeze()
Example #23
Source File: added_diag_lazy_tensor.py    From gpytorch with MIT License 5 votes vote down vote up
def _init_cache_for_constant_diag(self, eye, batch_shape, n, k):
        # A constant diagonal s can be factored out of both the QR and the solves,
        # so keep a single entry of the noise along the row axis.
        self._noise = self._noise.narrow(-2, 0, 1)
        stacked = torch.cat((self._piv_chol_self, self._noise.sqrt() * eye), dim=-2)
        q_full, self._r_cache = torch.qr(stacked)
        self._q_cache = q_full[..., :n, :]

        # Matrix determinant lemma with R^T R = L_k^T L_k + s I:
        # log|s I_n + L_k L_k^T| = 2 * sum(log|diag(R)|) + (n - k) * log(s).
        log_abs_diag = self._r_cache.diagonal(dim1=-1, dim2=-2).abs().log()
        logdet = 2 * log_abs_diag.sum(-1)
        logdet = logdet + (n - k) * self._noise.squeeze(-2).squeeze(-1).log()
        if len(batch_shape):
            self._precond_logdet_cache = logdet.view(*batch_shape)
        else:
            self._precond_logdet_cache = logdet.squeeze()
Example #24
Source File: xinit.py    From NQG with GNU General Public License v3.0 5 votes vote down vote up
def orthogonal(tensor, gain=1):
    """Fills the input Tensor or Variable with a (semi) orthogonal matrix, as described in "Exact solutions to the
    nonlinear dynamics of learning in deep linear neural networks" - Saxe, A. et al. (2013). The input tensor must have
    at least 2 dimensions, and for tensors with more than 2 dimensions the trailing dimensions are flattened.

    The flattened (rows, cols) matrix has orthonormal rows when rows < cols and orthonormal columns
    otherwise, scaled by ``gain``. The tensor is filled in place and returned.

    Args:
        tensor: an n-dimensional torch.Tensor or autograd.Variable, where n >= 2
        gain: optional scaling factor

    Raises:
        ValueError: if the tensor has fewer than 2 dimensions.

    Examples:
        >>> w = torch.Tensor(3, 5)
        >>> orthogonal(w)
    """
    # On PyTorch >= 0.4 every Tensor is an instance of ``Variable`` and ``.data`` is
    # again a Tensor, so recursing on ``tensor.data`` never terminates. Write through
    # ``.data`` exactly once instead: it shares storage with ``tensor`` and keeps the
    # in-place updates out of autograd (also for Parameters that require grad).
    target = tensor.data if isinstance(tensor, Variable) else tensor

    if target.ndimension() < 2:
        raise ValueError("Only tensors with 2 or more dimensions are supported")

    rows = target.size(0)
    cols = target[0].numel()
    flattened = torch.Tensor(rows, cols).normal_(0, 1)
    # QR of a wide matrix only yields a rows x rows Q. Factorise the transpose instead
    # so every column of the result is random rather than zero-padded.
    if rows < cols:
        flattened = flattened.t()
    # Compute the qr factorization
    q, r = torch.qr(flattened)
    # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
    ph = torch.diag(r, 0).sign()
    q *= ph.expand_as(q)
    if rows < cols:
        q = q.t()

    # copy_ handles dtype/device differences with the target and a non-contiguous q.
    target.copy_(q.reshape(target.shape))
    target.mul_(gain)
    return tensor
Example #25
Source File: xinit.py    From SEASS with MIT License 5 votes vote down vote up
def orthogonal(tensor, gain=1):
    """Fills the input Tensor or Variable with a (semi) orthogonal matrix, as described in "Exact solutions to the
    nonlinear dynamics of learning in deep linear neural networks" - Saxe, A. et al. (2013). The input tensor must have
    at least 2 dimensions, and for tensors with more than 2 dimensions the trailing dimensions are flattened.

    The flattened (rows, cols) matrix has orthonormal rows when rows < cols and orthonormal columns
    otherwise, scaled by ``gain``. The tensor is filled in place and returned.

    Args:
        tensor: an n-dimensional torch.Tensor or autograd.Variable, where n >= 2
        gain: optional scaling factor

    Raises:
        ValueError: if the tensor has fewer than 2 dimensions.

    Examples:
        >>> w = torch.Tensor(3, 5)
        >>> orthogonal(w)
    """
    # On PyTorch >= 0.4 every Tensor is an instance of ``Variable`` and ``.data`` is
    # again a Tensor, so recursing on ``tensor.data`` never terminates. Write through
    # ``.data`` exactly once instead: it shares storage with ``tensor`` and keeps the
    # in-place updates out of autograd (also for Parameters that require grad).
    target = tensor.data if isinstance(tensor, Variable) else tensor

    if target.ndimension() < 2:
        raise ValueError("Only tensors with 2 or more dimensions are supported")

    rows = target.size(0)
    cols = target[0].numel()
    flattened = torch.Tensor(rows, cols).normal_(0, 1)
    # QR of a wide matrix only yields a rows x rows Q. Factorise the transpose instead
    # so every column of the result is random rather than zero-padded.
    if rows < cols:
        flattened = flattened.t()
    # Compute the qr factorization
    q, r = torch.qr(flattened)
    # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
    ph = torch.diag(r, 0).sign()
    q *= ph.expand_as(q)
    if rows < cols:
        q = q.t()

    # copy_ handles dtype/device differences with the target and a non-contiguous q.
    target.copy_(q.reshape(target.shape))
    target.mul_(gain)
    return tensor
Example #26
Source File: glow.py    From tn2-wg with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def __init__(self, c):
        super(Invertible1x1Conv, self).__init__()
        self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0,
                                    bias=False)

        # Random orthonormal c x c matrix: the Q factor of a Gaussian matrix.
        gaussian = torch.FloatTensor(c, c).normal_()
        W, _ = torch.qr(gaussian)

        # Negate the first column if needed so that det(W) is +1 rather than -1.
        if torch.det(W) < 0:
            W[:, 0] = W[:, 0].neg()
        # Conv1d weights are laid out as (out_channels, in_channels, kernel_size).
        self.conv.weight.data = W.view(c, c, 1)
Example #27
Source File: WKPooling.py    From sentence-transformers with Apache License 2.0 5 votes vote down vote up
def unify_token(self, token_feature):
        """
            Unify Token Representation

        Pool the rows of ``token_feature`` into one vector.

        Each row k is QR-decomposed together with up to ``self.context_window_size``
        neighbouring rows on each side (neighbours first, row k last). The last column
        r of R yields an alignment score, computed from the normalised neighbour
        coordinates, and a novelty score ``|r[-1]| / ||r||``. Both score vectors are
        normalised to sum to 1, added and renormalised. The result weights the
        returned sum of the rows.

        :param token_feature: 2-D tensor; rows are presumably the per-layer hidden
            states of one token (TODO confirm against the caller).
        :return: 1-D tensor of length ``token_feature.size(1)``.
        """
        window_size = self.context_window_size
        n_rows = token_feature.size()[0]

        # Allocate the scores with the input's dtype so torch.mv below also works
        # for non-float32 features.
        alpha_alignment = torch.zeros(n_rows, dtype=token_feature.dtype, device=token_feature.device)
        alpha_novelty = torch.zeros(n_rows, dtype=token_feature.dtype, device=token_feature.device)

        for k in range(n_rows):
            # Clamp the start at 0. A negative start index would wrap around to the
            # end of the tensor and silently drop the left context of the first rows.
            left_window = token_feature[max(k - window_size, 0):k, :]
            right_window = token_feature[k + 1:k + window_size + 1, :]
            window_matrix = torch.cat([left_window, right_window, token_feature[k, :][None, :]])
            _, R = torch.qr(window_matrix.T)

            # Row k was placed last, so the last column of R holds its coordinates.
            r = R[:, -1]
            alpha_alignment[k] = torch.mean(self.norm_vector(R[:-1, :-1], dim=0), dim=1).matmul(R[:-1, -1]) / torch.norm(r[:-1])
            alpha_alignment[k] = 1 / (alpha_alignment[k] * window_matrix.size()[0] * 2)
            alpha_novelty[k] = torch.abs(r[-1]) / torch.norm(r)

        # Sum Norm
        alpha_alignment = alpha_alignment / torch.sum(alpha_alignment)  # Normalization Choice
        alpha_novelty = alpha_novelty / torch.sum(alpha_novelty)

        alpha = alpha_novelty + alpha_alignment
        alpha = alpha / torch.sum(alpha)  # Normalize

        out_embedding = torch.mv(token_feature.t(), alpha)
        return out_embedding
Example #28
Source File: WKPooling.py    From sentence-transformers with Apache License 2.0 5 votes vote down vote up
def forward(self, features: Dict[str, Tensor]):
        """
        Compute WK-pooled sentence embeddings from the hidden states of all layers.

        ``features['all_layer_embeddings']`` (a sequence of tensors) is stacked and its
        first two axes are swapped, giving a tensor indexed as
        [sentence, layer, token, dim] (layout inferred from the indexing below — confirm).
        Layers before ``self.layer_start`` are dropped. For each sentence, the first
        ``attention_mask.sum() - 1`` tokens are pooled across layers with
        ``self.unify_token``, and the token vectors are then combined with
        ``self.unify_sentence``.

        :return: ``features`` with ``'sentence_embedding'`` set to the stacked sentence
            embeddings, on the device of the input embeddings.
        """
        ft_all_layers = features['all_layer_embeddings']
        org_device = ft_all_layers[0].device
        all_layer_embedding = torch.stack(ft_all_layers).transpose(1, 0)
        all_layer_embedding = all_layer_embedding[:, self.layer_start:, :, :]  # Keep layers from self.layer_start on

        # torch.qr is slow on GPU (see https://github.com/pytorch/pytorch/issues/22573). So compute it on CPU until issue is fixed
        all_layer_embedding = all_layer_embedding.cpu()

        attention_mask = features['attention_mask'].cpu().numpy()
        # Attended tokens per sentence, minus 1 so the last attended item is not considered.
        unmask_num = attention_mask.sum(axis=1) - 1
        embedding = []

        # One sentence at a time
        for sent_index in range(len(unmask_num)):
            sentence_feature = all_layer_embedding[sent_index, :, :unmask_num[sent_index], :]
            # 'Unified Word Representation' for each token
            one_sentence_embedding = [
                self.unify_token(sentence_feature[:, token_index, :])
                for token_index in range(sentence_feature.shape[1])
            ]
            one_sentence_embedding = torch.stack(one_sentence_embedding)
            sentence_embedding = self.unify_sentence(sentence_feature, one_sentence_embedding)
            embedding.append(sentence_embedding)

        # The previous in-loop write of features['cls_token_embeddings'] was always
        # overwritten here and raised KeyError when that key was absent, so it is gone.
        output_vector = torch.stack(embedding).to(org_device)
        features.update({'sentence_embedding': output_vector})

        return features
Example #29
Source File: qr.py    From heat with MIT License 4 votes vote down vote up
def __split0_send_q_to_diag_pr(
    col, pr0, pr1, diag_process, comm, q_dict, key, q_dict_waits, q_dtype, q_device
):
    """
    Send the merged Q from process ``pr1`` to the diagonal process.

    This is needed for the Q calculation when two processes are merged and neither is the
    diagonal process. On ``pr1``, the shape of Q is sent with a blocking ``send``, Q itself with
    a non-blocking ``Isend``, and the two stored shapes with blocking ``send`` calls. On
    ``diag_process``, the matching receives are posted and their requests are stored in
    ``q_dict_waits``. Ranks other than ``pr0``, ``pr1`` and ``diag_process`` return immediately.

    Parameters
    ----------
    col : int
        The current column used in the parent QR loop
    pr0, pr1 : int, int
        Rank of processes 0 and 1. These are the processes used in the calculation of q.
        ``pr1`` may also be given as a 0-d torch.Tensor.
    diag_process : int
        The rank of the process which has the tile along the diagonal for the given column
    comm : MPICommunication (ht.DNDarray.comm)
        The communicator used. (Intended as the communication of the DNDarray 'a' given to qr)
    q_dict : Dict
        dictionary containing the Q values calculated for finding R;
        ``q_dict[col][key]`` holds (Q, upper shape, lower shape)
    key : string
        key for q_dict[col] which corresponds to the Q to send
    q_dict_waits : Dict
        Dictionary used in the collection of the Qs which are sent to the diagonal process
    q_dtype : torch.dtype
        Type of the Q tensor
    q_device : torch.device
        Device of the Q tensor

    Returns
    -------
    None, sets the values of q_dict_waits with the *waits* for the values of Q, upper.shape,
        and lower.shape
    """
    if comm.rank not in [pr0, pr1, diag_process]:
        return
    # this is to send the merged q to the diagonal process for the forming of q
    # Tags are built by string concatenation: "1" + str(pr1) + a per-message suffix.
    base_tag = "1" + str(pr1.item() if isinstance(pr1, torch.Tensor) else pr1)
    if comm.rank == pr1:
        q = q_dict[col][key][0]
        u_shape = q_dict[col][key][1]
        l_shape = q_dict[col][key][2]
        comm.send(tuple(q.shape), dest=diag_process, tag=int(base_tag + "1"))
        # NOTE(review): the request returned by Isend is neither kept nor waited on here;
        # confirm that q stays alive and unmodified until the message is delivered.
        comm.Isend(q, dest=diag_process, tag=int(base_tag + "12"))
        comm.send(u_shape, dest=diag_process, tag=int(base_tag + "123"))
        comm.send(l_shape, dest=diag_process, tag=int(base_tag + "1234"))
    if comm.rank == diag_process:
        # q_dict_waits[col][k] is filled with, in order:
        #   [q_recv buffer, Irecv request for Q], irecv request for u_shape,
        #   irecv request for l_shape, key[0]
        q_sh = comm.recv(source=pr1, tag=int(base_tag + "1"))
        q_recv = torch.zeros(q_sh, dtype=q_dtype, device=q_device)
        # NOTE(review): unlike base_tag, this uses str(pr1) directly, so a tensor pr1 would
        # put "tensor(...)" into the key; confirm consumers build the key the same way.
        k = "p0" + str(pr0) + "p1" + str(pr1)
        q_dict_waits[col][k] = []
        q_wait = comm.Irecv(q_recv, source=pr1, tag=int(base_tag + "12"))
        q_dict_waits[col][k].append([q_recv, q_wait])
        q_dict_waits[col][k].append(comm.irecv(source=pr1, tag=int(base_tag + "123")))
        q_dict_waits[col][k].append(comm.irecv(source=pr1, tag=int(base_tag + "1234")))
        q_dict_waits[col][k].append(key[0])