Source code for mlspm.graph.models

import warnings
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..modules import AttentionConvZ, Conv3dBlock, UNetAttentionConv, _get_padding
from . import Atom, MoleculeGraph, find_gaussian_peaks, make_box_borders
from .._weights import download_weights


def _get_activation(activation):
    if isinstance(activation, nn.Module):
        return activation
    elif activation == "relu":
        return nn.ReLU()
    elif activation == "lrelu":
        return nn.LeakyReLU()
    elif activation == "elu":
        return nn.ELU()
    else:
        raise ValueError(f"Unknown activation function {activation}")


def _get_pool(pool_type):
    if pool_type == "avg":
        return nn.AvgPool3d
    elif pool_type == "max":
        return nn.MaxPool3d
    else:
        raise ValueError(f"Unknown pool type {pool_type}")


[docs] class PosNet(nn.Module): """ Attention U-net for predicting the positions of atoms in an AFM image. Arguments: encode_block_channels: Number channels in encoding 3D conv blocks. encode_block_depth: Number of layers in each encoding 3D conv block. decode_block_channels: Number of channels in each decoding 3D conv block after upscale before skip connection. decode_block_depth: Number of layers in each decoding 3D conv block after upscale before skip connection. decode_block_channels2: Number of channels in each decoding 3D conv block after skip connection. decode_block_depth2: Number of layers in each decoding 3D conv block after skip connection. attention_channels: Number of channels in conv layers within each attention block. res_connections: Whether to use residual connections in conv blocks. activation: Activation to use after every layer. ``'relu'``, ``'lrelu'``, or ``'elu'`` or :class:`nn.Module`. padding_mode: Type of padding in each convolution layer. ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. pool_type: Type of pooling to use. ``'avg'`` or ``'max'``. decoder_z_sizes: Upscale sizes of decoder stages in the z dimension. z_outs: Size of the z-dimension after encoder and skip connections. attention_activation: Type of activation to use for attention map. ``'sigmoid'`` or ``'softmax'``. afm_res: Real-space size of pixels in xy-plane in input AFM images in angstroms. grid_z_range: The real-space range in z-direction of the position grid in angstroms. Of the format ``(z_min, z_max)``. peak_std: Standard deviation of atom position grid peaks in angstroms. match_threshold: Detection threshold for matching when finding atom position peaks. match_method: Method for template matching when finding atom position peaks. See :func:`.find_gaussian_peaks` for options. device: Device to store model on. 
""" def __init__( self, encode_block_channels: list[int] = [4, 8, 16, 32], encode_block_depth: int = 2, decode_block_channels: list[int] = [32, 16, 8], decode_block_depth: int = 2, decode_block_channels2: list[int] = [32, 16, 8], decode_block_depth2: int = 2, attention_channels: list[int] = [32, 32, 32], res_connections: bool = True, activation: str = "relu", padding_mode: str = "replicate", pool_type: str = "avg", decoder_z_sizes: list[int] = [5, 10, 20], z_outs: list[int] = [3, 3, 5, 10], attention_activation: str = "softmax", afm_res: float = 0.125, grid_z_range: Tuple[float, float] = (-1.4, 0.5), peak_std: float = 0.3, match_threshold: float = 0.7, match_method: str = "msd_norm", device: str = "cuda", ): super().__init__() assert ( len(encode_block_channels) == len(decoder_z_sizes) + 1 == len(decode_block_channels) + 1 == len(decode_block_channels2) + 1 ), "Numbers of blocks do not match" self.encode_block_channels = encode_block_channels self.num_blocks = len(encode_block_channels) self.act = _get_activation(activation) self.decoder_z_sizes = decoder_z_sizes self.upsample_mode = "trilinear" self.padding_mode = padding_mode self.afm_res = afm_res self.grid_z_range = grid_z_range self.peak_std = peak_std self.match_threshold = match_threshold self.match_method = match_method self.pool_type = pool_type pool = _get_pool(self.pool_type) self.pool = pool((2, 2, 1), stride=(2, 2, 1)) # No pool in z-dimension # Encoder conv blocks encode_block_channels = [1] + encode_block_channels self.encode_blocks = nn.ModuleList( [ Conv3dBlock( encode_block_channels[i], encode_block_channels[i + 1], 3, encode_block_depth, padding_mode, res_connections, self.act, ) for i in range(self.num_blocks) ] ) # Skip-connection attention conv blocks self.unet_attentions = nn.ModuleList( [ UNetAttentionConv( encode_block_channels[-(i + 2)], encode_block_channels[-1], attention_channels[i], 3, padding_mode, self.act, attention_activation, upsample_mode=self.upsample_mode, ndim=3, ) for i in 
range(self.num_blocks - 1) ] ) # Decoder conv blocks decode_block_channels2 = [encode_block_channels[-1]] + decode_block_channels2 self.decode_blocks = nn.ModuleList( [ Conv3dBlock( decode_block_channels2[i], decode_block_channels[i], 3, decode_block_depth, padding_mode, res_connections, self.act, False, ) for i in range(self.num_blocks - 1) ] ) self.decode_blocks2 = nn.ModuleList( [ Conv3dBlock( decode_block_channels[i] + encode_block_channels[-(i + 2)], decode_block_channels2[i + 1], 3, decode_block_depth2, padding_mode, res_connections, self.act, False, ) for i in range(self.num_blocks - 1) ] ) self.out_conv = nn.Conv3d( decode_block_channels2[-1], 1, kernel_size=3, padding=_get_padding(3, 3), padding_mode=padding_mode, ) # Attention convolution for dealing with variable z sizes at the end of the encoder enc_channels = self.encode_block_channels self.att_conv_enc = AttentionConvZ(enc_channels[-1], z_outs[0], conv_depth=3, padding_mode=self.padding_mode) # Attention convolutions for the skip connections self.att_conv_skip = nn.ModuleList( [ AttentionConvZ(c, z_out, conv_depth=3, padding_mode=self.padding_mode) for c, z_out in zip(reversed(enc_channels[:-1]), z_outs[1:]) ] ) self.device = device self.to(device)
[docs] def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, list[torch.Tensor]]: """ Arguments: x: Batch of AFM images. Should be of shape ``(n_batch, nx, ny, nz)```. Returns: Tuple (**pos_dist**, **attention_maps**), where - **pos_dist** - Predicted atom position distribution. - **attention_maps** - Attention maps from the skip connection attention layers. """ xs = [] for i in range(self.num_blocks): # Apply 3D conv block x = self.act(self.encode_blocks[i](x)) if i < self.num_blocks - 1: # Store feature maps for attention gating later xs.append(x) # Down-sample for next iteration of convolutions. x = self.pool(x) # Apply attention convolution to get to a fixed z size x = self.att_conv_enc(x) # Compute skip-connection attention maps attention_maps = [] x_gated = [] xs.reverse() for attention, att_z, x_ in zip(self.unet_attentions, self.att_conv_skip, xs): xg, a = attention(x_, x) xg = att_z(xg) x_gated.append(xg) attention_maps.append(a) # Decode for i, (conv1, conv2, xg) in enumerate(zip(self.decode_blocks, self.decode_blocks2, x_gated)): # Upsample and apply first conv block target_size = xs[i].shape[2:4] + (self.decoder_z_sizes[i],) x = F.interpolate(x, size=target_size, mode=self.upsample_mode, align_corners=False) x = self.act(conv1(x)) # Concatenate attention-gated skip connections and apply second conv block xg = F.interpolate(xg, size=target_size, mode=self.upsample_mode, align_corners=False) x = torch.cat([x, xg], dim=1) x = self.act(conv2(x)) # Get output grid x = self.out_conv(x).squeeze(1) return x, attention_maps
[docs] def get_positions( self, x: torch.Tensor | np.ndarray, device: str = "cuda" ) -> Tuple[list[torch.Tensor], torch.Tensor | np.ndarray, list[torch.Tensor | np.ndarray]]: """ Predict atom positions for a batch of AFM images. Arguments: x: Batch of AFM images. Should be of shape ``(n_batch, nx, ny, nz)``. device: Device used when **x** is an np.ndarray. Returns: atom_pos: Atom positions for each batch item. grid: Atom position grid from PosNet prediction. Type matches input AFM image type attention: Attention maps from skip-connection attention layers. Type matches input AFM image type. """ if isinstance(x, np.ndarray): xt = torch.from_numpy(x).float().to(device) else: xt = x with torch.no_grad(): xt, attention = self(xt.unsqueeze(1)) box_borders = make_box_borders(x.shape[1:3], (self.afm_res, self.afm_res), z_range=self.grid_z_range) atom_pos, _, _ = find_gaussian_peaks( xt, box_borders, match_threshold=self.match_threshold, std=self.peak_std, method=self.match_method, ) if isinstance(x, np.ndarray): attention = [a.cpu().numpy() for a in attention] xt = xt.cpu().numpy() return atom_pos, xt, attention
class GraphImgNet(nn.Module):
    """
    Image-to-graph model that constructs a molecule graph out of atom positions and an AFM image.

    Arguments:
        posnet: :class:`PosNet` for predicting atom positions from an AFM image. Required when training or doing inference
            without pre-defined atom positions.
        n_classes: Number of different classes for nodes.
        iters: Number of message passing iterations.
        node_feature_size: Number of hidden node features.
        message_size: Size of message vector.
        message_hidden_size: Size of hidden layers in message MLP.
        edge_cutoff: Cutoff radius in angstroms for edges between atoms within MPNN.
        afm_cutoff: Cutoff radius in angstroms for receptive regions around each atom.
        afm_res: Real-space size of pixels in xy-plane in input AFM images in angstroms.
        conv_channels: Number channels in 3D conv blocks encoding AFM image regions.
        conv_depth: Number of layers in each 3D conv block.
        node_out_hidden_size: Size of hidden layers in node classification MLP.
        edge_out_hidden_size: Size of hidden layers in edge classification MLP.
        res_connections: Whether to use residual connections in conv blocks.
        activation: Activation to use after every layer. ``'relu'``, ``'lrelu'``, or ``'elu'`` or :class:`nn.Module`.
        padding_mode: Type of padding in each convolution layer. ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
        pool_type: Type of pooling to use. ``'avg'`` or ``'max'``.
        device: Device to store model on.
    """

    def __init__(
        self,
        n_classes: int,
        posnet: Optional[PosNet] = None,
        iters: int = 3,
        node_feature_size: int = 20,
        message_size: int = 20,
        message_hidden_size: int = 64,
        edge_cutoff: float = 3.0,
        afm_cutoff: float = 1.25,
        afm_res: float = 0.125,
        conv_channels: list[int] = [4, 8, 16],
        conv_depth: int = 2,
        node_out_hidden_size: int = 128,
        edge_out_hidden_size: int = 128,
        res_connections: bool = True,
        activation: str | nn.Module = "relu",
        padding_mode: str = "zeros",
        pool_type: str = "avg",
        device: str = "cuda",
    ):
        super().__init__()

        # Copy so that the instance never aliases the shared (mutable) default list
        conv_channels = list(conv_channels)

        self.posnet = posnet
        self.n_classes = n_classes
        self.iters = iters
        self.node_feature_size = node_feature_size
        self.message_size = message_size
        self.act = _get_activation(activation)
        self.edge_cutoff = edge_cutoff
        self.afm_cutoff = afm_cutoff
        self.afm_res = afm_res
        self.conv_channels = conv_channels
        self.padding_mode = padding_mode

        if (self.posnet is not None) and not np.allclose(self.posnet.afm_res, self.afm_res):
            warnings.warn(
                f"AFM pixel resolution ({self.afm_res}) does not match with the resolution in PosNet ({self.posnet.afm_res}). "
                "This can lead to bad inference results."
            )

        self.pool_type = pool_type
        pool = _get_pool(self.pool_type)
        self.pool = pool((2, 2, 1), stride=(2, 2, 1))  # Don't pool in z direction

        # Message MLP. The input is the features of the two nodes of an edge plus the 3D vector between them.
        self.msg_net = nn.Sequential(
            nn.Linear(2 * node_feature_size + 3, message_hidden_size),
            self.act,
            nn.Linear(message_hidden_size, message_hidden_size),
            self.act,
            nn.Linear(message_hidden_size, message_size),
        )
        self.gru_node = nn.GRUCell(message_size, node_feature_size)
        self.gru_edge = nn.GRUCell(message_size, node_feature_size)

        # Conv blocks encoding the AFM image regions. The input image has a single channel.
        conv_in_channels = [1] + conv_channels[:-1]
        self.conv_blocks = nn.ModuleList(
            [
                Conv3dBlock(
                    conv_in_channels[i],
                    conv_channels[i],
                    3,
                    conv_depth,
                    padding_mode,
                    res_connections,
                    self.act,
                )
                for i in range(len(self.conv_channels))
            ]
        )

        # Attention convolution for dealing with variable feature maps size at the end of the AFM image encoder
        in_channels = self.conv_channels[-1]
        self.att_conv = Conv3dBlock(
            in_channels,
            in_channels,
            kernel_size=3,
            depth=3,
            padding_mode=self.padding_mode,
            res_connection=False,
            last_activation=False,
        )

        self.node_transform = nn.Linear(conv_channels[-1], node_feature_size)
        self.class_net = nn.Sequential(
            nn.Linear(node_feature_size, node_out_hidden_size),
            self.act,
            nn.Linear(node_out_hidden_size, n_classes),
        )
        self.edge_net = nn.Sequential(
            nn.Linear(node_feature_size, edge_out_hidden_size),
            self.act,
            nn.Linear(edge_out_hidden_size, 1),
            nn.Sigmoid(),
        )

        self.device = device
        self.to(device)

    def _gather_afm(self, x: torch.Tensor, pos: list[torch.Tensor]) -> torch.Tensor:
        """
        Cut out the region of the AFM image around every atom and stack the regions into one tensor.

        Arguments:
            x: Batch of AFM images of shape ``(n_batch, nx, ny, nz)``.
            pos: Atom positions for each batch item. The first two columns are used as the xy coordinates.

        Returns:
            Stacked image regions, one for every atom in the whole batch. Every region covers
            ``2 * int(afm_cutoff / afm_res) + 1`` pixels in both x and y directions and the full z range.
            An empty tensor of shape ``(0, 1, 1, 1)`` when there are no atoms at all.
        """
        if sum(len(p) for p in pos) == 0:
            # No atom positions, so just return an empty tensor
            print("Encountered an empty position list")
            return torch.zeros((0, 1, 1, 1), device=x.device)

        ind_radius = int(self.afm_cutoff / self.afm_res)

        # Pad AFM image so that image regions on the edges are the same size
        x = F.pad(
            x,
            (0, 0, ind_radius, ind_radius, ind_radius, ind_radius),
            mode="constant",
            value=0,
        )

        x_afm = []
        for ib, p in enumerate(pos):
            # Find xy index range around each atom
            ind_min = ((p[:, :2] - self.afm_cutoff) / self.afm_res).round().long()
            ind_min += ind_radius  # Add radius due to padding
            ind_max = ind_min + 2 * ind_radius + 1

            # Convert the indices to python ints once instead of reading them one-by-one from the tensors
            for (x_min, y_min), (x_max, y_max) in zip(ind_min.tolist(), ind_max.tolist()):
                x_afm.append(x[ib, x_min:x_max, y_min:y_max])

        return torch.stack(x_afm, dim=0)

    def _get_edges(self, pos: list[torch.Tensor]) -> list[torch.Tensor]:
        """
        Find the pairs of atoms that are within the edge cutoff distance of each other.

        Arguments:
            pos: Atom positions for each batch item.

        Returns:
            Edge index tensors of shape ``(2, n_edges)`` for each batch item, with each edge listed once
            with the smaller atom index first.
        """
        edges = []
        for p in pos:
            # Pairwise distances come in the same order as the upper-triangular indices generated below
            d = F.pdist(p)
            inds = torch.nonzero(d <= self.edge_cutoff)[:, 0]
            # Create the indices on the device of the positions, so that they can be indexed with inds
            pair_inds = torch.triu_indices(len(p), len(p), offset=1, device=p.device)
            edges.append(pair_inds[:, inds])
        return edges

    def _combine_graphs(self, pos: list[torch.Tensor], edges: list[torch.Tensor]):
        """
        Combine a batch of graphs into one big disconnected graph.

        Arguments:
            pos: Atom positions for each batch item.
            edges: Edge index tensors of shape ``(2, n_edges)`` for each batch item.

        Returns:
            Tuple (**edges_shifted**, **pos**, **Ns**), where

            - **edges_shifted** - Edge indices of all graphs in one tensor of shape ``(2, n_edges_total)``, where the
              atom indices of each graph are shifted by the total number of atoms in the preceding graphs. The shape
              is ``(2, 0)`` when there are no edges.
            - **pos** - Atom positions of all graphs concatenated into one tensor.
            - **Ns** - Number of atoms in each graph.
        """
        Ns = []
        ind_count = 0
        edges_shifted = []
        for p, e in zip(pos, edges):
            # Shift the whole edge tensor at once instead of looping over individual edges
            edges_shifted.append(e + ind_count)
            ind_count += len(p)
            Ns.append(len(p))

        if edges_shifted:
            # Concatenating along the edge dimension keeps the shape 2D also when there are no edges
            edges_shifted = torch.cat(edges_shifted, dim=1).to(dtype=torch.long)
        else:
            edges_shifted = torch.zeros((2, 0), dtype=torch.long, device=self.device)

        pos = torch.cat(pos, dim=0)

        return edges_shifted, pos, Ns

    def pred_to_graph(
        self,
        pos: list[torch.Tensor],
        node_classes: list[torch.Tensor],
        edge_classes: list[torch.Tensor],
        edges: list[torch.Tensor],
        bond_threshold: float,
    ) -> list[MoleculeGraph]:
        """
        Convert predicted batch to a simple list of molecule graphs.

        Arguments:
            pos: Atom positions for each batch item.
            node_classes: Predicted class probabilities for each atom in the molecule graphs. Each batch item is a tensor
                of shape ``(n_atoms, n_classes)``.
            edge_classes: Predicted probabilities for the existence of bonds between atoms indicated by **edges**. Each batch
                item is a tensor of shape ``(n_edges,)``.
            edges: Possible bond connection indices between atoms. Each batch item is a tensor of shape ``(2, n_edges)``.
            bond_threshold: Threshold probability when an edge is considered a bond between atoms.

        Returns:
            Molecule graphs corresponding to the predictions.
        """
        graphs = []
        for p, nc, e, ec in zip(pos, node_classes, edges, edge_classes):
            # Detach before converting to numpy, since the conversion fails for tensors that require grad
            p = p.detach().cpu().numpy()
            nc = F.softmax(nc, dim=1).detach().cpu().numpy()
            atoms = [Atom(pi, class_weights=nci) for pi, nci in zip(p, nc)]
            et = e[:, ec >= bond_threshold].cpu().numpy()
            bonds = [tuple(b) for b in et.T]
            graphs.append(MoleculeGraph(atoms, bonds))
        return graphs

    def encode_afm(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode AFM image regions into node feature vectors.

        Arguments:
            x: Image regions, one per atom, as returned by :meth:`_gather_afm`.

        Returns:
            Node features of shape ``(n_atoms, node_feature_size)``.
        """
        if x.shape[0] == 0:
            # No atoms, so there is nothing to encode
            return torch.zeros((0, self.node_feature_size), device=x.device)

        # Apply convolutions
        x = x.unsqueeze(1)
        for conv in self.conv_blocks:
            x = self.act(conv(x))
            x = self.pool(x)

        # Reduce xyz dimensions to 1 by attention convolution
        sh = x.shape
        a = self.att_conv(x)  # Get attention maps
        a = F.softmax(a.reshape(sh[0], sh[1], -1), dim=2).reshape(sh)  # Softmax so that xyz dimension sum to 1 for every channel
        x = (a * x).sum(dim=(2, 3, 4))  # Multiply by attention and reduce over xyz

        # Transform features by a linear layer
        x = self.node_transform(self.act(x))

        return x

    def mpnn(self, pos: torch.Tensor, node_features: torch.Tensor, edges: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Propagate node and edge features over the graph with message passing.

        Arguments:
            pos: Atom positions of all graphs concatenated into one tensor.
            node_features: Initial node features of shape ``(n_atoms, node_feature_size)``.
            edges: Edge indices of shape ``(2, n_edges)``. A tensor with fewer than two dimensions is treated
                as having no edges.

        Returns:
            Tuple (**node_features**, **edge_features**) after the message passing iterations.
        """
        device = node_features.device
        Ne = 0 if edges.ndim < 2 else edges.size(1)  # Number of edges

        if Ne > 0:
            # Symmetrise directional edge connections
            edges_sym = torch.cat([edges, edges[[1, 0]]], dim=1)

            # Compute vectors between nodes connected by edges
            src_pos = pos.index_select(0, edges_sym[0])
            dst_pos = pos.index_select(0, edges_sym[1])
            d_pos = dst_pos - src_pos

            # Initialize edge features to the average of the nodes they are connecting
            src_features = node_features.index_select(0, edges[0])
            dst_features = node_features.index_select(0, edges[1])
            edge_features = (src_features + dst_features) / 2

        else:
            edge_features = torch.empty((0, self.node_feature_size), device=device)

        for _ in range(self.iters):
            a = torch.zeros(node_features.size(0), self.message_size, device=device)

            if Ne > 0:  # No messages if no edges
                # Gather start and end nodes of edges
                src_features = node_features.index_select(0, edges_sym[0])
                dst_features = node_features.index_select(0, edges_sym[1])
                inputs = torch.cat([src_features, dst_features, d_pos], dim=1)

                # Calculate messages for all edges and add them to start nodes
                messages = self.msg_net(inputs)
                a.index_add_(0, edges_sym[0], messages)

                # Update edge features
                b = (messages[:Ne] + messages[Ne:]) / 2  # Average over two directions
                edge_features = self.gru_edge(b, edge_features)

            # Update node features
            node_features = self.gru_node(a, node_features)

        return node_features, edge_features

    def forward(
        self, x: torch.Tensor, pos: Optional[list[torch.Tensor]] = None
    ) -> Tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
        """
        Arguments:
            x: Batch of AFM images. Array of shape ``(n_batch, nx, ny, nz)``.
            pos: Atom positions for each batch item. If None, the positions are predicted from the AFM images.
                The positions should be such that the lower left corner of the AFM image is at coordinate ``(0, 0)``,
                and all positions are within the bounds of the AFM image. Model needs to be constructed with PosNet
                defined in order for the position prediction to work.

        Returns:
            Tuple (**node_classes**, **edge_classes**, **edges**), where

            - **node_classes** - Predicted class probabilities for each atom in the molecule graphs. Each batch item is a
              tensor of shape ``(n_atoms, n_classes)``.
            - **edge_classes** - Predicted probabilities for the existence of bonds between atoms indicated by **edges**.
              Each batch item is a tensor of shape ``(n_edges,)``.
            - **edges** - Possible bond connection indices between atoms. Each batch item is a tensor of shape
              ``(2, n_edges)``.
        """
        assert x.ndim == 4, "Wrong number of dimensions in AFM tensor. Should be 4."

        if pos is None:
            if self.posnet is None:
                raise RuntimeError("Attempting to predict atom positions, but PosNet is not defined.")
            pos, _, _ = self.posnet.get_positions(x)

        # Gather all AFM image regions to one tensor
        x = self._gather_afm(x, pos)

        # Get AFM embeddings for each node
        x_afm = self.encode_afm(x)

        # Propagate features with MPNN
        edges = self._get_edges(pos)  # Get edges based on distances between atoms
        edges_combined, pos_combined, Ns = self._combine_graphs(pos, edges)  # Combine graphs into one for faster processing
        node_features, edge_features = self.mpnn(pos_combined, x_afm, edges_combined)

        # Predict node and edge classes
        node_classes = self.class_net(node_features)
        edge_classes = self.edge_net(edge_features).squeeze(1)

        # Split into batch of separate graphs again
        node_classes = torch.split(node_classes, split_size_or_sections=Ns)
        edge_classes = torch.split(edge_classes, split_size_or_sections=[e.size(1) for e in edges])

        return node_classes, edge_classes, edges

    def predict_graph(
        self,
        x: torch.Tensor | np.ndarray,
        pos: Optional[list[torch.Tensor]] = None,
        bond_threshold: float = 0.5,
    ) -> Tuple[list[MoleculeGraph], Optional[torch.Tensor | np.ndarray]]:
        """
        Predict molecule graphs from AFM images.

        Arguments:
            x: Batch of AFM images. Array of shape ``(n_batch, nx, ny, nz)``. An np.ndarray is converted to a tensor
                on the device that the model was constructed with.
            pos: Atom positions for each batch item. If None, the positions are predicted from the AFM images.
                The positions should be such that the lower left corner of the AFM image is at coordinate ``(0, 0)``,
                and all positions are within the bounds of the AFM image.
            bond_threshold: Threshold probability when an edge is considered a bond between atoms.

        Returns:
            Tuple (**graphs**, **grid**), where

            - **graphs**: Predicted graphs.
            - **grid**: Atom position grid from PosNet prediction when input **pos** is ``None``, otherwise ``None``.
              NOTE: an np.ndarray input is converted to a tensor before the position prediction, so the grid is
              returned as a torch.Tensor also in that case.

        Raises:
            RuntimeError: If **pos** is None and the model was constructed without a PosNet.
        """
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x).float().to(self.device)

        # This is pure inference, so don't track gradients. The predictions are converted to numpy arrays
        # afterwards, which is not possible for tensors that require grad.
        with torch.no_grad():
            if pos is None:
                if self.posnet is None:
                    raise RuntimeError("Attempting to predict positions, but PosNet is not defined.")
                pos, grid, _ = self.posnet.get_positions(x)
            else:
                grid = None

            node_classes, edge_classes, edges = self.forward(x, pos)

        graphs = self.pred_to_graph(pos, node_classes, edge_classes, edges, bond_threshold)

        return graphs, grid
class GraphImgNetIce(GraphImgNet):
    """
    GraphImgNet with hyperparameters set exactly as in the paper "Structure discovery in Atomic Force Microscopy imaging of ice",
    https://arxiv.org/abs/2310.17161.

    Three sets of pretrained weights are available:

        - ``'cu111'``: trained on images of ice clusters on Cu(111)
        - ``'au111-monolayer'``: trained on images of ice clusters on monolayer Au(111)
        - ``'au111-bilayer'``: trained on images of ice clusters on bilayer Au(111)

    Arguments:
        pretrained_weights: Name of pretrained weights. If specified, load pretrained weights. Otherwise, weights are
            initialized randomly.
        grid_z_range: The real-space range in z-direction of the position grid in angstroms. Of the format ``(z_min, z_max)``.
            Has to be specified when **pretrained_weights** is not given.
        device: Device to store model on.

    Raises:
        KeyError: If **pretrained_weights** is not one of the available names.
        ValueError: If neither **pretrained_weights** nor **grid_z_range** is specified.
    """

    def __init__(self, pretrained_weights: Optional[str] = None, grid_z_range: Optional[Tuple[float, float]] = None, device="cuda"):
        if pretrained_weights is not None:
            # z ranges of the position grid that each set of weights was trained with
            ice_z_ranges = {
                "cu111": (-2.9, 0.5),
                "au111-monolayer": (-2.9, 0.5),
                "au111-bilayer": (-3.5, 0.5),
            }
            try:
                z_range_weights = ice_z_ranges[pretrained_weights]
            except KeyError:
                raise KeyError(
                    f"Unknown pretrained weights `{pretrained_weights}`. Available weights are {list(ice_z_ranges)}."
                ) from None
            if (grid_z_range is not None) and not np.allclose(z_range_weights, grid_z_range):
                # An explicitly specified range takes precedence, but is probably not what the weights expect
                warnings.warn(f"Specified grid z range ({grid_z_range}) does not match one for pretrained_weights ({z_range_weights})")
            else:
                grid_z_range = z_range_weights
        elif grid_z_range is None:
            raise ValueError("At least one of pretrained_weights or grid_z_range has to be specified.")

        # Size of the last decoder stage in z: one grid point per 0.1 angstroms over the z range, end points included
        outsize = round((grid_z_range[1] - grid_z_range[0]) / 0.1) + 1
        posnet = PosNet(
            encode_block_channels=[16, 32, 64, 128],
            encode_block_depth=3,
            decode_block_channels=[128, 64, 32],
            decode_block_depth=2,
            decode_block_channels2=[128, 64, 32],
            decode_block_depth2=3,
            attention_channels=[128, 128, 128],
            res_connections=True,
            activation="relu",
            padding_mode="zeros",
            pool_type="avg",
            decoder_z_sizes=[5, 15, outsize],
            z_outs=[3, 3, 5, 10],
            attention_activation="softmax",
            afm_res=0.125,
            grid_z_range=grid_z_range,
            peak_std=0.20,
            match_threshold=0.7,
            match_method="msd_norm",
            device=device,
        )
        super().__init__(
            n_classes=2,
            posnet=posnet,
            iters=5,
            node_feature_size=40,
            message_size=40,
            message_hidden_size=196,
            edge_cutoff=3.0,
            afm_cutoff=1.125,
            afm_res=0.125,
            conv_channels=[12, 24, 48],
            conv_depth=2,
            node_out_hidden_size=196,
            edge_out_hidden_size=196,
            res_connections=True,
            activation="relu",
            padding_mode="zeros",
            pool_type="avg",
            device=device,
        )

        if pretrained_weights is not None:
            weights_name = f"graph-ice-{pretrained_weights}"
            weights_path = download_weights(weights_name)
            # Map the tensors onto the requested device, so that weights saved on a GPU can also be loaded
            # on a machine without one.
            # NOTE(review): torch.load unpickles the file. Consider passing weights_only=True if the minimum
            # supported torch version allows it.
            weights = torch.load(weights_path, map_location=device)
            self.load_state_dict(weights)