Python torch.dtype() Examples

The following are 30 code examples of torch.dtype(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch, or try the search function.
Example #1
Source File: hooks.py    From mmdetection-annotated with Apache License 2.0 6 votes vote down vote up
def patch_forward_method(func, src_type, dst_type, convert_output=True):
    """Wrap a forward method so its inputs (and optionally output) are cast.

    Every positional and keyword argument is routed through
    ``cast_tensor_type(..., src_type, dst_type)`` before ``func`` runs; when
    ``convert_output`` is true the result is routed back through
    ``cast_tensor_type(..., dst_type, src_type)``.

    Args:
        func (callable): The forward method being wrapped.
        src_type (torch.dtype): Dtype the inputs are converted from.
        dst_type (torch.dtype): Dtype the inputs are converted to.
        convert_output (bool): Whether to cast the result back to src_type.

    Returns:
        callable: The replacement forward method.
    """

    def new_forward(*args, **kwargs):
        cast_args = cast_tensor_type(args, src_type, dst_type)
        cast_kwargs = cast_tensor_type(kwargs, src_type, dst_type)
        result = func(*cast_args, **cast_kwargs)
        return cast_tensor_type(result, dst_type, src_type) if convert_output else result

    return new_forward
Example #2
Source File: hooks.py    From mmdetection with Apache License 2.0 6 votes vote down vote up
def patch_forward_method(func, src_type, dst_type, convert_output=True):
    """Return a forward method that casts tensors around a call to ``func``.

    Args:
        func (callable): Original forward method.
        src_type (torch.dtype): Source dtype of the incoming arguments.
        dst_type (torch.dtype): Dtype the arguments are cast to before the call.
        convert_output (bool): If true, cast the result from dst_type back to
            src_type before returning it.

    Returns:
        callable: The patched forward method.
    """

    def new_forward(*args, **kwargs):
        out = func(*cast_tensor_type(args, src_type, dst_type),
                   **cast_tensor_type(kwargs, src_type, dst_type))
        if not convert_output:
            return out
        return cast_tensor_type(out, dst_type, src_type)

    return new_forward
Example #3
Source File: fcos_head.py    From AerialDetection with Apache License 2.0 6 votes vote down vote up
def get_points(self, featmap_sizes, dtype, device):
    """Compute the point grid for every feature-map level.

    Args:
        featmap_sizes (list[tuple]): Multi-level feature map sizes.
        dtype (torch.dtype): Type of points.
        device (torch.device): Device of points.

    Returns:
        list: One entry per level, each produced by
            ``self.get_points_single`` with that level's stride.
    """
    return [
        self.get_points_single(size, self.strides[level], dtype, device)
        for level, size in enumerate(featmap_sizes)
    ]
Example #4
Source File: euclidean.py    From geoopt with Apache License 2.0 6 votes vote down vote up
def origin(
    self, *size, dtype=None, device=None, seed=42
) -> "geoopt.ManifoldTensor":
    """
    Zero point origin.

    Parameters
    ----------
    size : shape
        the desired shape
    device : torch.device
        the desired device
    dtype : torch.dtype
        the desired dtype
    seed : int
        ignored

    Returns
    -------
    ManifoldTensor
        an all-zeros tensor attached to this manifold
    """
    shape = size2shape(*size)
    self._assert_check_shape(shape, "x")
    zeros = torch.zeros(*size, dtype=dtype, device=device)
    return geoopt.ManifoldTensor(zeros, manifold=self)
Example #5
Source File: birkhoff_polytope.py    From geoopt with Apache License 2.0 6 votes vote down vote up
def random_naive(self, *size, dtype=None, device=None) -> torch.Tensor:
    """
    Naive approach to get random matrix on Birkhoff Polytope manifold.

    Draws a Gaussian sample, takes its absolute value and projects it with
    ``self.projx``. Fast to compute, but the resulting measure is not uniform.

    Parameters
    ----------
    size : shape
        the desired output shape
    dtype : torch.dtype
        desired dtype
    device : torch.device
        desired device

    Returns
    -------
    ManifoldTensor
        random point on Birkhoff Polytope manifold
    """
    self._assert_check_shape(size2shape(*size), "x")
    sample = torch.randn(*size, device=device, dtype=dtype)
    # The projection needs a non-negative starting matrix.
    sample.abs_()
    projected = self.projx(sample)
    return ManifoldTensor(projected, manifold=self)
Example #6
Source File: stiefel.py    From geoopt with Apache License 2.0 6 votes vote down vote up
def origin(self, *size, dtype=None, device=None, seed=42) -> torch.Tensor:
    """
    Identity matrix point origin.

    Parameters
    ----------
    size : shape
        the desired shape
    device : torch.device
        the desired device
    dtype : torch.dtype
        the desired dtype
    seed : int
        ignored

    Returns
    -------
    ManifoldTensor
        zeros everywhere except ones at entries (i, i) for every column i,
        repeated over all leading batch dimensions
    """
    self._assert_check_shape(size2shape(*size), "x")
    point = torch.zeros(*size, dtype=dtype, device=device)
    diag = torch.arange(point.shape[-1])
    point[..., diag, diag] += 1
    return ManifoldTensor(point, manifold=self)
Example #7
Source File: birkhoff_polytope.py    From geoopt with Apache License 2.0 6 votes vote down vote up
def origin(self, *size, dtype=None, device=None, seed=42) -> torch.Tensor:
    """
    Identity matrix point origin.

    Parameters
    ----------
    size : shape
        the desired shape
    device : torch.device
        the desired device
    dtype : torch.dtype
        the desired dtype
    seed : int
        ignored

    Returns
    -------
    ManifoldTensor
        an identity matrix broadcast (via ``expand``, so without copying)
        to the requested batch shape
    """
    shape = size2shape(*size)
    self._assert_check_shape(shape, "x")
    identity = torch.eye(*shape[-2:], dtype=dtype, device=device)
    return ManifoldTensor(identity.expand(shape), manifold=self)
Example #8
Source File: degree.py    From pytorch_geometric with MIT License 6 votes vote down vote up
def degree(index, num_nodes: Optional[int] = None,
           dtype: Optional[int] = None):
    r"""Count how many times each node id occurs in a one-dimensional index
    tensor, i.e. its (unweighted) degree.

    Args:
        index (LongTensor): Index tensor.
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)
        dtype (:obj:`torch.dtype`, optional): The desired data type of the
            returned tensor. NOTE(review): annotated as ``Optional[int]``,
            presumably for TorchScript compatibility — confirm before changing.

    :rtype: :class:`Tensor`
    """
    num = maybe_num_nodes(index, num_nodes)
    counts = torch.zeros((num, ), dtype=dtype, device=index.device)
    ones = torch.ones((index.size(0), ), dtype=counts.dtype,
                      device=counts.device)
    counts.scatter_add_(0, index, ones)
    return counts
Example #9
Source File: grid.py    From pytorch_geometric with MIT License 6 votes vote down vote up
def grid(height, width, dtype=None, device=None):
    r"""Returns the edge indices of a two-dimensional grid graph with height
    :attr:`height` and width :attr:`width` and its node positions.

    Args:
        height (int): The height of the grid.
        width (int): The width of the grid.
        dtype (:obj:`torch.dtype`, optional): The desired data type of the
            returned position tensor.
        device (:obj:`torch.device`, optional): The desired device of the
            returned tensors.

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """

    # The edge indices are always integer-typed, so only positions take dtype.
    edge_index = grid_index(height, width, device)
    pos = grid_pos(height, width, dtype, device)
    return edge_index, pos
Example #10
Source File: grid.py    From pytorch_geometric with MIT License 6 votes vote down vote up
def grid_index(height, width, device=None):
    """Build the edge index of a ``height x width`` grid graph.

    Every node is connected to its 8 neighbours and to itself (offset 0).
    Offsets that would fall outside the grid are discarded. The result is
    passed through ``coalesce``.
    """
    num_nodes = height * width
    offsets = torch.tensor(
        [-width - 1, -1, width - 1, -width, 0, width, -width + 1, 1, width + 1],
        device=device)

    row = torch.arange(num_nodes, dtype=torch.long, device=device)
    row = row.view(-1, 1).repeat(1, offsets.size(0))
    col = row + offsets.view(1, -1)
    row = row.view(height, -1)
    col = col.view(height, -1)

    # Within each grid row, skip the first three offsets (the left-hand
    # neighbours of the row's first node) and the last three (the right-hand
    # neighbours of its last node) so edges do not wrap across rows.
    keep = torch.arange(3, row.size(1) - 3, dtype=torch.long, device=device)
    row = row[:, keep].view(-1)
    col = col[:, keep].view(-1)

    # Remove vertical neighbours that fall above the first or below the last row.
    valid = (col >= 0) & (col < num_nodes)
    edge_index = torch.stack([row[valid], col[valid]], dim=0)
    edge_index, _ = coalesce(edge_index, None, num_nodes, num_nodes)
    return edge_index
Example #11
Source File: hooks.py    From DenseMatchingBenchmark with MIT License 6 votes vote down vote up
def patch_forward_method(func, src_type, dst_type, convert_output=True):
    """Patch the forward method of a module with dtype casting.

    Args:
        func (callable): The original forward method.
        src_type (torch.dtype): Type of input arguments to be converted from.
        dst_type (torch.dtype): Type of input arguments to be converted to.
        convert_output (bool): Whether to convert the output back to src_type.
    Returns:
        callable: A wrapper that casts args/kwargs, calls ``func`` and, if
            requested, casts the result back.
    """

    def new_forward(*args, **kwargs):
        positional = cast_tensor_type(args, src_type, dst_type)
        keyword = cast_tensor_type(kwargs, src_type, dst_type)
        output = func(*positional, **keyword)
        if convert_output:
            output = cast_tensor_type(output, dst_type, src_type)
        return output

    return new_forward
Example #12
Source File: types.py    From chainer-compiler with MIT License 6 votes vote down vote up
def torch_dtype_to_np_dtype(dtype):
    """Return the ``numpy.dtype`` corresponding to a ``torch.dtype``.

    Torch aliases (``torch.long`` / ``torch.int64``, ``torch.half`` /
    ``torch.float16``, ...) are the very same dtype objects, so each dtype
    needs only one entry; lookups through an alias still succeed.

    Args:
        dtype (torch.dtype): The torch dtype to translate.

    Returns:
        numpy.dtype: The matching NumPy dtype.

    Raises:
        KeyError: If ``dtype`` has no NumPy counterpart in the table.
    """
    dtype_dict = {
        # ``np.bool`` was merely an alias of the builtin ``bool`` and was
        # removed in NumPy 1.24; ``np.bool_`` yields the same numpy.dtype.
        torch.bool: np.dtype(np.bool_),
        torch.uint8: np.dtype(np.uint8),
        torch.int8: np.dtype(np.int8),
        torch.int16: np.dtype(np.int16),
        torch.int32: np.dtype(np.int32),
        torch.int64: np.dtype(np.int64),
        torch.float16: np.dtype(np.float16),
        torch.float32: np.dtype(np.float32),
        torch.float64: np.dtype(np.float64),
    }
    try:
        return dtype_dict[dtype]
    except KeyError:
        raise KeyError(
            "unsupported torch dtype: {}".format(dtype)) from None


# ---------------------- InferenceEngine internal types ------------------------ 
Example #13
Source File: trial.py    From torchbearer with MIT License 6 votes vote down vote up
def update_device_and_dtype(state, *args, **kwargs):
    """Function gets data type and device values from the args / kwargs and updates state.

    The arguments are those given to :func:`Trial.to`, which follow the
    ``torch.nn.Module.to`` conventions:

    - ``dtype=`` / ``device=`` keywords set the data type / device.
    - A positional ``torch.dtype`` sets the data type.
    - A positional tensor sets both the data type and the device from that tensor.
    - A positional ``bool`` is the ``non_blocking`` flag and is ignored here.
    - Any other positional value is treated as a device.

    Positional arguments are applied after keywords.

    Args:
        state (State): The :class:`.State` to update
        args: Arguments to the :func:`Trial.to` function
        kwargs: Keyword arguments to the :func:`Trial.to` function

    Returns:
        state
    """
    if 'dtype' in kwargs:
        state[torchbearer.DATA_TYPE] = kwargs['dtype']
    if 'device' in kwargs:
        state[torchbearer.DEVICE] = kwargs['device']

    for arg in args:
        if isinstance(arg, torch.dtype):
            state[torchbearer.DATA_TYPE] = arg
        elif isinstance(arg, torch.Tensor):
            # ``module.to(other_tensor)`` adopts both the dtype and device of the tensor.
            state[torchbearer.DATA_TYPE] = arg.dtype
            state[torchbearer.DEVICE] = arg.device
        elif isinstance(arg, bool):
            # Positional ``non_blocking`` flag: not a device. Checked before the
            # fallback because bool is a subclass of int (int = CUDA ordinal).
            continue
        else:
            state[torchbearer.DEVICE] = arg

    return state
Example #14
Source File: torch_serde.py    From PySyft with Apache License 2.0 6 votes vote down vote up
def protobuf_tensor_deserializer(
    worker: AbstractWorker, protobuf_tensor: TensorDataPB
) -> torch.Tensor:
    """Strategy to deserialize a binary input using Protobuf"""
    shape = tuple(protobuf_tensor.shape.dims)
    dtype_name = protobuf_tensor.dtype
    contents = getattr(protobuf_tensor, "contents_" + dtype_name)

    if not protobuf_tensor.is_quantized:
        plain_dtype = TORCH_STR_DTYPE[dtype_name]
        return torch.tensor(contents, dtype=plain_dtype).reshape(shape)

    # Quantized dtype names carry a leading 'q'; stripping it gives the
    # underlying integer dtype the raw values are stored in.
    int_dtype = TORCH_STR_DTYPE[dtype_name[1:]]
    int_tensor = torch.tensor(contents, dtype=int_dtype).reshape(shape)
    # Automatically converts int types to quantized types
    return torch._make_per_tensor_quantized_tensor(
        int_tensor, protobuf_tensor.scale, protobuf_tensor.zero_point
    )
Example #15
Source File: anchor_free_head.py    From mmdetection with Apache License 2.0 6 votes vote down vote up
def get_points(self, featmap_sizes, dtype, device, flatten=False):
    """Get points according to feature map sizes.

    Args:
        featmap_sizes (list[tuple]): Multi-level feature map sizes.
        dtype (torch.dtype): Type of points.
        device (torch.device): Device of points.
        flatten (bool): Forwarded unchanged to ``self._get_points_single``
            for every level. Default: False.

    Returns:
        list: One entry per level, each produced by
            ``self._get_points_single`` with that level's stride from
            ``self.strides``.
    """
    return [
        self._get_points_single(size, self.strides[level], dtype, device,
                                flatten)
        for level, size in enumerate(featmap_sizes)
    ]
Example #16
Source File: fcos_head.py    From mmdetection-annotated with Apache License 2.0 6 votes vote down vote up
def get_points(self, featmap_sizes, dtype, device):
    """Build the points of every feature-map level.

    Args:
        featmap_sizes (list[tuple]): Multi-level feature map sizes.
        dtype (torch.dtype): Type of points.
        device (torch.device): Device of points.

    Returns:
        list: ``self.get_points_single`` output for each level, in order.
    """
    points_per_level = []
    for level, size in enumerate(featmap_sizes):
        stride = self.strides[level]
        points_per_level.append(
            self.get_points_single(size, stride, dtype, device))
    return points_per_level
Example #17
Source File: structures.py    From virtex with MIT License 6 votes vote down vote up
def to(self, *args, **kwargs) -> "Batch":
    """Return a copy of this batch moved to a device and/or cast to a dtype.

    Arguments are parsed by ``torch._C._nn._parse_to``, the same parser
    ``nn.Module.to`` uses. When a dtype is given, only entries that are
    already floating point are cast. When a device is given, every entry is
    moved. The parsed ``non_blocking`` flag is forwarded to each ``.to`` call.
    The original batch is not modified; a ``clone()`` is returned.

    Raises:
        TypeError: If a non-floating-point dtype is requested.
    """
    new_batch = self.clone()
    # ``_parse_to`` returns (device, dtype, non_blocking) on older PyTorch
    # and an additional memory_format on newer releases; only the first three
    # are needed here, so slice rather than unpack a fixed-size tuple.
    device, dtype, non_blocking = torch._C._nn._parse_to(*args, **kwargs)[:3]

    if dtype is not None:
        if not dtype.is_floating_point:
            raise TypeError(
                f"Can only cast {self.__class__.__name__} to a floating point "
                f"dtype, but got desired dtype={dtype}"
            )
        for key in new_batch.keys():
            # Integer entries (e.g. token ids) keep their dtype.
            if new_batch[key].dtype.is_floating_point:
                new_batch[key] = new_batch[key].to(
                    dtype, non_blocking=non_blocking
                )

    if device is not None:
        for key in new_batch.keys():
            new_batch[key] = new_batch[key].to(
                device, non_blocking=non_blocking
            )
    return new_batch
Example #18
Source File: textual_heads.py    From virtex with MIT License 6 votes vote down vote up
def _generate_future_mask(
        self, size: int, dtype: torch.dtype, device: torch.device
    ) -> torch.Tensor:
        r"""
        Generate a mask for "future" positions, useful when using this module
        for language modeling.

        Parameters
        ----------
        size: int
        """
        # Default mask is for forward direction. Flip for backward direction.
        mask = torch.triu(
            torch.ones(size, size, device=device, dtype=dtype), diagonal=1
        )
        mask = mask.masked_fill(mask == 1, float("-inf"))
        return mask 
Example #19
Source File: utils.py    From MONAI with Apache License 2.0 6 votes vote down vote up
def one_hot(labels, num_classes: int, dtype: torch.dtype = torch.float):
    """
    For a tensor `labels` of dimensions B1[spatial_dims], return a tensor of dimensions `BN[spatial_dims]`
    for `num_classes` N number of classes.

    Example:

        For every value v = labels[b,1,h,w], the value in the result at [b,v,h,w] will be 1 and all others 0.
        Note that this will include the background label, thus a binary mask should be treated as having 2 classes.

    A 1-D ``labels`` tensor of length B is treated as shape (B, 1).

    Raises:
        AssertionError: If ``labels`` is 0-dimensional or its dimension 1 is
            not of size one. Raised explicitly (not via ``assert``) so the
            check survives ``python -O``.
    """
    if labels.dim() < 1:
        raise AssertionError("labels should have dim of 1 or more.")

    # if 1D, add singleton dim at the end
    if labels.dim() == 1:
        labels = labels.view(-1, 1)

    sh = list(labels.shape)

    if sh[1] != 1:
        raise AssertionError("labels should have a channel with length equals to one.")
    sh[1] = num_classes

    o = torch.zeros(size=sh, dtype=dtype, device=labels.device)
    # Write 1 at the class index given by each label along dim 1.
    return o.scatter_(dim=1, index=labels.long(), value=1)
Example #20
Source File: hooks.py    From GCNet with Apache License 2.0 6 votes vote down vote up
def patch_forward_method(func, src_type, dst_type, convert_output=True):
    """Create a dtype-casting replacement for a module's forward method.

    Args:
        func (callable): The original forward method.
        src_type (torch.dtype): Type of input arguments to be converted from.
        dst_type (torch.dtype): Type of input arguments to be converted to.
        convert_output (bool): Whether to convert the output back to src_type.

    Returns:
        callable: The patched forward method.
    """

    def new_forward(*args, **kwargs):
        converted_args = cast_tensor_type(args, src_type, dst_type)
        converted_kwargs = cast_tensor_type(kwargs, src_type, dst_type)
        raw = func(*converted_args, **converted_kwargs)
        if not convert_output:
            return raw
        return cast_tensor_type(raw, dst_type, src_type)

    return new_forward
Example #21
Source File: fcos_head.py    From GCNet with Apache License 2.0 6 votes vote down vote up
def get_points(self, featmap_sizes, dtype, device):
    """Collect per-level points for all feature maps.

    Args:
        featmap_sizes (list[tuple]): Multi-level feature map sizes.
        dtype (torch.dtype): Type of points.
        device (torch.device): Device of points.

    Returns:
        list: Level-ordered results of ``self.get_points_single``.
    """
    strides = self.strides
    return [
        self.get_points_single(featmap_size, strides[idx], dtype, device)
        for idx, featmap_size in enumerate(featmap_sizes)
    ]
Example #22
Source File: test_msgpack_serde.py    From PySyft with Apache License 2.0 6 votes vote down vote up
def test_numpy_number_simplify(workers):
    """This tests our ability to simplify numpy.float objects

    At the time of writing, numpy number simplify to an object inside
    of a tuple where the first value is a byte representation of the number
    and the second value is the dtype
    """
    me = workers["me"]

    # Named ``value`` rather than ``input`` to avoid shadowing the builtin.
    value = numpy.float32(2.0)
    output = msgpack.serde._simplify(me, value)

    # make sure simplified type ID is correct
    assert (
        msgpack.serde.msgpack_global_state.detailers[output[0]] == native_serde._detail_numpy_number
    )

    # make sure serialized form is correct
    assert type(output[1][0]) is bytes
    assert output[1][1] == msgpack.serde._simplify(me, value.dtype.name)
Example #23
Source File: test_msgpack_serde.py    From PySyft with Apache License 2.0 6 votes vote down vote up
def test_torch_tensor_serde_generic(workers):
    """This tests our ability to ser-de torch.Tensor objects
    using "all" serialization strategy
    """
    # NOTE(review): ``workers`` is not used in the body; it is presumably
    # requested for its pytest fixture setup side effects — confirm before removing.
    worker = VirtualWorker(None, id="non-torch")

    # create a tensor (named ``original`` to avoid shadowing builtin ``input``)
    original = Tensor(numpy.random.random((100, 100)))

    # ser-de the tensor
    simplified = msgpack.serde._simplify(worker, original)
    detailed = msgpack.serde._detail(worker, simplified)

    # check tensor contents
    assert original.size() == detailed.size()
    assert original.dtype == detailed.dtype
    assert (original == detailed).all()
Example #24
Source File: serde_helpers.py    From PySyft with Apache License 2.0 6 votes vote down vote up
def make_numpy_number(dtype, **kwargs):
    """Build a serde fixture for a single numpy scalar of type ``dtype``.

    The expected simplified form is ``(CODE[dtype], (raw_bytes, dtype_name))``,
    where ``dtype_name`` is itself a simplified ``str``.
    """
    number = numpy.array([2.2], dtype=dtype)[0]
    raw_bytes = number.tobytes()  # (bytes)
    dtype_name = (CODE[str], (number.dtype.name.encode("utf-8"),))  # (str) dtype.name
    simplified = (CODE[dtype], (raw_bytes, dtype_name))
    return [{"value": number, "simplified": simplified}]


########################################################################
# PyTorch.
########################################################################

# Utility functions. 
Example #25
Source File: serde_helpers.py    From PySyft with Apache License 2.0 6 votes vote down vote up
def make_numpy_ndarray(**kwargs):
    """Build a serde fixture for a random 2x2 float64 numpy array."""
    array = numpy.random.random((2, 2))

    def compare(detailed, original):
        """Compare numpy arrays"""
        assert numpy.array_equal(detailed, original)
        return True

    payload = (
        array.tobytes(),  # (bytes) serialized bin
        (CODE[tuple], (2, 2)),  # (tuple) shape
        (CODE[str], (b"float64",)),  # (str) dtype.name
    )
    fixture = {
        "value": array,
        "simplified": (CODE[type(array)], payload),
        "cmp_detailed": compare,
    }
    return [fixture]


# numpy.float32, numpy.float64, numpy.int32, numpy.int64 
Example #26
Source File: torch_serde.py    From PySyft with Apache License 2.0 6 votes vote down vote up
def protobuf_tensor_serializer(worker: AbstractWorker, tensor: torch.Tensor) -> TensorDataPB:
    """Strategy to serialize a tensor using Protobuf"""
    dtype_name = TORCH_DTYPE_STR[tensor.dtype]
    message = TensorDataPB()

    flat = torch.flatten(tensor)
    if tensor.is_quantized:
        # Store quantization params plus the raw integer representation.
        message.is_quantized = True
        message.scale = tensor.q_scale()
        message.zero_point = tensor.q_zero_point()
        values = flat.int_repr().tolist()
    else:
        values = flat.tolist()

    message.dtype = dtype_name
    message.shape.dims.extend(tensor.size())
    # The payload goes into the field named after the dtype, e.g. "contents_<dtype>".
    getattr(message, "contents_" + dtype_name).extend(values)
    return message
Example #27
Source File: utils.py    From mmdetection with Apache License 2.0 5 votes vote down vote up
def cast_tensor_type(inputs, src_type, dst_type):
    """Recursively convert every Tensor in inputs to dst_type.

    Note that ``src_type`` is not consulted by this implementation: every
    tensor found is converted with ``Tensor.to(dst_type)`` regardless of its
    current dtype. The argument is still passed down the recursion.

    Args:
        inputs: Inputs that to be casted.
        src_type (torch.dtype): Source type (currently unused, see above).
        dst_type (torch.dtype): Destination type.

    Returns:
        The same type with inputs, but all contained Tensors have been cast.
        Mappings, namedtuples and other iterables are rebuilt with their own
        type; strings, numpy arrays and non-iterables are returned unchanged.
    """
    if isinstance(inputs, torch.Tensor):
        return inputs.to(dst_type)
    if isinstance(inputs, (str, np.ndarray)):
        # Iterable, but must be returned as-is rather than rebuilt element-wise.
        return inputs
    if isinstance(inputs, abc.Mapping):
        return type(inputs)({
            k: cast_tensor_type(v, src_type, dst_type)
            for k, v in inputs.items()
        })
    if isinstance(inputs, tuple) and hasattr(inputs, '_fields'):
        # namedtuple constructors take fields positionally, not one iterable.
        return type(inputs)(
            *(cast_tensor_type(item, src_type, dst_type) for item in inputs))
    if isinstance(inputs, abc.Iterable):
        return type(inputs)(
            cast_tensor_type(item, src_type, dst_type) for item in inputs)
    return inputs
Example #28
Source File: torch_serde.py    From PySyft with Apache License 2.0 5 votes vote down vote up
def simplified_tensor_deserializer(worker: AbstractWorker, tensor_tuple: tuple) -> torch.Tensor:
    """Strategy to deserialize a simplified tensor into a Torch tensor"""
    # The detailed tuple is (shape, dtype name, flat data).
    size, dtype_name, values = serde._detail(worker, tensor_tuple)
    torch_dtype = TORCH_STR_DTYPE[dtype_name]
    return torch.tensor(values, dtype=torch_dtype).reshape(size)


# Simplify/Detail Torch Tensors 
Example #29
Source File: torch_serde.py    From PySyft with Apache License 2.0 5 votes vote down vote up
def _simplify_torch_dtype(worker: AbstractWorker, dtype: torch.dtype) -> str:
    """Simplify a ``torch.dtype`` to its string name via ``TORCH_DTYPE_STR``.

    The return annotation is ``str``: the same table's values are
    concatenated with ``"contents_"`` by the protobuf serializer, so they are
    strings, not ``Tuple[int]``. ``worker`` is part of the serde strategy
    signature and is not used here.
    """
    return TORCH_DTYPE_STR[dtype]
Example #30
Source File: fcos_head.py    From mmdetection-annotated with Apache License 2.0 5 votes vote down vote up
def get_points_single(self, featmap_size, stride, dtype, device):
    """Return the (x, y) centres of every cell of one feature map.

    Coordinates are ``index * stride + stride // 2``. Rows are ordered
    row-major: x varies fastest. The result has shape ``(h * w, 2)``.
    """
    rows, cols = featmap_size
    xs = torch.arange(0, cols * stride, stride, dtype=dtype, device=device)
    ys = torch.arange(0, rows * stride, stride, dtype=dtype, device=device)
    grid_y, grid_x = torch.meshgrid(ys, xs)
    coords = torch.stack((grid_x.flatten(), grid_y.flatten()), dim=-1)
    return coords + stride // 2