import torch.nn as nn

from dfw.losses import MultiClassHingeLoss, set_smoothing_enabled


def get_loss(args):
    """Build the loss function selected by ``args.loss``.

    Args:
        args: Configuration object. The attributes read here are:
            ``loss`` -- ``'svm'`` selects ``MultiClassHingeLoss`` and
            ``'ce'`` selects ``nn.CrossEntropyLoss``;
            ``l2`` -- only printed, not otherwise used by this function;
            ``cuda`` -- when truthy, ``.cuda()`` is called on the loss.

    Returns:
        The constructed loss object (the result of ``.cuda()`` when
        ``args.cuda`` is truthy).

    Raises:
        ValueError: If ``args.loss`` is neither ``'svm'`` nor ``'ce'``.

    Side effects:
        Prints the L2 regularization value and the chosen loss to stdout.
    """
    if args.loss == 'svm':
        loss_fn = MultiClassHingeLoss()
    elif args.loss == 'ce':
        loss_fn = nn.CrossEntropyLoss()
    else:
        # Name the offending value so a mistyped option is easy to diagnose.
        raise ValueError(
            "Unsupported loss {!r}: expected 'svm' or 'ce'".format(args.loss)
        )

    print('L2 regularization: \t {}'.format(args.l2))
    print('\nLoss function:')
    print(loss_fn)

    if args.cuda:
        loss_fn = loss_fn.cuda()

    return loss_fn