mlspm.modules

class mlspm.modules.AttentionConvZ(in_channels: int, z_out: int, kernel_size: int | Tuple[int, int, int] = 3, conv_depth: int = 2, padding_mode: str = 'zeros')

Bases: Module

Reduce and expand 3D feature map in z direction by an attention convolution.

Performs the computation

\[y_{k'} = \sum_{k=1}^{K} \sigma(f_{k'}(x))_k \odot x_k \quad ,\]

where \(\sigma\) is the sigmoid function, \(f_{k'}\) are convolution blocks (one for each output z-index \(k'\)), and \(\odot\) denotes element-wise multiplication. The sum runs over the z-dimension of the input feature map \(x\), which has z-size \(K\), and \(k' \in \{1, \ldots, K'\}\) is the z-index of the output feature map \(y\), which has z-size \(K'\).
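The formula can be sketched in PyTorch as below. This is a hypothetical re-implementation, not the library's code: the class name AttentionConvZSketch is invented, it assumes a (batch, channels, x, y, z) layout with z as the last axis, and it uses a single Conv3d as a stand-in for each convolution block \(f_{k'}\) (the documented module uses blocks of depth conv_depth).

```python
import torch
import torch.nn as nn


class AttentionConvZSketch(nn.Module):
    """Hypothetical sketch of y_{k'} = sum_k sigmoid(f_{k'}(x))_k * x_k.

    Assumes input shape (batch, channels, x, y, z) with z last.
    """

    def __init__(self, in_channels: int, z_out: int, kernel_size: int = 3):
        super().__init__()
        # One convolution per output z-index k'. Padding kernel_size // 2
        # keeps the shape of f_{k'}(x) equal to the shape of x.
        self.convs = nn.ModuleList(
            [
                nn.Conv3d(in_channels, in_channels, kernel_size, padding=kernel_size // 2)
                for _ in range(z_out)
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        slices = []
        for conv in self.convs:
            attention = torch.sigmoid(conv(x))  # sigma(f_{k'}(x)), same shape as x
            slices.append((attention * x).sum(dim=-1))  # sum over the K input z-slices
        return torch.stack(slices, dim=-1)  # (batch, channels, x, y, z_out)


x = torch.randn(2, 4, 8, 8, 6)  # K = 6 input z-slices
y = AttentionConvZSketch(in_channels=4, z_out=3)(x)
print(y.shape)  # torch.Size([2, 4, 8, 8, 3])
```

Each output slice gets its own attention weights over all \(K\) input slices, which is what lets the layer both reduce (\(K' < K\)) and expand (\(K' > K\)) the z-dimension.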

Parameters:
  • in_channels – Number of channels in input feature map.

  • z_out – Size of z dimension in output feature map (= \(K'\)).

  • kernel_size – Convolution kernel size.

  • conv_depth – Convolution block depth.

  • padding_mode – Type of padding in the convolution layers. 'zeros', 'reflect', 'replicate' or 'circular'.

forward(x: Tensor) → Tensor

Perform the forward computation.

Parameters:

x – Input feature map.

Returns:

Output feature map with z-size z_out (= \(K'\)).

class mlspm.modules.Conv2dBlock(in_channels: int, out_channels: int, kernel_size: int | Tuple[int, int] = 3, depth: int = 2, padding_mode: str = 'zeros', res_connection: bool = True, activation: Module | None = None, last_activation: bool = True)

Bases: _ConvNdBlock

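PyTorch 2D convolution block module: depth 2D convolution layers applied in sequence, with an optional residual connection over the whole block.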

Parameters:
  • in_channels – Number of channels entering the first convolution layer.

  • out_channels – Number of output channels in each layer of the block.

  • kernel_size – Size of convolution kernel.

  • depth – Number of convolution layers in the block.

  • padding_mode – Type of padding in each convolution layer. 'zeros', 'reflect', 'replicate' or 'circular'.

  • res_connection – Whether to use a residual connection over the block (f(x) = h(x) + x). If in_channels != out_channels, a 1x1 convolution is applied on the residual path so that the channel counts match.

  • activation – Activation function to use after every layer in block. If None, defaults to ReLU.

  • last_activation – Whether to apply the activation after the last convolution layer (before the residual connection).
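A minimal sketch of how such a block can be assembled from the parameters above. The class ConvBlock2dSketch is hypothetical, not the library's implementation; it assumes an odd integer kernel_size with padding kernel_size // 2, so that every layer preserves the spatial size and the residual sum h(x) + x is well defined.

```python
import torch
import torch.nn as nn


class ConvBlock2dSketch(nn.Module):
    """Hypothetical 2D conv block: depth conv layers plus optional residual."""

    def __init__(self, in_channels, out_channels, kernel_size=3, depth=2,
                 padding_mode="zeros", res_connection=True, activation=None,
                 last_activation=True):
        super().__init__()
        self.activation = activation if activation is not None else nn.ReLU()
        self.last_activation = last_activation
        # First layer maps in_channels -> out_channels, the rest keep out_channels.
        self.convs = nn.ModuleList(
            [
                nn.Conv2d(in_channels if i == 0 else out_channels, out_channels,
                          kernel_size, padding=kernel_size // 2,
                          padding_mode=padding_mode)
                for i in range(depth)
            ]
        )
        if not res_connection:
            self.res = None
        elif in_channels != out_channels:
            # 1x1 convolution only changes the channel count on the residual path.
            self.res = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.res = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = x
        for i, conv in enumerate(self.convs):
            h = conv(h)
            # Skip the activation after the last layer if last_activation is False.
            if i < len(self.convs) - 1 or self.last_activation:
                h = self.activation(h)
        if self.res is not None:
            h = h + self.res(x)  # f(x) = h(x) + x (projected if needed)
        return h


x = torch.randn(2, 3, 32, 32)
block = ConvBlock2dSketch(3, 16)
print(block(x).shape)  # torch.Size([2, 16, 32, 32])
```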

class mlspm.modules.Conv3dBlock(in_channels: int, out_channels: int, kernel_size: int | Tuple[int, int, int] = 3, depth: int = 2, padding_mode: str = 'zeros', res_connection: bool = True, activation: Module | None = None, last_activation: bool = True)

Bases: _ConvNdBlock

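PyTorch 3D convolution block module: depth 3D convolution layers applied in sequence, with an optional residual connection over the whole block.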

Parameters:
  • in_channels – Number of channels entering the first convolution layer.

  • out_channels – Number of output channels in each layer of the block.

  • kernel_size – Size of convolution kernel.

  • depth – Number of convolution layers in the block.

  • padding_mode – Type of padding in each convolution layer. 'zeros', 'reflect', 'replicate' or 'circular'.

  • res_connection – Whether to use a residual connection over the block (f(x) = h(x) + x). If in_channels != out_channels, a 1x1x1 convolution is applied on the residual path so that the channel counts match.

  • activation – Activation function to use after every layer in block. If None, defaults to ReLU.

  • last_activation – Whether to apply the activation after the last convolution layer (before the residual connection).
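For the 3D block, the residual path with in_channels != out_channels can be illustrated with plain torch.nn layers. This sketch assumes depth = 2, ReLU activations, and size-preserving padding; h and res are illustrative names, not attributes of Conv3dBlock.

```python
import torch
import torch.nn as nn

in_channels, out_channels = 4, 8

# Main path h(x): two 3x3x3 convolutions; padding=1 keeps (x, y, z) unchanged.
h = nn.Sequential(
    nn.Conv3d(in_channels, out_channels, 3, padding=1), nn.ReLU(),
    nn.Conv3d(out_channels, out_channels, 3, padding=1), nn.ReLU(),
)
# Residual path: a 1x1x1 convolution that only changes the channel count.
res = nn.Conv3d(in_channels, out_channels, kernel_size=1)

x = torch.randn(1, in_channels, 16, 16, 10)
y = h(x) + res(x)  # f(x) = h(x) + projected x
print(y.shape)  # torch.Size([1, 8, 16, 16, 10])
```

Because the 1x1x1 kernel mixes channels at each voxel independently, it matches channel counts without altering the spatial layout that the sum requires.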

class mlspm.modules.UNetAttentionConv(in_channels: int, query_channels: int, attention_channels: int, kernel_size: int | Tuple[int, ...], padding_mode: str = 'zeros', conv_activation: Module = ReLU(), attention_activation: str = 'softmax', upsample_mode: str = 'bilinear', ndim: int = 2)

Bases: Module

PyTorch attention layer for the upsampling stage of a U-Net model.

Given the input feature map \(x\) and the query feature map \(q\), performs the computation

\[\begin{split}q' &= \sigma(f_\mathrm{q}(\mathrm{Interp}(q))) \\ x' &= \sigma(f_\mathrm{x}(x)) \\ a &= \sigma'(f_\mathrm{a}(\sigma(x'+ q'))) \\ y &= x \odot a,\end{split}\]

where \(f_\mathrm{q}\), \(f_\mathrm{x}\), and \(f_\mathrm{a}\) are convolution layers, \(\mathrm{Interp}(q)\) denotes an interpolation of \(q\) to the spatial size of \(x\), \(\sigma\) and \(\sigma'\) are the activation functions given by conv_activation and attention_activation, respectively, and \(\odot\) denotes element-wise multiplication.
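The computation above can be sketched for ndim = 2 as follows. This is a hypothetical illustration, not the library's implementation: the class AttentionGate2dSketch is invented, it fixes \(\sigma\) = ReLU (the default conv_activation) and \(\sigma'\) = sigmoid (one of the two documented attention_activation options), and it assumes a single-channel attention map \(a\) that broadcasts over the channels of \(x\), since the channel count of \(a\) is not stated here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionGate2dSketch(nn.Module):
    """Hypothetical 2D sketch of y = x * sigmoid(f_a(relu(x' + q')))."""

    def __init__(self, in_channels, query_channels, attention_channels, kernel_size=3):
        super().__init__()
        pad = kernel_size // 2  # keeps spatial size for odd kernel sizes
        self.f_q = nn.Conv2d(query_channels, attention_channels, kernel_size, padding=pad)
        self.f_x = nn.Conv2d(in_channels, attention_channels, kernel_size, padding=pad)
        self.f_a = nn.Conv2d(attention_channels, 1, kernel_size, padding=pad)  # assumed 1 channel
        self.act = nn.ReLU()  # sigma

    def forward(self, x: torch.Tensor, q: torch.Tensor):
        # Interp(q): resize the query to the spatial size of x.
        q = F.interpolate(q, size=x.shape[2:], mode="bilinear", align_corners=False)
        q_ = self.act(self.f_q(q))  # q'
        x_ = self.act(self.f_x(x))  # x'
        a = torch.sigmoid(self.f_a(self.act(x_ + q_)))  # sigma' = sigmoid here
        return x * a, a  # (y, a), a broadcasts over the channels of x


x = torch.randn(2, 16, 32, 32)  # attended feature map
q = torch.randn(2, 32, 16, 16)  # coarser query from the decoder
y, a = AttentionGate2dSketch(16, 32, 8)(x, q)
print(y.shape, a.shape)  # torch.Size([2, 16, 32, 32]) torch.Size([2, 1, 32, 32])
```

With sigmoid, each spatial position of \(x\) is scaled independently by a weight in [0, 1]; the softmax option would instead normalize the weights, which this sketch does not cover.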

Parameters:
  • in_channels – Number of channels in the attended feature map.

  • query_channels – Number of channels in query feature map.

  • attention_channels – Number of channels in hidden convolution layer before computing attention.

  • kernel_size – Size of convolution kernel.

  • padding_mode – Type of padding in each convolution layer. 'zeros', 'reflect', 'replicate' or 'circular'.

  • conv_activation – Activation function to use after convolution layers.

  • attention_activation – Type of activation to use for the attention map. 'sigmoid' or 'softmax'.

  • upsample_mode – Algorithm for upsampling query feature map to the attended feature map size. See torch.nn.functional.interpolate().

  • ndim – Dimensionality of convolution. 2 or 3.

Reference: https://arxiv.org/abs/1804.03999

forward(x: Tensor, q: Tensor) → Tuple[Tensor, Tensor]

Perform the forward computation.

Parameters:
  • x – Input feature map.

  • q – Query feature map.

Returns:

Tuple (y, a), where

  • y – Attention-multiplied output feature map.

  • a – Attention map.