# -*- coding: utf-8 -*-
# @Author  : DevinYang(pistonyang@gmail.com)
"""Smoke tests for losses and layers exported by ``torchtoolbox.nn``."""
from torchtoolbox.nn.functional import class_balanced_weight
# NOTE: the star-import stays above the torch/numpy imports on purpose, so the
# explicit `torch`, `nn` and `np` bindings below win over anything it re-exports.
from torchtoolbox.nn import *
import torch
from torch import nn
import numpy as np


@torch.no_grad()
def test_lsloss():
    """LabelSmoothingLoss should reduce to the same shape as nn.CrossEntropyLoss."""
    pred = torch.rand(3, 10)
    label = torch.randint(0, 10, size=(3,))
    ls_loss = LabelSmoothingLoss(10, 0.1)
    ce_loss = nn.CrossEntropyLoss()

    cost = ls_loss(pred, label)
    cost_ref = ce_loss(pred, label)
    assert cost.shape == cost_ref.shape


@torch.no_grad()
def test_logits_loss():
    """Class-weighted SigmoidCrossEntropy / FocalLoss / ArcLoss run and yield finite values."""
    pred = torch.rand(3, 10)
    label = torch.randint(0, 10, size=(3,))
    # Per-class sample counts are drawn from [1, 100). A count of 0 is excluded
    # because class-balanced weighting is presumably (1 - beta) / (1 - beta**n),
    # whose denominator vanishes at n == 0 and would put inf into the weights,
    # making this test flaky. NOTE(review): confirm against class_balanced_weight.
    samples_per_class = np.random.randint(1, 100, size=(10,)).tolist()
    weight = class_balanced_weight(0.9999, samples_per_class)

    sigmoid_ce = SigmoidCrossEntropy(classes=10, weight=weight)
    focal = FocalLoss(classes=10, weight=weight, gamma=0.5)
    arc = ArcLoss(classes=10, weight=weight)

    costs = [loss_fn(pred, label) for loss_fn in (sigmoid_ce, focal, arc)]
    print(*costs)
    for cost in costs:
        assert torch.isfinite(cost).all()


class n_to_n(nn.Module):
    """Two inputs -> two outputs: each input goes through its own 1x1 conv."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 3, 1, 1, bias=False)
        self.conv2 = nn.Conv2d(3, 3, 1, 1, bias=False)

    def forward(self, x1, x2):
        y1 = self.conv1(x1)
        y2 = self.conv2(x2)
        return y1, y2


class n_to_one(nn.Module):
    """Two inputs -> one output: the sum of two independent 1x1 convs."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 3, 1, 1, bias=False)
        self.conv2 = nn.Conv2d(3, 3, 1, 1, bias=False)

    def forward(self, x1, x2):
        y1 = self.conv1(x1)
        y2 = self.conv2(x2)
        return y1 + y2


class one_to_n(nn.Module):
    """One input -> two outputs: the same tensor through two 1x1 convs."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 3, 1, 1, bias=False)
        self.conv2 = nn.Conv2d(3, 3, 1, 1, bias=False)

    def forward(self, x):
        y1 = self.conv1(x)
        y2 = self.conv2(x)
        return y1, y2


@torch.no_grad()
def test_ad_sequential():
    """AdaptiveSequential should thread tuple outputs into multi-arg forwards."""
    # one_to_n yields 2 tensors -> n_to_n consumes/yields 2 -> n_to_one sums to 1.
    seq = AdaptiveSequential(one_to_n(), n_to_n(), n_to_one())
    td = torch.rand(1, 3, 32, 32)
    out = seq(td)
    assert out.size() == torch.Size([1, 3, 32, 32])


@torch.no_grad()
def test_switch_norm():
    """SwitchNorm2d / SwitchNorm3d should preserve their input shapes."""
    td2 = torch.rand(1, 3, 32, 32)
    td3 = torch.rand(1, 3, 32, 32, 3)
    norm2 = SwitchNorm2d(3)
    norm3 = SwitchNorm3d(3)

    out2 = norm2(td2)
    out3 = norm3(td3)
    assert out2.size() == td2.size()
    assert out3.size() == td3.size()


@torch.no_grad()
def test_swish():
    """Swish is an element-wise activation, so the output shape must match the input."""
    td = torch.rand(1, 16, 32, 32)
    swish = Swish(beta=10.0)
    out = swish(td)
    assert out.shape == td.shape