mlspm.data_loading
- class mlspm.data_loading.ShardList(urls: list[str] | str, base_path: str = '', world_size: int = 1, rank: int = 0, substitute_param: bool = False, log: str | None = None)
Bases: IterableDataset
A webdataset shard list that pads the URL list so that its length is divisible by the world size, and then splits the URLs by rank. The padding is done by duplicating randomly chosen elements of the URL list.
Additionally, it can yield random parameter sets for the same shard, using the pattern *K-{num}* in the file names and substituting different numbers for num.
- Parameters:
urls – URLs as a list or brace notation string.
base_path – The URL paths are relative to this path. Leave empty to use absolute paths.
world_size – Number of parallel processes over which the URLs are split.
rank – Rank of the current process, which determines the subset of URLs it receives.
substitute_param – Split shards into parameter sets and yield a random parameter set for each shard.
log – If not None, path to a file where the yielded shard URLs are logged.
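The padding and splitting described above can be sketched in plain Python. This is a minimal illustration, not ShardList's actual code: the helper name `pad_and_split` is hypothetical, and the strided split (`urls[rank::world_size]`) is an assumption about how URLs are assigned to ranks.

```python
import random
from typing import Optional


def pad_and_split(urls: list[str], world_size: int, rank: int, seed: Optional[int] = None) -> list[str]:
    """Hypothetical helper: pad `urls` to a multiple of `world_size`, then return the share for `rank`."""
    rng = random.Random(seed)
    urls = list(urls)  # copy so the caller's list is not modified
    remainder = len(urls) % world_size
    if remainder:
        # Duplicate randomly chosen URLs (with replacement) until the length is divisible by world_size.
        urls += rng.choices(urls, k=world_size - remainder)
    # Assumed strided split: rank r gets every world_size-th URL starting at index r.
    return urls[rank::world_size]


shards = [f"shard-{i:04d}.tar" for i in range(5)]
for r in range(2):
    print(r, pad_and_split(shards, world_size=2, rank=r, seed=0))
```

Every rank ends up with the same number of shards, which keeps the processes of a distributed job in step. Using the same seed on all ranks keeps the padded list identical across processes.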
- mlspm.data_loading.batched(batch_size: int) → RestCurried
Wrapper for webdataset.batched() with a suitable collation function.
The collation function takes collections of sample dictionaries with the following keys and collects them into batched sample dictionaries with the same keys:
'X' - AFM images.
'sw' - Scan windows that determine the real-space extent of the AFM image region.
'Ys' - (Optional) Auxiliary image descriptors corresponding to the AFM images.
The remaining keys in the dictionary are simply gathered into lists.
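A collation function of this kind can be sketched with NumPy. The name `collate_samples` is hypothetical, and the sketch assumes each of 'X', 'sw' and 'Ys' holds a single array per sample that is stacked along a new leading batch axis; the real function may handle these fields differently.

```python
import numpy as np


def collate_samples(samples: list[dict]) -> dict:
    """Hypothetical collation function: stack array fields, gather everything else into lists."""
    stacked_keys = {"X", "sw", "Ys"}  # 'Ys' is optional and only stacked when present
    batch = {}
    for key in samples[0]:
        values = [sample[key] for sample in samples]
        if key in stacked_keys:
            # Add a leading batch axis: (batch_size, *per_sample_shape)
            batch[key] = np.stack(values, axis=0)
        else:
            batch[key] = values
    return batch


samples = [{"X": np.zeros((8, 8, 10)), "sw": np.zeros((2, 3)), "__key__": f"sample{i}"} for i in range(4)]
batch = collate_samples(samples)
print(batch["X"].shape)  # (4, 8, 8, 10)
print(batch["__key__"])
```

In a webdataset pipeline, a function like this would be passed as the `collation_fn` argument of `webdataset.batched()`.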
- mlspm.data_loading.decode_xyz(key: str, data: Any) → Tuple[ndarray, ndarray] | Tuple[None, None]
Webdataset pipeline function for decoding xyz files.
- Parameters:
key – Stream value key. If the key is '.xyz', the data is decoded.
data – Data to decode.
- Returns:
Tuple (xyz, scan_window), where
xyz - Decoded atom coordinates and elements as an array where each row is of the form [x, y, z, element].
scan_window - The xyz coordinates of the opposite corners of the scan window in the form ((x_start, y_start, z_start), (x_end, y_end, z_end)).
If the stream key did not match, the tuple is (None, None) instead.
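A decoder with this contract can be sketched as follows. This is an illustration under stated assumptions, not mlspm's implementation: `decode_xyz_sketch` and the tiny `ELEMENTS` table are hypothetical, the file is assumed to follow the usual xyz layout (atom count, comment line, then rows starting with `element x y z`), elements are stored as atomic numbers, and only the "Scan window:" comment format is handled.

```python
import re

import numpy as np

# Hypothetical minimal symbol table; a real decoder would cover the whole periodic table.
ELEMENTS = {"H": 1, "C": 6, "N": 7, "O": 8, "F": 9}
_NUM = r"-?\d+(?:\.\d*)?(?:[eE][-+]?\d+)?"


def decode_xyz_sketch(key, data):
    """Return (xyz, scan_window) for '.xyz' entries and (None, None) for anything else."""
    if key != ".xyz":
        return None, None
    lines = data.decode("utf-8").splitlines()
    n_atoms = int(lines[0])
    comment = lines[1]
    rows = []
    for line in lines[2 : 2 + n_atoms]:
        elem, x, y, z = line.split()[:4]
        # Accept either an atomic number or a chemical symbol in the first column.
        atomic_number = int(elem) if elem.isdigit() else ELEMENTS[elem]
        rows.append([float(x), float(y), float(z), atomic_number])
    xyz = np.array(rows, dtype=np.float64)
    # Assumes the "Scan window: [[x y z], [x y z]]" comment format: take the first six numbers.
    numbers = [float(v) for v in re.findall(_NUM, comment)]
    scan_window = np.array(numbers[:6]).reshape(2, 3)
    return xyz, scan_window
```

Returning `(None, None)` for other keys mirrors the documented behaviour for non-matching stream keys.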
- mlspm.data_loading.default_collate(batch: Tuple[ndarray, ...]) → Tuple[Tensor, ...]
Convert a batch of NumPy arrays into PyTorch tensors.
- Parameters:
batch – Should contain at least two arrays (X, Y, …), where X are AFM images and Y are image descriptors.
- Returns:
Tuple (X, Y, …), where X and Y are the AFM images and image descriptors as tensors, and the remaining elements are passed through unchanged.
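The conversion can be sketched in a few lines, assuming the first two entries of the batch are NumPy arrays. The name `default_collate_sketch` is hypothetical, and whether the original casts dtypes is not stated, so this sketch keeps the input dtype.

```python
import numpy as np
import torch


def default_collate_sketch(batch: tuple) -> tuple:
    """Convert the first two entries (X, Y) to tensors and pass the rest through unchanged."""
    X, Y, *rest = batch
    # torch.from_numpy shares memory with the NumPy array instead of copying it.
    X = torch.from_numpy(np.asarray(X))
    Y = torch.from_numpy(np.asarray(Y))
    return (X, Y, *rest)
```

Unpacking with `*rest` lets any number of trailing elements (for example xyz data) pass through untouched.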
- mlspm.data_loading.get_scan_window_from_comment(comment: str) → ndarray
Process the comment line in a .xyz file and extract the bounding box of the scan. The comment has either the format (QUAM dataset)
Lattice="x0 x1 x2 y0 y1 y2 z0 z1 z2"
where the lattice is assumed to be orthogonal with the origin at zero, or
Scan window: [[x_start y_start z_start], [x_end y_end z_end]]
- Parameters:
comment – Comment to parse.
- Returns:
- The xyz coordinates of the opposite corners of the scan window in the form
((x_start, y_start, z_start), (x_end, y_end, z_end))
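Both comment formats can be parsed with regular expressions. This is a sketch, not the library function: `scan_window_from_comment` is a hypothetical name, and for the Lattice format it assumes the lattice vectors are axis-aligned, so the far corner of the box is the diagonal (x0, y1, z2).

```python
import re

import numpy as np

_NUM = r"-?\d+(?:\.\d*)?(?:[eE][-+]?\d+)?"


def scan_window_from_comment(comment: str) -> np.ndarray:
    """Return [[x_start, y_start, z_start], [x_end, y_end, z_end]] parsed from an xyz comment line."""
    lattice_match = re.search(r'Lattice="([^"]*)"', comment)
    if lattice_match:
        # Rows of the 3x3 matrix are the lattice vectors (x0 x1 x2), (y0 y1 y2), (z0 z1 z2).
        lattice = np.array([float(v) for v in lattice_match.group(1).split()]).reshape(3, 3)
        # Axis-aligned box with the origin at zero: the far corner is the diagonal.
        return np.array([np.zeros(3), np.diag(lattice)])
    if "Scan window" in comment:
        values = [float(v) for v in re.findall(_NUM, comment)]
        return np.array(values[:6]).reshape(2, 3)
    raise ValueError(f"Unrecognized comment format: {comment!r}")


print(scan_window_from_comment("Scan window: [[2.0 2.0 7.3], [18.0 18.0 8.3]]"))
```

Checking for the quoted `Lattice="..."` field first keeps the extended-xyz `Properties=...` text that often follows it from being picked up as numbers.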
- mlspm.data_loading.rotate_and_stack = <webdataset.filters.RestCurried object>
Webdataset pipeline filter for _rotate_and_stack().
- mlspm.data_loading.worker_init_fn(worker_id: int)
Initialize each worker with a unique random seed based on its ID and the current time.
- Parameters:
worker_id – ID of the worker process.
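A seeding scheme of this kind might look like the sketch below. The exact formula and which random number generators the original seeds are not stated, so `make_seed` (timestamp in milliseconds plus worker ID, reduced to 32 bits) and seeding both Python's `random` and NumPy are assumptions.

```python
import random
import time

import numpy as np


def make_seed(worker_id: int, time_ms: int) -> int:
    """Hypothetical scheme: combine a timestamp with the worker ID into a 32-bit seed."""
    return (time_ms + worker_id) % 2**32


def worker_init_fn_sketch(worker_id: int) -> None:
    seed = make_seed(worker_id, int(time.time() * 1000))
    random.seed(seed)     # Python's built-in RNG
    np.random.seed(seed)  # NumPy's legacy global RNG, which requires a seed below 2**32
```

Such a function would be passed as `worker_init_fn` to `torch.utils.data.DataLoader`. Adding the worker ID keeps workers started at the same moment from sharing a seed, and the timestamp makes seeds differ between runs.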
- mlspm.data_loading._rotate_and_stack(src: Iterable[dict], reverse: bool = False) → Generator[dict, None, None]
Take a sample in dict format and update it with fields containing an image stack, xyz coordinates and the scan window. The images are rotated to follow the xy-indexing convention and stacked into a single array.
You most likely want to use the wrapper rotate_and_stack instead of calling this function directly.
- Parameters:
src – Iterable of dicts with the fields:
'{000..0xx}.{jpg,png}' - PIL.Image.Image of one slice of the simulation.
'xyz' - Tuple (np.ndarray, np.ndarray) of the xyz data and the scan window.
'desc_{0..x}.npy' - Optional np.ndarray of image descriptors.
reverse – Whether the order of the image stack is reversed.
- Returns:
Generator that yields sample dicts with the updated fields 'X', 'Y', 'xyz', 'sw'.
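The rotate-and-stack step for the image slices can be sketched with NumPy. This is a sketch under assumptions: `stack_slices` is a hypothetical helper, the slices are assumed to be already converted to 2D grayscale arrays (for example `np.asarray` on a mode "L" PIL image), and "xy-indexing" is taken to mean a clockwise rotation so that `arr[x, y]` has y increasing upwards. The exact rotation mlspm applies is not stated here.

```python
import numpy as np


def stack_slices(sample: dict, reverse: bool = False) -> np.ndarray:
    """Hypothetical helper: rotate each 2D slice to xy-indexing and stack along a new last axis."""
    # Slice keys look like '000.png', '001.png', ...; zero-padding makes lexical order match slice order.
    slice_keys = sorted(
        key for key in sample
        if key.split(".")[0].isdigit() and key.rsplit(".", 1)[-1] in ("png", "jpg")
    )
    # Image arrays are indexed [row, col] with row 0 at the top; a clockwise rotation
    # (np.rot90 with k=-1) gives arr[x, y] with y increasing upwards.
    slices = [np.rot90(np.asarray(sample[key]), k=-1) for key in slice_keys]
    if reverse:
        slices = slices[::-1]
    return np.stack(slices, axis=-1)  # shape (nx, ny, n_slices)


img = np.arange(6).reshape(2, 3)  # 2 rows (y) by 3 columns (x)
X = stack_slices({"000.png": img, "001.png": img + 10, "xyz": None})
print(X.shape)  # (3, 2, 2)
```

Non-image fields such as 'xyz' and 'desc_0.npy' are skipped by the key filter, so they can be handled separately.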
- mlspm.data_loading.collate_graph(batch: Tuple[ndarray, list[MoleculeGraph], list[ndarray]]) → Tuple[Tensor, list[Tensor], list[Tensor], list[Tensor], list[MoleculeGraph], list[ndarray]]
Collate graph samples into a batch.
- Parameters:
batch –
Tuple (X, mols, xyz), where
X - Input AFM images. Array of shape (batch_size, x, y, z).
mols - Input molecules.
xyz - List of original molecules. Arrays of shape (n_atoms, 5).
- Returns:
Tuple (X, pos, node_classes, edges, mols, xyz), where
X - Input AFM images.
pos - Graph node xyz coordinates.
node_classes - Graph node class indices.
edges - Graph edge indices.
mols - Input molecules. Unchanged from input argument.
xyz - List of original molecules. Unchanged from input argument.
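The conversion into per-sample tensors can be sketched as below. MoleculeGraph's attributes are not documented in this section, so the sketch uses a hypothetical `ToyGraph` dataclass; `collate_graph_sketch`, the dtypes, and the (2, n_edges) edge layout (the convention used by PyTorch Geometric) are all assumptions.

```python
from dataclasses import dataclass

import numpy as np
import torch


@dataclass
class ToyGraph:
    """Hypothetical stand-in for MoleculeGraph."""
    pos: np.ndarray               # (n_nodes, 3) node coordinates
    classes: list[int]            # class index of each node
    edges: list[tuple[int, int]]  # pairs of node indices


def collate_graph_sketch(batch):
    X, mols, xyz = batch
    X = torch.from_numpy(np.asarray(X))
    pos = [torch.as_tensor(m.pos, dtype=torch.float32) for m in mols]
    node_classes = [torch.as_tensor(m.classes, dtype=torch.long) for m in mols]
    # reshape(-1, 2) keeps graphs without edges valid: an empty list becomes shape (0, 2).
    edges = [torch.as_tensor(m.edges, dtype=torch.long).reshape(-1, 2).T for m in mols]
    return X, pos, node_classes, edges, mols, xyz
```

Keeping `pos`, `node_classes` and `edges` as per-sample lists avoids padding molecules with different atom counts to a common size.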