mlspm.data_loading

class mlspm.data_loading.ShardList(urls: list[str] | str, base_path: str = '', world_size: int = 1, rank: int = 0, substitute_param: bool = False, log: str | None = None)

Bases: IterableDataset

A webdataset shard list that pads the URL list so that its length is divisible by the world size and then splits the URLs by rank. The padding is done by randomly duplicating elements of the URL list.
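The padding and splitting step can be sketched in plain Python as follows. This is a minimal illustration, not ShardList's actual implementation: the helper name pad_and_split is hypothetical, and it assumes a strided (round-robin) split and that all ranks use the same random seed, so that every process pads the list identically before taking its share.

```python
from __future__ import annotations

import random


def pad_and_split(urls: list[str], world_size: int, rank: int, rng: random.Random | None = None) -> list[str]:
    """Hypothetical helper: pad `urls` to a multiple of `world_size`, then return this rank's share."""
    rng = rng if rng is not None else random.Random()
    urls = list(urls)
    missing = (-len(urls)) % world_size  # number of extra URLs needed for divisibility
    if missing:
        # Duplicate randomly chosen URLs; fall back to sampling with replacement
        # when there are fewer URLs than needed.
        if missing <= len(urls):
            urls += rng.sample(urls, missing)
        else:
            urls += rng.choices(urls, k=missing)
    # Assumption: strided split, so rank r gets elements r, r + world_size, ...
    return urls[rank::world_size]
```

For 10 URLs and world_size=4, two URLs are duplicated and each of the four ranks receives 3 URLs.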

Additionally, it can yield random parameter sets for the same shard, using the pattern *K-{num}* in the file names and substituting different numbers for num.

Parameters:
  • urls – URLs as a list or brace notation string.

  • base_path – The URL paths are relative to this path. Leave empty to use absolute paths.

  • world_size – Number of parallel processes over which the URLs are split.

  • rank – Rank of the current process, which determines the subset of the URLs it receives.

  • substitute_param – Split the shards into parameter sets and yield a random parameter set for each shard.

  • log – If not None, path to a file where the yielded shard urls are logged.
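With substitute_param enabled, shards whose names differ only in the K-{num} part are treated as parameter-set variants of the same shard. One way to pick a random variant per shard is sketched below. The function name is hypothetical, and matching the pattern with the regular expression K-\d+ anywhere in the URL is an assumption made for illustration, not taken from mlspm.

```python
from __future__ import annotations

import random
import re
from collections import defaultdict

# Assumption: the parameter index appears as "K-<digits>" in the file name.
_PARAM_PATTERN = re.compile(r"K-\d+")


def pick_parameter_sets(urls: list[str], rng: random.Random | None = None) -> list[str]:
    """Hypothetical helper: choose one random parameter-set variant per shard."""
    rng = rng if rng is not None else random.Random()
    groups: dict[str, list[str]] = defaultdict(list)
    for url in urls:
        # URLs that map to the same template are variants of the same shard.
        groups[_PARAM_PATTERN.sub("K-{num}", url)].append(url)
    # One random variant per shard, in order of first appearance.
    return [rng.choice(variants) for variants in groups.values()]
```

Shards without the pattern form a group of one and are always returned unchanged.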

mlspm.data_loading.batched(batch_size: int) → RestCurried

Wrapper for webdataset.batched() with a suitable collation function.

The collation function takes collections of sample dictionaries with the following keys and collects them into batched sample dictionaries with the same keys:

  • 'X' - AFM images.

  • 'sw' - Scan windows that determine the real-space extent of the AFM image region.

  • 'Ys' - (Optional) Auxiliary image descriptors corresponding to the AFM images.

The rest of the keys in the dictionary are simply gathered into lists.
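The collation behavior described above can be sketched as follows. The function name and the use of np.stack along a new leading batch axis are assumptions; the actual mlspm collation may arrange axes differently (for example when 'X' holds several AFM tips per sample). The sketch also assumes that each stacked key holds arrays of identical shape across the samples of a batch.

```python
from __future__ import annotations

from collections import defaultdict

import numpy as np

# Keys whose values are stacked into a single batch array.
STACKED_KEYS = ("X", "sw", "Ys")


def collate_samples(samples: list[dict]) -> dict:
    """Hypothetical collation function: stack array keys, gather the rest into lists."""
    gathered = defaultdict(list)
    for sample in samples:
        for key, value in sample.items():
            gathered[key].append(value)
    # Optional keys such as 'Ys' only appear in the output if the samples contain them.
    return {
        key: np.stack(values) if key in STACKED_KEYS else values
        for key, values in gathered.items()
    }
```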

mlspm.data_loading.decode_xyz(key: str, data: Any) → Tuple[ndarray, ndarray] | Tuple[None, None]

Webdataset pipeline function for decoding xyz files.

Parameters:
  • key – Stream value key. If the key is '.xyz', then the data is decoded.

  • data – Data to decode.

Returns:

Tuple (xyz, scan_window), where

  • xyz - Decoded atom coordinates and elements as an array where each row is of the form [x, y, z, element].

  • scan_window - The xyz coordinates of the opposite corners of the scan window in the form ((x_start, y_start, z_start), (x_end, y_end, z_end)).

If the stream key did not match, the tuple is (None, None) instead.
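To illustrate the decoder contract, the sketch below parses a plain-text .xyz payload into the (xyz, scan_window) pair. It is not mlspm's decode_xyz: the function name, the small element table, and the restriction to the "Scan window: ..." comment format are simplifying assumptions (the Lattice format handled by get_scan_window_from_comment is omitted here).

```python
from __future__ import annotations

import re
from typing import Any

import numpy as np

# Matches floats such as 1, -2.5, .5 and 1e-3.
_FLOAT = r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?"

# Hypothetical, deliberately incomplete symbol-to-atomic-number table.
_ATOMIC_NUMBERS = {"H": 1, "C": 6, "N": 7, "O": 8, "F": 9, "S": 16, "Cl": 17, "Br": 35}


def decode_xyz_sketch(key: str, data: Any) -> tuple[np.ndarray, np.ndarray] | tuple[None, None]:
    if key != ".xyz":
        return None, None  # not an xyz entry
    text = data.decode("utf-8") if isinstance(data, bytes) else str(data)
    lines = text.splitlines()
    n_atoms = int(lines[0])  # line 1: number of atoms
    comment = lines[1]  # line 2: comment with the scan window
    rows = []
    for line in lines[2 : 2 + n_atoms]:
        symbol, x, y, z = line.split()[:4]
        element = int(symbol) if symbol.isdigit() else _ATOMIC_NUMBERS[symbol]
        rows.append([float(x), float(y), float(z), element])
    xyz = np.array(rows, dtype=np.float64)  # rows of [x, y, z, element]
    # Assumption: comment has the form "Scan window: [[x0 y0 z0], [x1 y1 z1]]".
    values = [float(v) for v in re.findall(_FLOAT, comment.split(":", 1)[1])]
    scan_window = np.array(values).reshape(2, 3)
    return xyz, scan_window
```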

mlspm.data_loading.default_collate(batch: Tuple[ndarray, ...]) → Tuple[Tensor, ...]

Convert a batch of NumPy arrays into PyTorch tensors.

Parameters:

batch – Should contain at least two arrays (X, Y, …), where X are AFM images and Y are image descriptors.

Returns:

Tuple (X, Y, …), where X and Y are the AFM images and image descriptors as tensors, and the rest of the elements are passed through unchanged.

mlspm.data_loading.get_scan_window_from_comment(comment: str) → ndarray

Process the comment line in a .xyz file and extract the bounding box of the scan. The comment has either the format (QUAM dataset)

Lattice="x0 x1 x2 y0 y1 y2 z0 z1 z2"

where the lattice is assumed to be orthogonal with its origin at zero, or

Scan window: [[x_start y_start z_start], [x_end y_end z_end]]

Parameters:

comment – Comment to parse.

Returns:

The xyz coordinates of the opposite corners of the scan window in the form

((x_start, y_start, z_start), (x_end, y_end, z_end))
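A minimal regex-based sketch of parsing the two comment formats is shown below. The function name is hypothetical. For the Lattice format it assumes the nine numbers are three axis-aligned lattice vectors listed one after another, so that the far corner of the box is the diagonal (x0, y1, z2); mlspm's own parser may differ in details such as the accepted number formats.

```python
import re

import numpy as np

# Matches floats such as 1, -2.5, .5 and 1e-3.
_FLOAT = r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?"


def parse_scan_window(comment: str) -> np.ndarray:
    """Hypothetical parser returning [[x_start, y_start, z_start], [x_end, y_end, z_end]]."""
    lattice = re.search(r'Lattice="([^"]*)"', comment)
    if lattice:
        values = [float(v) for v in lattice.group(1).split()]
        if len(values) != 9:
            raise ValueError(f"Expected 9 lattice values, got {len(values)}")
        cell = np.array(values).reshape(3, 3)  # one lattice vector per row
        # Orthogonal, axis-aligned cell with origin at zero: the end corner is the diagonal.
        return np.array([np.zeros(3), np.diag(cell)])
    window = re.search(r"Scan window:\s*(.*)", comment)
    if window:
        values = [float(v) for v in re.findall(_FLOAT, window.group(1))]
        if len(values) != 6:
            raise ValueError(f"Expected 6 scan window values, got {len(values)}")
        return np.array(values).reshape(2, 3)
    raise ValueError(f"Could not parse a scan window from comment: {comment!r}")
```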

mlspm.data_loading.rotate_and_stack = <webdataset.filters.RestCurried object>

Webdataset pipeline filter for _rotate_and_stack().

mlspm.data_loading.worker_init_fn(worker_id: int)

Initialize each worker with a unique random seed based on its ID and the current time.

Parameters:

worker_id – ID of the worker process.
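A seeding function of this kind could look like the sketch below. The name seed_worker and the exact way the ID and the time are combined are assumptions for illustration; the point is that each worker seeds both NumPy's global generator and Python's random module with a distinct value. Such a function is passed to a PyTorch DataLoader through its worker_init_fn argument.

```python
import random
import time

import numpy as np


def seed_worker(worker_id: int) -> int:
    """Hypothetical worker initializer: seed from the worker ID and the current time."""
    # Offset the clock reading by a multiple of the ID so that workers reading
    # the same clock value still get different seeds.
    seed = (time.time_ns() + worker_id * 1_000_003) % 2**32
    np.random.seed(seed)  # NumPy's legacy global RNG requires 0 <= seed < 2**32
    random.seed(seed)
    return seed
```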

mlspm.data_loading._rotate_and_stack(src: Iterable[dict], reverse: bool = False) → Generator[dict, None, None]

Take a sample in dict format and update it with fields containing an image stack, xyz coordinates and scan window. Rotate the images to the xy-indexing convention and stack them into a single array.

You most likely want to use the wrapper rotate_and_stack instead of calling this function directly.

Parameters:
  • src – Iterable of dicts with the fields:

    • '{000..0xx}.{jpg,png}' - PIL.Image.Image of one slice of the simulation.

    • 'xyz' - Tuple(np.ndarray, np.ndarray) of the xyz data and the scan window.

    • 'desc_{0..x}.npy' - optional np.ndarray of image descriptors.

  • reverse – Whether the order of the image stack is reversed.

Returns:

Generator that yields sample dicts with the updated fields 'X', 'Y', 'xyz', 'sw'.
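The rotate-and-stack step can be sketched as a plain generator. This is an illustration under stated assumptions, not mlspm's code: the names are hypothetical, the slices are assumed to be single-channel images whose row 0 is the top edge, and "xy-indexing" is taken to mean arr[x, y] with y increasing upwards. The order of the stack (and hence what reverse flips) follows the numeric slice index.

```python
from __future__ import annotations

import re
from typing import Generator, Iterable

import numpy as np

_SLICE_KEY = re.compile(r"^(\d+)\.(?:jpg|png)$")
_DESC_KEY = re.compile(r"^desc_(\d+)\.npy$")


def _numbered(sample: dict, pattern: re.Pattern) -> list[str]:
    """Keys matching `pattern`, sorted by their numeric index."""
    keys = [k for k in sample if pattern.match(k)]
    return sorted(keys, key=lambda k: int(pattern.match(k).group(1)))


def rotate_and_stack_sketch(src: Iterable[dict], reverse: bool = False) -> Generator[dict, None, None]:
    for sample in src:
        # Transpose (row, col) -> (x, y), then flip y so that it increases upwards.
        slices = [
            np.asarray(sample[k], dtype=np.float32).T[:, ::-1]
            for k in _numbered(sample, _SLICE_KEY)
        ]
        if reverse:
            slices = slices[::-1]
        sample["X"] = np.stack(slices, axis=-1)  # shape (nx, ny, n_slices)
        sample["Y"] = [np.asarray(sample[k]) for k in _numbered(sample, _DESC_KEY)]
        sample["xyz"], sample["sw"] = sample["xyz"]  # split the (xyz, scan_window) tuple
        yield sample
```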

mlspm.data_loading.collate_graph(batch: Tuple[ndarray, list[MoleculeGraph], list[ndarray]]) → Tuple[Tensor, list[Tensor], list[Tensor], list[Tensor], list[MoleculeGraph], list[ndarray]]

Collate graph samples into a batch.

Parameters:

batch – Tuple (X, mols, xyz), where

  • X - Input AFM images. Array of shape (batch_size, x, y, z).

  • mols - Input molecules.

  • xyz - List of original molecules. Arrays of shape (n_atoms, 5).

Returns:

Tuple (X, pos, node_classes, edges, mols, xyz), where

  • X - Input AFM images.

  • pos - Graph node xyz coordinates.

  • node_classes - Graph node class indices.

  • edges - Graph edge indices.

  • mols - Input molecules. Unchanged from input argument.

  • xyz - List of original molecules. Unchanged from input argument.