Python torch.int16() Examples

The following are 30 code examples of torch.int16(), collected from open-source projects; the link above each example leads to the original project and source file. Note that torch.int16 is a torch.dtype object (with alias torch.short) rather than a callable, so it is typically passed as the dtype argument to tensor factories. You may also want to check out all available functions/classes of the module torch, or try the search function.
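A minimal, self-contained sketch before the project examples (the tensor values are arbitrary):

import torch

x = torch.tensor([1, 2, 3], dtype=torch.int16)
print(x.dtype)                        # torch.int16
print(torch.int16 == torch.short)     # True: torch.short is an alias
print(torch.iinfo(torch.int16).max)   # 32767, the largest int16 value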
Example #1
Source File: types.py    From chainer-compiler with MIT License
def torch_dtype_to_np_dtype(dtype):
    dtype_dict = {
            torch.bool    : np.dtype(np.bool_),  # np.bool alias was removed in NumPy 1.24+
            torch.uint8   : np.dtype(np.uint8),
            torch.int8    : np.dtype(np.int8),
            torch.int16   : np.dtype(np.int16),
            torch.short   : np.dtype(np.int16),
            torch.int32   : np.dtype(np.int32),
            torch.int     : np.dtype(np.int32),
            torch.int64   : np.dtype(np.int64),
            torch.long    : np.dtype(np.int64),
            torch.float16 : np.dtype(np.float16),
            torch.half    : np.dtype(np.float16),
            torch.float32 : np.dtype(np.float32),
            torch.float   : np.dtype(np.float32),
            torch.float64 : np.dtype(np.float64),
            torch.double  : np.dtype(np.float64),
            }
    return dtype_dict[dtype]
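A short usage sketch for the mapping above, assuming numpy and torch are imported as np and torch as in the source file:

# torch.short is an alias of torch.int16, so both keys map to np.int16:
assert torch_dtype_to_np_dtype(torch.int16) == np.dtype(np.int16)
assert torch_dtype_to_np_dtype(torch.short) == np.dtype(np.int16)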


Example #2
Source File: constants.py    From heat with MIT License
def sanitize_infinity(dtype):
    """
    Returns largest possible value for the specified dtype.

    Parameters:
    -----------
    dtype: torch dtype

    Returns:
    --------
    large_enough: largest possible value for the given dtype
    """
    if dtype is torch.int8:
        large_enough = (1 << 7) - 1
    elif dtype is torch.int16:
        large_enough = (1 << 15) - 1
    elif dtype is torch.int32:
        large_enough = (1 << 31) - 1
    elif dtype is torch.int64:
        large_enough = (1 << 63) - 1
    else:
        large_enough = float("inf")

    return large_enough 
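For example, a quick usage sketch:

# Fixed-width integer dtypes yield their maximum representable value;
# all other dtypes fall through to float("inf"):
assert sanitize_infinity(torch.int16) == 32767          # 2 ** 15 - 1
assert sanitize_infinity(torch.float32) == float("inf")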
Example #3
Source File: test_types.py    From heat with MIT License
def test_canonical_heat_type(self):
        self.assertEqual(ht.core.types.canonical_heat_type(ht.float32), ht.float32)
        self.assertEqual(ht.core.types.canonical_heat_type("?"), ht.bool)
        self.assertEqual(ht.core.types.canonical_heat_type(int), ht.int32)
        self.assertEqual(ht.core.types.canonical_heat_type("u1"), ht.uint8)
        self.assertEqual(ht.core.types.canonical_heat_type(np.int8), ht.int8)
        self.assertEqual(ht.core.types.canonical_heat_type(torch.short), ht.int16)

        with self.assertRaises(TypeError):
            ht.core.types.canonical_heat_type({})
        with self.assertRaises(TypeError):
            ht.core.types.canonical_heat_type(object)
        with self.assertRaises(TypeError):
            ht.core.types.canonical_heat_type(1)
        with self.assertRaises(TypeError):
            ht.core.types.canonical_heat_type("i7") 
Example #4
Source File: ops.py    From adeptRL with GNU General Public License v3.0
def update_dtype(self, old_dtype):
        updated = {}
        for k, v in old_dtype.items():
            if v == np.float32:
                dt = torch.float32
            elif v == np.float64:
                dt = torch.float64
            elif v == np.float16:
                dt = torch.float16
            elif v == np.uint8:
                dt = torch.uint8
            elif v == np.int8:
                dt = torch.int8
            elif v == np.int16:
                dt = torch.int16
            elif v == np.int32:
                dt = torch.int32
            elif v == np.int64:
                dt = torch.int64
            else:
                raise ValueError("Unsupported dtype {}".format(v))
            updated[k] = dt
        return updated 
Example #5
Source File: pytorch_abstract_types.py    From myia with MIT License
def pytorch_dtype_to_type(dtype):
    """Map a pytorch dtype to a myia type."""
    import torch

    _type_map = {
        torch.int8: Int[8],
        torch.int16: Int[16],
        torch.int32: Int[32],
        torch.int64: Int[64],
        torch.uint8: UInt[8],
        torch.float16: Float[16],
        torch.float32: Float[32],
        torch.float64: Float[64],
        torch.bool: Bool,
    }
    if dtype not in _type_map:
        raise TypeError(f"Unsupported dtype {dtype}")
    return _type_map[dtype] 
Example #6
Source File: coders.py    From L3C-PyTorch with GNU General Public License v3.0
def range_decode(self, encoded_bytes, cdf, time_logger: StackTimeLogger = no_op.NoOp):
        """
        :param encoded_bytes: bytes encoded by range_encode
        :param cdf: cdf to use, either a NHWLp matrix or instance of CDFOut
        :return: decoded matrix as np.int16, NHW
        """
        if isinstance(cdf, CDFOut):
            logit_probs_c_sm, means_c, log_scales_c, K, targets = cdf

            N, _, H, W = means_c.shape

            with time_logger.run('ac.decode'):
                decoded = torchac.decode_logistic_mixture(
                        targets, means_c, log_scales_c, logit_probs_c_sm, encoded_bytes)

        else:
            N, H, W, Lp = cdf.shape
            assert Lp == self.L + 1, (Lp, self.L)

            with time_logger.run('ac.decode'):
                decoded = torchac.decode_cdf(cdf, encoded_bytes)

        return decoded.reshape(N, H, W) 
Example #7
Source File: wav_utils.py    From audio with BSD 2-Clause "Simplified" License
def normalize_wav(tensor: torch.Tensor) -> torch.Tensor:
    if tensor.dtype == torch.float32:
        pass
    elif tensor.dtype == torch.int32:
        tensor = tensor.to(torch.float32)
        tensor[tensor > 0] /= 2147483647.
        tensor[tensor < 0] /= 2147483648.
    elif tensor.dtype == torch.int16:
        tensor = tensor.to(torch.float32)
        tensor[tensor > 0] /= 32767.
        tensor[tensor < 0] /= 32768.
    elif tensor.dtype == torch.uint8:
        tensor = tensor.to(torch.float32) - 128
        tensor[tensor > 0] /= 127.
        tensor[tensor < 0] /= 128.
    return tensor 
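The separate divisors reflect the asymmetric integer ranges (e.g. int16 spans -32768 to 32767): dividing each half by its own extreme maps full-scale samples to exactly -1.0 and +1.0. A quick sketch:

wav = torch.tensor([32767, -32768, 0], dtype=torch.int16)
print(normalize_wav(wav))  # tensor([ 1., -1.,  0.])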
Example #8
Source File: pytorch.py    From incubator-tvm with Apache License 2.0
def _convert_dtype_value(val):
    """converts a PyTorch the PyTorch numeric type id to a torch scalar type."""
    convert_torch_dtype_map = {7:"torch.float64",
                               6:"torch.float32",
                               5:"torch.float16",
                               4:"torch.int64",
                               3:"torch.int32",
                               2:"torch.int16",
                               1:"torch.int8",
                               0:"torch.unit8",
                               None:"torch.int64"} # Default is torch.int64
    if val in convert_torch_dtype_map:
        return _convert_data_type(convert_torch_dtype_map[val])
    else:
        msg = "Torch data type value %d is not handled yet." % (val)
        raise NotImplementedError(msg) 
Example #9
Source File: bitcoding.py    From L3C-PyTorch with GNU General Public License v3.0
def encode_uniform(self, dmll, S, fout):
        """ encode coarsest scale, for which we assume a uniform prior. """
        write_shape(S.shape, fout)
        r = ArithmeticCoder(dmll.L)

        entropy_coding_bytes = 0
        with self.times.prefix_scope('uniform encode'):
            c_uniform = self._get_uniform_cdf(S.shape, dmll.L)
            for c in range(S.shape[1]):
                S_c = S[:, c, ...].to(torch.int16)
                encoded = r.range_encode(S_c, c_uniform, self.times)
                write_num_bytes_encoded(len(encoded), fout)
                entropy_coding_bytes += len(encoded)
                fout.write(encoded)

        return entropy_coding_bytes 
Example #10
Source File: pytorch.py    From incubator-tvm with Apache License 2.0
def _create_typed_const(data, dtype):
    """create a (scalar) constant of given value and dtype.
       dtype should be a TVM dtype"""

    if dtype == "float64":
        typed_data = _expr.const(np.float64(data), dtype=dtype)
    elif dtype == "float32":
        typed_data = _expr.const(np.float32(data), dtype=dtype)
    elif dtype == "float16":
        typed_data = _expr.const(np.float16(data), dtype=dtype)
    elif dtype == "int64":
        typed_data = _expr.const(np.int64(data), dtype=dtype)
    elif dtype == "int32":
        typed_data = _expr.const(np.int32(data), dtype=dtype)
    elif dtype == "int16":
        typed_data = _expr.const(np.int16(data), dtype=dtype)
    elif dtype == "int8":
        typed_data = _expr.const(np.int8(data), dtype=dtype)
    elif dtype == "uint8":
        typed_data = _expr.const(np.uint8(data), dtype=dtype)
    else:
        raise NotImplementedError("input_type {} is not handled yet".format(dtype))
    return typed_data 
Example #11
Source File: types.py    From heat with MIT License
def promote_types(type1, type2):
    """
    Returns the data type with the smallest size and smallest scalar kind to which both type1 and type2 may be
    intuitively cast, where intuitive casting refers to maintaining the same bit length if possible. This
    function is symmetric.

    Parameters
    ----------
    type1 : type, str, ht.dtype
        type of first operand
    type2 : type, str, ht.dtype
        type of second operand

    Returns
    -------
    out : ht.dtype
        The promoted data type.

    Examples
    --------
    >>> ht.promote_types(ht.uint8, ht.uint8)
    ht.uint8
    >>> ht.promote_types(ht.int32, ht.float32)
    ht.float32
    >>> ht.promote_types(ht.int8, ht.uint8)
    ht.int16
    >>> ht.promote_types("i8", "f4")
    ht.float64
    """
    typecode_type1 = __type_codes[canonical_heat_type(type1)]
    typecode_type2 = __type_codes[canonical_heat_type(type2)]

    return __type_promotions[typecode_type1][typecode_type2] 
Example #12
Source File: test_types.py    From heat with MIT License
def test_int16(self):
        self.assert_is_instantiable_heat_type(ht.int16, torch.int16)
        self.assert_is_instantiable_heat_type(ht.short, torch.int16) 
Example #13
Source File: generate_opus.py    From audio with BSD 2-Clause "Simplified" License
def _generate(num_channels, compression_level, bitrate):
    org_path = 'original.wav'
    ops_path = f'{bitrate}_{compression_level}_{num_channels}ch.opus'

    # Note: ffmpeg forces a 48 kHz sample rate for opus https://stackoverflow.com/a/39186779
    # 1. generate original wav
    data = torch.linspace(-32768, 32767, 32768, dtype=torch.int16).repeat([num_channels, 1]).t()
    scipy.io.wavfile.write(org_path, 48000, data.numpy())
    # 2. convert to opus
    convert_to_opus(org_path, ops_path, bitrate=bitrate, compression_level=compression_level) 
Example #14
Source File: test_types.py    From heat with MIT License
def test_type_promotions(self):
        self.assertEqual(ht.promote_types(ht.uint8, ht.uint8), ht.uint8)
        self.assertEqual(ht.promote_types(ht.int8, ht.uint8), ht.int16)
        self.assertEqual(ht.promote_types(ht.int32, ht.float32), ht.float32)
        self.assertEqual(ht.promote_types("f4", ht.float), ht.float32)
        self.assertEqual(ht.promote_types(ht.bool_, "?"), ht.bool)

        # exceptions
        with self.assertRaises(TypeError):
            ht.promote_types(1, "?")
        with self.assertRaises(TypeError):
            ht.promote_types(ht.float32, "hello world") 
Example #15
Source File: test_dndarray.py    From heat with MIT License
def test_and(self):
        int16_tensor = ht.array([[1, 1], [2, 2]], dtype=ht.int16)
        int16_vector = ht.array([[3, 4]], dtype=ht.int16)

        self.assertTrue(
            ht.equal(int16_tensor & int16_vector, ht.bitwise_and(int16_tensor, int16_vector))
        ) 
Example #16
Source File: test_dndarray.py    From heat with MIT License
def test_or(self):
        int16_tensor = ht.array([[1, 1], [2, 2]], dtype=ht.int16)
        int16_vector = ht.array([[3, 4]], dtype=ht.int16)

        self.assertTrue(
            ht.equal(int16_tensor | int16_vector, ht.bitwise_or(int16_tensor, int16_vector))
        ) 
Example #17
Source File: test_dndarray.py    From heat with MIT License
def test_xor(self):
        int16_tensor = ht.array([[1, 1], [2, 2]], dtype=ht.int16)
        int16_vector = ht.array([[3, 4]], dtype=ht.int16)

        self.assertTrue(
            ht.equal(int16_tensor ^ int16_vector, ht.bitwise_xor(int16_tensor, int16_vector))
        ) 
Example #18
Source File: pytorch.py    From incubator-tvm with Apache License 2.0
def _convert_data_type(input_type, default_dtype=None):
    """converts the PyTorch scalar type input_type to a TVM dtype.
       optionally, default_dtype can be a TVM dtype that is used
       if input_type is None (but not when it is unknown)"""
    if input_type is None and default_dtype is not None:
        return default_dtype

    input_type = input_type.lower()
    if input_type in ["double", "torch.float64"]:
        return "float64"
    elif input_type in ["float", "torch.float32"]:
        return "float32"
    elif input_type in ["half", "torch.float16"]:
        return "float16"
    elif input_type in ["long", "torch.int64"]:
        return "int64"
    elif input_type in ["int", "torch.int32"]:
        return "int32"
    elif input_type in ["short", "torch.int16"]:
        return "int16"
    elif input_type in ["char", "torch.int8"]:
        return "int8"
    elif input_type in ["byte", "torch.uint8"]:
        return "uint8"
    elif input_type in ["quint8", "torch.quint8"]:
        return "quint8"
    elif input_type in ["qint8", "torch.qint8"]:
        return "qint8"
    elif input_type in ["qint32", "torch.qint32"]:
        return "qint32"
    elif input_type in ["bool", "torch.bool"]:
        return "bool"
    else:
        raise NotImplementedError("input_type {} is not handled yet".format(input_type))
    return "float32"  # Never reached 
Example #19
Source File: algorithmic.py    From spectre with Apache License 2.0
def split(self, data: torch.Tensor) -> torch.Tensor:
        ret = torch.take(data, self._sorted_indices)
        assert ret.dtype not in {torch.int8, torch.int16, torch.int32, torch.int64}, \
            'tensor cannot be an int dtype; float32 is recommended'
        ret.masked_fill_(self._padding_mask, np.nan)
        return ret 
Example #20
Source File: test_forward.py    From incubator-tvm with Apache License 2.0
def test_type_as():
    torch.set_grad_enabled(False)
    input_shape = [1, 3]

    def _create_module(dtype):
        class TypeAs(Module):
            def forward(self, *args):
                expected_type_tensor = torch.zeros(1, 3, dtype=dtype)
                return args[0].type_as(expected_type_tensor)

        return TypeAs()

    input_data = torch.randn(input_shape).float()
    verify_model(_create_module(torch.float64), input_data=input_data)
    verify_model(_create_module(torch.float32), input_data=input_data)
    verify_model(_create_module(torch.int64), input_data=input_data)
    verify_model(_create_module(torch.int32), input_data=input_data)
    verify_model(_create_module(torch.int16), input_data=input_data)
    verify_model(_create_module(torch.int8), input_data=input_data)

    if torch.cuda.is_available():
        check_fp16 = False
        try:
            # Only check half precision on supported hardwares.
            if have_fp16(tvm.gpu(0).compute_version):
                check_fp16 = True
        except Exception:
            # If GPU is not enabled in TVM, skip the fp16 test.
            pass

        # Temporary disable fp16 test
        check_fp16 = False

        if check_fp16:
            verify_model(_create_module(torch.float16), input_data=input_data) 
Example #21
Source File: factor.py    From spectre with Apache License 2.0
def compute(self, data: torch.Tensor) -> torch.Tensor:
        if data.dtype in {torch.int8, torch.int16, torch.int32, torch.int64}:
            raise ValueError('factor.shift() does not support `int` type, '
                             'please convert to float by using `factor.float()`, upstreams: {}'
                             .format(self.inputs))

        shift = data.roll(self.periods, dims=1)
        if self.periods > 0:
            shift[:, 0:self.periods] = np.nan
        else:
            shift[:, self.periods:] = np.nan
        return shift 
Example #22
Source File: utils.py    From Adversarial-Continual-Learning with MIT License
def read_sn3_pascalvincent_tensor(path, strict=True):
    """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').
       Argument may be a filename, compressed filename, or file object.
    """
    # typemap
    if not hasattr(read_sn3_pascalvincent_tensor, 'typemap'):
        read_sn3_pascalvincent_tensor.typemap = {
            8: (torch.uint8, np.uint8, np.uint8),
            9: (torch.int8, np.int8, np.int8),
            11: (torch.int16, np.dtype('>i2'), 'i2'),
            12: (torch.int32, np.dtype('>i4'), 'i4'),
            13: (torch.float32, np.dtype('>f4'), 'f4'),
            14: (torch.float64, np.dtype('>f8'), 'f8')}
    # read
    with open_maybe_compressed_file(path) as f:
        data = f.read()
    # parse
    magic = get_int(data[0:4])
    nd = magic % 256
    ty = magic // 256
    assert nd >= 1 and nd <= 3
    assert ty >= 8 and ty <= 14
    m = read_sn3_pascalvincent_tensor.typemap[ty]
    s = [get_int(data[4 * (i + 1): 4 * (i + 2)]) for i in range(nd)]
    parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1)))
    assert parsed.shape[0] == np.prod(s) or not strict
    return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s) 
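As a concrete check of the magic-number decoding: the standard MNIST image file (idx3-ubyte) begins with magic 2051 = 8 * 256 + 3, which the code above splits into a type id and a dimension count:

magic = 2051
nd = magic % 256   # -> 3 dimensions (count, rows, cols)
ty = magic // 256  # -> 8, i.e. the (torch.uint8, np.uint8, np.uint8) entry of the typemap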
Example #23
Source File: mnist.py    From catalyst with Apache License 2.0
def read_sn3_pascalvincent_tensor(path, strict=True):
    """Read a SN3 file in "Pascal Vincent" format.
    Argument may be a filename, compressed filename, or file object.
    """
    # typemap
    if not hasattr(read_sn3_pascalvincent_tensor, "typemap"):
        read_sn3_pascalvincent_tensor.typemap = {
            8: (torch.uint8, np.uint8, np.uint8),
            9: (torch.int8, np.int8, np.int8),
            11: (torch.int16, np.dtype(">i2"), "i2"),
            12: (torch.int32, np.dtype(">i4"), "i4"),
            13: (torch.float32, np.dtype(">f4"), "f4"),
            14: (torch.float64, np.dtype(">f8"), "f8"),
        }
    # read
    with open_maybe_compressed_file(path) as f:
        data = f.read()
    # parse
    magic = get_int(data[0:4])  # noqa: WPS349
    nd = magic % 256
    ty = magic // 256
    assert nd >= 1 and nd <= 3
    assert ty >= 8 and ty <= 14
    m = read_sn3_pascalvincent_tensor.typemap[ty]
    s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)]
    parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1)))
    assert parsed.shape[0] == np.prod(s) or not strict
    return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s) 
Example #24
Source File: pytorch.py    From incubator-tvm with Apache License 2.0
def _pytorch_result_type(dtypes, non_tensor_inputs):
    """This promotes TVM dtypes like PyTorch would"""
    import torch
    dtype_map = {
        "float64": torch.float64,
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "int64": torch.int64,
        "int32": torch.int32,
        "int16": torch.int16,
        "int8": torch.int8,
        "uint8": torch.uint8,
        "bool": torch.bool
        }
    if len(dtypes) > 0:
        result_type = dtypes[0]
        for dt in dtypes[1:]:
            if dt != result_type: # we don't want to work with same types as we
                                  # don't do quantized here (which cannot be promoted?)
                result_type = _convert_data_type(str(torch.result_type(
                    torch.zeros((), dtype=dtype_map[result_type]),
                    torch.zeros((), dtype=dtype_map[dt]))))
    else:
        result_type = "bool"  # this is the smallest type...
    for inp in non_tensor_inputs:
        result_type = _convert_data_type(
            str(torch.result_type(torch.zeros((), dtype=dtype_map[result_type]),
                                  inp)))
    return result_type 
Example #25
Source File: data_utils.py    From BraTS-DMFNet with Apache License 2.0
def sample(x, size):
    #https://gist.github.com/yoavram/4134617
    i = random.sample(range(x.shape[0]), size)
    return torch.tensor(x[i], dtype=torch.int16)
    #x = np.random.permutation(x)
    #return torch.tensor(x[:size]) 
Example #26
Source File: data_utils.py    From BraTS-DMFNet with Apache License 2.0
def get_all_coords(stride):
    return torch.tensor(
        np.stack([v.reshape(-1) for v in
            np.meshgrid(
                    *[stride//2 + np.arange(0, s, stride) for s in _shape],
                    indexing='ij')],
            -1), dtype=torch.int16) 
Example #27
Source File: torchac.py    From L3C-PyTorch with GNU General Public License v3.0
def encode_cdf(cdf, sym):
    """
    :param cdf: CDF as 1HWLp, as int16, on CPU!
    :param sym: the symbols to encode, as int16, on CPU
    :return: byte-string, encoding `sym`
    """
    if cdf.is_cuda or sym.is_cuda:
        raise ValueError('CDF and symbols must be on CPU for `encode_cdf`')
    # encode_cdf is defined in both backends, so it doesn't matter which one we use!
    return any_backend.encode_cdf(cdf, sym) 
Example #28
Source File: fsl_trainer.py    From FEAT with MIT License
def evaluate(self, data_loader):
        # restore model args
        args = self.args
        # evaluation mode
        self.model.eval()
        record = np.zeros((args.num_eval_episodes, 2)) # loss and acc
        label = torch.arange(args.eval_way, dtype=torch.int16).repeat(args.eval_query)
        label = label.type(torch.LongTensor)
        if torch.cuda.is_available():
            label = label.cuda()
        print('best epoch {}, best val acc={:.4f} + {:.4f}'.format(
                self.trlog['max_acc_epoch'],
                self.trlog['max_acc'],
                self.trlog['max_acc_interval']))
        with torch.no_grad():
            for i, batch in enumerate(data_loader, 1):
                if torch.cuda.is_available():
                    data, _ = [_.cuda() for _ in batch]
                else:
                    data = batch[0]

                logits = self.model(data)
                loss = F.cross_entropy(logits, label)
                acc = count_acc(logits, label)
                record[i-1, 0] = loss.item()
                record[i-1, 1] = acc
                
        assert(i == record.shape[0])
        vl, _ = compute_confidence_interval(record[:,0])
        va, vap = compute_confidence_interval(record[:,1])
        
        # train mode
        self.model.train()
        if self.args.fix_BN:
            self.model.encoder.eval()

        return vl, va, vap 
Example #29
Source File: fsl_trainer.py    From FEAT with MIT License
def prepare_label(self):
        args = self.args

        # prepare episode labels (class indices, not one-hot)
        label = torch.arange(args.way, dtype=torch.int16).repeat(args.query)
        label_aux = torch.arange(args.way, dtype=torch.int8).repeat(args.shot + args.query)
        
        label = label.type(torch.LongTensor)
        label_aux = label_aux.type(torch.LongTensor)
        
        if torch.cuda.is_available():
            label = label.cuda()
            label_aux = label_aux.cuda()
            
        return label, label_aux 
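For instance, with way=3 and query=2 the labels tile the class indices query times:

label = torch.arange(3, dtype=torch.int16).repeat(2)
print(label)  # tensor([0, 1, 2, 0, 1, 2], dtype=torch.int16)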
Example #30
Source File: test_common.py    From CrypTen with MIT License
def test_encode_decode(self):
        """Tests tensor encoding and decoding."""
        for float in [False, True]:
            if float:
                fpe = FixedPointEncoder(precision_bits=16)
            else:
                fpe = FixedPointEncoder(precision_bits=0)
            tensor = get_test_tensor(float=float)
            decoded = fpe.decode(fpe.encode(tensor))
            self._check(
                decoded,
                tensor,
                "Encoding/decoding a %s failed." % "float" if float else "long",
            )

        # Make sure encoding a subclass of CrypTensor is a no-op
        crypten.mpc.set_default_provider(crypten.mpc.provider.TrustedFirstParty)
        crypten.init()

        tensor = get_test_tensor(float=True)
        encrypted_tensor = crypten.cryptensor(tensor)
        encrypted_tensor = fpe.encode(encrypted_tensor)
        self._check(
            encrypted_tensor.get_plain_text(),
            tensor,
            "Encoding an EncryptedTensor failed.",
        )

        # Try a few other types.
        fpe = FixedPointEncoder(precision_bits=0)
        for dtype in [torch.uint8, torch.int8, torch.int16]:
            tensor = torch.zeros(5, dtype=dtype).random_()
            decoded = fpe.decode(fpe.encode(tensor)).type(dtype)
            self._check(decoded, tensor, "Encoding/decoding a %s failed." % dtype)