mlspm.graph
- class mlspm.graph.Atom(xyz: Iterable[int], element: int | str = None, q: float = None, classes: Iterable[list[int | str]] = None, class_weights: Iterable[float] = None)
Bases: object
A class representing an atom with a position, element, and charge.
- Parameters:
xyz – The xyz position of the atom.
element – The element of the atom. Either atomic number or chemical symbol.
q – The charge of the atom.
classes – Classes for categorizing atoms based on their chemical elements. Each class is a list of elements either as atomic numbers or as chemical symbols.
class_weights – List of weights or one-hot vector for classes. The weights must sum to unity.
Note: only one of classes and class_weights can be specified at the same time.
- array(xyz: bool = False, q: bool = False, element: bool = False, class_index: bool = False, class_weights: bool = False) → ndarray
Return an array representation of the atom in the order [xyz, q, element, class_index, class_weights].
- Parameters:
xyz – Include xyz coordinates.
q – Include charge.
element – Include element.
class_index – Include class index.
class_weights – Include class weights.
- Returns:
Array with requested information.
- class mlspm.graph.GraphStats(classes: list[list[int]], dist_threshold: float = 0.35, bin_size: int = 4)
Bases: object
Gather statistics on graph predictions, divided into bins based on graph size.
- Parameters:
classes – Classes for categorizing atoms based on their chemical elements. Each class is a list of elements either as atomic numbers or as chemical symbols.
dist_threshold – Distance threshold for considering atoms as matching between prediction and reference.
bin_size – Bin size for graph sizes. The samples are divided into bins based on the number of nodes in the graphs.
- add_batch(pred: list[MoleculeGraph], ref: list[MoleculeGraph])
Gather stats from one batch of predictions and references.
- Parameters:
pred – Predicted molecule graphs.
ref – Reference molecule graphs.
- bond_count_diffs(size_bin: int = -1) → ndarray
Get the full list of differences in edge counts between predictions and references.
- Parameters:
size_bin – Index of graph size bin. If negative, include all bins.
- Returns:
Array of edge count differences.
- conf_mat_edge(size_bin: int = -1)
Get the confusion matrix for edge classification.
- Parameters:
size_bin – Index of graph size bin. If negative, include all bins.
- Returns:
Confusion matrix of predicted vs reference edge classes.
- conf_mat_node(size_bin: int = -1) → ndarray
Get the confusion matrix for node classification.
- Parameters:
size_bin – Index of graph size bin. If negative, include all bins.
- Returns:
Confusion matrix of predicted vs reference node classes.
- edge_precision(size_bin: int = -1) → ndarray
Get the precision for edge classification.
- Parameters:
size_bin – Index of graph size bin. If negative, include all bins.
- Returns:
Edge classification precision for every class.
- edge_recall(size_bin: int = -1) → ndarray
Get the recall for edge classification.
- Parameters:
size_bin – Index of graph size bin. If negative, include all bins.
- Returns:
Edge classification recall for every class.
- extra_nodes(size_bin: int = -1) → ndarray
Get the full list of the number of extra nodes in predictions compared to references.
- Parameters:
size_bin – Index of graph size bin. If negative, include all bins.
- Returns:
Array of extra node counts.
- graph_sizes(size_bin: int = -1) → ndarray
Get the full list of graph sizes.
- Parameters:
size_bin – Index of graph size bin. If negative, include all bins.
- Returns:
Array of graph sizes.
- hausdorff_distances(size_bin: int = -1) → ndarray
Get the full list of Hausdorff distances between predictions and references.
- Parameters:
size_bin – Index of graph size bin. If negative, include all bins.
- Returns:
Array of Hausdorff distances.
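The Hausdorff distance between two point sets is the largest distance from any point in either set to its nearest neighbour in the other set. The following is a minimal pure-Python sketch for two sets of atom positions. It illustrates the metric only and is not the mlspm implementation. The helper name hausdorff_distance is hypothetical, and empty graphs are not handled.

```python
import math


def hausdorff_distance(points_a, points_b):
    """Symmetric Hausdorff distance between two non-empty sets of 3D points."""

    def directed(src, dst):
        # Worst case, over src, of the distance to the nearest point in dst.
        return max(min(math.dist(p, q) for q in dst) for p in src)

    return max(directed(points_a, points_b), directed(points_b, points_a))


pred = [(0.0, 0.0, 0.0), (1.5, 0.0, 0.0)]
ref = [(0.1, 0.0, 0.0), (1.5, 0.0, 0.0), (3.0, 0.0, 0.0)]
print(hausdorff_distance(pred, ref))  # 1.5 (the unmatched reference atom at x = 3.0 dominates)
```

A single missing or extra atom can dominate this metric, which is one reason to read it together with the missing/extra node counts.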
- property largest_graph: int
Size of the largest graph seen in all batches.
- matching_distances(size_bin: int = -1) → ndarray
Get the full list of matching distances between predictions and references. A matching distance is the distance between a pair of matched nodes, i.e. a predicted node and a reference node that lie within dist_threshold of each other.
- Parameters:
size_bin – Index of graph size bin. If negative, include all bins.
- Returns:
Array of matching distances.
- missing_nodes(size_bin: int = -1) → ndarray
Get the full list of the number of missing nodes in predictions compared to references.
- Parameters:
size_bin – Index of graph size bin. If negative, include all bins.
- Returns:
Array of missing node counts.
- node_count_diffs(size_bin: int = -1) → ndarray
Get the full list of differences in node counts between predictions and references.
- Parameters:
size_bin – Index of graph size bin. If negative, include all bins.
- Returns:
Array of node count differences.
- node_precision(size_bin: int = -1) → ndarray
Get the precision for node classification.
- Parameters:
size_bin – Index of graph size bin. If negative, include all bins.
- Returns:
Node classification precision for every class.
- node_recall(size_bin: int = -1) → ndarray
Get the recall for node classification.
- Parameters:
size_bin – Index of graph size bin. If negative, include all bins.
- Returns:
Node classification recall for every class.
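Per-class precision and recall can both be read off a confusion matrix. Precision is the diagonal entry divided by the number of predictions of that class. Recall is the diagonal entry divided by the number of reference items of that class. The sketch below assumes that rows index the reference class and columns index the predicted class. The source only says "predicted vs reference", so check the orientation of conf_mat_node() / conf_mat_edge() before reusing it. The helper precision_recall is hypothetical.

```python
def precision_recall(conf_mat):
    """Per-class precision and recall from a square confusion matrix.

    Assumes conf_mat[i][j] counts items of reference class i predicted as class j.
    Classes with no predictions (or no references) get NaN.
    """
    n = len(conf_mat)
    precision, recall = [], []
    for k in range(n):
        tp = conf_mat[k][k]
        predicted_k = sum(conf_mat[i][k] for i in range(n))  # column sum: all predictions of class k
        reference_k = sum(conf_mat[k])  # row sum: all reference items of class k
        precision.append(tp / predicted_k if predicted_k else float("nan"))
        recall.append(tp / reference_k if reference_k else float("nan"))
    return precision, recall


p, r = precision_recall([[8, 2],
                         [1, 9]])
print([round(x, 3) for x in p])  # [0.889, 0.818]
print(r)  # [0.8, 0.9]
```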
- plot(outdir: str = './', verbose: str = 1)
Plot histograms of graph sizes, node/bond count differences, Hausdorff distances, and matching distances, as well as confusion matrices for node and edge classification.
- Parameters:
outdir – Directory where images are saved.
verbose – Whether to print information.
- report(outdir: str = './', verbose: int = 1)
Save to file the mean absolute node/bond count differences, mean Hausdorff distance, mean matching distance, missing/extra atoms, total samples/nodes, average/largest graph size, and node/edge precision/recall.
- Parameters:
outdir – Directory where files are saved.
verbose – Whether to print information.
- property total_nodes: int
Total number of nodes in the graphs seen in all batches.
- property total_samples: int
Total number of samples seen in all batches.
- class mlspm.graph.MoleculeGraph(atoms: list[Atom] | ndarray, bonds: list[Tuple[int, int]], classes: Iterable[list[int | str]] = None, class_weights: Iterable[float] = None)
Bases: object
A class representing a molecule graph with atoms and bonds. The atoms are stored as a list of Atom objects.
- Parameters:
atoms – Molecule atom positions and elements. If an np.ndarray, it must be of shape (num_atoms, 4), where each row corresponds to one atom with [x, y, z, element].
bonds – Indices of bonded atoms. Each bond is a tuple (bond_start, bond_end) with the indices of the atoms that the bond connects.
classes – Classes for categorizing atoms based on their chemical elements. Each class is a list of elements either as atomic numbers or as chemical symbols.
class_weights – List of weights or one-hot vector for classes. The weights must sum to unity.
Note: only one of classes and class_weights can be specified at the same time.
- add_atom(atom: Atom, bonds: list[int]) → Self
Add an atom and its bonds to the molecule graph.
- Parameters:
atom – Atom to add.
bonds – Indicator list (0s and 1s) of bond connections from new atom to existing atoms in the graph.
- Returns:
New molecule graph where the atom and bonds have been added.
- adjacency_matrix() → ndarray
Return the adjacency matrix of the graph.
- Returns:
Adjacency matrix of shape (n_atoms, n_atoms), where the presence of a bond between a pair of atoms is indicated by a binary value.
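The sketch below shows how a bond list maps to a binary adjacency matrix. It is plain Python rather than NumPy, and it assumes bonds are undirected, so each bond sets both (i, j) and (j, i). The helper adjacency_matrix here is hypothetical and not the method above.

```python
def adjacency_matrix(n_atoms, bonds):
    """Symmetric 0/1 adjacency matrix from a list of (i, j) bond index pairs."""
    adj = [[0] * n_atoms for _ in range(n_atoms)]
    for i, j in bonds:
        adj[i][j] = 1
        adj[j][i] = 1  # assumption: bonds are undirected
    return adj


print(adjacency_matrix(3, [(0, 1), (1, 2)]))  # [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
```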
- array(xyz: bool = False, q: bool = False, element: bool = False, class_index: bool = False, class_weights: bool = False) → ndarray | list
Return an array representation of the atoms in the molecule in the order [xyz, q, element, class_index, class_weights].
- Parameters:
xyz – Include xyz coordinates.
q – Include charge.
element – Include element.
class_index – Include class index.
class_weights – Include class weights.
- Returns:
Array with requested information. Each element in first dimension corresponds to one atom.
- crop_atoms(box_borders: ndarray) → Self
Delete atoms that are outside of the specified region.
- Parameters:
box_borders – Real-space extent of the region outside of which atoms are deleted. The array should be of the form ((x_start, y_start, z_start), (x_end, y_end, z_end)).
- Returns:
A new molecule graph without the deleted atoms.
- permute(permutation: list[int]) → Self
Permute the index order of atoms and corresponding bond indices.
- Parameters:
permutation – New index order. Must have the same length as the number of atoms in the graph.
- Returns:
New molecule graph with indices permuted.
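Permuting a graph means reordering the atoms and rewriting every bond through the old-to-new index map. The sketch below assumes that permutation lists old indices in the new order (new_atoms[i] = atoms[permutation[i]]). mlspm may use the inverse convention, so verify before relying on it. The helper permute_graph and its plain-list graph representation are hypothetical.

```python
def permute_graph(atoms, bonds, permutation):
    """Reorder atoms and remap bond indices.

    Assumes new_atoms[i] = atoms[permutation[i]], i.e. permutation lists old indices in the new order.
    """
    if sorted(permutation) != list(range(len(atoms))):
        raise ValueError("permutation must contain each atom index exactly once")
    new_atoms = [atoms[old] for old in permutation]
    old_to_new = {old: new for new, old in enumerate(permutation)}
    new_bonds = [(old_to_new[i], old_to_new[j]) for i, j in bonds]
    return new_atoms, new_bonds


print(permute_graph(["C", "O", "H"], [(0, 1), (1, 2)], [2, 0, 1]))
# (['H', 'C', 'O'], [(1, 2), (2, 0)])
```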
- randomize_positions(sigma: Tuple[int, int, int] = (0.2, 0.2, 0.1)) → Self
Randomly displace atom positions according to a Gaussian distribution.
- Parameters:
sigma – Standard deviations for displacement in x, y, and z directions in Ångströms.
- Returns:
New molecule graph with randomized atom positions.
- remove_atoms(remove_inds: Iterable[int]) → Tuple[Self, list[Tuple[Atom, list[int]]]]
Remove atoms and corresponding bonds from a molecule graph.
- Parameters:
remove_inds – Indices of atoms to remove.
- Returns:
Tuple (new_molecule, removed), where
new_molecule - New molecule graph where the atoms and bonds have been removed.
removed - Removed atoms and bonds. Each list item is a tuple (atom, bonds) corresponding to one of the removed atoms. The bonds are encoded as an indicator list, where 0 indicates no bond and 1 indicates a bond with the atom at the corresponding index in the new molecule.
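The indicator encoding of removed bonds refers to indices in the reduced molecule. It can therefore be fed back to add_atom() to restore an atom. The following is a minimal sketch of that bookkeeping on plain lists. It is not the mlspm code. The helper remove_atoms and the (atoms, bonds) representation are hypothetical. Bonds between two removed atoms are simply dropped, because the indicator lists cannot represent them.

```python
def remove_atoms(atoms, bonds, remove_inds):
    """Remove atoms and their bonds; return the reduced graph and the removed items.

    Each removed item is (atom, indicator), where indicator[k] == 1 iff the removed
    atom was bonded to atom k of the reduced graph.
    """
    remove = set(remove_inds)
    keep = [i for i in range(len(atoms)) if i not in remove]
    old_to_new = {old: new for new, old in enumerate(keep)}
    new_atoms = [atoms[i] for i in keep]
    # Keep only bonds whose both ends survive, re-indexed into the reduced graph.
    new_bonds = [(old_to_new[i], old_to_new[j]) for i, j in bonds
                 if i in old_to_new and j in old_to_new]
    removed = []
    for r in remove_inds:
        indicator = [0] * len(keep)
        for i, j in bonds:
            # Mark the surviving bond partner of atom r, if any.
            if i == r and j in old_to_new:
                indicator[old_to_new[j]] = 1
            elif j == r and i in old_to_new:
                indicator[old_to_new[i]] = 1
        removed.append((atoms[r], indicator))
    return (new_atoms, new_bonds), removed


print(remove_atoms(["C", "O", "H", "H"], [(0, 1), (1, 2), (0, 3)], [1]))
# ((['C', 'H', 'H'], [(0, 2)]), [('O', [1, 1, 0])])
```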
- transform_xy(shift: Tuple[float, float] = (0, 0), rot_xy: float = 0, flip_x: bool = False, flip_y: bool = False, center: Tuple[float, float] = (0, 0)) → Self
Transform atom positions in the xy plane.
Transformations are performed in the order: shift, rotate, flip x, flip y.
- Parameters:
shift – Shift atom positions in xy plane.
rot_xy – Rotate atoms in xy plane by rot_xy degrees around center point.
flip_x – Mirror atom positions in x direction with respect to the center point.
flip_y – Mirror atom positions in y direction with respect to the center point.
center – Point around which transformations are performed.
- Returns:
A new molecule graph with transformed atom positions.
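The sketch below applies the documented order (shift, rotate, flip x, flip y) to raw (x, y, z) tuples. It assumes that rot_xy is a counterclockwise rotation in degrees. It also assumes that a flip mirrors about the center point, so x becomes 2 * cx - x. Neither convention is stated in the source. The standalone function transform_xy here is hypothetical and leaves z untouched.

```python
import math


def transform_xy(points, shift=(0.0, 0.0), rot_xy=0.0, flip_x=False, flip_y=False, center=(0.0, 0.0)):
    """Shift, rotate (degrees, assumed counterclockwise about center), then mirror xy; z is unchanged."""
    cx, cy = center
    c, s = math.cos(math.radians(rot_xy)), math.sin(math.radians(rot_xy))
    out = []
    for x, y, z in points:
        x, y = x + shift[0], y + shift[1]  # 1. shift
        dx, dy = x - cx, y - cy
        x, y = cx + c * dx - s * dy, cy + s * dx + c * dy  # 2. rotate around center
        if flip_x:
            x = 2 * cx - x  # 3. mirror in x about center
        if flip_y:
            y = 2 * cy - y  # 4. mirror in y about center
        out.append((x, y, z))
    return out


res = transform_xy([(1.0, 0.0, 0.5)], rot_xy=90, flip_y=True)
print([tuple(round(v, 6) for v in q) for q in res])  # [(0.0, -1.0, 0.5)]
```

Because rotation and reflection do not commute, swapping steps 2 and 3 would generally give a different result. This is why the order matters.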
- mlspm.graph.add_rotation_reflection_graph(X: list[ndarray], mols: list[MoleculeGraph], box_borders: ndarray, num_rotations: int = 1, reflections: bool = True, crop: Tuple[int, int] | str | None = None, per_batch_item: bool = True) → Tuple[list[ndarray], list[MoleculeGraph]]
Random rotation and reflection of AFM images and corresponding molecule graphs.
- Parameters:
X – Batch of AFM images. Each array in the list is of the shape (batch, x, y, z).
mols – Molecule graphs corresponding to the AFM images.
box_borders – Real-space extent of the AFM image region in Ångströms. The array should be of the form ((x_start, y_start, ...), (x_end, y_end, ...)).
num_rotations – Number of rotations for each batch item. The batch size is multiplied by this number.
reflections – Whether to augment with reflections.
crop – If a tuple, the output batch is cropped to the specified size. If 'max', the crop region is maximized to fit into the rotated image. Atoms outside the cropped region are deleted from the molecule graphs. The crop region is centered on the middle of the image.
per_batch_item – If True, the rotation is randomized per batch item; otherwise the same rotation is used for all items.
- Returns:
Tuple (X, mols), where
X - Rotation-augmented AFM images.
mols - New rotated molecule graphs.
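For a square image of side s rotated by angle θ, the largest centered, axis-aligned square that still fits inside has side s / (|cos θ| + |sin θ|). This is one way the 'max' crop could be sized. Whether mlspm uses exactly this rule, and how it rounds to whole pixels or handles non-square images, is an assumption. The helper max_crop_size is hypothetical.

```python
import math


def max_crop_size(size, angle_deg):
    """Side length of the largest centered, axis-aligned square inside a square
    of side `size` rotated by `angle_deg` degrees."""
    a = math.radians(angle_deg)
    return size / (abs(math.cos(a)) + abs(math.sin(a)))


for angle in (0, 30, 45, 90):
    # Truncate to whole pixels; the crop is smallest at 45 degrees.
    print(angle, int(max_crop_size(128, angle)))
```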
- mlspm.graph.crop_graph(X: list[ndarray], mols: list[MoleculeGraph], start: Tuple[int, int], size: Tuple[int, int], box_borders: ndarray, new_start: Tuple[float, float] = (0.0, 0.0)) → Tuple[list[ndarray], list[MoleculeGraph], ndarray]
Crop AFM images and molecule graphs in a batch to a different size.
- Parameters:
X – Batch of AFM images. Each array in the list is of the shape (batch, x, y, z).
mols – Molecule graphs corresponding to the AFM images.
start – Start pixels of the crop in the x and y directions.
size – Size of the cropped region in the x and y directions.
box_borders – Real-space extent of the AFM image region in Ångströms. The array should be of the form ((x_start, y_start, ...), (x_end, y_end, ...)).
new_start – The start coordinates of the cropped region in Ångströms.
- Returns:
Tuple (X, mols, box_borders_cropped), where
X - Cropped AFM images.
mols - Cropped molecule graphs.
box_borders_cropped - Real-space extent of the cropped region as ((x_start, y_start, ...), (x_end, y_end, ...)).
- mlspm.graph.find_bonds(molecules: list[ndarray], tolerance=0.2) → list[list[Tuple[int, int]]]
Find bonds in molecules based on atomic distances and tabulated bond lengths.
- Parameters:
molecules – Molecule atom positions and elements. List of arrays of shape (num_atoms, 4), where each row corresponds to one atom with [x, y, z, element].
tolerance – Two atoms are bonded if their distance is at most 1 + tolerance times the tabulated bond length.
- Returns:
Indices of bonded atoms for each molecule.
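The bonding criterion compares each pairwise distance with (1 + tolerance) times a tabulated length. The following is a minimal O(n²) sketch of that rule for one molecule. It uses a small hypothetical table of typical single-bond lengths (C-C 1.54 Å, C-H 1.09 Å, O-H 0.96 Å, C-O 1.43 Å). mlspm uses its own table, and element pairs missing from this table are treated here as never bonded. BOND_LENGTHS and find_bonds are illustrative names.

```python
import math

# Hypothetical bond-length table in Ångströms, keyed by the set of atomic numbers.
BOND_LENGTHS = {
    frozenset({6}): 1.54,     # C-C
    frozenset({1, 6}): 1.09,  # C-H
    frozenset({1, 8}): 0.96,  # O-H
    frozenset({6, 8}): 1.43,  # C-O
}


def find_bonds(atoms, tolerance=0.2):
    """atoms: list of (x, y, z, Z). Returns bonded index pairs (i, j) with i < j."""
    bonds = []
    for i in range(len(atoms)):
        for j in range(i + 1, len(atoms)):
            ref = BOND_LENGTHS.get(frozenset({atoms[i][3], atoms[j][3]}))
            if ref is None:
                continue  # no tabulated length for this element pair
            if math.dist(atoms[i][:3], atoms[j][:3]) <= (1 + tolerance) * ref:
                bonds.append((i, j))
    return bonds


water = [(0.0, 0.0, 0.0, 8), (0.96, 0.0, 0.0, 1), (-0.24, 0.93, 0.0, 1)]
print(find_bonds(water))  # [(0, 1), (0, 2)]
```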
- mlspm.graph.find_gaussian_peaks(pos_dist: ndarray | Tensor, box_borders: ndarray, match_threshold: float = 0.7, std: float = 0.3, method: str = 'mad') → Tuple[list[ndarray], ndarray, ndarray] | Tuple[list[Tensor], Tensor, Tensor]
Find real-space positions of Gaussian peaks in a 3D position distribution grid.
- Parameters:
pos_dist – Position distribution array. Should be of shape (n_batch, nx, ny, nz).
box_borders – Real-space extent of the distribution grid in Ångströms. The array should be of the form ((x_start, y_start, z_start), (x_end, y_end, z_end)).
match_threshold – Detection threshold for matching. Regions above the threshold are chosen for method 'zncc', and regions below the threshold are chosen for methods 'mad', 'msd', 'mad_norm', and 'msd_norm'.
std – Standard deviation of peaks to search for in Ångströms.
method – Matching method to use. Either zero-normalized cross correlation ('zncc'), mean absolute distance ('mad'), mean squared distance ('msd'), or the normalized versions of the latter two ('mad_norm', 'msd_norm').
- Returns:
Tuple (xyzs, matches, labels), where
xyzs - Positions of the found atoms. Each item in the list is an array of shape (num_atoms, 3) that corresponds to one batch item.
matches - Array of matching values, of the same shape as the input pos_dist. For method 'zncc' larger values, and for 'mad', 'msd', 'mad_norm', and 'msd_norm' smaller values, correspond to a better match.
labels - Labelled regions where the match is better than match_threshold, of the same shape as the input pos_dist.
The arrays are of the same type as the input pos_dist array.
- mlspm.graph.make_box_borders(shape: Tuple[int, int], res: Tuple[float, float], z_range: Tuple[float, float]) → ndarray
Make grid box borders for a given grid xy shape.
- Parameters:
shape – Grid xy shape.
res – Grid xy pixel resolution in Ångströms.
z_range – Grid z start and end coordinates in Ångströms.
- Returns:
Box start and end coordinates in the form ((x_start, y_start, z_start), (x_end, y_end, z_end)).
- mlspm.graph.make_position_distribution(mols: list[MoleculeGraph], box_borders: ndarray, box_res: Tuple[float, float, float] = (0.125, 0.125, 0.1), std: float = 0.3) → ndarray
Make a distribution on a grid based on atom positions. Each atom is represented by a normal distribution.
- Parameters:
mols – List of molecules.
box_borders – Real-space extent of the distribution grid in Ångströms. The array should be of the form ((x_start, y_start, z_start), (x_end, y_end, z_end)).
box_res – Real-space size of a voxel in each direction in Ångströms.
std – Standard deviation of the normal distribution for each atom in Ångströms.
- Returns:
Array of shape (n_batch, n_x, n_y, n_z).
- mlspm.graph.save_graphs_to_xyzs(molecules: list[MoleculeGraph], classes: list[list[int | str]], outfile_format: str = './{ind}_graph.xyz', start_ind: int = 0, verbose: int = 1)
Save molecule graphs to xyz files.
- Parameters:
molecules – Molecule graphs to save.
classes – Chemical elements for atom classification. Either atomic numbers or chemical symbols. The element for each atom in the graph is the first element in the corresponding class.
outfile_format – Formatting string for saved files. The sample index is available in the variable ind.
start_ind – Index where file numbering starts.
verbose – Whether to print output information.
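An xyz file holds an atom count, a comment line, and then one "element x y z" line per atom. The sketch below formats one graph in that layout, writing each atom as the first element of its class as described above. It is not mlspm's writer. The helper graph_to_xyz, its inputs (positions plus a class index per atom), and the six-decimal number format are assumptions. It also shows how the outfile_format placeholder ind is filled.

```python
def graph_to_xyz(positions, class_indices, classes, comment=""):
    """Format one molecule as xyz-file text; each atom uses the first element of its class."""
    lines = [str(len(positions)), comment]
    for (x, y, z), c in zip(positions, class_indices):
        element = classes[c][0]
        lines.append(f"{element} {x:.6f} {y:.6f} {z:.6f}")
    return "\n".join(lines) + "\n"


classes = [["H"], ["C", "N"], ["O", "F"]]
text = graph_to_xyz([(0.0, 0.0, 0.0), (1.1, 0.0, 0.0)], [1, 0], classes)
outfile_format = "./{ind}_graph.xyz"
print(outfile_format.format(ind=0))  # ./0_graph.xyz
print(text)
```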
- mlspm.graph.shift_mols_window(molecules: list[MoleculeGraph], scan_windows: ndarray, start: Tuple[float, float] = (0, 0)) → ndarray
Shift molecule xy coordinates to use the same scan window. All molecules should have the same scan window size.
- Parameters:
molecules – Molecules whose atom positions to shift.
scan_windows – Scan window for each molecule. Array of shape (n_mol, 2, 3).
start – The lower left corner of the new scan window.
- Returns:
Tuple (new_molecules, new_scan_window), where
new_molecules - Molecules with shifted atom coordinates.
new_scan_window - New scan window in the form ((x_start, y_start), (x_end, y_end)).
- mlspm.graph.threshold_atoms_bonds(molecules: list[MoleculeGraph], threshold: float = -1.0, use_vdW: bool = False) → list[MoleculeGraph]
Remove atoms and corresponding bonds beyond a threshold depth in molecules.
- Parameters:
molecules – Molecules to threshold.
threshold – Deepest z-coordinate for included atoms (top is 0).
use_vdW – Whether to add vdW radii to the atom z coordinates when calculating depth.
- Returns:
Molecules with deep atoms removed.
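Depth thresholding keeps atoms whose (optionally vdW-adjusted) z coordinate is at or above threshold. It then re-indexes the bonds between the surviving atoms. The sketch below assumes the coordinates are already referenced so that the top of the molecule is at z = 0, as "(top is 0)" suggests. It also assumes use_vdW means adding each atom's radius to its z. The helper threshold_atoms, its (x, y, z, element) tuples, and the vdw_radii mapping are hypothetical.

```python
def threshold_atoms(atoms, bonds, threshold=-1.0, vdw_radii=None):
    """Drop atoms deeper than `threshold` and re-index the remaining bonds.

    atoms: list of (x, y, z, element); assumes the top of the molecule is at z = 0.
    vdw_radii: optional mapping element -> radius, added to z before comparison.
    """
    def z_eff(atom):
        return atom[2] + (vdw_radii[atom[3]] if vdw_radii else 0.0)

    keep = [i for i, a in enumerate(atoms) if z_eff(a) >= threshold]
    old_to_new = {old: new for new, old in enumerate(keep)}
    new_atoms = [atoms[i] for i in keep]
    # A bond survives only if both of its atoms survive.
    new_bonds = [(old_to_new[i], old_to_new[j]) for i, j in bonds
                 if i in old_to_new and j in old_to_new]
    return new_atoms, new_bonds


atoms = [(0, 0, 0.0, 8), (0.9, 0, -0.3, 1), (0, 0, -2.5, 29)]
bonds = [(0, 1), (1, 2)]
print(threshold_atoms(atoms, bonds))
# ([(0, 0, 0.0, 8), (0.9, 0, -0.3, 1)], [(0, 1)])
```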
mlspm.graph.losses
- class mlspm.graph.losses.GraphLoss(node_factor: float = 1.0, edge_factor: float = 1.0)
Bases: Module
Loss that compares two graphs.
- Parameters:
node_factor – Weight for node classification loss.
edge_factor – Weight for edge classification loss.
- forward(pred: Tuple[list[Tensor], list[Tensor], list[Tensor]], ref: Tuple[list[Tensor], list[Tensor]], separate_loss_factors=False) → Tensor | list[Tensor, Tensor, Tensor]
- Parameters:
pred – Predicted graph batch as returned by GraphImgNet.forward().
ref – Reference graph batch. A tuple (node_classes, edges), where
node_classes - Node classes as class index numbers. List of tensors of shape (n_atoms,).
edges - Edges as pairs of node indices. List of tensors of shape (2, n_edges).
separate_loss_factors – Whether to return a single total loss or a separated list of values with each loss component.
- Returns:
Computed loss value. Either a single value when separate_loss_factors==False, or a list [total_loss, node_loss, edge_loss] when separate_loss_factors==True.
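The total loss weights a node classification term by node_factor and an edge classification term by edge_factor. The following is a single-graph, pure-Python sketch of that structure. It assumes cross-entropy for node classes and binary cross-entropy for edges. It also assumes that each predicted candidate edge gets target 1 when it appears in the reference edges (in either direction). The source does not state the exact loss functions or batch averaging, so mlspm may differ. graph_loss is a hypothetical helper.

```python
import math


def graph_loss(pred_node_probs, pred_edge_probs, pred_edges, ref_classes, ref_edges,
               node_factor=1.0, edge_factor=1.0):
    """Return (total, node_loss, edge_loss) for one graph (assumed CE + BCE)."""
    eps = 1e-12  # guards against log(0)
    # Mean cross-entropy of the probability assigned to each atom's reference class.
    node_loss = -sum(math.log(max(p[c], eps)) for p, c in zip(pred_node_probs, ref_classes))
    node_loss /= max(len(ref_classes), 1)
    # Candidate edges are positives if they appear in the reference (order-independent).
    ref_set = {frozenset(e) for e in ref_edges}
    targets = [1.0 if frozenset(e) in ref_set else 0.0 for e in pred_edges]
    edge_loss = -sum(t * math.log(max(p, eps)) + (1 - t) * math.log(max(1 - p, eps))
                     for p, t in zip(pred_edge_probs, targets))
    edge_loss /= max(len(targets), 1)
    total = node_factor * node_loss + edge_factor * edge_loss
    return total, node_loss, edge_loss


total, node_loss, edge_loss = graph_loss(
    [[0.9, 0.1], [0.2, 0.8]], [0.7], [(0, 1)], ref_classes=[0, 1], ref_edges=[(1, 0)])
print(round(total, 4), round(node_loss, 4), round(edge_loss, 4))
```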
mlspm.graph.models
- class mlspm.graph.models.GraphImgNet(n_classes: int, posnet: PosNet | None = None, iters: int = 3, node_feature_size: int = 20, message_size: int = 20, message_hidden_size: int = 64, edge_cutoff: float = 3.0, afm_cutoff: float = 1.25, afm_res: float = 0.125, conv_channels: list[int] = [4, 8, 16], conv_depth: int = 2, node_out_hidden_size: int = 128, edge_out_hidden_size: int = 128, res_connections: bool = True, activation: str | Module = 'relu', padding_mode: str = 'zeros', pool_type: str = 'avg', device: str = 'cuda')
Bases: Module
Image-to-graph model that constructs a molecule graph from atom positions and an AFM image.
- Parameters:
n_classes – Number of different classes for nodes.
posnet – PosNet for predicting atom positions from an AFM image. Required when training or doing inference without predefined atom positions.
iters – Number of message passing iterations.
node_feature_size – Number of hidden node features.
message_size – Size of message vector.
message_hidden_size – Size of hidden layers in message MLP.
edge_cutoff – Cutoff radius in angstroms for edges between atoms within MPNN.
afm_cutoff – Cutoff radius in angstroms for receptive regions around each atom.
afm_res – Real-space size of pixels in xy-plane in input AFM images in angstroms.
conv_channels – Number of channels in 3D conv blocks encoding AFM image regions.
conv_depth – Number of layers in each 3D conv block.
node_out_hidden_size – Size of hidden layers in node classification MLP.
edge_out_hidden_size – Size of hidden layers in edge classification MLP.
res_connections – Whether to use residual connections in conv blocks.
activation – Activation to use after every layer. 'relu', 'lrelu', or 'elu', or an nn.Module.
padding_mode – Type of padding in each convolution layer. 'zeros', 'reflect', 'replicate', or 'circular'.
pool_type – Type of pooling to use. 'avg' or 'max'.
device – Device to store model on.
- forward(x: Tensor, pos: list[Tensor] | None = None) → Tuple[list[Tensor], list[Tensor], list[Tensor]]
- Parameters:
x – Batch of AFM images. Array of shape (n_batch, nx, ny, nz).
pos – Atom positions for each batch item. If None, the positions are predicted from the AFM images. The positions should be such that the lower left corner of the AFM image is at coordinate (0, 0), and all positions are within the bounds of the AFM image. The model must be constructed with a PosNet for position prediction to work.
- Returns:
Tuple (node_classes, edge_classes, edges), where
node_classes - Predicted class probabilities for each atom in the molecule graphs. Each batch item is a tensor of shape (n_atoms, n_classes).
edge_classes - Predicted probabilities for the existence of bonds between the atoms indicated by edges. Each batch item is a tensor of shape (n_edges,).
edges - Possible bond connection indices between atoms. Each batch item is a tensor of shape (2, n_edges).
- pred_to_graph(pos: list[Tensor], node_classes: list[Tensor], edge_classes: list[Tensor], edges: list[Tensor], bond_threshold: float) → list[MoleculeGraph]
Convert predicted batch to a simple list of molecule graphs.
- Parameters:
pos – Atom positions for each batch item.
node_classes – Predicted class probabilities for each atom in the molecule graphs. Each batch item is a tensor of shape (n_atoms, n_classes).
edge_classes – Predicted probabilities for the existence of bonds between the atoms indicated by edges. Each batch item is a tensor of shape (n_edges,).
edges – Possible bond connection indices between atoms. Each batch item is a tensor of shape (2, n_edges).
bond_threshold – Threshold probability for considering an edge a bond between atoms.
- Returns:
Molecule graphs corresponding to the predictions.
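Converting one predicted batch item into a graph involves two steps. First, each atom is assigned a class from its probability vector. Second, only the candidate edges whose probability passes bond_threshold are kept. The sketch below does both on plain lists. It reduces node probabilities to a single class by argmax and uses >= for the threshold comparison. Both are assumptions: mlspm may keep the full probability vectors as class weights. The standalone function pred_to_graph here is hypothetical.

```python
def pred_to_graph(pos, node_probs, edge_probs, edges, bond_threshold=0.5):
    """Turn one predicted graph into (atoms, bonds) with atoms as (xyz, class_index)."""
    atoms = []
    for xyz, probs in zip(pos, node_probs):
        cls = probs.index(max(probs))  # argmax over class probabilities
        atoms.append((tuple(xyz), cls))
    i_list, j_list = edges  # edges has shape (2, n_edges)
    bonds = [(i, j) for i, j, p in zip(i_list, j_list, edge_probs) if p >= bond_threshold]
    return atoms, bonds


pos = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (2.0, 0.0, 0.0)]
node_probs = [[0.1, 0.9], [0.7, 0.3], [0.4, 0.6]]
edges = ([0, 1, 0], [1, 2, 2])
edge_probs = [0.95, 0.6, 0.1]
print(pred_to_graph(pos, node_probs, edge_probs, edges))
# ([((0.0, 0.0, 0.0), 1), ((1.0, 0.0, 0.0), 0), ((2.0, 0.0, 0.0), 1)], [(0, 1), (1, 2)])
```

Raising bond_threshold trades missed bonds for fewer spurious ones. With a threshold of 0.99, this example keeps no bonds at all.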
- predict_graph(x: Tensor | ndarray, pos: Tensor | None = None, bond_threshold: float = 0.5) → Tuple[list[MoleculeGraph], Tensor | ndarray | None]
Predict molecule graphs from AFM images.
- Parameters:
x – Batch of AFM images. Array of shape (n_batch, nx, ny, nz).
pos – Atom positions for each batch item. If None, the positions are predicted from the AFM images. The positions should be such that the lower left corner of the AFM image is at coordinate (0, 0), and all positions are within the bounds of the AFM image.
bond_threshold – Threshold probability for considering an edge a bond between atoms.
- Returns:
Tuple (graphs, grid), where
graphs - Predicted graphs.
grid - Atom position grid from the PosNet prediction when the input pos is None. Type matches the input AFM image type.
- class mlspm.graph.models.GraphImgNetIce(pretrained_weights: str | None = None, grid_z_range: Tuple[float, float] | None = None, device='cuda')
Bases: GraphImgNet
GraphImgNet with hyperparameters set exactly as in the paper “Structure discovery in Atomic Force Microscopy imaging of ice”, https://arxiv.org/abs/2310.17161.
Three sets of pretrained weights are available:
- 'cu111': trained on images of ice clusters on Cu(111)
- 'au111-monolayer': trained on images of ice clusters on monolayer Au(111)
- 'au111-bilayer': trained on images of ice clusters on bilayer Au(111)
- Parameters:
pretrained_weights – Name of pretrained weights. If specified, load pretrained weights. Otherwise, weights are initialized randomly.
grid_z_range – The real-space range in the z-direction of the position grid in angstroms, in the format (z_min, z_max). Has to be specified when pretrained_weights is not given.
device – Device to store model on.
- class mlspm.graph.models.PosNet(encode_block_channels: list[int] = [4, 8, 16, 32], encode_block_depth: int = 2, decode_block_channels: list[int] = [32, 16, 8], decode_block_depth: int = 2, decode_block_channels2: list[int] = [32, 16, 8], decode_block_depth2: int = 2, attention_channels: list[int] = [32, 32, 32], res_connections: bool = True, activation: str = 'relu', padding_mode: str = 'replicate', pool_type: str = 'avg', decoder_z_sizes: list[int] = [5, 10, 20], z_outs: list[int] = [3, 3, 5, 10], attention_activation: str = 'softmax', afm_res: float = 0.125, grid_z_range: Tuple[float, float] = (-1.4, 0.5), peak_std: float = 0.3, match_threshold: float = 0.7, match_method: str = 'msd_norm', device: str = 'cuda')
Bases: Module
Attention U-Net for predicting the positions of atoms in an AFM image.
- Parameters:
encode_block_channels – Number of channels in encoding 3D conv blocks.
encode_block_depth – Number of layers in each encoding 3D conv block.
decode_block_channels – Number of channels in each decoding 3D conv block after upscale before skip connection.
decode_block_depth – Number of layers in each decoding 3D conv block after upscale before skip connection.
decode_block_channels2 – Number of channels in each decoding 3D conv block after skip connection.
decode_block_depth2 – Number of layers in each decoding 3D conv block after skip connection.
attention_channels – Number of channels in conv layers within each attention block.
res_connections – Whether to use residual connections in conv blocks.
activation – Activation to use after every layer. 'relu', 'lrelu', or 'elu', or an nn.Module.
padding_mode – Type of padding in each convolution layer. 'zeros', 'reflect', 'replicate', or 'circular'.
pool_type – Type of pooling to use. 'avg' or 'max'.
decoder_z_sizes – Upscale sizes of decoder stages in the z dimension.
z_outs – Size of the z-dimension after encoder and skip connections.
attention_activation – Type of activation to use for the attention map. 'sigmoid' or 'softmax'.
afm_res – Real-space size of pixels in xy-plane in input AFM images in angstroms.
grid_z_range – The real-space range in the z-direction of the position grid in angstroms, in the format (z_min, z_max).
peak_std – Standard deviation of atom position grid peaks in angstroms.
match_threshold – Detection threshold for matching when finding atom position peaks.
match_method – Method for template matching when finding atom position peaks. See find_gaussian_peaks() for options.
device – Device to store model on.
- forward(x: Tensor) → Tuple[Tensor, list[Tensor]]
- Parameters:
x – Batch of AFM images. Should be of shape (n_batch, nx, ny, nz).
- Returns:
Tuple (pos_dist, attention_maps), where
pos_dist - Predicted atom position distribution.
attention_maps - Attention maps from the skip connection attention layers.
- get_positions(x: Tensor | ndarray, device: str = 'cuda') → Tuple[list[Tensor], Tensor | ndarray, list[Tensor | ndarray]]
Predict atom positions for a batch of AFM images.
- Parameters:
x – Batch of AFM images. Should be of shape (n_batch, nx, ny, nz).
device – Device used when x is an np.ndarray.
- Returns:
Tuple (atom_pos, grid, attention), where
atom_pos - Atom positions for each batch item.
grid - Atom position grid from the PosNet prediction. Type matches the input AFM image type.
attention - Attention maps from the skip-connection attention layers. Type matches the input AFM image type.