Source code for mlspm._weights

from os import PathLike
from pathlib import Path
from typing import Optional
from urllib.request import urlretrieve

from .utils import _print_progress

# Mapping from the public weights names accepted by ``download_weights`` to the
# Zenodo download URLs of the corresponding pretrained ``.pth`` weight files.
WEIGHTS_URLS = {
    # PosNet weights for ice clusters on metal surfaces (Zenodo record 10054348).
    "graph-ice-cu111": "https://zenodo.org/records/10054348/files/weights_ice-cu111.pth?download=1",
    "graph-ice-au111-monolayer": "https://zenodo.org/records/10054348/files/weights_ice-au111-monolayer.pth?download=1",
    "graph-ice-au111-bilayer": "https://zenodo.org/records/10054348/files/weights_ice-au111-bilayer.pth?download=1",
    # ASDAFMNet weights for light / heavy element sets (Zenodo record 10514470).
    "asdafm-light": "https://zenodo.org/records/10514470/files/weights_asdafm_light.pth?download=1",
    "asdafm-heavy": "https://zenodo.org/records/10514470/files/weights_asdafm_heavy.pth?download=1",
    # EDAFMNet weights: the base model plus ablation variants (Zenodo record 10606273).
    # Note that the remote file names do not always match the key spelling
    # (e.g. ``uniform_noise.pth`` for ``edafm-uniform-noise``).
    "edafm-base": "https://zenodo.org/records/10606273/files/base.pth?download=1",
    "edafm-single-channel": "https://zenodo.org/records/10606273/files/single-channel.pth?download=1",
    "edafm-CO-Cl": "https://zenodo.org/records/10606273/files/CO-Cl.pth?download=1",
    "edafm-Xe-Cl": "https://zenodo.org/records/10606273/files/Xe-Cl.pth?download=1",
    "edafm-constant-noise": "https://zenodo.org/records/10606273/files/constant-noise.pth?download=1",
    "edafm-uniform-noise": "https://zenodo.org/records/10606273/files/uniform_noise.pth?download=1",
    "edafm-no-gradient": "https://zenodo.org/records/10606273/files/no-gradient.pth?download=1",
    "edafm-matched-tips": "https://zenodo.org/records/10606273/files/matched-tips.pth?download=1",
}


def download_weights(weights_name: str, target_path: Optional[PathLike] = None) -> PathLike:
    """
    Download pretrained weights for models.

    The following weights are available:

        - ``'graph-ice-cu111'``: PosNet trained on ice clusters on Cu(111). (https://doi.org/10.5281/zenodo.10054348)
        - ``'graph-ice-au111-monolayer'``: PosNet trained on monolayer ice clusters on Au(111).
          (https://doi.org/10.5281/zenodo.10054348)
        - ``'graph-ice-au111-bilayer'``: PosNet trained on bilayer ice clusters on Au(111).
          (https://doi.org/10.5281/zenodo.10054348)
        - ``'asdafm-light'``: :class:`.ASDAFMNet` trained on molecules containing the elements H, C, N, O, and F.
          (https://doi.org/10.5281/zenodo.10514470)
        - ``'asdafm-heavy'``: :class:`.ASDAFMNet` trained on molecules additionally containing Si, P, S, Cl, and Br.
          (https://doi.org/10.5281/zenodo.10514470)
        - ``'edafm-base'``: :class:`.EDAFMNet` used for all predictions in the main ED-AFM paper and used for
          comparison in the various tests in the supplementary information of the paper.
          (https://doi.org/10.5281/zenodo.10606273)
        - ``'edafm-single-channel'``: :class:`.EDAFMNet` trained on only a single CO-tip AFM input.
          (https://doi.org/10.5281/zenodo.10606273)
        - ``'edafm-CO-Cl'``: :class:`.EDAFMNet` trained on alternative tip combination of CO and Cl.
          (https://doi.org/10.5281/zenodo.10606273)
        - ``'edafm-Xe-Cl'``: :class:`.EDAFMNet` trained on alternative tip combination of Xe and Cl.
          (https://doi.org/10.5281/zenodo.10606273)
        - ``'edafm-constant-noise'``: :class:`.EDAFMNet` trained using constant noise amplitude instead of
          normally distributed amplitude. (https://doi.org/10.5281/zenodo.10606273)
        - ``'edafm-uniform-noise'``: :class:`.EDAFMNet` trained using uniform random noise amplitude instead of
          normally distributed amplitude. (https://doi.org/10.5281/zenodo.10606273)
        - ``'edafm-no-gradient'``: :class:`.EDAFMNet` trained without background-gradient augmentation.
          (https://doi.org/10.5281/zenodo.10606273)
        - ``'edafm-matched-tips'``: :class:`.EDAFMNet` trained on data with matched tip distance between CO and Xe,
          instead of independently randomized distances. (https://doi.org/10.5281/zenodo.10606273)

    Arguments:
        weights_name: Name of weights to download.
        target_path: Path where the weights file will be saved. If specified, the parent directory for the file
            has to exist. If not specified, a location in a cache directory (``~/.cache/mlspm``) is chosen.
            If the target file already exists, the download is skipped.

    Returns:
        Path where the weights were saved.

    Raises:
        ValueError: If ``weights_name`` is not one of the available weights.

    Note:
        The file is first downloaded to a temporary ``<target>.part`` file next to the target and only moved to
        the target path once the download has completed. An interrupted download therefore never leaves a
        truncated file at the target path that a later call would mistake for a finished download.
    """
    try:
        weights_url = WEIGHTS_URLS[weights_name]
    except KeyError:
        available = ", ".join(f"`{name}`" for name in WEIGHTS_URLS)
        # `from None` hides the internal KeyError so the traceback shows only the user-facing error.
        raise ValueError(f"Unrecognized weights name `{weights_name}`. Available weights: {available}") from None

    if target_path is None:
        cache_dir = Path.home() / ".cache" / "mlspm"
        cache_dir.mkdir(exist_ok=True, parents=True)
        target_path = cache_dir / f"{weights_name}.pth"
    else:
        target_path = Path(target_path)

    if target_path.exists():
        print(f"Target path `{target_path}` already exists. Skipping downloading weights `{weights_name}`.")
        return target_path

    # Download into a sibling temporary file so that a failed or interrupted download
    # cannot leave a partial file at `target_path` (which would be silently reused later).
    part_path = target_path.with_name(f"{target_path.name}.part")
    print(f"Downloading weights `{weights_name}`: ", end="")
    try:
        urlretrieve(weights_url, part_path, _print_progress)
    except BaseException:
        # BaseException so that a Ctrl-C (KeyboardInterrupt) mid-download also cleans up; always re-raised.
        if part_path.exists():
            part_path.unlink()
        raise
    # Same directory as the target, so this rename is atomic on POSIX and overwrites any stale file.
    part_path.replace(target_path)

    return target_path