Python torch.half Examples

The following are 30 code examples of torch.half. Each example notes the open-source project, source file, and license it was taken from.
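As a quick orientation before the examples: torch.half is PyTorch's 16-bit floating-point dtype and an alias of torch.float16, and the Tensor.half() / Module.half() methods cast to it. A minimal sketch:

import torch

x = torch.randn(4)                     # torch.float32 by default
y = x.half()                           # cast to 16-bit floats
assert y.dtype == torch.half == torch.float16
z = torch.zeros(4, dtype=torch.half)   # create directly in half precision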
Example #1
Source File: test_compute_adaptive_lr.py    From torchlars with Apache License 2.0
def test_when_param_norm_is_zero_with_half():
    param_norm = torch.tensor(0., dtype=torch.half, device='cuda')
    grad_norm = torch.tensor(1., dtype=torch.half, device='cuda')
    adaptive_lr = torch.tensor(0., dtype=torch.half, device='cuda')

    weight_decay = 1.
    eps = 1.
    trust_coef = 1.

    adaptive_lr = compute_adaptive_lr(
        param_norm,
        grad_norm,
        weight_decay,
        eps,
        trust_coef,
        adaptive_lr)

    assert adaptive_lr == torch.tensor(1., dtype=torch.half, device='cuda') 
Example #2
Source File: hooks.py    From mmdetection with Apache License 2.0
def wrap_fp16_model(model):
    """Wrap the FP32 model to FP16.

    1. Convert FP32 model to FP16.
    2. Remain some necessary layers to be FP32, e.g., normalization layers.

    Args:
        model (nn.Module): Model in FP32.
    """
    # convert model to fp16
    model.half()
    # patch the normalization layers to make them work in fp32 mode
    patch_norm_fp32(model)
    # set `fp16_enabled` flag
    for m in model.modules():
        if hasattr(m, 'fp16_enabled'):
            m.fp16_enabled = True 
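A minimal usage sketch for the two helpers above (the toy model is illustrative, not from mmdetection): after wrapping, regular layers hold FP16 parameters while the patched normalization layers are back in FP32.

import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
wrap_fp16_model(model)
assert model[0].weight.dtype == torch.half   # conv now runs in fp16
assert model[1].weight.dtype == torch.float  # batchnorm patched back to fp32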
Example #3
Source File: cross_entropy.py    From fairseq with MIT License
def cross_entropy(logits, target, ignore_index=-100, reduction='mean'):
        if logits.device == torch.device('cpu'):
            return _cross_entropy_pytorch(logits, target, ignore_index, reduction)
        else:
            half_to_float = (logits.dtype == torch.half)
            losses = xentropy.SoftmaxCrossEntropyLoss.apply(
                logits, target, 0.0, ignore_index, half_to_float,
            )
            if reduction == 'sum':
                return losses.sum()
            elif reduction == 'mean':
                if ignore_index >= 0:
                    return losses.sum() / target.ne(ignore_index).sum()
                else:
                    return losses.mean()
            elif reduction == 'none':
                return losses
            else:
                raise NotImplementedError 
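Here xentropy is the fused CUDA softmax cross-entropy extension (in fairseq it is imported from apex.contrib behind a try/except, an assumption worth checking against the file's imports). The half_to_float flag asks the fused kernel to return FP32 losses for FP16 logits, which keeps the reductions above numerically safe. A usage sketch, assuming CUDA and the extension are available:

logits = torch.randn(8, 100, device='cuda', dtype=torch.half)
target = torch.randint(100, (8,), device='cuda')
loss = cross_entropy(logits, target, reduction='mean')  # fp32 loss from fp16 logits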
Example #4
Source File: test_rnn.py    From apex with BSD 3-Clause "New" or "Revised" License
def run_cell_test(self, cell, state_tuple=False):
        shape = (self.b, self.h)
        for typ in [torch.float, torch.half]:
            xs = [torch.randn(shape, dtype=typ).requires_grad_()
                  for _ in range(self.t)]
            hidden_fn = lambda: torch.zeros(shape, dtype=typ)
            if state_tuple:
                hidden = (hidden_fn(), hidden_fn())
            else:
                hidden = hidden_fn()
            outputs = []
            for i in range(self.t):
                hidden = cell(xs[i], hidden)
                if state_tuple:
                    output = hidden[0]
                else:
                    output = hidden
                outputs.append(output)
            for y in outputs:
                self.assertEqual(y.type(), HALF)
            outputs[-1].float().sum().backward()
            for i, x in enumerate(xs):
                self.assertEqual(x.grad.dtype, x.dtype) 
Example #5
Source File: types.py    From chainer-compiler with MIT License
def torch_dtype_to_np_dtype(dtype):
    dtype_dict = {
            torch.bool    : np.dtype(np.bool_),  # np.bool was removed in NumPy 1.24
            torch.uint8   : np.dtype(np.uint8),
            torch.int8    : np.dtype(np.int8),
            torch.int16   : np.dtype(np.int16),
            torch.short   : np.dtype(np.int16),
            torch.int32   : np.dtype(np.int32),
            torch.int     : np.dtype(np.int32),
            torch.int64   : np.dtype(np.int64),
            torch.long    : np.dtype(np.int64),
            torch.float16 : np.dtype(np.float16),
            torch.half    : np.dtype(np.float16),
            torch.float32 : np.dtype(np.float32),
            torch.float   : np.dtype(np.float32),
            torch.float64 : np.dtype(np.float64),
            torch.double  : np.dtype(np.float64),
            }
    return dtype_dict[dtype]


# ---------------------- InferenceEngine internal types ------------------------ 
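A quick usage check for the mapping above. Several torch names are aliases of the same dtype object (torch.half is torch.float16, torch.short is torch.int16, and so on), so some keys in the dictionary are redundant but harmless:

assert torch.half is torch.float16  # aliases share one dict key
assert torch_dtype_to_np_dtype(torch.half) == np.dtype(np.float16)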
Example #6
Source File: test_rnn.py    From apex with BSD 3-Clause "New" or "Revised" License
def test_rnn_packed_sequence(self):
        num_layers = 2
        rnn = nn.RNN(input_size=self.h, hidden_size=self.h, num_layers=num_layers)
        for typ in [torch.float, torch.half]:
            x = torch.randn((self.t, self.b, self.h), dtype=typ).requires_grad_()
            lens = sorted([random.randint(self.t // 2, self.t) for _ in range(self.b)],
                          reverse=True)
            # `pack_padded_sequence` breaks if default tensor type is non-CPU
            torch.set_default_tensor_type(torch.FloatTensor)
            lens = torch.tensor(lens, dtype=torch.int64, device=torch.device('cpu'))
            packed_seq = nn.utils.rnn.pack_padded_sequence(x, lens)
            torch.set_default_tensor_type(torch.cuda.FloatTensor)
            hidden = torch.zeros((num_layers, self.b, self.h), dtype=typ)
            output, _ = rnn(packed_seq, hidden)
            self.assertEqual(output.data.type(), HALF)
            output.data.float().sum().backward()
            self.assertEqual(x.grad.dtype, x.dtype) 
Example #7
Source File: gan.py    From torchsupport with MIT License
def _mix_on_path(real, fake):
  result = None
  if isinstance(real, (list, tuple)):
    result = [
      _mix_on_path(real_part, fake_part)
      for real_part, fake_part in zip(real, fake)
    ]
  elif isinstance(real, dict):
    result = {
      key: _mix_on_path(real[key], fake[key])
      for key in real
    }
  elif isinstance(real, torch.Tensor):
    if real.dtype in (torch.half, torch.float, torch.double):
      result = _mix_on_path_aux(real, fake)
    else:
      result = random.choice([real, fake])
  else:
    result = random.choice([real, fake])
  return result 
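_mix_on_path_aux is defined elsewhere in torchsupport; judging from the dtype check above, it presumably performs the random interpolation used by WGAN-GP-style gradient penalties on floating-point tensors, while everything else falls back to a random choice. A sketch of what such a helper could look like (an assumption, not the library's actual code):

def _mix_on_path_aux(real, fake):
  # one interpolation coefficient per example, broadcast over remaining dims
  alpha = torch.rand(real.size(0), *([1] * (real.dim() - 1)),
                     device=real.device, dtype=real.dtype)
  mixed = alpha * real + (1 - alpha) * fake
  # detach so the mix is a leaf tensor we can differentiate with respect to
  return mixed.detach().requires_grad_(True)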
Example #8
Source File: test_roi_align.py    From mmcv with Apache License 2.0
def _test_roialign_gradcheck(device, dtype):
    if not torch.cuda.is_available() and device == 'cuda':
        pytest.skip('test requires GPU')
    try:
        from mmcv.ops import RoIAlign
    except ModuleNotFoundError:
        pytest.skip('RoIAlign op is not successfully compiled')
    if dtype is torch.half:
        pytest.skip('grad check does not support fp16')
    for case in inputs:
        np_input = np.array(case[0])
        np_rois = np.array(case[1])

        x = torch.tensor(
            np_input, dtype=dtype, device=device, requires_grad=True)
        rois = torch.tensor(np_rois, dtype=dtype, device=device)

        froipool = RoIAlign((pool_h, pool_w), spatial_scale, sampling_ratio)

        gradcheck(froipool, (x, rois), eps=1e-5, atol=1e-5) 
Example #9
Source File: cross_entropy.py    From attn2d with MIT License
def cross_entropy(logits, target, ignore_index=-100, reduction='mean'):
        if logits.device == torch.device('cpu'):
            return _cross_entropy_pytorch(logits, target, ignore_index, reduction)
        else:
            half_to_float = (logits.dtype == torch.half)
            losses = xentropy.SoftmaxCrossEntropyLoss.apply(
                logits, target, 0.0, ignore_index, half_to_float,
            )
            if reduction == 'sum':
                return losses.sum()
            elif reduction == 'mean':
                if ignore_index >= 0:
                    return losses.sum() / target.ne(ignore_index).sum()
                else:
                    return losses.mean()
            elif reduction == 'none':
                return losses
            else:
                raise NotImplementedError 
Example #10
Source File: waveglow.py    From NeMo with Apache License 2.0
def forward(self, z, reverse: bool = False):
        # shape
        batch_size, group_size, n_of_groups = z.size()

        W = self.conv.weight.squeeze()

        if reverse:
            if not hasattr(self, 'W_inverse'):
                # Reverse computation: invert W in fp32 (matrix inverse is
                # not supported in half precision), then cast back for fp16
                W_inverse = W.float().inverse()
                W_inverse = Variable(W_inverse[..., None])
                if z.dtype == torch.half:
                    W_inverse = W_inverse.half()
                self.W_inverse = W_inverse
            z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
            return z
        else:
            # Forward computation
            log_det_W = batch_size * n_of_groups * torch.logdet(W.float())
            z = self.conv(z)
            return (
                z,
                log_det_W,
            ) 
Example #11
Source File: modeling.py    From VLP with Apache License 2.0
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        _, pooled_output = self.bert(
            input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        if labels is not None:
            if labels.dtype == torch.long:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(
                    logits.view(-1, self.num_labels), labels.view(-1))
            elif labels.dtype == torch.half or labels.dtype == torch.float:
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                print('unknown labels.dtype')
                loss = None
            return loss
        else:
            return logits 
Example #12
Source File: modeling.py    From VLP with Apache License 2.0
def __init__(self, config, bert_model_embedding_weights):
        super(BertLMPredictionHead, self).__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
                                 bert_model_embedding_weights.size(0),
                                 bias=False)
        self.decoder.weight = bert_model_embedding_weights
        self.bias = nn.Parameter(torch.zeros(
            bert_model_embedding_weights.size(0)))
        if hasattr(config, 'relax_projection') and (config.relax_projection > 1):
            self.relax_projection = config.relax_projection
        else:
            self.relax_projection = 0
        self.fp32_embedding = config.fp32_embedding

        def convert_to_type(tensor):
            if self.fp32_embedding:
                return tensor.half()
            else:
                return tensor
        self.type_converter = convert_to_type
        self.converted = False 
Example #13
Source File: hooks.py    From mmdetection with Apache License 2.0
def patch_norm_fp32(module):
    """Recursively convert normalization layers from FP16 to FP32.

    Args:
        module (nn.Module): The FP16 module whose normalization layers
            should be converted.

    Returns:
        nn.Module: The converted module, with its normalization layers
            cast to FP32.
    """
    if isinstance(module, (nn.modules.batchnorm._BatchNorm, nn.GroupNorm)):
        module.float()
        if isinstance(module, nn.GroupNorm) or torch.__version__ < '1.3':
            module.forward = patch_forward_method(module.forward, torch.half,
                                                  torch.float)
    for child in module.children():
        patch_norm_fp32(child)
    return module 
Example #14
Source File: test_compute_adaptive_lr.py    From torchlars with Apache License 2.0
def test_when_grad_norm_is_zero_with_half():
    param_norm = torch.tensor(1., dtype=torch.half, device='cuda')
    grad_norm = torch.tensor(0., dtype=torch.half, device='cuda')
    adaptive_lr = torch.tensor(0., dtype=torch.half, device='cuda')

    weight_decay = 1.
    eps = 1.
    trust_coef = 1.

    adaptive_lr = compute_adaptive_lr(
        param_norm,
        grad_norm,
        weight_decay,
        eps,
        trust_coef,
        adaptive_lr)

    assert adaptive_lr == torch.tensor(1., dtype=torch.half, device='cuda') 
Example #15
Source File: test_compute_adaptive_lr.py    From torchlars with Apache License 2.0
def test_specific_case_with_half():
    param_norm = torch.tensor(1.234, dtype=torch.half, device='cuda')
    grad_norm = torch.tensor(5.678, dtype=torch.half, device='cuda')
    adaptive_lr = torch.tensor(0., dtype=torch.half, device='cuda')

    weight_decay = 1e-4
    eps = 1e-8
    trust_coef = 0.001

    adaptive_lr = compute_adaptive_lr(
        param_norm,
        grad_norm,
        weight_decay,
        eps,
        trust_coef,
        adaptive_lr)

    assert torch.allclose(adaptive_lr, torch.tensor(0.000217325, dtype=torch.half, device='cuda')) 
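The expected values in these torchlars tests are consistent with the usual LARS trust ratio, adaptive_lr = trust_coef * param_norm / (grad_norm + weight_decay * param_norm + eps), falling back to 1 when either norm is zero (Examples #1 and #14). That formula is an inference about compute_adaptive_lr's internals, since the kernel itself is not shown; checking the test above by hand:

trust_coef, param_norm, grad_norm = 0.001, 1.234, 5.678
weight_decay, eps = 1e-4, 1e-8
print(trust_coef * param_norm / (grad_norm + weight_decay * param_norm + eps))
# ~0.000217325, matching the asserted value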
Example #16
Source File: test_utils.py    From pytorch-tools with MIT License
def test_box2delta(device_dtype):
    ## this test only checks that encoding followed by decoding gives back the same boxes
    device, dtype = device_dtype
    boxes = random_boxes([10, 10, 20, 20], 10, 10).to(device).to(dtype)
    anchors = random_boxes([10, 10, 20, 20], 10, 10).to(device).to(dtype)
    deltas = pt.utils.box.box2delta(boxes, anchors)
    boxes_reconstructed = pt.utils.box.delta2box(deltas, anchors)
    atol = 2e-2 if dtype == torch.half else 1e-6  # for fp16 sometimes error is large
    assert torch.allclose(boxes, boxes_reconstructed, atol=atol)

    # check that it's jit friendly
    jit_box2delta = torch.jit.script(pt.utils.box.box2delta)
    jit_delta2box = torch.jit.script(pt.utils.box.delta2box)
    deltas2 = jit_box2delta(boxes, anchors)
    boxes_reconstructed2 = jit_delta2box(deltas2, anchors)
    assert torch.allclose(boxes, boxes_reconstructed2, atol=atol) 
Example #17
Source File: modeling.py    From unilm with MIT License
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, mask_qkv=None, task_idx=None):
        _, pooled_output = self.bert(
            input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, mask_qkv=mask_qkv, task_idx=task_idx)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        if labels is not None:
            if labels.dtype == torch.long:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(
                    logits.view(-1, self.num_labels), labels.view(-1))
            elif labels.dtype == torch.half or labels.dtype == torch.float:
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                print('unknown labels.dtype')
                loss = None
            return loss
        else:
            return logits 
Example #18
Source File: modeling.py    From unilm with MIT License
def forward(self, hidden_states, task_idx=None):
        if not self.converted:
            self.converted = True
            if self.fp32_embedding:
                self.transform.half()
        hidden_states = self.transform(self.type_converter(hidden_states))
        if self.relax_projection > 1:
            num_batch = hidden_states.size(0)
            num_pos = hidden_states.size(1)
            # (batch, num_pos, relax_projection*hid) -> (batch, num_pos, relax_projection, hid) -> (batch, num_pos, hid)
            hidden_states = hidden_states.view(
                num_batch, num_pos, self.relax_projection, -1)[torch.arange(0, num_batch).long(), :, task_idx, :]
        if self.fp32_embedding:
            hidden_states = F.linear(self.type_converter(hidden_states), self.type_converter(
                self.decoder.weight), self.type_converter(self.bias))
        else:
            hidden_states = self.decoder(hidden_states) + self.bias
        return hidden_states 
Example #19
Source File: modeling.py    From unilm with MIT License
def __init__(self, config, bert_model_embedding_weights):
        super(BertLMPredictionHead, self).__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
                                 bert_model_embedding_weights.size(0),
                                 bias=False)
        self.decoder.weight = bert_model_embedding_weights
        self.bias = nn.Parameter(torch.zeros(
            bert_model_embedding_weights.size(0)))
        if hasattr(config, 'relax_projection') and (config.relax_projection > 1):
            self.relax_projection = config.relax_projection
        else:
            self.relax_projection = 0
        self.fp32_embedding = config.fp32_embedding

        def convert_to_type(tensor):
            if self.fp32_embedding:
                return tensor.half()
            else:
                return tensor
        self.type_converter = convert_to_type
        self.converted = False 
Example #20
Source File: modeling.py    From unilm with MIT License
def forward(self, input_ids, token_type_ids=None, position_ids=None, task_idx=None):
        seq_length = input_ids.size(1)
        if position_ids is None:
            position_ids = torch.arange(
                seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        if self.num_pos_emb > 1:
            num_batch = position_embeddings.size(0)
            num_pos = position_embeddings.size(1)
            position_embeddings = position_embeddings.view(
                num_batch, num_pos, self.num_pos_emb, -1)[torch.arange(0, num_batch).long(), :, task_idx, :]

        embeddings = words_embeddings + position_embeddings + token_type_embeddings
        if self.fp32_embedding:
            embeddings = embeddings.half()
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings 
Example #21
Source File: modeling.py    From VLP with Apache License 2.0
def forward(self, vis_feats, vis_pe, input_ids, token_type_ids=None, position_ids=None, vis_input=True, len_vis_input=49):
        seq_length = input_ids.size(1)
        if position_ids is None:
            position_ids = torch.arange(
                seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        if vis_input:
            words_embeddings = torch.cat((words_embeddings[:, :1], vis_feats,
                words_embeddings[:, len_vis_input+1:]), dim=1)
            assert len_vis_input == 100, 'only support region attn!'
            position_embeddings = torch.cat((position_embeddings[:, :1], vis_pe,
                position_embeddings[:, len_vis_input+1:]), dim=1) # hacky...
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = words_embeddings + position_embeddings + token_type_embeddings
        if self.fp32_embedding:
            embeddings = embeddings.half()
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings 
Example #22
Source File: modeling.py    From VLP with Apache License 2.0
def forward(self, hidden_states, task_idx=None):
        if not self.converted:
            self.converted = True
            if self.fp32_embedding:
                self.transform.half()
        hidden_states = self.transform(self.type_converter(hidden_states))
        if self.relax_projection > 1:
            num_batch = hidden_states.size(0)
            num_pos = hidden_states.size(1)
            # (batch, num_pos, relax_projection*hid) -> (batch, num_pos, relax_projection, hid) -> (batch, num_pos, hid)
            hidden_states = hidden_states.view(
                num_batch, num_pos, self.relax_projection, -1)[torch.arange(0, num_batch).long(), :, task_idx, :]
        if self.fp32_embedding:
            hidden_states = F.linear(self.type_converter(hidden_states), self.type_converter(
                self.decoder.weight), self.type_converter(self.bias))
        else:
            hidden_states = self.decoder(hidden_states) + self.bias
        return hidden_states 
Example #23
Source File: device_dtype_mixin.py    From pytorch-lightning with Apache License 2.0
def half(self) -> Module:
        """Casts all floating point parameters and buffers to ``half`` datatype.

        Returns:
            Module: self
        """
        self._dtype = torch.half
        return super().half() 
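The override records the new dtype on the mixin before delegating the actual cast to nn.Module.half(), so the module's cached dtype stays in sync with its parameters. A usage sketch (the toy module is illustrative):

class TinyModule(DeviceDtypeModuleMixin, nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

m = TinyModule().half()
assert m.fc.weight.dtype == torch.half  # parameters were cast by super().half()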
Example #24
Source File: performance_test.py    From torchfunc with MIT License
def test_report():
    goal = r"""
===========================GENERAL TIPS===========================

- Make sure you are running newest PyTorch version. See available releases: https://github.com/pytorch/pytorch/tags
- Use GPU for larger batches, CPU might be suitable for smaller jobs.
- Use mixed-precision training on GPU, preferably automated, e.g. NVIDIA Apex: https://github.com/NVIDIA/apex.

===========================SPECIFIC TIPS===========================

=======> Module should be an instance of torch.jit.ScriptModule.
See https://pytorch.org/docs/stable/jit.html for more information.
=======> NVIDIA's Apex is not installed. It is the easiest way to use mixed precision training.
See https://github.com/NVIDIA/apex for more information and installation.
=======> In-place operations might harm kernel fusion. Indices of those modules:
[3, 5]
You may want to remove inplace flag (see this issue: https://github.com/pytorch/pytorch/issues/23655)
=======> Depthwise convolutions are not currently using specialized kernel and might be slower.
See this issue: https://github.com/pytorch/pytorch/issues/18631 for more information.
Indices of those modules:
[4]
You may want to decrease number of groups (like it's done for ResNeXt) for possible speed & accuracy improvements.
=======> TensorCores incompatible modules:
Modules where float type is not torch.half:
[2, 4, 6, 8, 10]
Modules where inputs shape should be divisible by 8:
[2, 8]
Modules where outputs shape should be divisible by 8:
[6, 10]"""
    tips = torchfunc.performance.tips(Model())
    print(tips)
    assert tips == goal 
Example #25
Source File: hooks.py    From CenterNet with Apache License 2.0
def patch_norm_fp32(module):
    if isinstance(module, (nn.modules.batchnorm._BatchNorm, nn.GroupNorm)):
        module.float()
        module.forward = patch_forward_method(module.forward, torch.half,
                                              torch.float)
    for child in module.children():
        patch_norm_fp32(child)
    return module 
Example #26
Source File: hooks.py    From Libra_R-CNN with Apache License 2.0
def patch_norm_fp32(module):
    if isinstance(module, (nn.modules.batchnorm._BatchNorm, nn.GroupNorm)):
        module.float()
        module.forward = patch_forward_method(module.forward, torch.half,
                                              torch.float)
    for child in module.children():
        patch_norm_fp32(child)
    return module 
Example #27
Source File: hooks.py    From Libra_R-CNN with Apache License 2.0
def wrap_fp16_model(model):
    # convert model to fp16
    model.half()
    # patch the normalization layers to make it work in fp32 mode
    patch_norm_fp32(model)
    # set `fp16_enabled` flag
    for m in model.modules():
        if hasattr(m, 'fp16_enabled'):
            m.fp16_enabled = True 
Example #28
Source File: test_rnn.py    From apex with BSD 3-Clause "New" or "Revised" License
def run_rnn_test(self, rnn, layers, bidir, state_tuple=False):
        for typ in [torch.float, torch.half]:
            x = torch.randn((self.t, self.b, self.h), dtype=typ).requires_grad_()
            hidden_fn = lambda: torch.zeros((layers + (layers * bidir),
                                             self.b, self.h), dtype=typ)
            if state_tuple:
                hidden = (hidden_fn(), hidden_fn())
            else:
                hidden = hidden_fn()
            output, _ = rnn(x, hidden)
            self.assertEqual(output.type(), HALF)
            output[-1, :, :].float().sum().backward()
            self.assertEqual(x.grad.dtype, x.dtype) 
Example #29
Source File: test_promotion.py    From apex with BSD 3-Clause "New" or "Revised" License
def test_inplace_exp_is_error_for_half(self):
        xs = torch.randn(self.b)
        xs.exp_()
        self.assertEqual(xs.type(), FLOAT)
        xs = torch.randn(self.b, dtype=torch.half)
        with self.assertRaises(NotImplementedError):
            xs.exp_() 
Example #30
Source File: test_promotion.py    From apex with BSD 3-Clause "New" or "Revised" License
def test_cat_matches_widest(self):
        shape = self.b
        ys = [torch.randn(shape, dtype=torch.half) for _ in range(5)]
        x_float = torch.randn(shape)
        out = torch.cat(ys + [x_float])
        self.assertEqual(out.type(), FLOAT)
        x_half = torch.randn(shape, dtype=torch.half)
        out = torch.cat(ys + [x_half])
        self.assertEqual(out.type(), HALF)