Python torch.int8 Examples
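The following are 30 code examples of torch.int8, PyTorch's signed 8-bit integer dtype (a dtype object, not a function, so it is used without parentheses).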
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module torch, or try the search function.
Example #1
Source File: test_quantization.py From nlp-architect with Apache License 2.0 | 6 votes |
def test_export_to_8bit_with_bias(self):
    """Check that toggling ``mode_8bit`` swaps float and quantized state-dict entries.

    With a bias, 8-bit mode must export ``quantized_weight`` (int8),
    ``_quantized_bias`` (int32) and ``bias_scale`` instead of the float
    ``weight``/``bias``; switching the flag back restores the float entries.
    """
    qlinear = QuantizedLinear(10, 5, mode="EMA")
    qlinear.eval()

    # Before enabling 8-bit mode only the float parameters are exported.
    state_dict = qlinear.state_dict()
    self.assertIn("weight", state_dict)
    self.assertIn("bias", state_dict)
    self.assertNotIn("quantized_weight", state_dict)
    self.assertNotIn("_quantized_bias", state_dict)
    self.assertNotIn("bias_scale", state_dict)

    # 8-bit mode: the float parameters are replaced by quantized tensors.
    qlinear.mode_8bit = True
    state_dict = qlinear.state_dict()
    self.assertNotIn("weight", state_dict)
    self.assertNotIn("bias", state_dict)
    self.assertIn("quantized_weight", state_dict)
    self.assertEqual(state_dict["quantized_weight"].dtype, torch.int8)
    self.assertIn("_quantized_bias", state_dict)
    self.assertEqual(state_dict["_quantized_bias"].dtype, torch.int32)
    self.assertIn("bias_scale", state_dict)

    # Switching back restores the original float export layout.
    qlinear.mode_8bit = False
    state_dict = qlinear.state_dict()
    self.assertIn("weight", state_dict)
    self.assertIn("bias", state_dict)
    self.assertNotIn("quantized_weight", state_dict)
    self.assertNotIn("_quantized_bias", state_dict)
    self.assertNotIn("bias_scale", state_dict)
Example #2
Source File: types.py From chainer-compiler with MIT License | 6 votes |
def torch_dtype_to_np_dtype(dtype):
    """Return the NumPy dtype equivalent to a PyTorch dtype.

    PyTorch's alias names (``torch.short``, ``torch.int``, ``torch.long``,
    ``torch.half``, ``torch.float``, ``torch.double``) are the very same
    objects as the sized dtypes below, so they are covered without separate
    entries.

    Raises:
        KeyError: if ``dtype`` has no entry in the table.
    """
    dtype_dict = {
        # ``np.bool`` was removed in NumPy 1.24; ``np.bool_`` works everywhere.
        torch.bool: np.dtype(np.bool_),
        torch.uint8: np.dtype(np.uint8),
        torch.int8: np.dtype(np.int8),
        torch.int16: np.dtype(np.int16),
        torch.int32: np.dtype(np.int32),
        torch.int64: np.dtype(np.int64),
        torch.float16: np.dtype(np.float16),
        torch.float32: np.dtype(np.float32),
        torch.float64: np.dtype(np.float64),
    }
    return dtype_dict[dtype]


# ---------------------- InferenceEngine internal types ------------------------
Example #3
Source File: test_quantization.py From nlp-architect with Apache License 2.0 | 6 votes |
def test_export_to_8bit_without_bias(self):
    """Check 8-bit export of a bias-free QuantizedLinear.

    Without a bias, 8-bit mode must export only ``quantized_weight`` (int8).
    No ``_quantized_bias`` or ``bias_scale`` entries may appear, and switching
    the flag back restores the float ``weight``.
    """
    qlinear = QuantizedLinear(10, 5, bias=False, mode="EMA")
    qlinear.eval()

    # 8-bit mode: only the quantized weight is exported.
    qlinear.mode_8bit = True
    state_dict = qlinear.state_dict()
    self.assertNotIn("weight", state_dict)
    self.assertNotIn("bias", state_dict)
    self.assertIn("quantized_weight", state_dict)
    self.assertEqual(state_dict["quantized_weight"].dtype, torch.int8)
    self.assertNotIn("_quantized_bias", state_dict)
    self.assertNotIn("bias_scale", state_dict)

    # Back to float mode: the float weight returns; there is still no bias.
    qlinear.mode_8bit = False
    state_dict = qlinear.state_dict()
    self.assertIn("weight", state_dict)
    self.assertNotIn("bias", state_dict)
    self.assertNotIn("quantized_weight", state_dict)
    self.assertNotIn("_quantized_bias", state_dict)
    self.assertNotIn("bias_scale", state_dict)
Example #4
Source File: test_forward.py From incubator-tvm with Apache License 2.0 | 6 votes |
def test_forward_logical_xor():
    """Verify conversion of ``torch.logical_xor`` with tensor and constant operands."""
    torch.set_grad_enabled(False)

    class LogicalXor1(Module):
        def forward(self, *args):
            return torch.logical_xor(args[0], args[1])

    class LogicalXor2(Module):
        def forward(self, *args):
            # Second operand is a constant created inside the module.
            const_rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
            if torch.cuda.is_available():
                const_rhs = const_rhs.cuda()
            return torch.logical_xor(args[0], const_rhs)

    two_input_cases = (
        [torch.tensor([-1, -2, 3], dtype=torch.int8),
         torch.tensor([1, 0, 3], dtype=torch.int8)],
        [torch.tensor([True, True, False]),
         torch.tensor([False, True, False])],
    )
    for operands in two_input_cases:
        verify_model(LogicalXor1().float().eval(), input_data=operands)

    single_lhs = torch.tensor([-1, -2, 3], dtype=torch.int8)
    verify_model(LogicalXor2().float().eval(), input_data=[single_lhs])
Example #5
Source File: test_forward.py From incubator-tvm with Apache License 2.0 | 6 votes |
def test_forward_bitwise_xor():
    """Verify conversion of ``torch.bitwise_xor`` with tensor and constant operands."""
    torch.set_grad_enabled(False)

    class BitwiseXor1(Module):
        def forward(self, *args):
            return torch.bitwise_xor(args[0], args[1])

    class BitwiseXor2(Module):
        def forward(self, *args):
            # Second operand is a constant created inside the module.
            const_rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
            if torch.cuda.is_available():
                const_rhs = const_rhs.cuda()
            return torch.bitwise_xor(args[0], const_rhs)

    two_input_cases = (
        [torch.tensor([-1, -2, 3], dtype=torch.int8),
         torch.tensor([1, 0, 3], dtype=torch.int8)],
        [torch.tensor([True, True, False]),
         torch.tensor([False, True, False])],
    )
    for operands in two_input_cases:
        verify_model(BitwiseXor1().float().eval(), input_data=operands)

    single_lhs = torch.tensor([-1, -2, 3], dtype=torch.int8)
    verify_model(BitwiseXor2().float().eval(), input_data=[single_lhs])
Example #6
Source File: ops.py From adeptRL with GNU General Public License v3.0 | 6 votes |
def update_dtype(self, old_dtype):
    """Translate a mapping of NumPy dtypes into the matching torch dtypes.

    Args:
        old_dtype: dict whose values are NumPy dtypes; each value is matched
            with ``==``, so both scalar types (``np.float32``) and ``np.dtype``
            instances compare equal to the table entries.

    Returns:
        A new dict with the same keys and torch dtypes as values.

    Raises:
        ValueError: if a value has no torch equivalent in the table.
    """
    # Equality (not hashing) is used for matching, hence a list of pairs
    # rather than a dict keyed by NumPy type.
    np_to_torch = (
        (np.float32, torch.float32),
        (np.float64, torch.float64),
        (np.float16, torch.float16),
        (np.uint8, torch.uint8),
        (np.int8, torch.int8),
        (np.int16, torch.int16),
        (np.int32, torch.int32),
        (np.int64, torch.int64),
    )
    updated = {}
    for k, v in old_dtype.items():
        for np_type, torch_type in np_to_torch:
            if v == np_type:
                updated[k] = torch_type
                break
        else:
            raise ValueError("Unsupported dtype {}".format(v))
    return updated
Example #7
Source File: pytorch_abstract_types.py From myia with MIT License | 6 votes |
def pytorch_dtype_to_type(dtype):
    """Return the myia scalar type corresponding to a torch dtype.

    Raises:
        TypeError: if the torch dtype is not in the supported table.
    """
    import torch

    supported = {
        torch.bool: Bool,
        torch.uint8: UInt[8],
        torch.int8: Int[8],
        torch.int16: Int[16],
        torch.int32: Int[32],
        torch.int64: Int[64],
        torch.float16: Float[16],
        torch.float32: Float[32],
        torch.float64: Float[64],
    }
    if dtype in supported:
        return supported[dtype]
    raise TypeError(f"Unsupported dtype {dtype}")
Example #8
Source File: test_forward.py From incubator-tvm with Apache License 2.0 | 6 votes |
def test_forward_logical_not():
    """Verify conversion of ``torch.logical_not`` over bool, int and float inputs."""
    torch.set_grad_enabled(False)

    class LogicalNot1(Module):
        def forward(self, *args):
            return torch.logical_not(args[0])

    # The last tensor is built from float literals but stored as int32.
    inputs = (
        torch.tensor([True, False]),
        torch.tensor([0, 1, -10], dtype=torch.int8),
        torch.tensor([0., 1.5, -10.], dtype=torch.double),
        torch.tensor([0., 1., -10.], dtype=torch.int32),
    )
    for sample in inputs:
        verify_model(LogicalNot1().float().eval(), input_data=sample)
Example #9
Source File: distributed_communicator.py From CrypTen with MIT License | 6 votes |
def broadcast_obj(self, obj, src, group=None):
    """Broadcasts a given object to all parties.

    The source party pickles ``obj`` and broadcasts its byte length first, so
    the receiving parties can allocate a matching int8 buffer. The bytes
    themselves are broadcast next.

    Args:
        obj: object to send; only read on the ``src`` rank, where it must not
            be None.
        src: rank that owns the object.
        group: process group to use; defaults to ``self.main_group``.

    Returns:
        On ``src``, the ``obj`` that was passed in. On other ranks, the
        object decoded with ``serial.restricted_loads``.
    """
    if group is None:
        group = self.main_group

    if self.rank == src:
        assert obj is not None, "src party must provide obj for broadcast"
        buf = pickle.dumps(obj)
        size = torch.tensor(len(buf), dtype=torch.int32)
        # numpy.frombuffer over immutable bytes gives a read-only array, which
        # torch.from_numpy warns about; copying into a bytearray makes it writable.
        arr = torch.from_numpy(numpy.frombuffer(bytearray(buf), dtype=numpy.int8))
        dist.broadcast(size, src, group=group)
        dist.broadcast(arr, src, group=group)
    else:
        # Placeholder that dist.broadcast overwrites with the sender's length.
        size = torch.tensor(1, dtype=torch.int32)
        dist.broadcast(size, src, group=group)
        data = torch.empty(size=(size.item(),), dtype=torch.int8)
        dist.broadcast(data, src, group=group)
        buf = data.numpy().tobytes()
        obj = serial.restricted_loads(buf)
    return obj
Example #10
Source File: pytorch.py From incubator-tvm with Apache License 2.0 | 6 votes |
def _create_typed_const(data, dtype):
    """Create a scalar relay constant holding ``data`` with TVM dtype ``dtype``.

    The value is first cast with the NumPy scalar type named by ``dtype``, then
    wrapped by ``_expr.const``.

    Raises:
        NotImplementedError: if ``dtype`` is not one of the listed names.
    """
    numpy_casts = (
        ("float64", np.float64),
        ("float32", np.float32),
        ("float16", np.float16),
        ("int64", np.int64),
        ("int32", np.int32),
        ("int16", np.int16),
        ("int8", np.int8),
        ("uint8", np.uint8),
    )
    for dtype_name, cast in numpy_casts:
        if dtype == dtype_name:
            return _expr.const(cast(data), dtype=dtype)
    raise NotImplementedError("input_type {} is not handled yet".format(dtype))
Example #11
Source File: pytorch.py From incubator-tvm with Apache License 2.0 | 6 votes |
def _convert_dtype_value(val):
    """Convert a PyTorch numeric scalar-type id to a TVM dtype string.

    The id is mapped to its "torch.<name>" spelling, which is then converted
    by ``_convert_data_type``. ``None`` defaults to int64.

    Raises:
        NotImplementedError: if ``val`` is not one of the known ids.
    """
    convert_torch_dtype_map = {
        7: "torch.float64",
        6: "torch.float32",
        5: "torch.float16",
        4: "torch.int64",
        3: "torch.int32",
        2: "torch.int16",
        1: "torch.int8",
        0: "torch.uint8",
        None: "torch.int64",  # Default is torch.int64
    }
    if val in convert_torch_dtype_map:
        return _convert_data_type(convert_torch_dtype_map[val])
    # %s (not %d) so a non-integer value still reaches this error instead of
    # failing with a TypeError while formatting the message.
    msg = "Torch data type value %s is not handled yet." % (val,)
    raise NotImplementedError(msg)
Example #12
Source File: external_configurables_test.py From gin-config with Apache License 2.0 | 6 votes |
def testDtypes(self):
    """Spot-check that torch dtypes can be referenced from a gin config.

    The config text must span several lines: gin treats everything after a
    ``#`` on the same line as a comment.
    """
    # Spot check a few.
    config_str = """
      # Test without torch prefix, but using the
      # prefix is strongly recommended!
      configurable.float32 = %float32

      # Test with torch prefix.
      configurable.int8 = %torch.int8
      configurable.float16 = %torch.float16
    """
    config.parse_config(config_str)
    vals = configurable()
    # pylint: disable=E1101
    self.assertIs(vals['float32'], torch.float32)
    self.assertIs(vals['int8'], torch.int8)
    self.assertIs(vals['float16'], torch.float16)
    # pylint: enable=E1101
Example #13
Source File: constants.py From heat with MIT License | 6 votes |
def sanitize_infinity(dtype):
    """
    Returns largest possible value for the specified dtype.

    Parameters:
    -----------
    dtype: torch dtype

    Returns:
    --------
    large_enough: largest possible value for the given dtype. Integer dtypes
        (uint8, int8, int16, int32, int64) give their maximum representable
        value as a Python int; every other dtype gives ``float("inf")``.
    """
    # uint8 is included so it gets its real maximum (255): infinity is not
    # representable as a uint8 value.
    if dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
        return torch.iinfo(dtype).max
    return float("inf")
Example #14
Source File: seq2slate.py From ReAgent with BSD 3-Clause "New" or "Revised" License | 5 votes |
def subsequent_mask(size, device):
    """
    Build a causal mask of shape ``(1, size, size)`` as an int8 tensor.

    Entry ``[0, i, j]`` is 1 when ``j <= i`` and 0 otherwise, so during
    decoding an item cannot attend to items that come after it.
    """
    ones = torch.ones(1, size, size, device=device)
    # tril keeps the diagonal and below: identical to 1 - triu(ones, diagonal=1).
    return torch.tril(ones).to(torch.int8)
Example #15
Source File: tensor.py From pytorch_sparse with MIT License | 5 votes |
def char(self):
    """Cast this tensor to ``torch.int8`` by matching a zero-dim int8 prototype."""
    int8_prototype = torch.tensor(0, dtype=torch.int8, device=self.device())
    return self.type_as(int8_prototype)
Example #16
Source File: seq2slate.py From ReAgent with BSD 3-Clause "New" or "Revised" License | 5 votes |
def subsequent_and_padding_mask(tgt_in_idx):
    """
    Combine a padding mask and a causal mask for the target sequence.

    A position stays 1 only when the attended item is not padding and does
    not come after the attending item.
    """
    # tgt_in_idx shape: batch_size, seq_len
    seq_len = tgt_in_idx.size(-1)
    # (batch_size, 1, seq_len): 1 where the target item is not padding
    not_padding = (tgt_in_idx != PADDING_SYMBOL).unsqueeze(-2).type(torch.int8)
    # (1, seq_len, seq_len): causal mask
    causal = subsequent_mask(seq_len, tgt_in_idx.device)
    # Broadcasting the AND gives (batch_size, seq_len, seq_len).
    return not_padding & causal
Example #17
Source File: mutation_gpu.py From scikit-opt with MIT License | 5 votes |
def mutation(self):
    '''
    Flip each gene of a 0/1 chromosome population with probability ``prob_mut``.

    XOR-ing with a random 0/1 mask is faster than
    `self.Chrom = (mask + self.Chrom) % 2`.
    :param self: object exposing size_pop, len_chrom, device, prob_mut and Chrom
    :return: the updated population tensor (self.Chrom)
    '''
    shape = (self.size_pop, self.len_chrom)
    draws = torch.rand(size=shape, device=self.device)
    flip = draws.lt(self.prob_mut).type(torch.int8)
    self.Chrom ^= flip
    return self.Chrom
Example #18
Source File: test_matcher.py From detectron2 with Apache License 2.0 | 5 votes |
def test_scriptability(self):
    """The eager Matcher and its TorchScript-compiled version must agree."""
    cfg = get_cfg()
    matcher_args = (cfg.MODEL.RPN.IOU_THRESHOLDS, cfg.MODEL.RPN.IOU_LABELS)
    anchor_matcher = Matcher(*matcher_args, allow_low_quality_matches=True)
    match_quality_matrix = torch.tensor(
        [[0.15, 0.45, 0.2, 0.6], [0.3, 0.65, 0.05, 0.1], [0.05, 0.4, 0.25, 0.4]]
    )
    expected_matches = torch.tensor([1, 1, 2, 0])
    expected_match_labels = torch.tensor([-1, 1, 0, 1], dtype=torch.int8)

    def check(matcher):
        matches, match_labels = matcher(match_quality_matrix)
        self.assertTrue(torch.allclose(matches, expected_matches))
        self.assertTrue(torch.allclose(match_labels, expected_match_labels))

    check(anchor_matcher)

    # nonzero_tuple has to be imported explicitly so TorchScript can resolve it
    # when compiling Matcher.
    # https://github.com/pytorch/pytorch/issues/38964
    from detectron2.layers import nonzero_tuple  # noqa F401

    scripted_matcher = torch.jit.script(Matcher)
    check(scripted_matcher(*matcher_args, allow_low_quality_matches=True))
Example #19
Source File: aev.py From torchani with MIT License | 5 votes |
def triple_by_molecule(atom_index12: Tensor) -> Tuple[Tensor, Tensor, Tensor]: """Input: indices for pairs of atoms that are close to each other. each pair only appear once, i.e. only one of the pairs (1, 2) and (2, 1) exists. Output: indices for all central atoms and it pairs of neighbors. For example, if input has pair (0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4), then the output would have central atom 0, 1, 2, 3, 4 and for cental atom 0, its pairs of neighbors are (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4) """ # convert representation from pair to central-others ai1 = atom_index12.view(-1) sorted_ai1, rev_indices = ai1.sort() # sort and compute unique key uniqued_central_atom_index, counts = torch.unique_consecutive(sorted_ai1, return_inverse=False, return_counts=True) # compute central_atom_index pair_sizes = counts * (counts - 1) // 2 pair_indices = torch.repeat_interleave(pair_sizes) central_atom_index = uniqued_central_atom_index.index_select(0, pair_indices) # do local combinations within unique key, assuming sorted m = counts.max().item() if counts.numel() > 0 else 0 n = pair_sizes.shape[0] intra_pair_indices = torch.tril_indices(m, m, -1, device=ai1.device).unsqueeze(1).expand(-1, n, -1) mask = (torch.arange(intra_pair_indices.shape[2], device=ai1.device) < pair_sizes.unsqueeze(1)).flatten() sorted_local_index12 = intra_pair_indices.flatten(1, 2)[:, mask] sorted_local_index12 += cumsum_from_zero(counts).index_select(0, pair_indices) # unsort result from last part local_index12 = rev_indices[sorted_local_index12] # compute mapping between representation of central-other to pair n = atom_index12.shape[1] sign12 = ((local_index12 < n).to(torch.int8) * 2) - 1 return central_atom_index, local_index12 % n, sign12
Example #20
Source File: crossover_gpu.py From scikit-opt with MIT License | 5 votes |
def crossover_2point_bit(self):
    """Two-point crossover between the first and second half of a 0/1 population.

    Row ``i`` of the first half is paired with row ``i`` of the second half.
    For each pair a random segment ``[n1, n2)`` is drawn, and the genes in
    that segment are swapped in place using XOR.

    :param self: object exposing Chrom (rows are chromosomes of length
        len_chrom), size_pop, len_chrom and device
    :return: the updated population tensor (self.Chrom)
    """
    Chrom, size_pop, len_chrom = self.Chrom, self.size_pop, self.len_chrom
    half_size_pop = size_pop // 2
    # Both halves must have exactly half_size_pop rows to line up with mask.
    # With an odd population the last row is left unpaired instead of causing
    # a shape mismatch.
    Chrom1, Chrom2 = Chrom[:half_size_pop], Chrom[half_size_pop:2 * half_size_pop]
    mask = torch.zeros(size=(half_size_pop, len_chrom), dtype=torch.int8, device=self.device)
    for i in range(half_size_pop):
        n1, n2 = np.random.randint(0, len_chrom, 2)
        if n1 > n2:
            n1, n2 = n2, n1
        mask[i, n1:n2] = 1
    # Bits that differ inside the segment; XOR-ing both rows with them swaps
    # the segment. Chrom1/Chrom2 are views, so self.Chrom is updated in place.
    mask2 = (Chrom1 ^ Chrom2) & mask
    Chrom1 ^= mask2
    Chrom2 ^= mask2
    return self.Chrom
Example #21
Source File: mnist.py From catalyst with Apache License 2.0 | 5 votes |
def read_sn3_pascalvincent_tensor(path, strict=True):
    """Read a SN3 file in "Pascal Vincent" format.

    Argument may be a filename, compressed filename, or file object.

    The 4-byte magic number encodes the element type (``magic // 256``) and
    the number of dimensions (``magic % 256``). It is followed by one 4-byte
    size per dimension and then the raw data.

    Args:
        path: source passed to ``open_maybe_compressed_file``.
        strict: if True, assert that the element count matches the header.

    Returns:
        A torch tensor with the header's shape.
    """
    # typemap: type code -> (torch dtype, on-disk numpy dtype (big-endian for
    # multi-byte types), in-memory numpy dtype). Cached on the function object.
    if not hasattr(read_sn3_pascalvincent_tensor, "typemap"):
        read_sn3_pascalvincent_tensor.typemap = {
            8: (torch.uint8, np.uint8, np.uint8),
            9: (torch.int8, np.int8, np.int8),
            11: (torch.int16, np.dtype(">i2"), "i2"),
            12: (torch.int32, np.dtype(">i4"), "i4"),
            13: (torch.float32, np.dtype(">f4"), "f4"),
            14: (torch.float64, np.dtype(">f8"), "f8"),
        }
    typemap = read_sn3_pascalvincent_tensor.typemap

    # read
    with open_maybe_compressed_file(path) as f:
        data = f.read()

    # parse
    magic = get_int(data[0:4])  # noqa: WPS349
    nd = magic % 256
    ty = magic // 256
    assert 1 <= nd <= 3
    # Membership instead of an 8..14 range check: code 10 is unused and would
    # otherwise pass the check and then fail with a KeyError.
    assert ty in typemap
    m = typemap[ty]
    s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)]
    parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1)))
    assert parsed.shape[0] == np.prod(s) or not strict
    return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
Example #22
Source File: test_quantization.py From nlp-architect with Apache License 2.0 | 5 votes |
def test_export_to_8bit(self):
    """Check that toggling ``mode_8bit`` swaps ``weight`` and ``quantized_weight``.

    In 8-bit mode the embedding must export an int8 ``quantized_weight``
    instead of the float ``weight``; switching the flag back restores the
    float entry.
    """
    qembed = QuantizedEmbedding(10, 5, mode="EMA")
    qembed.eval()

    # Before enabling 8-bit mode only the float weight is exported.
    state_dict = qembed.state_dict()
    self.assertNotIn("quantized_weight", state_dict)
    self.assertIn("weight", state_dict)

    # 8-bit mode: the float weight is replaced by its int8 quantization.
    qembed.mode_8bit = True
    state_dict = qembed.state_dict()
    self.assertIn("quantized_weight", state_dict)
    self.assertEqual(state_dict["quantized_weight"].dtype, torch.int8)
    self.assertNotIn("weight", state_dict)

    # Switching back restores the float export layout.
    qembed.mode_8bit = False
    state_dict = qembed.state_dict()
    self.assertNotIn("quantized_weight", state_dict)
    self.assertIn("weight", state_dict)
Example #23
Source File: pytorch.py From incubator-tvm with Apache License 2.0 | 5 votes |
def _pytorch_result_type(dtypes, non_tensor_inputs):
    """Promote TVM dtype strings the way PyTorch's type promotion would.

    Tensor dtypes are combined pairwise with ``torch.result_type`` on
    zero-dim tensors. Each entry of ``non_tensor_inputs`` is then folded in
    against a zero-dim tensor of the running result. With no tensor dtypes,
    the promotion starts from "bool".
    """
    import torch

    dtype_map = {
        "float64": torch.float64,
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "int64": torch.int64,
        "int32": torch.int32,
        "int16": torch.int16,
        "int8": torch.int8,
        "uint8": torch.uint8,
        "bool": torch.bool,
    }

    def zero(dtype_name):
        # zero-dim tensor carrying the torch dtype for a TVM dtype name
        return torch.zeros((), dtype=dtype_map[dtype_name])

    def to_tvm(torch_dtype):
        return _convert_data_type(str(torch_dtype))

    if len(dtypes) == 0:
        result_type = "bool"  # this is the smallest type...
    else:
        result_type = dtypes[0]
        for dt in dtypes[1:]:
            # Equal types are skipped: quantized types are not promoted here
            # (they presumably cannot be promoted).
            if dt == result_type:
                continue
            result_type = to_tvm(torch.result_type(zero(result_type), zero(dt)))

    for inp in non_tensor_inputs:
        result_type = to_tvm(torch.result_type(zero(result_type), inp))
    return result_type
Example #24
Source File: pytorch.py From incubator-tvm with Apache License 2.0 | 5 votes |
def _convert_data_type(input_type, default_dtype=None): """converts the PyTorch scalar type input_type to a TVM dtype. optionally, default_dtype can be a TVM dtype that is used if input_type is None (but not when it is unknown)""" if input_type is None and default_dtype is not None: return default_dtype input_type = input_type.lower() if input_type in ["double", "torch.float64"]: return "float64" elif input_type in ["float", "torch.float32"]: return "float32" elif input_type in ["half", "torch.float16"]: return "float16" elif input_type in ["long", "torch.int64"]: return "int64" elif input_type in ["int", "torch.int32"]: return "int32" elif input_type in ["short", "torch.int16"]: return "int16" elif input_type in ["char", "torch.int8"]: return "int8" elif input_type in ["byte", "torch.uint8"]: return "uint8" elif input_type in ["quint8", "torch.quint8"]: return "quint8" elif input_type in ["qint8", "torch.qint8"]: return "qint8" elif input_type in ["qint32", "torch.qint32"]: return "qint32" elif input_type in ["bool", "torch.bool"]: return "bool" else: raise NotImplementedError("input_type {} is not handled yet".format(input_type)) return "float32" # Never reached
Example #25
Source File: test_forward.py From incubator-tvm with Apache License 2.0 | 5 votes |
def test_type_as():
    """Verify conversion of ``Tensor.type_as`` to several target dtypes."""
    torch.set_grad_enabled(False)
    input_shape = [1, 3]

    def _create_module(dtype):
        class TypeAs(Module):
            def forward(self, *args):
                expected_type_tensor = torch.zeros(1, 3, dtype=dtype)
                return args[0].type_as(expected_type_tensor)

        return TypeAs()

    input_data = torch.randn(input_shape).float()
    target_dtypes = (
        torch.float64,
        torch.float32,
        torch.int64,
        torch.int32,
        torch.int16,
        torch.int8,
    )
    for target in target_dtypes:
        verify_model(_create_module(target), input_data=input_data)

    if torch.cuda.is_available():
        try:
            # Half precision is only checked on hardware that supports it.
            check_fp16 = bool(have_fp16(tvm.gpu(0).compute_version))
        except Exception:
            # TVM built without GPU support: skip the fp16 test.
            check_fp16 = False

        # The fp16 check is temporarily disabled whatever the probe found.
        check_fp16 = False

        if check_fp16:
            verify_model(_create_module(torch.float16), input_data=input_data)
Example #26
Source File: test_forward.py From incubator-tvm with Apache License 2.0 | 5 votes |
def test_forward_bitwise_not():
    """Verify conversion of ``torch.bitwise_not`` over int and bool inputs."""
    torch.set_grad_enabled(False)

    class BitwiseNot1(Module):
        def forward(self, *args):
            return torch.bitwise_not(args[0])

    # The int32 tensor is built from float literals but stored as int32.
    inputs = (
        torch.tensor([0, 1, -10], dtype=torch.int8),
        torch.tensor([0., 1., -10.], dtype=torch.int32),
        torch.tensor([True, False]),
    )
    for sample in inputs:
        verify_model(BitwiseNot1().float().eval(), input_data=sample)
Example #27
Source File: torch2trt.py From torch2trt with MIT License | 5 votes |
def torch_dtype_to_trt(dtype):
    """Map a torch dtype to the corresponding TensorRT dtype.

    ``torch.bool`` is only mapped when TensorRT is version 7 or newer.

    Raises:
        TypeError: if the dtype has no TensorRT equivalent here.
    """
    # Compare the numeric major version: a string comparison is
    # lexicographic and would rank e.g. '10.0' below '7.0'.
    if int(trt_version().split('.')[0]) >= 7 and dtype == torch.bool:
        return trt.bool
    elif dtype == torch.int8:
        return trt.int8
    elif dtype == torch.int32:
        return trt.int32
    elif dtype == torch.float16:
        return trt.float16
    elif dtype == torch.float32:
        return trt.float32
    else:
        raise TypeError("%s is not supported by tensorrt" % dtype)
Example #28
Source File: torch2trt.py From torch2trt with MIT License | 5 votes |
def torch_dtype_from_trt(dtype):
    """Map a TensorRT dtype to the corresponding torch dtype.

    ``trt.bool`` is only mapped when TensorRT is version 7 or newer.

    Raises:
        TypeError: if the dtype has no torch equivalent here.
    """
    if dtype == trt.int8:
        return torch.int8
    # Compare the numeric major version: a string comparison is
    # lexicographic and would rank e.g. '10.0' below '7.0'.
    elif int(trt_version().split('.')[0]) >= 7 and dtype == trt.bool:
        return torch.bool
    elif dtype == trt.int32:
        return torch.int32
    elif dtype == trt.float16:
        return torch.float16
    elif dtype == trt.float32:
        return torch.float32
    else:
        raise TypeError("%s is not supported by torch" % dtype)
Example #29
Source File: test_types.py From heat with MIT License | 5 votes |
def test_type_promotions(self):
    """Check ``ht.promote_types`` results and its rejection of invalid arguments."""
    expected_promotions = (
        ((ht.uint8, ht.uint8), ht.uint8),
        ((ht.int8, ht.uint8), ht.int16),
        ((ht.int32, ht.float32), ht.float32),
        (("f4", ht.float), ht.float32),
        ((ht.bool_, "?"), ht.bool),
    )
    for (first, second), expected in expected_promotions:
        self.assertEqual(ht.promote_types(first, second), expected)

    # exceptions
    for bad_args in ((1, "?"), (ht.float32, "hello world")):
        with self.assertRaises(TypeError):
            ht.promote_types(*bad_args)
Example #30
Source File: tensor.py From dgl with Apache License 2.0 | 5 votes |
def data_type_dict():
    """Return a mapping from dtype name strings to the backend (``th``) dtype objects."""
    names = ('float16', 'float32', 'float64', 'uint8', 'int8',
             'int16', 'int32', 'int64', 'bool')
    return {name: getattr(th, name) for name in names}