Python torch.isinf() Examples

The following are 30 code examples of torch.isinf(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch, or try the search function.
Example #1
Source File: training.py    From TTS with Mozilla Public License 2.0 6 votes vote down vote up
def check_update(model, grad_clip, ignore_stopnet=False):
    r'''Clip the model's gradients and flag non-finite gradient norms.

    Args:
        model: module whose ``.grad`` tensors are clipped in place by
            ``torch.nn.utils.clip_grad_norm_``.
        grad_clip: maximum total gradient norm passed to ``clip_grad_norm_``.
        ignore_stopnet: when True, parameters whose name contains
            ``'stopnet'`` are excluded from both the norm and the clipping.

    Returns:
        ``(grad_norm, skip_flag)``: ``grad_norm`` is the value returned by
        ``clip_grad_norm_`` (a float on old torch versions, a 0-dim tensor on
        newer ones); ``skip_flag`` is True when that norm is INF or NaN, i.e.
        the caller should not apply this update.
    '''
    if ignore_stopnet:
        params = [param for name, param in model.named_parameters() if 'stopnet' not in name]
    else:
        params = model.parameters()
    grad_norm = torch.nn.utils.clip_grad_norm_(params, grad_clip)
    # compatibility with different torch versions: older releases return a
    # Python float, newer ones a 0-dim tensor.
    if isinstance(grad_norm, float):
        is_inf, is_nan = bool(np.isinf(grad_norm)), bool(np.isnan(grad_norm))
    else:
        is_inf, is_nan = bool(torch.isinf(grad_norm)), bool(torch.isnan(grad_norm))
    skip_flag = False
    if is_inf:
        print(" | > Gradient is INF !!")
        skip_flag = True
    elif is_nan:
        # A NaN norm means at least one gradient is NaN; stepping with it
        # would corrupt the parameters just like an INF norm would.
        print(" | > Gradient is NaN !!")
        skip_flag = True
    return grad_norm, skip_flag
Example #2
Source File: base_test.py    From nsf with MIT License 6 votes vote down vote up
def test_sample(self):
        # Sampling from a VAE built entirely from standard normals must yield
        # a finite tensor of shape (num_samples, *input_shape), both with and
        # without mean=True.
        n_draws = 10
        x_shape = [2, 3, 4]
        z_shape = [5, 6]

        vae = base.VariationalAutoencoder(
            distributions.StandardNormal(z_shape),
            distributions.StandardNormal(z_shape),
            distributions.StandardNormal(x_shape),
        )

        expected_shape = torch.Size([n_draws] + x_shape)
        for use_mean in (True, False):
            with self.subTest(mean=use_mean):
                drawn = vae.sample(n_draws, mean=use_mean)
                self.assertIsInstance(drawn, torch.Tensor)
                self.assertFalse(torch.isnan(drawn).any())
                self.assertFalse(torch.isinf(drawn).any())
                self.assertEqual(drawn.shape, expected_shape)
Example #3
Source File: base_test.py    From nsf with MIT License 6 votes vote down vote up
def test_stochastic_elbo(self):
        # For each tested number of samples, the stochastic ELBO of a
        # standard-normal VAE must be a finite tensor with one entry per
        # batch element.
        batch_size = 10
        x_shape = [2, 3, 4]
        z_shape = [5, 6]

        vae = base.VariationalAutoencoder(
            distributions.StandardNormal(z_shape),
            distributions.StandardNormal(z_shape),
            distributions.StandardNormal(x_shape),
        )

        inputs = torch.randn(batch_size, *x_shape)
        for k in (1, 10, 100):
            with self.subTest(num_samples=k):
                elbo = vae.stochastic_elbo(inputs, k)
                self.assertIsInstance(elbo, torch.Tensor)
                self.assertFalse(torch.isnan(elbo).any())
                self.assertFalse(torch.isinf(elbo).any())
                self.assertEqual(elbo.shape, torch.Size([batch_size]))
Example #4
Source File: training.py    From TTS with Mozilla Public License 2.0 6 votes vote down vote up
def check_update(model, grad_clip, ignore_stopnet=False):
    r'''Clip the model's gradients and flag non-finite gradient norms.

    Args:
        model: module whose ``.grad`` tensors are clipped in place by
            ``torch.nn.utils.clip_grad_norm_``.
        grad_clip: maximum total gradient norm passed to ``clip_grad_norm_``.
        ignore_stopnet: when True, parameters whose name contains
            ``'stopnet'`` are excluded from both the norm and the clipping.

    Returns:
        ``(grad_norm, skip_flag)``: ``grad_norm`` is the value returned by
        ``clip_grad_norm_`` (a float on old torch versions, a 0-dim tensor on
        newer ones); ``skip_flag`` is True when that norm is INF or NaN, i.e.
        the caller should not apply this update.
    '''
    if ignore_stopnet:
        params = [param for name, param in model.named_parameters() if 'stopnet' not in name]
    else:
        params = model.parameters()
    grad_norm = torch.nn.utils.clip_grad_norm_(params, grad_clip)
    # compatibility with different torch versions: older releases return a
    # Python float, newer ones a 0-dim tensor.
    if isinstance(grad_norm, float):
        is_inf, is_nan = bool(np.isinf(grad_norm)), bool(np.isnan(grad_norm))
    else:
        is_inf, is_nan = bool(torch.isinf(grad_norm)), bool(torch.isnan(grad_norm))
    skip_flag = False
    if is_inf:
        print(" | > Gradient is INF !!")
        skip_flag = True
    elif is_nan:
        # A NaN norm means at least one gradient is NaN; stepping with it
        # would corrupt the parameters just like an INF norm would.
        print(" | > Gradient is NaN !!")
        skip_flag = True
    return grad_norm, skip_flag
Example #5
Source File: training.py    From TTS with Mozilla Public License 2.0 6 votes vote down vote up
def check_update(model, grad_clip, ignore_stopnet=False):
    r'''Clip the model's gradients and flag non-finite gradient norms.

    Args:
        model: module whose ``.grad`` tensors are clipped in place by
            ``torch.nn.utils.clip_grad_norm_``.
        grad_clip: maximum total gradient norm passed to ``clip_grad_norm_``.
        ignore_stopnet: when True, parameters whose name contains
            ``'stopnet'`` are excluded from both the norm and the clipping.

    Returns:
        ``(grad_norm, skip_flag)``: ``grad_norm`` is the value returned by
        ``clip_grad_norm_`` (a float on old torch versions, a 0-dim tensor on
        newer ones); ``skip_flag`` is True when that norm is INF or NaN, i.e.
        the caller should not apply this update.
    '''
    if ignore_stopnet:
        params = [param for name, param in model.named_parameters() if 'stopnet' not in name]
    else:
        params = model.parameters()
    grad_norm = torch.nn.utils.clip_grad_norm_(params, grad_clip)
    # compatibility with different torch versions: older releases return a
    # Python float, newer ones a 0-dim tensor.
    if isinstance(grad_norm, float):
        is_inf, is_nan = bool(np.isinf(grad_norm)), bool(np.isnan(grad_norm))
    else:
        is_inf, is_nan = bool(torch.isinf(grad_norm)), bool(torch.isnan(grad_norm))
    skip_flag = False
    if is_inf:
        print(" | > Gradient is INF !!")
        skip_flag = True
    elif is_nan:
        # A NaN norm means at least one gradient is NaN; stepping with it
        # would corrupt the parameters just like an INF norm would.
        print(" | > Gradient is NaN !!")
        skip_flag = True
    return grad_norm, skip_flag
Example #6
Source File: training.py    From TTS with Mozilla Public License 2.0 6 votes vote down vote up
def check_update(model, grad_clip, ignore_stopnet=False):
    r'''Clip the model's gradients and flag non-finite gradient norms.

    Args:
        model: module whose ``.grad`` tensors are clipped in place by
            ``torch.nn.utils.clip_grad_norm_``.
        grad_clip: maximum total gradient norm passed to ``clip_grad_norm_``.
        ignore_stopnet: when True, parameters whose name contains
            ``'stopnet'`` are excluded from both the norm and the clipping.

    Returns:
        ``(grad_norm, skip_flag)``: ``grad_norm`` is the value returned by
        ``clip_grad_norm_`` (a float on old torch versions, a 0-dim tensor on
        newer ones); ``skip_flag`` is True when that norm is INF or NaN, i.e.
        the caller should not apply this update.
    '''
    if ignore_stopnet:
        params = [param for name, param in model.named_parameters() if 'stopnet' not in name]
    else:
        params = model.parameters()
    grad_norm = torch.nn.utils.clip_grad_norm_(params, grad_clip)
    # compatibility with different torch versions: older releases return a
    # Python float, newer ones a 0-dim tensor.
    if isinstance(grad_norm, float):
        is_inf, is_nan = bool(np.isinf(grad_norm)), bool(np.isnan(grad_norm))
    else:
        is_inf, is_nan = bool(torch.isinf(grad_norm)), bool(torch.isnan(grad_norm))
    skip_flag = False
    if is_inf:
        print(" | > Gradient is INF !!")
        skip_flag = True
    elif is_nan:
        # A NaN norm means at least one gradient is NaN; stepping with it
        # would corrupt the parameters just like an INF norm would.
        print(" | > Gradient is NaN !!")
        skip_flag = True
    return grad_norm, skip_flag
Example #7
Source File: discrete_test.py    From nsf with MIT License 6 votes vote down vote up
def test_sample_and_log_prob_with_context(self):
        # Joint sampling from a context-conditioned Bernoulli: check output
        # types and shapes, finiteness, log_prob <= 0, and that every sample
        # is exactly 0.0 or 1.0.
        n_draws = 10
        n_contexts = 20
        event_shape = [2, 3, 4]
        ctx_shape = [2, 3, 4]

        dist = discrete.ConditionalIndependentBernoulli(event_shape)
        ctx = torch.randn(n_contexts, *ctx_shape)
        samples, log_prob = dist.sample_and_log_prob(n_draws, context=ctx)

        for produced in (samples, log_prob):
            self.assertIsInstance(produced, torch.Tensor)

        leading = [n_contexts, n_draws]
        self.assertEqual(samples.shape, torch.Size(leading + event_shape))
        self.assertEqual(log_prob.shape, torch.Size(leading))

        self.assertFalse(torch.isnan(log_prob).any())
        self.assertFalse(torch.isinf(log_prob).any())
        self.assert_tensor_less_equal(log_prob, 0.0)

        self.assertFalse(torch.isnan(samples).any())
        self.assertFalse(torch.isinf(samples).any())
        is_binary = (samples == 1.0) | (samples == 0.0)
        self.assertEqual(is_binary, torch.ones_like(is_binary))
Example #8
Source File: prognn.py    From DeepRobust with MIT License 6 votes vote down vote up
def feature_smoothing(self, adj, X):
        """Return the feature-smoothness penalty ``trace(X^T L_norm X)``.

        ``adj`` is first symmetrized as ``(adj^T + adj) / 2``. Its Laplacian
        ``L = D - A`` (``D`` = diagonal of row sums) is then normalized as
        ``D'^-1/2 L D'^-1/2``, where ``D'`` adds ``1e-3`` to every row sum.
        Any infinite entry of ``D'^-1/2`` is replaced by 0.
        """
        sym_adj = (adj.t() + adj) / 2
        degrees = sym_adj.sum(1).flatten()
        laplacian = torch.diag(degrees) - sym_adj

        # the 1e-3 offset keeps the inverse square root away from 1/0
        inv_sqrt_deg = (degrees + 1e-3).pow(-1/2).flatten()
        inv_sqrt_deg[torch.isinf(inv_sqrt_deg)] = 0.
        scale = torch.diag(inv_sqrt_deg)
        laplacian = scale @ laplacian @ scale

        quad_form = torch.matmul(torch.matmul(X.t(), laplacian), X)
        return torch.trace(quad_form)
Example #9
Source File: training.py    From TTS with Mozilla Public License 2.0 6 votes vote down vote up
def check_update(model, grad_clip, ignore_stopnet=False):
    r'''Clip the model's gradients and flag non-finite gradient norms.

    Args:
        model: module whose ``.grad`` tensors are clipped in place by
            ``torch.nn.utils.clip_grad_norm_``.
        grad_clip: maximum total gradient norm passed to ``clip_grad_norm_``.
        ignore_stopnet: when True, parameters whose name contains
            ``'stopnet'`` are excluded from both the norm and the clipping.

    Returns:
        ``(grad_norm, skip_flag)``: ``grad_norm`` is the value returned by
        ``clip_grad_norm_`` (a float on old torch versions, a 0-dim tensor on
        newer ones); ``skip_flag`` is True when that norm is INF or NaN, i.e.
        the caller should not apply this update.
    '''
    if ignore_stopnet:
        params = [param for name, param in model.named_parameters() if 'stopnet' not in name]
    else:
        params = model.parameters()
    grad_norm = torch.nn.utils.clip_grad_norm_(params, grad_clip)
    # compatibility with different torch versions: older releases return a
    # Python float, newer ones a 0-dim tensor.
    if isinstance(grad_norm, float):
        is_inf, is_nan = bool(np.isinf(grad_norm)), bool(np.isnan(grad_norm))
    else:
        is_inf, is_nan = bool(torch.isinf(grad_norm)), bool(torch.isnan(grad_norm))
    skip_flag = False
    if is_inf:
        print(" | > Gradient is INF !!")
        skip_flag = True
    elif is_nan:
        # A NaN norm means at least one gradient is NaN; stepping with it
        # would corrupt the parameters just like an INF norm would.
        print(" | > Gradient is NaN !!")
        skip_flag = True
    return grad_norm, skip_flag
Example #10
Source File: training.py    From TTS with Mozilla Public License 2.0 6 votes vote down vote up
def check_update(model, grad_clip, ignore_stopnet=False):
    r'''Clip the model's gradients and flag non-finite gradient norms.

    Args:
        model: module whose ``.grad`` tensors are clipped in place by
            ``torch.nn.utils.clip_grad_norm_``.
        grad_clip: maximum total gradient norm passed to ``clip_grad_norm_``.
        ignore_stopnet: when True, parameters whose name contains
            ``'stopnet'`` are excluded from both the norm and the clipping.

    Returns:
        ``(grad_norm, skip_flag)``: ``grad_norm`` is the value returned by
        ``clip_grad_norm_`` (a float on old torch versions, a 0-dim tensor on
        newer ones); ``skip_flag`` is True when that norm is INF or NaN, i.e.
        the caller should not apply this update.
    '''
    if ignore_stopnet:
        params = [param for name, param in model.named_parameters() if 'stopnet' not in name]
    else:
        params = model.parameters()
    grad_norm = torch.nn.utils.clip_grad_norm_(params, grad_clip)
    # compatibility with different torch versions: older releases return a
    # Python float, newer ones a 0-dim tensor.
    if isinstance(grad_norm, float):
        is_inf, is_nan = bool(np.isinf(grad_norm)), bool(np.isnan(grad_norm))
    else:
        is_inf, is_nan = bool(torch.isinf(grad_norm)), bool(torch.isnan(grad_norm))
    skip_flag = False
    if is_inf:
        print(" | > Gradient is INF !!")
        skip_flag = True
    elif is_nan:
        # A NaN norm means at least one gradient is NaN; stepping with it
        # would corrupt the parameters just like an INF norm would.
        print(" | > Gradient is NaN !!")
        skip_flag = True
    return grad_norm, skip_flag
Example #11
Source File: base.py    From torch-kalman with MIT License 6 votes vote down vote up
def _validate(self):
        if self.means.dim() != 2:
            raise ValueError("means should be 2D (first dimension batch-size)")
        if self.covs.dim() != 3:
            raise ValueError("covs should be 3D (first dimension batch-size)")
        if torch.isinf(self.means).any():
            raise ValueError("Infs in `means`.")
        if torch.isinf(self.covs).any():
            raise ValueError("Infs in `covs`.")
        if torch.isnan(self.means).any():
            raise ValueError("nans in `means`.")
        if torch.isnan(self.covs).any():
            raise ValueError("nans in `covs`.")
        if self.covs.shape[0] != self.means.shape[0]:
            raise ValueError("The batch-size (1st dimension) of cov doesn't match that of mean.")
        if self.covs.shape[1] != self.covs.shape[2]:
            raise ValueError("The cov should be symmetric in the last two dimensions.")
        if self.covs.shape[1] != self.means.shape[1]:
            raise ValueError("The state-size (2nd/3rd dimension) of cov doesn't match that of mean.")
        if self.last_measured.shape[0] != self.num_groups or self.last_measured.dim() != 1:
            raise ValueError(f"`last_measured` should be 1D tensor w/length of {self.num_groups:,}.") 
Example #12
Source File: training.py    From TTS with Mozilla Public License 2.0 6 votes vote down vote up
def check_update(model, grad_clip, ignore_stopnet=False):
    r'''Clip the model's gradients and flag non-finite gradient norms.

    Args:
        model: module whose ``.grad`` tensors are clipped in place by
            ``torch.nn.utils.clip_grad_norm_``.
        grad_clip: maximum total gradient norm passed to ``clip_grad_norm_``.
        ignore_stopnet: when True, parameters whose name contains
            ``'stopnet'`` are excluded from both the norm and the clipping.

    Returns:
        ``(grad_norm, skip_flag)``: ``grad_norm`` is the value returned by
        ``clip_grad_norm_`` (a float on old torch versions, a 0-dim tensor on
        newer ones); ``skip_flag`` is True when that norm is INF or NaN, i.e.
        the caller should not apply this update.
    '''
    if ignore_stopnet:
        params = [param for name, param in model.named_parameters() if 'stopnet' not in name]
    else:
        params = model.parameters()
    grad_norm = torch.nn.utils.clip_grad_norm_(params, grad_clip)
    # compatibility with different torch versions: older releases return a
    # Python float, newer ones a 0-dim tensor.
    if isinstance(grad_norm, float):
        is_inf, is_nan = bool(np.isinf(grad_norm)), bool(np.isnan(grad_norm))
    else:
        is_inf, is_nan = bool(torch.isinf(grad_norm)), bool(torch.isnan(grad_norm))
    skip_flag = False
    if is_inf:
        print(" | > Gradient is INF !!")
        skip_flag = True
    elif is_nan:
        # A NaN norm means at least one gradient is NaN; stepping with it
        # would corrupt the parameters just like an INF norm would.
        print(" | > Gradient is NaN !!")
        skip_flag = True
    return grad_norm, skip_flag
Example #13
Source File: training.py    From TTS with Mozilla Public License 2.0 6 votes vote down vote up
def check_update(model, grad_clip, ignore_stopnet=False):
    r'''Clip the model's gradients and flag non-finite gradient norms.

    Args:
        model: module whose ``.grad`` tensors are clipped in place by
            ``torch.nn.utils.clip_grad_norm_``.
        grad_clip: maximum total gradient norm passed to ``clip_grad_norm_``.
        ignore_stopnet: when True, parameters whose name contains
            ``'stopnet'`` are excluded from both the norm and the clipping.

    Returns:
        ``(grad_norm, skip_flag)``: ``grad_norm`` is the value returned by
        ``clip_grad_norm_`` (a float on old torch versions, a 0-dim tensor on
        newer ones); ``skip_flag`` is True when that norm is INF or NaN, i.e.
        the caller should not apply this update.
    '''
    if ignore_stopnet:
        params = [param for name, param in model.named_parameters() if 'stopnet' not in name]
    else:
        params = model.parameters()
    grad_norm = torch.nn.utils.clip_grad_norm_(params, grad_clip)
    # compatibility with different torch versions: older releases return a
    # Python float, newer ones a 0-dim tensor.
    if isinstance(grad_norm, float):
        is_inf, is_nan = bool(np.isinf(grad_norm)), bool(np.isnan(grad_norm))
    else:
        is_inf, is_nan = bool(torch.isinf(grad_norm)), bool(torch.isnan(grad_norm))
    skip_flag = False
    if is_inf:
        print(" | > Gradient is INF !!")
        skip_flag = True
    elif is_nan:
        # A NaN norm means at least one gradient is NaN; stepping with it
        # would corrupt the parameters just like an INF norm would.
        print(" | > Gradient is NaN !!")
        skip_flag = True
    return grad_norm, skip_flag
Example #14
Source File: training.py    From TTS with Mozilla Public License 2.0 6 votes vote down vote up
def check_update(model, grad_clip, ignore_stopnet=False):
    r'''Clip the model's gradients and flag non-finite gradient norms.

    Args:
        model: module whose ``.grad`` tensors are clipped in place by
            ``torch.nn.utils.clip_grad_norm_``.
        grad_clip: maximum total gradient norm passed to ``clip_grad_norm_``.
        ignore_stopnet: when True, parameters whose name contains
            ``'stopnet'`` are excluded from both the norm and the clipping.

    Returns:
        ``(grad_norm, skip_flag)``: ``grad_norm`` is the value returned by
        ``clip_grad_norm_`` (a float on old torch versions, a 0-dim tensor on
        newer ones); ``skip_flag`` is True when that norm is INF or NaN, i.e.
        the caller should not apply this update.
    '''
    if ignore_stopnet:
        params = [param for name, param in model.named_parameters() if 'stopnet' not in name]
    else:
        params = model.parameters()
    grad_norm = torch.nn.utils.clip_grad_norm_(params, grad_clip)
    # compatibility with different torch versions: older releases return a
    # Python float, newer ones a 0-dim tensor.
    if isinstance(grad_norm, float):
        is_inf, is_nan = bool(np.isinf(grad_norm)), bool(np.isnan(grad_norm))
    else:
        is_inf, is_nan = bool(torch.isinf(grad_norm)), bool(torch.isnan(grad_norm))
    skip_flag = False
    if is_inf:
        print(" | > Gradient is INF !!")
        skip_flag = True
    elif is_nan:
        # A NaN norm means at least one gradient is NaN; stepping with it
        # would corrupt the parameters just like an INF norm would.
        print(" | > Gradient is NaN !!")
        skip_flag = True
    return grad_norm, skip_flag
Example #15
Source File: training.py    From TTS with Mozilla Public License 2.0 6 votes vote down vote up
def check_update(model, grad_clip, ignore_stopnet=False):
    r'''Clip the model's gradients and flag non-finite gradient norms.

    Args:
        model: module whose ``.grad`` tensors are clipped in place by
            ``torch.nn.utils.clip_grad_norm_``.
        grad_clip: maximum total gradient norm passed to ``clip_grad_norm_``.
        ignore_stopnet: when True, parameters whose name contains
            ``'stopnet'`` are excluded from both the norm and the clipping.

    Returns:
        ``(grad_norm, skip_flag)``: ``grad_norm`` is the value returned by
        ``clip_grad_norm_`` (a float on old torch versions, a 0-dim tensor on
        newer ones); ``skip_flag`` is True when that norm is INF or NaN, i.e.
        the caller should not apply this update.
    '''
    if ignore_stopnet:
        params = [param for name, param in model.named_parameters() if 'stopnet' not in name]
    else:
        params = model.parameters()
    grad_norm = torch.nn.utils.clip_grad_norm_(params, grad_clip)
    # compatibility with different torch versions: older releases return a
    # Python float, newer ones a 0-dim tensor.
    if isinstance(grad_norm, float):
        is_inf, is_nan = bool(np.isinf(grad_norm)), bool(np.isnan(grad_norm))
    else:
        is_inf, is_nan = bool(torch.isinf(grad_norm)), bool(torch.isnan(grad_norm))
    skip_flag = False
    if is_inf:
        print(" | > Gradient is INF !!")
        skip_flag = True
    elif is_nan:
        # A NaN norm means at least one gradient is NaN; stepping with it
        # would corrupt the parameters just like an INF norm would.
        print(" | > Gradient is NaN !!")
        skip_flag = True
    return grad_norm, skip_flag
Example #16
Source File: utils.py    From DeepRobust with MIT License 6 votes vote down vote up
def degree_normalize_adj_tensor(adj, sparse=True):
    """Degree-normalize an adjacency matrix.

    Args:
        adj: square adjacency matrix tensor.
        sparse: if True, ``adj`` is converted with ``to_scipy``, normalized by
            ``degree_normalize_adj`` and converted back with
            ``sparse_mx_to_torch_sparse_tensor``. Otherwise the dense path
            computes ``D^-1 (adj + I)``, where ``D`` holds the row sums of
            ``adj + I``.

    Returns:
        The normalized matrix, placed on the same device as ``adj``. In the
        dense path, rows of ``adj + I`` that sum to zero stay all-zero instead
        of becoming INF/NaN.
    """
    # Use the tensor's own device (including its index, e.g. cuda:1) rather
    # than a bare "cuda", which resolves to the current default GPU.
    device = adj.device
    if sparse:
        adj = to_scipy(adj)
        mx = degree_normalize_adj(adj)
        return sparse_mx_to_torch_sparse_tensor(mx).to(device)
    mx = adj + torch.eye(adj.shape[0], device=device)
    rowsum = mx.sum(1)
    r_inv = rowsum.pow(-1).flatten()
    # zero-sum rows give 1/0 = inf; map them to 0 so those rows become zeros
    r_inv[torch.isinf(r_inv)] = 0.
    r_mat_inv = torch.diag(r_inv)
    return r_mat_inv @ mx
Example #17
Source File: training.py    From TTS with Mozilla Public License 2.0 6 votes vote down vote up
def check_update(model, grad_clip, ignore_stopnet=False):
    r'''Clip the model's gradients and flag non-finite gradient norms.

    Args:
        model: module whose ``.grad`` tensors are clipped in place by
            ``torch.nn.utils.clip_grad_norm_``.
        grad_clip: maximum total gradient norm passed to ``clip_grad_norm_``.
        ignore_stopnet: when True, parameters whose name contains
            ``'stopnet'`` are excluded from both the norm and the clipping.

    Returns:
        ``(grad_norm, skip_flag)``: ``grad_norm`` is the value returned by
        ``clip_grad_norm_`` (a float on old torch versions, a 0-dim tensor on
        newer ones); ``skip_flag`` is True when that norm is INF or NaN, i.e.
        the caller should not apply this update.
    '''
    if ignore_stopnet:
        params = [param for name, param in model.named_parameters() if 'stopnet' not in name]
    else:
        params = model.parameters()
    grad_norm = torch.nn.utils.clip_grad_norm_(params, grad_clip)
    # compatibility with different torch versions: older releases return a
    # Python float, newer ones a 0-dim tensor.
    if isinstance(grad_norm, float):
        is_inf, is_nan = bool(np.isinf(grad_norm)), bool(np.isnan(grad_norm))
    else:
        is_inf, is_nan = bool(torch.isinf(grad_norm)), bool(torch.isnan(grad_norm))
    skip_flag = False
    if is_inf:
        print(" | > Gradient is INF !!")
        skip_flag = True
    elif is_nan:
        # A NaN norm means at least one gradient is NaN; stepping with it
        # would corrupt the parameters just like an INF norm would.
        print(" | > Gradient is NaN !!")
        skip_flag = True
    return grad_norm, skip_flag
Example #18
Source File: training.py    From TTS with Mozilla Public License 2.0 6 votes vote down vote up
def check_update(model, grad_clip, ignore_stopnet=False):
    r'''Clip the model's gradients and flag non-finite gradient norms.

    Args:
        model: module whose ``.grad`` tensors are clipped in place by
            ``torch.nn.utils.clip_grad_norm_``.
        grad_clip: maximum total gradient norm passed to ``clip_grad_norm_``.
        ignore_stopnet: when True, parameters whose name contains
            ``'stopnet'`` are excluded from both the norm and the clipping.

    Returns:
        ``(grad_norm, skip_flag)``: ``grad_norm`` is the value returned by
        ``clip_grad_norm_`` (a float on old torch versions, a 0-dim tensor on
        newer ones); ``skip_flag`` is True when that norm is INF or NaN, i.e.
        the caller should not apply this update.
    '''
    if ignore_stopnet:
        params = [param for name, param in model.named_parameters() if 'stopnet' not in name]
    else:
        params = model.parameters()
    grad_norm = torch.nn.utils.clip_grad_norm_(params, grad_clip)
    # compatibility with different torch versions: older releases return a
    # Python float, newer ones a 0-dim tensor.
    if isinstance(grad_norm, float):
        is_inf, is_nan = bool(np.isinf(grad_norm)), bool(np.isnan(grad_norm))
    else:
        is_inf, is_nan = bool(torch.isinf(grad_norm)), bool(torch.isnan(grad_norm))
    skip_flag = False
    if is_inf:
        print(" | > Gradient is INF !!")
        skip_flag = True
    elif is_nan:
        # A NaN norm means at least one gradient is NaN; stepping with it
        # would corrupt the parameters just like an INF norm would.
        print(" | > Gradient is NaN !!")
        skip_flag = True
    return grad_norm, skip_flag
Example #19
Source File: training.py    From TTS with Mozilla Public License 2.0 6 votes vote down vote up
def check_update(model, grad_clip, ignore_stopnet=False):
    r'''Clip the model's gradients and flag non-finite gradient norms.

    Args:
        model: module whose ``.grad`` tensors are clipped in place by
            ``torch.nn.utils.clip_grad_norm_``.
        grad_clip: maximum total gradient norm passed to ``clip_grad_norm_``.
        ignore_stopnet: when True, parameters whose name contains
            ``'stopnet'`` are excluded from both the norm and the clipping.

    Returns:
        ``(grad_norm, skip_flag)``: ``grad_norm`` is the value returned by
        ``clip_grad_norm_`` (a float on old torch versions, a 0-dim tensor on
        newer ones); ``skip_flag`` is True when that norm is INF or NaN, i.e.
        the caller should not apply this update.
    '''
    if ignore_stopnet:
        params = [param for name, param in model.named_parameters() if 'stopnet' not in name]
    else:
        params = model.parameters()
    grad_norm = torch.nn.utils.clip_grad_norm_(params, grad_clip)
    # compatibility with different torch versions: older releases return a
    # Python float, newer ones a 0-dim tensor.
    if isinstance(grad_norm, float):
        is_inf, is_nan = bool(np.isinf(grad_norm)), bool(np.isnan(grad_norm))
    else:
        is_inf, is_nan = bool(torch.isinf(grad_norm)), bool(torch.isnan(grad_norm))
    skip_flag = False
    if is_inf:
        print(" | > Gradient is INF !!")
        skip_flag = True
    elif is_nan:
        # A NaN norm means at least one gradient is NaN; stepping with it
        # would corrupt the parameters just like an INF norm would.
        print(" | > Gradient is NaN !!")
        skip_flag = True
    return grad_norm, skip_flag
Example #20
Source File: training.py    From TTS with Mozilla Public License 2.0 6 votes vote down vote up
def check_update(model, grad_clip, ignore_stopnet=False):
    r'''Clip the model's gradients and flag non-finite gradient norms.

    Args:
        model: module whose ``.grad`` tensors are clipped in place by
            ``torch.nn.utils.clip_grad_norm_``.
        grad_clip: maximum total gradient norm passed to ``clip_grad_norm_``.
        ignore_stopnet: when True, parameters whose name contains
            ``'stopnet'`` are excluded from both the norm and the clipping.

    Returns:
        ``(grad_norm, skip_flag)``: ``grad_norm`` is the value returned by
        ``clip_grad_norm_`` (a float on old torch versions, a 0-dim tensor on
        newer ones); ``skip_flag`` is True when that norm is INF or NaN, i.e.
        the caller should not apply this update.
    '''
    if ignore_stopnet:
        params = [param for name, param in model.named_parameters() if 'stopnet' not in name]
    else:
        params = model.parameters()
    grad_norm = torch.nn.utils.clip_grad_norm_(params, grad_clip)
    # compatibility with different torch versions: older releases return a
    # Python float, newer ones a 0-dim tensor.
    if isinstance(grad_norm, float):
        is_inf, is_nan = bool(np.isinf(grad_norm)), bool(np.isnan(grad_norm))
    else:
        is_inf, is_nan = bool(torch.isinf(grad_norm)), bool(torch.isnan(grad_norm))
    skip_flag = False
    if is_inf:
        print(" | > Gradient is INF !!")
        skip_flag = True
    elif is_nan:
        # A NaN norm means at least one gradient is NaN; stepping with it
        # would corrupt the parameters just like an INF norm would.
        print(" | > Gradient is NaN !!")
        skip_flag = True
    return grad_norm, skip_flag
Example #21
Source File: training.py    From TTS with Mozilla Public License 2.0 6 votes vote down vote up
def check_update(model, grad_clip, ignore_stopnet=False):
    r'''Clip the model's gradients and flag non-finite gradient norms.

    Args:
        model: module whose ``.grad`` tensors are clipped in place by
            ``torch.nn.utils.clip_grad_norm_``.
        grad_clip: maximum total gradient norm passed to ``clip_grad_norm_``.
        ignore_stopnet: when True, parameters whose name contains
            ``'stopnet'`` are excluded from both the norm and the clipping.

    Returns:
        ``(grad_norm, skip_flag)``: ``grad_norm`` is the value returned by
        ``clip_grad_norm_`` (a float on old torch versions, a 0-dim tensor on
        newer ones); ``skip_flag`` is True when that norm is INF or NaN, i.e.
        the caller should not apply this update.
    '''
    if ignore_stopnet:
        params = [param for name, param in model.named_parameters() if 'stopnet' not in name]
    else:
        params = model.parameters()
    grad_norm = torch.nn.utils.clip_grad_norm_(params, grad_clip)
    # compatibility with different torch versions: older releases return a
    # Python float, newer ones a 0-dim tensor.
    if isinstance(grad_norm, float):
        is_inf, is_nan = bool(np.isinf(grad_norm)), bool(np.isnan(grad_norm))
    else:
        is_inf, is_nan = bool(torch.isinf(grad_norm)), bool(torch.isnan(grad_norm))
    skip_flag = False
    if is_inf:
        print(" | > Gradient is INF !!")
        skip_flag = True
    elif is_nan:
        # A NaN norm means at least one gradient is NaN; stepping with it
        # would corrupt the parameters just like an INF norm would.
        print(" | > Gradient is NaN !!")
        skip_flag = True
    return grad_norm, skip_flag
Example #22
Source File: training.py    From TTS with Mozilla Public License 2.0 6 votes vote down vote up
def check_update(model, grad_clip, ignore_stopnet=False):
    r'''Clip the model's gradients and flag non-finite gradient norms.

    Args:
        model: module whose ``.grad`` tensors are clipped in place by
            ``torch.nn.utils.clip_grad_norm_``.
        grad_clip: maximum total gradient norm passed to ``clip_grad_norm_``.
        ignore_stopnet: when True, parameters whose name contains
            ``'stopnet'`` are excluded from both the norm and the clipping.

    Returns:
        ``(grad_norm, skip_flag)``: ``grad_norm`` is the value returned by
        ``clip_grad_norm_`` (a float on old torch versions, a 0-dim tensor on
        newer ones); ``skip_flag`` is True when that norm is INF or NaN, i.e.
        the caller should not apply this update.
    '''
    if ignore_stopnet:
        params = [param for name, param in model.named_parameters() if 'stopnet' not in name]
    else:
        params = model.parameters()
    grad_norm = torch.nn.utils.clip_grad_norm_(params, grad_clip)
    # compatibility with different torch versions: older releases return a
    # Python float, newer ones a 0-dim tensor.
    if isinstance(grad_norm, float):
        is_inf, is_nan = bool(np.isinf(grad_norm)), bool(np.isnan(grad_norm))
    else:
        is_inf, is_nan = bool(torch.isinf(grad_norm)), bool(torch.isnan(grad_norm))
    skip_flag = False
    if is_inf:
        print(" | > Gradient is INF !!")
        skip_flag = True
    elif is_nan:
        # A NaN norm means at least one gradient is NaN; stepping with it
        # would corrupt the parameters just like an INF norm would.
        print(" | > Gradient is NaN !!")
        skip_flag = True
    return grad_norm, skip_flag
Example #23
Source File: training.py    From TTS with Mozilla Public License 2.0 6 votes vote down vote up
def check_update(model, grad_clip, ignore_stopnet=False):
    r'''Clip the model's gradients and flag non-finite gradient norms.

    Args:
        model: module whose ``.grad`` tensors are clipped in place by
            ``torch.nn.utils.clip_grad_norm_``.
        grad_clip: maximum total gradient norm passed to ``clip_grad_norm_``.
        ignore_stopnet: when True, parameters whose name contains
            ``'stopnet'`` are excluded from both the norm and the clipping.

    Returns:
        ``(grad_norm, skip_flag)``: ``grad_norm`` is the value returned by
        ``clip_grad_norm_`` (a float on old torch versions, a 0-dim tensor on
        newer ones); ``skip_flag`` is True when that norm is INF or NaN, i.e.
        the caller should not apply this update.
    '''
    if ignore_stopnet:
        params = [param for name, param in model.named_parameters() if 'stopnet' not in name]
    else:
        params = model.parameters()
    grad_norm = torch.nn.utils.clip_grad_norm_(params, grad_clip)
    # compatibility with different torch versions: older releases return a
    # Python float, newer ones a 0-dim tensor.
    if isinstance(grad_norm, float):
        is_inf, is_nan = bool(np.isinf(grad_norm)), bool(np.isnan(grad_norm))
    else:
        is_inf, is_nan = bool(torch.isinf(grad_norm)), bool(torch.isnan(grad_norm))
    skip_flag = False
    if is_inf:
        print(" | > Gradient is INF !!")
        skip_flag = True
    elif is_nan:
        # A NaN norm means at least one gradient is NaN; stepping with it
        # would corrupt the parameters just like an INF norm would.
        print(" | > Gradient is NaN !!")
        skip_flag = True
    return grad_norm, skip_flag
Example #24
Source File: mlp_test.py    From nsf with MIT License 6 votes vote down vote up
def test_forward(self):
        # The MLP must map (batch, *in_shape) to (batch, *out_shape) with
        # finite outputs for several hidden-layer configurations, and
        # construction must fail when no hidden layers are given.
        batch_size = 10
        in_shape = [2, 3, 4]
        out_shape = [5, 6]
        inputs = torch.randn(batch_size, *in_shape)
        expected_shape = torch.Size([batch_size] + out_shape)

        for sizes in ([20], [20, 30], [20, 30, 40]):
            with self.subTest(hidden_sizes=sizes):
                net = mlp.MLP(in_shape=in_shape, out_shape=out_shape, hidden_sizes=sizes)
                result = net(inputs)
                self.assertIsInstance(result, torch.Tensor)
                self.assertEqual(result.shape, expected_shape)
                self.assertFalse(torch.isnan(result).any())
                self.assertFalse(torch.isinf(result).any())

        with self.assertRaises(Exception):
            mlp.MLP(in_shape=in_shape, out_shape=out_shape, hidden_sizes=[])
Example #25
Source File: search_model_gdas.py    From AutoDL-Projects with MIT License 6 votes vote down vote up
def forward(self, inputs):
  # Resample Gumbel noise until neither the noise nor the resulting
  # softmax probabilities contain INF/NaN values.
  while True:
    gumbels = -torch.empty_like(self.arch_parameters).exponential_().log()
    logits  = (self.arch_parameters.log_softmax(dim=1) + gumbels) / self.tau
    probs   = nn.functional.softmax(logits, dim=1)
    if not (torch.isinf(gumbels).any() or torch.isinf(probs).any() or torch.isnan(probs).any()):
      break
  # Straight-through weights: numerically equal to the one-hot argmax of
  # `probs`, while gradients flow through `probs`.
  index   = probs.max(-1, keepdim=True)[1]
  one_h   = torch.zeros_like(logits).scatter_(-1, index, 1.0)
  hardwts = one_h - probs.detach() + probs

  feature = self.stem(inputs)
  for cell in self.cells:
    if isinstance(cell, SearchCell):
      feature = cell.forward_gdas(feature, hardwts, index)
    else:
      feature = cell(feature)
  pooled = self.global_pooling(self.lastact(feature))
  out = pooled.view(pooled.size(0), -1)
  logits = self.classifier(out)
  return out, logits
Example #26
Source File: SoftSelect.py    From AutoDL-Projects with MIT License 6 votes vote down vote up
def select2withP(logits, tau, just_prob=False, num=2, eps=1e-7):
  """Sample ``num`` candidates per row of ``logits`` and re-weight them.

  With ``tau <= 0`` the probabilities are a plain softmax over ``logits``.
  Otherwise Gumbel noise is added to ``logits.log_softmax(dim=1)`` and the
  sum is divided by ``tau``; the noise is redrawn until it and the resulting
  probabilities contain no INF/NaN.

  Returns ``probs`` when ``just_prob`` is True. Otherwise returns
  ``(selected_index, selected_probs)``: ``num`` indices per row drawn without
  replacement from ``probs + eps``, and the softmax over the corresponding
  (possibly noisy) logits.
  """
  if tau <= 0:
    new_logits = logits
    probs = nn.functional.softmax(new_logits, dim=1)
  else:
    while True:  # a trick to avoid the gumbels bug
      gumbels = -torch.empty_like(logits).exponential_().log()
      new_logits = (logits.log_softmax(dim=1) + gumbels) / tau
      probs = nn.functional.softmax(new_logits, dim=1)
      if not (torch.isinf(gumbels).any() or torch.isinf(probs).any() or torch.isnan(probs).any()):
        break

  if just_prob:
    return probs

  with torch.no_grad():  # add eps for unexpected torch error
    cpu_probs = probs.cpu()
    chosen = torch.multinomial(cpu_probs + eps, num, False).to(logits.device)
  chosen_logits = torch.gather(new_logits, 1, chosen)
  chosen_probs = nn.functional.softmax(chosen_logits, dim=1)
  return chosen, chosen_probs
Example #27
Source File: training.py    From TTS with Mozilla Public License 2.0 6 votes vote down vote up
def check_update(model, grad_clip, ignore_stopnet=False):
    r'''Clip the model's gradients and flag non-finite gradient norms.

    Args:
        model: module whose ``.grad`` tensors are clipped in place by
            ``torch.nn.utils.clip_grad_norm_``.
        grad_clip: maximum total gradient norm passed to ``clip_grad_norm_``.
        ignore_stopnet: when True, parameters whose name contains
            ``'stopnet'`` are excluded from both the norm and the clipping.

    Returns:
        ``(grad_norm, skip_flag)``: ``grad_norm`` is the value returned by
        ``clip_grad_norm_`` (a float on old torch versions, a 0-dim tensor on
        newer ones); ``skip_flag`` is True when that norm is INF or NaN, i.e.
        the caller should not apply this update.
    '''
    if ignore_stopnet:
        params = [param for name, param in model.named_parameters() if 'stopnet' not in name]
    else:
        params = model.parameters()
    grad_norm = torch.nn.utils.clip_grad_norm_(params, grad_clip)
    # compatibility with different torch versions: older releases return a
    # Python float, newer ones a 0-dim tensor.
    if isinstance(grad_norm, float):
        is_inf, is_nan = bool(np.isinf(grad_norm)), bool(np.isnan(grad_norm))
    else:
        is_inf, is_nan = bool(torch.isinf(grad_norm)), bool(torch.isnan(grad_norm))
    skip_flag = False
    if is_inf:
        print(" | > Gradient is INF !!")
        skip_flag = True
    elif is_nan:
        # A NaN norm means at least one gradient is NaN; stepping with it
        # would corrupt the parameters just like an INF norm would.
        print(" | > Gradient is NaN !!")
        skip_flag = True
    return grad_norm, skip_flag
Example #28
Source File: training.py    From TTS with Mozilla Public License 2.0 6 votes vote down vote up
def check_update(model, grad_clip, ignore_stopnet=False):
    r'''Clip the model's gradients and flag non-finite gradient norms.

    Args:
        model: module whose ``.grad`` tensors are clipped in place by
            ``torch.nn.utils.clip_grad_norm_``.
        grad_clip: maximum total gradient norm passed to ``clip_grad_norm_``.
        ignore_stopnet: when True, parameters whose name contains
            ``'stopnet'`` are excluded from both the norm and the clipping.

    Returns:
        ``(grad_norm, skip_flag)``: ``grad_norm`` is the value returned by
        ``clip_grad_norm_`` (a float on old torch versions, a 0-dim tensor on
        newer ones); ``skip_flag`` is True when that norm is INF or NaN, i.e.
        the caller should not apply this update.
    '''
    if ignore_stopnet:
        params = [param for name, param in model.named_parameters() if 'stopnet' not in name]
    else:
        params = model.parameters()
    grad_norm = torch.nn.utils.clip_grad_norm_(params, grad_clip)
    # compatibility with different torch versions: older releases return a
    # Python float, newer ones a 0-dim tensor.
    if isinstance(grad_norm, float):
        is_inf, is_nan = bool(np.isinf(grad_norm)), bool(np.isnan(grad_norm))
    else:
        is_inf, is_nan = bool(torch.isinf(grad_norm)), bool(torch.isnan(grad_norm))
    skip_flag = False
    if is_inf:
        print(" | > Gradient is INF !!")
        skip_flag = True
    elif is_nan:
        # A NaN norm means at least one gradient is NaN; stepping with it
        # would corrupt the parameters just like an INF norm would.
        print(" | > Gradient is NaN !!")
        skip_flag = True
    return grad_norm, skip_flag
Example #29
Source File: training.py    From TTS with Mozilla Public License 2.0 6 votes vote down vote up
def check_update(model, grad_clip, ignore_stopnet=False):
    r'''Clip the model's gradients and flag non-finite gradient norms.

    Args:
        model: module whose ``.grad`` tensors are clipped in place by
            ``torch.nn.utils.clip_grad_norm_``.
        grad_clip: maximum total gradient norm passed to ``clip_grad_norm_``.
        ignore_stopnet: when True, parameters whose name contains
            ``'stopnet'`` are excluded from both the norm and the clipping.

    Returns:
        ``(grad_norm, skip_flag)``: ``grad_norm`` is the value returned by
        ``clip_grad_norm_`` (a float on old torch versions, a 0-dim tensor on
        newer ones); ``skip_flag`` is True when that norm is INF or NaN, i.e.
        the caller should not apply this update.
    '''
    if ignore_stopnet:
        params = [param for name, param in model.named_parameters() if 'stopnet' not in name]
    else:
        params = model.parameters()
    grad_norm = torch.nn.utils.clip_grad_norm_(params, grad_clip)
    # compatibility with different torch versions: older releases return a
    # Python float, newer ones a 0-dim tensor.
    if isinstance(grad_norm, float):
        is_inf, is_nan = bool(np.isinf(grad_norm)), bool(np.isnan(grad_norm))
    else:
        is_inf, is_nan = bool(torch.isinf(grad_norm)), bool(torch.isnan(grad_norm))
    skip_flag = False
    if is_inf:
        print(" | > Gradient is INF !!")
        skip_flag = True
    elif is_nan:
        # A NaN norm means at least one gradient is NaN; stepping with it
        # would corrupt the parameters just like an INF norm would.
        print(" | > Gradient is NaN !!")
        skip_flag = True
    return grad_norm, skip_flag
Example #30
Source File: test_Gan_networks.py    From Variational_Discriminator_Bottleneck with MIT License 6 votes vote down vote up
def test_forward(self):
        # Each generator (edge-sized first, then full-sized) must turn random
        # latent input into outputs of the expected shape with no NaN or INF.
        cases = (
            (self.gen_edge, (3, 128), (3, 3, 4, 4)),
            (self.gen, (16, 8), (16, 3, 256, 256)),
        )
        for generator, latent_shape, image_shape in cases:
            latent = torch.randn(*latent_shape).to(device)
            generated = generator(latent)
            self.assertEqual(generated.shape, image_shape)
            self.assertEqual(torch.isnan(generated).sum().item(), 0)
            self.assertEqual(torch.isinf(generated).sum().item(), 0)