Python torch.nn.TripletMarginLoss() Examples

The following are 5 code examples of torch.nn.TripletMarginLoss(), drawn from open-source projects. Each example notes its original project and source file, so you can follow the link above it to the full context. You may also want to check out all available functions and classes of the torch.nn module.
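For orientation, here is a minimal standalone sketch of how nn.TripletMarginLoss() is typically called; the tensor shapes and variable names below are illustrative and not taken from the projects that follow.

import torch
import torch.nn as nn

# Triplet loss with margin 1.0 and Euclidean distance (p=2)
triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)

# Batches of embeddings: anchor, positive (same identity), negative (different identity)
anchor = torch.randn(32, 128, requires_grad=True)
positive = torch.randn(32, 128, requires_grad=True)
negative = torch.randn(32, 128, requires_grad=True)

loss = triplet_loss(anchor, positive, negative)
loss.backward()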
Example #1
Source File: base_agent.py    From 2D-Motion-Retargeting with MIT License
def __init__(self, config, net):
    self.log_dir = config.log_dir
    self.model_dir = config.model_dir
    self.net = net
    self.clock = TrainClock()
    self.device = config.device

    self.use_triplet = config.use_triplet
    self.use_footvel_loss = config.use_footvel_loss

    # set loss function
    self.mse = nn.MSELoss()
    self.tripletloss = nn.TripletMarginLoss(margin=config.triplet_margin)
    self.triplet_weight = config.triplet_weight
    self.foot_idx = config.foot_idx
    self.footvel_loss_weight = config.footvel_loss_weight

    # set optimizer
    self.optimizer = optim.Adam(self.net.parameters(), config.lr)
    self.scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, 0.99)
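The excerpt above only constructs the criteria; a hedged sketch of how they might be combined in the agent's training step is shown below. The feature and output names are hypothetical, not taken from the 2D-Motion-Retargeting code.

# Hypothetical fragment of a training step using the criteria built in __init__ above.
recon_loss = self.mse(output, target)
total_loss = recon_loss
if self.use_triplet:
    # anchor_feat / positive_feat / negative_feat are illustrative embedding tensors
    total_loss = total_loss + self.triplet_weight * self.tripletloss(
        anchor_feat, positive_feat, negative_feat)
self.optimizer.zero_grad()
total_loss.backward()
self.optimizer.step()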
Example #2
Source File: __init__.py    From Deep-Expander-Networks with GNU General Public License v3.0
def setup(model, opt):

    if opt.criterion == "l1":
        criterion = nn.L1Loss().cuda()
    elif opt.criterion == "mse":
        criterion = nn.MSELoss().cuda()
    elif opt.criterion == "crossentropy":
        criterion = nn.CrossEntropyLoss().cuda()
    elif opt.criterion == "hingeEmbedding":
        criterion = nn.HingeEmbeddingLoss().cuda()
    elif opt.criterion == "tripletmargin":
        criterion = nn.TripletMarginLoss(margin = opt.margin, swap = opt.anchorswap).cuda()

    parameters = filter(lambda p: p.requires_grad, model.parameters())

    if opt.optimType == 'sgd':
        optimizer = optim.SGD(parameters, lr = opt.lr, momentum = opt.momentum, nesterov = opt.nesterov, weight_decay = opt.weightDecay)
    elif opt.optimType == 'adam':
        optimizer = optim.Adam(parameters, lr = opt.maxlr, weight_decay = opt.weightDecay)

    if opt.weight_init:
        utils.weights_init(model, opt)

    return model, criterion, optimizer 
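A hedged sketch of how setup() might be called, assuming an argparse-style opt namespace and a model already moved to the GPU; the attribute values below are illustrative, and weight_init is set to False so utils.weights_init is skipped.

import argparse

# Illustrative options; the original project reads these from its command line.
opt = argparse.Namespace(
    criterion='tripletmargin', margin=0.2, anchorswap=True,
    optimType='sgd', lr=0.1, momentum=0.9, nesterov=True, weightDecay=1e-4,
    maxlr=0.1, weight_init=False)

model, criterion, optimizer = setup(my_model, opt)  # my_model: any nn.Module (hypothetical)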
Example #3
Source File: utils.py    From mmfashion with Apache License 2.0
def build_criterion(loss_dict):

    if loss_dict.type == 'CrossEntropyLoss':
        weight = loss_dict.weight
        size_average = loss_dict.size_average
        reduce = loss_dict.reduce
        reduction = loss_dict.reduction

        if loss_dict.use_sigmoid:
            return nn.BCEWithLogitsLoss(
                weight=weight,
                size_average=size_average,
                reduce=reduce,
                reduction=reduction)
        else:
            return nn.CrossEntropyLoss(
                weight=weight,
                size_average=size_average,
                reduce=reduce,
                reduction=reduction)

    elif loss_dict.type == 'TripletLoss':
        return nn.TripletMarginLoss(margin=loss_dict.margin, p=loss_dict.p)

    else:
        raise TypeError('{} cannot be processed'.format(loss_dict.type)) 
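A hedged usage sketch for build_criterion(); in mmfashion the loss_dict would come from an mmcv config file, so the SimpleNamespace below is only a stand-in with illustrative values.

from types import SimpleNamespace

# Stand-in for the loss section of an mmfashion config
loss_cfg = SimpleNamespace(type='TripletLoss', margin=0.3, p=2)
criterion = build_criterion(loss_cfg)  # -> nn.TripletMarginLoss(margin=0.3, p=2)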
Example #4
Source File: loss.py    From triplet-reid-pytorch with Apache License 2.0
def __init__(self, margin=None):
    super(TripletLoss, self).__init__()
    self.margin = margin
    if self.margin is None:  # use soft-margin
        self.Loss = nn.SoftMarginLoss()
    else:
        self.Loss = nn.TripletMarginLoss(margin=margin, p=2)
Example #5
Source File: triplet_loss.py    From pytorch-loss with MIT License
def __init__(self, margin=None):
    super(TripletLoss, self).__init__()
    self.margin = margin
    if self.margin is None:  # if no margin assigned, use soft-margin
        self.Loss = nn.SoftMarginLoss()
    else:
        self.Loss = nn.TripletMarginLoss(margin=margin, p=2)
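Examples #4 and #5 show only the constructor of this TripletLoss wrapper; below is a hedged sketch of the kind of forward() such a wrapper typically pairs with it. It is illustrative rather than the exact code of either project, and assumes torch and torch.nn are imported as in those files.

def forward(self, anchor, positive, negative):
    if self.margin is None:
        # soft-margin triplet: log(1 + exp(d(a, p) - d(a, n))) via SoftMarginLoss
        pos_dist = (anchor - positive).pow(2).sum(dim=1)
        neg_dist = (anchor - negative).pow(2).sum(dim=1)
        target = torch.ones_like(pos_dist)  # label +1, so a larger neg_dist lowers the loss
        return self.Loss(neg_dist - pos_dist, target)
    # hard-margin triplet handled directly by nn.TripletMarginLoss
    return self.Loss(anchor, positive, negative)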