# Source code for mlspm.utils

import glob
import os
import re
from typing import Optional, Tuple

import numpy as np
import torch
import torch.distributed as dist

# Chemical element symbols in order of atomic number, so that ``elements[Z - 1]`` is the symbol of
# element Z (hydrogen through xenon).
elements = (
    "H He "
    "Li Be B C N O F Ne "
    "Na Mg Al Si P S Cl Ar "
    "K Ca Sc Ti V Cr Mn Fe Co Ni Cu Zn Ga Ge As Se Br Kr "
    "Rb Sr Y Zr Nb Mo Tc Ru Rh Pd Ag Cd In Sn Sb Te I Xe"
).split()


def _calc_plot_dim(n: int, f: float = 0.3):
    rows = max(int(np.sqrt(n) - f), 1)
    cols = 1
    while rows * cols < n:
        cols += 1
    return rows, cols


def _get_distributed() -> Tuple[int, int, int, Optional[dist.ProcessGroup]]:
    try:
        if "RANK" in os.environ and "LOCAL_RANK" in os.environ and "WORLD_SIZE" in os.environ:
            world_size = int(os.environ["WORLD_SIZE"])
            global_rank = int(os.environ["RANK"])
            local_rank = int(os.environ["LOCAL_RANK"])
        else:
            world_size = dist.get_world_size()
            local_rank = global_rank = dist.get_rank()
        group = dist.group.WORLD
    except (RuntimeError, AssertionError, ValueError):
        world_size = 1
        group = None
    if world_size <= 1:
        world_size = 1
        local_rank = global_rank = 0
    return world_size, local_rank, global_rank, group


def _print_progress(block_num: int, block_size: int, total_size: int):
    if total_size == -1:
        return
    delta = block_size / total_size * 100
    current_size = block_num * block_size
    percent = current_size / total_size * 100
    percent_int = int(percent)
    if (percent - percent_int) > 1.0001 * delta:
        # Only print when crossing an integer percentage
        return
    if block_num > 0:
        print("\b\b\b", end="", flush=True)
    if current_size < total_size:
        print(f"{percent_int:2d}%", end="", flush=True)
    else:
        print("Done")


class Checkpointer:
    """
    Keep checkpoints of a Pytorch model and optimizer over epochs, keeping only the one with best loss.
    Also load the latest checkpoint in the beginning if any exist.

    Only the process with global rank 0 writes or deletes checkpoint files.

    Arguments:
        model: Pytorch model. If it has a ``module`` attribute (e.g. a DistributedDataParallel wrapper), the
            wrapped module is checkpointed instead.
        optimizer: Pytorch optimizer.
        additional_data: Additional modules or data whose state will be saved to the checkpoints. The dictionary
            is updated in place with the keys ``best_loss`` and ``best_epoch``. Defaults to a new empty dictionary.
        checkpoint_dir: Path to directory where checkpoints are saved.
        keep_last_epoch: Also keep the last epoch even if it does not have the best loss.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        additional_data: Optional[dict] = None,
        checkpoint_dir: str = "./Checkpoints",
        keep_last_epoch: bool = True,
    ):
        if hasattr(model, "module"):
            model = model.module
        self.model = model
        self.optimizer = optimizer
        # Use a fresh dict per instance: the dict is mutated below, so a shared default would leak
        # best_loss/best_epoch between Checkpointer instances.
        self.additional_data = {} if additional_data is None else additional_data
        self.additional_data["best_loss"] = np.inf
        self.additional_data["best_epoch"] = 0
        self.checkpoint_dir = checkpoint_dir
        self.keep_last_epoch = keep_last_epoch
        self.world_size, self.local_rank, self.global_rank, self.group = _get_distributed()
        self._get_init_epoch()

    @property
    def best_epoch(self):
        """Epoch with the lowest loss so far (0 if no epoch has improved on the initial infinite loss)."""
        return self.additional_data["best_epoch"]

    @property
    def best_loss(self):
        """Lowest loss seen so far (``np.inf`` initially)."""
        return self.additional_data["best_loss"]

    def _get_init_epoch(self):
        """
        Set ``self.epoch`` to the epoch to start from.

        Creates the checkpoint directory (on rank 0) if it does not exist. If checkpoints named
        ``model_<epoch>.pth`` exist, the one with the highest epoch is loaded and training resumes from the
        following epoch; otherwise the start epoch is 1.
        """
        # All ranks check for the directory before the barrier, i.e. before rank 0 may create it.
        dir_exists = os.path.exists(self.checkpoint_dir)
        if self.world_size > 1:
            dist.barrier()
        if not dir_exists and self.global_rank == 0:
            os.makedirs(self.checkpoint_dir, exist_ok=True)

        epochs = []
        for path in glob.glob(os.path.join(self.checkpoint_dir, "model_*.pth")):
            match = re.fullmatch(r"model_([0-9]+)\.pth", os.path.basename(path))
            if match:  # Ignore files such as model_best.pth that carry no epoch number
                epochs.append(int(match.group(1)))
        if not epochs:
            self.epoch = 1
            return

        self.epoch = max(epochs)
        last_cp_path = os.path.join(self.checkpoint_dir, f"model_{self.epoch}.pth")
        load_checkpoint(self.model, self.optimizer, last_cp_path, self.additional_data, self.local_rank)
        self.epoch += 1

    def next_epoch(self, loss: float):
        """
        Advance epoch and save state if loss improved.

        On rank 0 the best loss/epoch is updated and checkpoint files are saved/deleted; in distributed runs
        the best loss and epoch are then broadcast from rank 0 to the other ranks.

        Arguments:
            loss: Loss value for the current epoch.
        """
        if self.global_rank == 0:
            improved = loss < self.best_loss
            prev_best_epoch = self.best_epoch
            if improved:
                # Best epoch 0 means that no best checkpoint has been written yet.
                if prev_best_epoch > 0:
                    os.remove(os.path.join(self.checkpoint_dir, f"model_{prev_best_epoch}.pth"))
                self.additional_data["best_loss"] = loss
                self.additional_data["best_epoch"] = self.epoch
            # With keep_last_epoch, the previous epoch's checkpoint was kept only as the "last" one; drop it
            # now unless it is the previous best one (which is handled above).
            if self.keep_last_epoch and prev_best_epoch != (self.epoch - 1):
                os.remove(os.path.join(self.checkpoint_dir, f"model_{self.epoch - 1}.pth"))
            if self.keep_last_epoch or improved:
                save_checkpoint(
                    self.model, self.optimizer, self.epoch, self.checkpoint_dir, additional_data=self.additional_data
                )
            if self.world_size > 1:
                dist.broadcast(
                    torch.tensor(self.additional_data["best_loss"], device=self.local_rank, dtype=torch.float), src=0
                )
                dist.broadcast(
                    torch.tensor(self.additional_data["best_epoch"], device=self.local_rank, dtype=torch.long), src=0
                )
        else:
            # Receive the best loss/epoch decided on rank 0.
            best_loss = torch.tensor(0, device=self.local_rank, dtype=torch.float)
            best_epoch = torch.tensor(0, device=self.local_rank, dtype=torch.long)
            dist.broadcast(best_loss, src=0)
            dist.broadcast(best_epoch, src=0)
            self.additional_data["best_loss"] = best_loss.cpu().item()
            self.additional_data["best_epoch"] = best_epoch.cpu().item()
        self.epoch += 1

    def revert_to_best_epoch(self):
        """Revert model and optimizer state (and ``additional_data``) to the best epoch."""
        best_cp_path = os.path.join(self.checkpoint_dir, f"model_{self.best_epoch}.pth")
        # Pass the local rank like _get_init_epoch does, so that each rank maps the tensors to its own device.
        load_checkpoint(self.model, self.optimizer, best_cp_path, self.additional_data, self.local_rank)
def save_checkpoint(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    epoch: int,
    save_dir: str,
    additional_data: Optional[dict] = None,
):
    """
    Save a Pytorch model/optimizer checkpoint to ``<save_dir>/model_<epoch>.pth``.

    Arguments:
        model: Model whose state to save. If it has a ``module`` attribute (e.g. a DistributedDataParallel
            wrapper), the state of the wrapped module is saved.
        optimizer: Optimizer whose state to save.
        epoch: Training epoch.
        save_dir: Directory to save in. Created if it does not exist; an empty string means the current directory.
        additional_data: A dictionary of additional modules or data to save to the checkpoint. Values with a
            ``state_dict`` method are stored as their state dict, other values as-is.
    """
    # os.makedirs("") would raise, so an empty save_dir (current directory) is left alone.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    if hasattr(model, "module"):
        model = model.module
    state = {
        "model_params": model.state_dict(),
        "optim_params": optimizer.state_dict(),
    }
    for key, data in (additional_data or {}).items():
        state[key] = data.state_dict() if hasattr(data, "state_dict") else data
    save_path = os.path.join(save_dir, f"model_{epoch}.pth")
    torch.save(state, save_path)
    print(f"Model, optimizer weights on epoch {epoch} saved to {save_path}")
def load_checkpoint(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer = None,
    file_name: str = "./model.pth",
    additional_data: Optional[dict] = None,
    rank: Optional[int] = None,
):
    """
    Load a Pytorch model/optimizer checkpoint.

    Note: the checkpoint is read with ``weights_only=False``, which unpickles arbitrary objects, so only load
    checkpoint files from trusted sources.

    Arguments:
        model: Model whose state to load.
        optimizer: Optimizer whose state to load. Skipped if falsy.
        file_name: Checkpoint file to load from.
        additional_data: If not None, a dictionary with additional modules or other data. Every key that also
            appears in the checkpoint is restored: objects with ``load_state_dict`` load the saved state, other
            values are replaced by the saved value.
        rank: Process rank for distributed training. If given, tensors saved on ``cuda:0`` are mapped to
            ``cuda:<rank>``.
    """
    map_location = None if rank is None else {"cuda:0": f"cuda:{rank}"}
    state = torch.load(file_name, map_location=map_location, weights_only=False)

    model.load_state_dict(state["model_params"])
    if not optimizer:
        print(f"Model weights loaded from {file_name}")
    else:
        optimizer.load_state_dict(state["optim_params"])
        print(f"Model, optimizer weights loaded from {file_name}")

    if additional_data is None:
        return
    for key, saved in state.items():
        if key not in additional_data:
            continue
        target = additional_data[key]
        if hasattr(target, "load_state_dict"):
            target.load_state_dict(saved)
            print(f"Loaded state for `{key}`")
        else:
            additional_data[key] = saved
            print(f"Updated data for `{key}`")
def read_xyzs(file_paths: list[str], return_comment: bool = False) -> list[np.ndarray]:
    """
    Read molecule xyz files.

    Every non-blank line after the comment line is parsed as one atom, whose first field is either an atomic
    number or an element symbol (looked up in ``elements``) followed by the numeric columns. Blank lines, e.g.
    trailing empty lines at the end of a file, are ignored.

    Arguments:
        file_paths: Paths to xyz files
        return_comment: If True, also return the comment string on second line of file (including its
            line terminator).

    Returns:
        Arrays of shape ``(num_atoms, 4)`` or ``(num_atoms, 5)``. Each row in the arrays corresponds to one atom
        with ``[x, y, z, element]`` or ``[x, y, z, charge, element]``. If ``return_comment`` is True, a tuple
        ``(arrays, comments)`` is returned instead.
    """
    mols = []
    comments = []
    for file_path in file_paths:
        with open(file_path, "r") as f:
            # First line holds the atom count; it is parsed only to validate the header.
            int(f.readline().strip())
            comments.append(f.readline())
            atoms = []
            for line in f:
                fields = line.split()
                if not fields:
                    continue
                try:
                    elem = int(fields[0])
                except ValueError:
                    elem = elements.index(fields[0]) + 1
                atoms.append([float(p) for p in fields[1:]] + [elem])
            mols.append(np.array(atoms))
    if return_comment:
        return mols, comments
    return mols
def write_to_xyz(molecule: np.ndarray, outfile: str = "./pos.xyz", comment_str: str = "", verbose: int = 1):
    """
    Write molecule into xyz file.

    Rows whose element (last column) is not positive are skipped; such rows presumably mark padding atoms.
    Each atom line is written as the integer element followed by the remaining columns, tab-separated.

    Arguments:
        molecule: Molecule to write. np.array of shape (num_atoms, 4) or (num_atoms, 5). Each row corresponds
            to one atom with [x, y, z, element] or [x, y, z, charge, element].
        outfile: Path where xyz file will be saved.
        comment_str: Comment written to the second line of the file. Trailing line breaks are removed so that
            the comment occupies exactly one line.
        verbose: 0 or 1. Whether to print output information.
    """
    molecule = molecule[molecule[:, -1] > 0]
    # A trailing newline (e.g. on a comment obtained from read_xyzs, which keeps it) would otherwise insert a
    # blank line into the atom block.
    comment_str = comment_str.rstrip("\r\n")
    with open(outfile, "w") as f:
        f.write(f"{len(molecule)}\n{comment_str}\n")
        for atom in molecule:
            columns = "".join(f"{value:10.8f}\t" for value in atom[:-1])
            f.write(f"{int(atom[-1])}\t{columns}\n")
    if verbose > 0:
        print(f"Molecule xyz file saved to {outfile}")
def batch_write_xyzs(xyzs: list[np.ndarray], outdir: str = "./", start_ind: int = 0, verbose: int = 1):
    """
    Write a batch of xyz files 0_mol.xyz, 1_mol.xyz, ...

    Each molecule is written with ``write_to_xyz`` to ``<outdir>/<index>_mol.xyz``, with the index counting up
    from ``start_ind``. A non-empty ``outdir`` is created if it does not exist.

    Arguments:
        xyzs: Molecules to write.
        outdir: Directory where files are saved.
        start_ind: Index where file numbering starts.
        verbose: 0 or 1. Whether to print output information.
    """
    needs_dir = bool(outdir) and not os.path.exists(outdir)
    if needs_dir:
        os.makedirs(outdir)
    for index, xyz in enumerate(xyzs, start=start_ind):
        target = os.path.join(outdir, f"{index}_mol.xyz")
        write_to_xyz(xyz, target, verbose=verbose)
def count_parameters(module: torch.nn.Module) -> int:
    """
    Count trainable parameters in a Pytorch module.

    Arguments:
        module: Pytorch module.

    Returns:
        Total number of elements over all parameters of ``module`` that have ``requires_grad`` set.
    """
    total = 0
    for param in module.parameters():
        if param.requires_grad:
            total += param.numel()
    return total