Python torch.nn.Module Examples

The following are 4 code examples of torch.nn.Module. You can go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch.nn, or try the search function.
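Before the examples, here is a minimal sketch of the torch.nn.Module contract they all rely on: subclass nn.Module, register sublayers as attributes in __init__, and implement forward. The layer sizes here are arbitrary placeholders.

import torch
import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Sublayers assigned as attributes are registered automatically,
        # so their parameters show up in net.parameters().
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        # Only forward is defined; gradients come from autograd.
        return self.fc2(torch.relu(self.fc1(x)))

net = TinyNet()
print(net(torch.randn(3, 4)).shape)  # torch.Size([3, 2])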
Example #1
Source File: conv_block.py    From batchflow with Apache License 2.0
def _make_layer(self, *args, inputs=None, base_block=BaseConvBlock, **kwargs):
        # each element in `args` is a dict or module: make a sequential out of them
        if args:
            layers = []
            for item in args:
                if isinstance(item, dict):
                    block = item.pop('base_block', None) or item.pop('base', None) or base_block
                    block_args = {'inputs': inputs, **dict(Config(kwargs) + Config(item))}
                    layer = block(**block_args)
                    inputs = layer(inputs)
                    layers.append(layer)
                elif isinstance(item, nn.Module):
                    inputs = item(inputs)
                    layers.append(item)
                else:
                    raise ValueError('Positional arguments of ConvBlock must be either dicts '
                                     'or nn.Modules, got instead {}'.format(type(item)))
            return nn.Sequential(*layers)
        # one block only
        return base_block(inputs=inputs, **kwargs) 
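To see the pattern in isolation, here is a hypothetical standalone sketch of the same idea, without batchflow's Config and BaseConvBlock machinery: each positional argument is either a config dict (instantiated via a base block class) or a ready-made nn.Module, and a dummy forward pass propagates the input between blocks, as in the original.

import torch
import torch.nn as nn

def make_layer(*args, inputs=None, base_block=nn.Conv2d):
    # Build an nn.Sequential from a mix of config dicts and module instances.
    layers = []
    for item in args:
        if isinstance(item, dict):
            block = item.pop('base_block', None) or base_block
            layer = block(**item)      # instantiate the block from its config
            inputs = layer(inputs)     # dummy forward pass, as in the original
            layers.append(layer)
        elif isinstance(item, nn.Module):
            inputs = item(inputs)
            layers.append(item)
        else:
            raise ValueError('Expected dict or nn.Module, got {}'.format(type(item)))
    return nn.Sequential(*layers)

x = torch.randn(1, 3, 32, 32)
seq = make_layer({'in_channels': 3, 'out_channels': 8, 'kernel_size': 3},
                 nn.ReLU(),
                 inputs=x)
print(seq(x).shape)  # torch.Size([1, 8, 30, 30])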
Example #2
Source File: classy_model.py    From ClassyVision with MIT License
def get_heads(self):
        """Returns the heads on the model

        Returns the heads as a dictionary mapping block names to the
        `nn.Modules <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_
        attached to that block.

        """
        return {
            block_name: list(heads.values())
            for block_name, heads in self._heads.items()
        } 
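Since the method depends on the model's internal _heads mapping, here is a hypothetical illustration of the transformation it performs, assuming _heads maps block names to dicts of named heads. All names and layer sizes below are made up.

import torch.nn as nn

# Hypothetical _heads contents: block name -> {head name: nn.Module}.
_heads = {
    'block3': {'classification': nn.Linear(256, 10)},
    'block4': {'classification': nn.Linear(512, 10),
               'embedding': nn.Linear(512, 64)},
}
heads = {block_name: list(h.values()) for block_name, h in _heads.items()}
print({k: [type(m).__name__ for m in v] for k, v in heads.items()})
# {'block3': ['Linear'], 'block4': ['Linear', 'Linear']}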
Example #3
Source File: blocks_3d.py    From novelty-detection with MIT License
def residual_op(x, functions, bns, activation_fn):
    # type: (torch.Tensor, List[Optional[Module]], List[Optional[Module]], Module) -> torch.Tensor
    """
    Implements a global residual operation.

    :param x: the input tensor.
    :param functions: a list of functions (nn.Modules).
    :param bns: a list of optional batch-norm layers.
    :param activation_fn: the activation to be applied.
    :return: the output of the residual operation.
    """
    assert len(functions) == len(bns) == 3

    f1, f2, f3 = functions
    bn1, bn2, bn3 = bns

    assert f1 is not None and f2 is not None
    assert not (f3 is None and bn3 is not None)

    # A-branch
    ha = x
    ha = f1(ha)
    if bn1 is not None:
        ha = bn1(ha)
    ha = activation_fn(ha)

    ha = f2(ha)
    if bn2 is not None:
        ha = bn2(ha)

    # B-branch
    hb = x
    if f3 is not None:
        hb = f3(hb)
    if bn3 is not None:
        hb = bn3(hb)

    # Residual connection
    out = ha + hb
    return activation_fn(out) 
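A usage sketch for the 3D version; the channel counts, kernel sizes, and input shape are illustrative, not taken from the repository. f3 acts as an optional projection on the shortcut branch so that the two branches can be summed.

import torch
import torch.nn as nn

x = torch.randn(2, 8, 4, 16, 16)                 # (N, C, D, H, W)
f1 = nn.Conv3d(8, 16, kernel_size=3, padding=1)
f2 = nn.Conv3d(16, 16, kernel_size=3, padding=1)
f3 = nn.Conv3d(8, 16, kernel_size=1)             # projection shortcut
bns = [nn.BatchNorm3d(16), nn.BatchNorm3d(16), nn.BatchNorm3d(16)]
out = residual_op(x, [f1, f2, f3], bns, nn.ReLU())
print(out.shape)  # torch.Size([2, 16, 4, 16, 16])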
Example #4
Source File: blocks_2d.py    From novelty-detection with MIT License
def residual_op(x, functions, bns, activation_fn):
    # type: (torch.Tensor, List[Optional[Module]], List[Optional[Module]], Module) -> torch.Tensor
    """
    Implements a global residual operation.

    :param x: the input tensor.
    :param functions: a list of functions (nn.Modules).
    :param bns: a list of optional batch-norm layers.
    :param activation_fn: the activation to be applied.
    :return: the output of the residual operation.
    """
    assert len(functions) == len(bns) == 3

    f1, f2, f3 = functions
    bn1, bn2, bn3 = bns

    assert f1 is not None and f2 is not None
    assert not (f3 is None and bn3 is not None)

    # A-branch
    ha = x
    ha = f1(ha)
    if bn1 is not None:
        ha = bn1(ha)
    ha = activation_fn(ha)

    ha = f2(ha)
    if bn2 is not None:
        ha = bn2(ha)

    # B-branch
    hb = x
    if f3 is not None:
        hb = f3(hb)
    if bn3 is not None:
        hb = bn3(hb)

    # Residual connection
    out = ha + hb
    return activation_fn(out)
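The 2D variant is identical to the 3D one; only the modules passed in change. A matching hypothetical call, with made-up shapes, operates on 4D (N, C, H, W) tensors:

import torch
import torch.nn as nn

x = torch.randn(2, 8, 16, 16)                    # (N, C, H, W)
f1 = nn.Conv2d(8, 16, kernel_size=3, padding=1)
f2 = nn.Conv2d(16, 16, kernel_size=3, padding=1)
f3 = nn.Conv2d(8, 16, kernel_size=1)             # projection shortcut
bns = [nn.BatchNorm2d(16), nn.BatchNorm2d(16), nn.BatchNorm2d(16)]
out = residual_op(x, [f1, f2, f3], bns, nn.ReLU())
print(out.shape)  # torch.Size([2, 16, 16, 16])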