UNet++

class mmit.decoders.UNetPlusPlus(input_channels, input_reductions, decoder_channels=None, upsample_layer=<class 'mmit.base.upsamplers.ConvTranspose2d'>, norm_layer=<class 'torch.nn.modules.batchnorm.BatchNorm2d'>, activation_layer=<class 'torch.nn.modules.activation.ReLU'>, extra_layer=<class 'torch.nn.modules.linear.Identity'>, mismatch_layer=<class 'mmit.base.mismatch.Pad'>, return_features=False)

Implementation of the U-Net++ decoder. Paper: https://arxiv.org/abs/1807.10165.

This implementation uses the following naming convention, which refers to Figure 1.a of the paper:
  • lidx is the layer index, i.e. the index that spans horizontally.

  • didx is the depth index, i.e. the index that spans vertically.

  • i_j is the key of the block at depth i and layer j.

  • i_j is also the key of the tensor produced by that block.

Since only the decoder is implemented, there are no i_0 blocks; the i_0 tensors are the input features from the encoder.
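To make the convention concrete, the sketch below lists which tensors each block consumes. It assumes the dense skip pattern of Figure 1.a, where block i_j concatenates every earlier tensor at the same depth (i_0 through i_(j-1)) with the upsampled output of block (i+1)_(j-1). The helper `unetpp_block_inputs` and its `num_depths` argument are hypothetical; this is an illustration of the indexing, not mmit's internal code.

```python
from typing import Dict, List


def unetpp_block_inputs(num_depths: int) -> Dict[str, List[str]]:
    """Map each decoder block key "i_j" to the keys of the tensors it consumes.

    Hypothetical helper (not part of mmit). Assumes the dense skip pattern of
    Figure 1.a: block i_j takes i_0, ..., i_(j-1) plus the upsampled (i+1)_(j-1).
    """
    deepest = num_depths - 1
    inputs: Dict[str, List[str]] = {}
    for j in range(1, num_depths):  # layer index (lidx), left to right
        for i in range(deepest - j + 1):  # depth index (didx), top to bottom
            same_depth = [f"{i}_{k}" for k in range(j)]  # dense skip connections
            from_below = f"{i + 1}_{j - 1}"  # upsampled before concatenation
            inputs[f"{i}_{j}"] = same_depth + [from_below]
    return inputs


blocks = unetpp_block_inputs(5)
print(len(blocks))    # 10 blocks: 4 + 3 + 2 + 1
print(blocks["0_4"])  # ['0_0', '0_1', '0_2', '0_3', '1_3']
```

With five depths no key ends in `_0`, consistent with the i_0 tensors being inputs rather than blocks, and the top-right block 0_4 receives every tensor at depth 0.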

Parameters:
  • input_channels (List[int]) – The channels of the input features.

  • input_reductions (List[int]) – The reduction factor of the input features.

  • decoder_channels (Optional[List[int]]) – The channels on each layer of the decoder.

  • upsample_layer (Type[Module]) – Layer to use for the upsampling.

  • norm_layer (Type[Module]) – Normalization layer to use.

  • activation_layer (Type[Module]) – Activation function to use.

  • extra_layer (Type[Module]) – Additional layer to use.

  • mismatch_layer (Type[Module]) – Strategy for handling spatial-size mismatches between tensors caused by odd resolutions.

  • return_features (bool) – Whether to return the intermediate results of the decoder.
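As an example of the mismatch that `mismatch_layer` addresses: with a floor-rounding downsampler such as 2×2 max pooling, a 15×15 map becomes 7×7, and upsampling by 2 gives 14×14, which can no longer be concatenated with the 15×15 skip tensor. A padding strategy has to work out how much to add on each side. The function `pad_amounts` below is a hypothetical illustration of that arithmetic, not the code of `mmit.base.mismatch.Pad`; its return order (left, right, top, bottom) is the one `torch.nn.functional.pad` expects for the last two dimensions of an NCHW tensor.

```python
from typing import Tuple


def pad_amounts(upsampled_hw: Tuple[int, int], skip_hw: Tuple[int, int]) -> Tuple[int, int, int, int]:
    """Padding (left, right, top, bottom) that grows `upsampled_hw` to `skip_hw`.

    Hypothetical helper. Any odd leftover pixel goes to the right/bottom side.
    """
    dh = skip_hw[0] - upsampled_hw[0]
    dw = skip_hw[1] - upsampled_hw[1]
    if dh < 0 or dw < 0:
        # Padding can only grow a tensor; a larger upsampled map would need cropping.
        raise ValueError("upsampled tensor is larger than the skip tensor")
    return (dw // 2, dw - dw // 2, dh // 2, dh - dh // 2)


print(pad_amounts((14, 14), (15, 15)))  # (0, 1, 0, 1)
```

Splitting the difference keeps the upsampled content roughly centred on the skip tensor; a cropping strategy would be the alternative when the upsampled map is the larger one.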

forward(*features)

Forward pass of the decoder.

Parameters:
  • *features (Tensor) – Features from the encoder, ordered from the first (the input image) to the last (the deepest).
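The expected ordering can be checked from tensor shapes alone. The sketch below assumes NCHW shapes and measures each reduction factor relative to the first feature (the input image); `infer_reductions` is a hypothetical helper, and whether its output matches the values `input_reductions` expects is an assumption this page does not confirm.

```python
from typing import List, Sequence, Tuple

Shape = Tuple[int, int, int, int]  # (N, C, H, W)


def infer_reductions(shapes: Sequence[Shape]) -> List[int]:
    """Each feature's downsampling factor relative to the first (the input image).

    Hypothetical helper. Raises if the features are not ordered from shallowest
    to deepest, i.e. if the spatial size ever grows along the sequence.
    """
    _, _, h0, w0 = shapes[0]
    reductions: List[int] = []
    prev_h, prev_w = h0, w0
    for _, _, h, w in shapes:
        if h > prev_h or w > prev_w:
            raise ValueError("features must go from the input image to the deepest map")
        reductions.append(h0 // h)
        prev_h, prev_w = h, w
    return reductions


image = (1, 3, 64, 64)
encoder = [(1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8)]
print(infer_reductions([image, *encoder]))  # [1, 2, 4, 8]
```

Tensors in this order would then be passed positionally, e.g. `decoder(image, *encoder)` for a hypothetical `decoder` instance.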

property out_classes: int

Number of output classes.