Python torch.nn.Module Examples
The following are 4 code examples that work with torch.nn.Module instances ("nn.Modules").
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module torch.nn, or try the search function.
Example #1
Source File: conv_block.py From batchflow with Apache License 2.0 | 6 votes |
def _make_layer(self, *args, inputs=None, base_block=BaseConvBlock, **kwargs):
    """Build a sequential layer from positional block specs, or one block from kwargs.

    Each element of ``args`` is either a dict describing a block or an
    already-constructed ``nn.Module``. Elements are processed in order: each
    resulting layer is called on the current ``inputs`` and its output becomes
    the ``inputs`` used for the next element.

    Parameters
    ----------
    *args : dict or nn.Module
        Block specifications. A dict may name its block class under
        ``'base_block'`` (or the alias ``'base'``); otherwise ``base_block`` is
        used. Both selector keys are removed before the remaining entries are
        combined with ``kwargs`` via ``Config`` addition and passed to the
        block constructor. The caller's dicts are not modified.
    inputs
        Value passed to each block constructor as ``inputs`` and run through
        each layer (presumably an example tensor -- confirm against callers).
    base_block : callable
        Default block class for dict specs without a selector key, and the
        class used when ``args`` is empty.
    **kwargs
        Parameters shared by every dict-spec block, or the full configuration
        of the single block built when ``args`` is empty.

    Returns
    -------
    nn.Sequential of the built layers when ``args`` is non-empty; otherwise
    ``base_block(inputs=inputs, **kwargs)``.

    Raises
    ------
    ValueError
        If an element of ``args`` is neither a dict nor an ``nn.Module``.
    """
    # No positional specs: a single block configured entirely by kwargs.
    if not args:
        return base_block(inputs=inputs, **kwargs)

    layers = []
    for item in args:
        if isinstance(item, dict):
            # Copy so a caller's spec can be reused, and pop BOTH selector keys
            # so neither one leaks into the block constructor's kwargs.
            spec = dict(item)
            explicit_block = spec.pop('base_block', None)
            alias_block = spec.pop('base', None)
            block = explicit_block or alias_block or base_block
            block_args = {'inputs': inputs, **dict(Config(kwargs) + Config(spec))}
            layer = block(**block_args)
        elif isinstance(item, nn.Module):
            layer = item
        else:
            raise ValueError('Positional arguments of ConvBlock must be either dicts or nn.Modules, '
                             'got instead {}'.format(type(item)))
        # Run the layer so the next block is constructed against this layer's output.
        inputs = layer(inputs)
        layers.append(layer)
    return nn.Sequential(*layers)
Example #2
Source File: classy_model.py From ClassyVision with MIT License | 5 votes |
def get_heads(self):
    """Return the heads attached to each block of the model.

    Returns:
        A new dict mapping every block name in ``self._heads`` to a new list of
        the heads attached to that block (the values of that block's inner
        mapping, in its iteration order). Per the model's contract the heads are
        `nn.Modules <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_.
    """
    heads_by_block = {}
    for name, attached in self._heads.items():
        heads_by_block[name] = [head for head in attached.values()]
    return heads_by_block
Example #3
Source File: blocks_3d.py From novelty-detection with MIT License | 5 votes |
def residual_op(x, functions, bns, activation_fn):
    # type: (torch.Tensor, List[Module], List[Module], Module) -> torch.Tensor
    """
    Implements a global residual operation.

    Computes ``activation_fn(A + B)`` where
    ``A = bn2(f2(activation_fn(bn1(f1(x)))))`` and ``B = bn3(f3(x))``;
    any entry that is None is skipped (treated as identity).

    :param x: the input tensor.
    :param functions: a list of exactly three functions (nn.Modules)
        ``[f1, f2, f3]``. f1 and f2 are required; f3 may be None, in which case
        the shortcut branch passes ``x`` through unchanged.
    :param bns: a list of exactly three optional batch-norm layers
        ``[bn1, bn2, bn3]``; each may be None. bn3 may only be set when f3 is.
    :param activation_fn: the activation to be applied.
    :return: the output of the residual operation.
    :raises ValueError: if ``functions`` or ``bns`` does not have exactly three
        entries, if f1 or f2 is None, or if bn3 is given without f3.
    """
    # Check lengths *before* unpacking (otherwise the unpack fails first with an
    # unhelpful message), and raise explicitly: asserts are stripped under -O.
    if len(functions) != 3 or len(bns) != 3:
        raise ValueError('residual_op expects 3 functions and 3 bns, got {} and {}'.format(
            len(functions), len(bns)))
    f1, f2, f3 = functions
    bn1, bn2, bn3 = bns
    if f1 is None or f2 is None:
        raise ValueError('residual_op requires functions[0] and functions[1] to be set')
    if f3 is None and bn3 is not None:
        raise ValueError('residual_op: bns[2] requires functions[2] to be set')

    # A-branch
    ha = f1(x)
    if bn1 is not None:
        ha = bn1(ha)
    ha = activation_fn(ha)
    ha = f2(ha)
    if bn2 is not None:
        ha = bn2(ha)

    # B-branch (shortcut)
    hb = x
    if f3 is not None:
        hb = f3(hb)
    if bn3 is not None:
        hb = bn3(hb)

    # Residual connection
    return activation_fn(ha + hb)
Example #4
Source File: blocks_2d.py From novelty-detection with MIT License | 5 votes |
def residual_op(x, functions, bns, activation_fn):
    # type: (torch.Tensor, List[Module], List[Module], Module) -> torch.Tensor
    """
    Implements a global residual operation.

    Computes ``activation_fn(A + B)`` where
    ``A = bn2(f2(activation_fn(bn1(f1(x)))))`` and ``B = bn3(f3(x))``;
    any entry that is None is skipped (treated as identity).

    :param x: the input tensor.
    :param functions: a list of exactly three functions (nn.Modules)
        ``[f1, f2, f3]``. f1 and f2 are required; f3 may be None, in which case
        the shortcut branch passes ``x`` through unchanged.
    :param bns: a list of exactly three optional batch-norm layers
        ``[bn1, bn2, bn3]``; each may be None. bn3 may only be set when f3 is.
    :param activation_fn: the activation to be applied.
    :return: the output of the residual operation.
    :raises ValueError: if ``functions`` or ``bns`` does not have exactly three
        entries, if f1 or f2 is None, or if bn3 is given without f3.
    """
    # Check lengths *before* unpacking (otherwise the unpack fails first with an
    # unhelpful message), and raise explicitly: asserts are stripped under -O.
    if len(functions) != 3 or len(bns) != 3:
        raise ValueError('residual_op expects 3 functions and 3 bns, got {} and {}'.format(
            len(functions), len(bns)))
    f1, f2, f3 = functions
    bn1, bn2, bn3 = bns
    if f1 is None or f2 is None:
        raise ValueError('residual_op requires functions[0] and functions[1] to be set')
    if f3 is None and bn3 is not None:
        raise ValueError('residual_op: bns[2] requires functions[2] to be set')

    # A-branch
    ha = f1(x)
    if bn1 is not None:
        ha = bn1(ha)
    ha = activation_fn(ha)
    ha = f2(ha)
    if bn2 is not None:
        ha = bn2(ha)

    # B-branch (shortcut)
    hb = x
    if f3 is not None:
        hb = f3(hb)
    if bn3 is not None:
        hb = bn3(hb)

    # Residual connection
    return activation_fn(ha + hb)