Source code for mlspm.data_generation

import io
import multiprocessing as mp
import os
import queue
import tarfile
import time
from multiprocessing.shared_memory import SharedMemory
from os import PathLike
from pathlib import Path
from typing import Optional, TypedDict
import warnings

import numpy as np
from PIL import Image
from ppafm.ocl.field import ElectronDensity, HartreePotential


class TarWriter:
    """
    Write samples of AFM images, molecules and descriptors to tar files. Use as a context manager and add samples with
    :meth:`add_sample`.

    Each tar file has a maximum number of samples, and whenever that maximum is reached, a new tar file is created.
    The generated tar files are named as ``{base_name}_{n}.tar`` and saved into the specified folder.

    Arguments:
        base_path: Path to directory where tar files are saved.
        base_name: Base name for output tar files. The number of the tar file is appended to the name.
        max_count: Maximum number of samples per tar file.
        async_write: Write tar files asynchronously in a parallel process.
    """

    # Timeout in seconds when handing samples to, or waiting on, the asynchronous write process.
    _timeout = 60

    def __init__(self, base_path: PathLike = "./", base_name: str = "", max_count: int = 100, async_write=True):
        self.base_path = Path(base_path)
        self.base_name = base_name
        self.max_count = max_count
        self.async_write = async_write

    def __enter__(self):
        self.sample_count = 0  # Number of samples in the current tar file
        self.total_count = 0  # Number of samples written in total; used as the sample index in member names
        self.tar_count = 0  # Index of the current tar file
        if self.async_write:
            self._launch_write_process()
        else:
            self._ft = self._get_tar_file()
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        if self.async_write:
            self._stop_write_process()
        else:
            self._ft.close()

    def _launch_write_process(self):
        """Start the process that writes queued samples to the tar files."""
        self._q = mp.Queue(1)
        self._event_tar_close = mp.Event()
        p = mp.Process(target=self._write_async)
        p.start()
        # Stored only after start(): with the 'spawn' start method `self` is pickled by start(), and a Process object
        # cannot be pickled.
        self._write_process = p

    def _stop_write_process(self):
        """Tell the write process to finish, wait for it to close the last tar file, and join it."""
        if not self._event_tar_close.is_set():  # Already set if the write process stopped because of an error
            try:
                # None marks the end of the samples. The queue is FIFO, so every sample added before it is written
                # before the write process stops.
                self._q.put(None, block=True, timeout=self._timeout)
            except queue.Full:
                pass  # The write process is not consuming samples; the wait below reports this
        if self._event_tar_close.wait(self._timeout):
            self._write_process.join(self._timeout)
        else:
            warnings.warn(
                "Write process did not respond within timeout period. Last tar file may not have been closed properly."
            )

    def _write_async(self):
        """Run in the write process: write samples from the queue until the ``None`` end marker is received."""
        self._ft = None
        try:
            self._ft = self._get_tar_file()
            # Blocking get, so that the process sleeps instead of busy-polling while it waits for samples.
            while (sample := self._q.get()) is not None:
                self._add_sample(*sample)
        finally:
            # Close the tar file and notify the main process also on failure. The exception itself propagates, so
            # that multiprocessing reports its traceback instead of it being silently swallowed.
            if self._ft is not None:
                self._ft.close()
            self._event_tar_close.set()

    def _get_tar_file(self):
        """Open a new tar file for writing, named by the current tar file index.

        Raises:
            RuntimeError: If the tar file already exists.
        """
        file_path = self.base_path / f"{self.base_name}_{self.tar_count}.tar"
        if file_path.exists():
            raise RuntimeError(f"Tar file already exists at `{file_path}`")
        return tarfile.open(file_path, "w", format=tarfile.GNU_FORMAT)

    @staticmethod
    def _image_to_uint8(x: np.ndarray) -> np.ndarray:
        """Linearly rescale an image to the integer range 0-255. A constant image maps to all zeros."""
        x_range = np.ptp(x)
        if x_range == 0:
            # Avoid 0/0, which would produce NaNs and an undefined integer conversion.
            return np.zeros(x.shape, dtype=np.uint8)
        return ((x - x.min()) / x_range * (2**8 - 1)).astype(np.uint8)

    def _add_file(self, name: str, file_bytes: io.BytesIO):
        """Add the whole content of an in-memory stream to the current tar file as member ``name``, then close it."""
        file_bytes.seek(0)  # Return stream to start so that addfile can read it correctly
        self._ft.addfile(get_tarinfo(name, file_bytes), file_bytes)
        file_bytes.close()

    def _add_sample(self, X, xyzs, Y, comment_str):
        """Write one sample to the current tar file, first starting a new tar file if the current one is full."""
        if self.sample_count >= self.max_count:
            self.tar_count += 1
            self.sample_count = 0
            self._ft.close()
            self._ft = self._get_tar_file()

        # Write AFM images as 8-bit grayscale PNGs named {sample index}.{z index:02d}.{tip index}.png
        for i, x in enumerate(X):
            for j in range(x.shape[-1]):
                img_bytes = io.BytesIO()
                Image.fromarray(self._image_to_uint8(x[:, :, j]).T[::-1], mode="L").save(img_bytes, "png")
                self._add_file(f"{self.total_count}.{j:02d}.{i}.png", img_bytes)

        # Write xyz file: atom count, comment line, then one tab-separated row per atom with the element first
        lines = [f"{len(xyzs)}\n{comment_str}\n"]
        for xyz in xyzs:
            coords = "".join(f"{c:10.8f}\t" for c in xyz[:-1])
            lines.append(f"{int(xyz[-1])}\t{coords}\n")
        self._add_file(f"{self.total_count}.xyz", io.BytesIO("".join(lines).encode("utf-8")))

        # Write image descriptors (if any) as float32 .npy files
        if Y is not None:
            for i, y in enumerate(Y):
                npy_bytes = io.BytesIO()
                np.save(npy_bytes, y.astype(np.float32))
                self._add_file(f"{self.total_count}.desc_{i}.npy", npy_bytes)

        self.sample_count += 1
        self.total_count += 1

    def add_sample(self, X: list[np.ndarray], xyzs: np.ndarray, Y: Optional[np.ndarray] = None, comment_str: str = ""):
        """
        Add a sample to the current tar file.

        Arguments:
            X: AFM images. Each list item corresponds to an AFM tip and is an array of shape (nx, ny, nz).
            xyzs: Atom coordinates and elements. Each row is one atom and is of the form [x, y, z, element].
            Y: Image descriptors. Each list item is one descriptor and is an array of shape (nx, ny).
            comment_str: Comment line (second line) to add to the xyz file.
        """
        if self.async_write:
            self._q.put((X, xyzs, Y, comment_str), block=True, timeout=self._timeout)
        else:
            self._add_sample(X, xyzs, Y, comment_str)
def get_tarinfo(fname: str, file_bytes: io.BytesIO):
    """
    Build a tar header for a file whose contents are held in memory.

    Arguments:
        fname: Member name of the file inside the tar archive.
        file_bytes: In-memory stream with the file contents. The size is taken from the whole underlying buffer,
            regardless of the current stream position.

    Returns:
        :class:`tarfile.TarInfo` with the name, the size, and the current time as modification time.
    """
    with file_bytes.getbuffer() as view:
        size = view.nbytes
    info = tarfile.TarInfo(fname)
    info.size = size
    info.mtime = time.time()
    return info
class TarSampleList(TypedDict, total=False):
    """
    A group of samples stored in tar files, as consumed by :class:`TarDataGenerator`. The i-th member name in
    ``'hartree'`` (and ``'rho'``, if given) and the i-th entry of ``'rots'`` belong to the same sample.

    - ``'hartree'``: Paths to the Hartree potentials. First item in the tuple is the path to the tar file, and second
      entry is a list of tar file member names.
    - ``'rho'``: (Optional, may also be ``None``) Paths to the electron densities. First item in the tuple is the path
      to the tar file, and second entry is a list of tar file member names. Must have as many member names as
      ``'hartree'``.
    - ``'rots'``: One entry per sample, each a sequence of rotations for that sample. The sample is yielded once per
      rotation. Must have as many entries as ``'hartree'`` has member names.
    """

    hartree: tuple[PathLike, list[str]]
    rho: tuple[PathLike, list[str]]
    rots: list[np.ndarray]
class _SampleDoneSignal:
    """
    Counting replacement for ``multiprocessing.Event``, used to tell a worker that the main process is done with a
    sample.

    An ``Event`` is a single flag, so two ``set()`` calls made before the worker calls ``wait()`` collapse into one and
    a signal is lost; the worker then times out waiting for its last sample. Backing the signal with a semaphore counts
    every ``set()``, and each successful ``wait()`` consumes exactly one of them.
    """

    def __init__(self):
        self._sem = mp.Semaphore(0)

    def set(self):
        """Signal that the main process is done with one more sample."""
        self._sem.release()

    def wait(self, timeout: Optional[float] = None) -> bool:
        """Consume one signal, waiting at most ``timeout`` seconds. Return whether a signal was received."""
        return self._sem.acquire(timeout=timeout)

    def clear(self):
        """No-op kept for ``Event`` API compatibility: :meth:`wait` already consumes the signal."""


class TarDataGenerator:
    """
    Iterable that loads data from tar archives with data saved in npz format for generating samples with
    ``GeneratorAFMTrainer`` in *ppafm*.

    The npz files should contain the following entries:

        - ``'data'``: An array containing the potential/density on a 3D grid.
        - ``'origin'``: Lattice origin in 3D space as an array of shape ``(3,)``.
        - ``'lattice'``: Lattice vectors as an array of shape ``(3, 3)``, where the rows are the vectors.
        - ``'xyz'``: Atom xyz coordinates as an array of shape ``(n_atoms, 3)``.
        - ``'Z'``: Atom atomic numbers as an array of shape ``(n_atoms,)``.

    Yields dicts that contain the following:

        - ``'xyzs'``: Atom xyz coordinates.
        - ``'Zs'``: Atomic numbers.
        - ``'qs'``: Sample Hartree potential.
        - ``'rho_sample'``: Sample electron density if the sample dict contained ``rho``, or ``None`` otherwise.
        - ``'rot'``: Rotation matrix.

    Each loaded sample is yielded once per rotation in its ``rots`` entry. The ``'qs'`` and ``'rho_sample'`` objects
    are reused and updated in place for the next sample, so they are only valid until the next sample is requested.

    Note:
        It is recommended to use ``multiprocessing.set_start_method('spawn')`` when using the
        :class:`TarDataGenerator`. Otherwise a lot of warnings about leaked memory objects may be thrown on exit.

    Arguments:
        samples: List of sample dicts as :class:`TarSampleList`. File paths should be relative to ``base_path``.
        base_path: Path to the directory with the tar files.
        n_proc: Number of parallel processes for loading data. The sample lists get divided evenly over the processes.
            For memory usage, note that up to twice as many samples as there are processes can be loaded into memory
            at the same time.
        scale_pot: The loaded Hartree potentials are scaled by this factor in order to correct the units. The yielded
            potential should be in units of V. The default value of -1 works for potentials in units of eV.
        scale_rho: The loaded electron densities are scaled by this factor in order to correct the units. The yielded
            density should be in units of e/Å^3 with positive sign for the electron density.
    """

    _timings = False

    def __init__(
        self,
        samples: list[TarSampleList],
        base_path: PathLike = "./",
        n_proc: int = 1,
        scale_pot: float = -1,
        scale_rho: float = 1,
    ):
        self.samples = samples
        self.base_path = Path(base_path)
        self.n_proc = n_proc
        self.scale_pot = scale_pot
        self.scale_rho = scale_rho
        self.pot = None  # HartreePotential reused for all samples; created when the first sample is received
        self.rho = None  # ElectronDensity reused for all samples; created when the first sample with rho is received

    def __len__(self) -> int:
        """Total number of samples (including rotations)"""
        return sum(len(rots) for sample_list in self.samples for rots in sample_list["rots"])

    def _launch_procs(self):
        """Start ``n_proc`` worker processes, each loading an even share of the sample lists."""
        queue_size = 2 * self.n_proc
        self.q = mp.Queue(queue_size)
        self.events = []
        samples_split = np.array_split(self.samples, self.n_proc)
        for i in range(self.n_proc):
            # Counting signal instead of mp.Event, so that "done" signals for consecutive samples cannot be lost
            event = _SampleDoneSignal()
            p = mp.Process(target=self._load_samples, args=(samples_split[i], i, event))
            p.start()
            self.events.append(event)

    def __iter__(self):
        self._launch_procs()
        self.iterator = iter(self._yield_samples())
        return self

    def __next__(self):
        return next(self.iterator)

    def _get_data(self, tar: tarfile.TarFile, name: str) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Read one npz member of a tar file.

        Returns:
            Tuple (grid data, lattice vectors with the origin as the first row, atom xyz coordinates, atomic numbers).
        """
        # Close both the extracted member and the npz archive once the arrays have been read.
        with tar.extractfile(name) as f, np.load(f) as data:
            array = data["data"]
            origin = data["origin"]
            lattice = data["lattice"]
            xyzs = data["xyz"]
            Zs = data["Z"]
        lvec = np.concatenate([origin[None, :], lattice], axis=0)
        return array, lvec, xyzs, Zs

    def _load_samples(self, sample_lists: list[TarSampleList], i_proc: int, event: _SampleDoneSignal):
        """Run in worker process ``i_proc``: load all given sample lists and pass them to the main process."""
        proc_id = str(time.time_ns() + 1000000000 * i_proc)[-10:]
        print(f"Starting worker {i_proc}, id {proc_id}, samples: {len(sample_lists)}")
        start_time = time.perf_counter()
        total_bytes = 0
        n_sample_total = 0

        for sample_list in sample_lists:
            n_sample, n_bytes = self._load_sample_list(sample_list, i_proc, proc_id, event)
            n_sample_total += n_sample
            total_bytes += n_bytes

        if self._timings:
            dt = time.perf_counter() - start_time
            avg = dt / n_sample_total if n_sample_total > 0 else float("nan")
            print(
                f"[Worker {i_proc}]: Loaded {n_sample_total} samples in {dt}s, totaling {total_bytes / 2**30:.3f}GiB. "
                f"Average load time: {avg}s."
            )

    def _load_sample_list(
        self, sample_list: TarSampleList, i_proc: int, proc_id: str, event: _SampleDoneSignal
    ) -> tuple[int, int]:
        """
        Load the samples of one sample list and pass them to the main process through shared memory.

        Returns:
            Tuple (number of samples sent, number of bytes loaded).

        Raises:
            ValueError: If the numbers of rotations, Hartree members and rho members are inconsistent.
        """
        tar_path_hartree, name_list_hartree = sample_list["hartree"]
        n_sample = len(name_list_hartree)
        if len(sample_list["rots"]) != n_sample:
            raise ValueError(f"Inconsistent number of rotations in sample list ({len(sample_list['rots'])} != {n_sample})")

        use_rho = ("rho" in sample_list) and (sample_list["rho"] is not None)
        if use_rho:
            tar_path_rho, name_list_rho = sample_list["rho"]
            if len(name_list_rho) != n_sample:
                raise ValueError(
                    f"Inconsistent number of samples between hartree and rho lists ({len(name_list_rho)} != {n_sample})"
                )

        total_bytes = 0
        shm_prev = None  # Shared memory (pot, rho) of the previously sent sample, possibly still in use by main
        tar_hartree = tarfile.open(self.base_path / tar_path_hartree, "r")
        tar_rho = None
        try:
            if use_rho:
                tar_rho = tarfile.open(self.base_path / tar_path_rho, "r")

            for i_sample in range(n_sample):
                if self._timings:
                    t0 = time.perf_counter()

                # Load data from tar(s)
                rots = sample_list["rots"][i_sample]
                pot, lvec_pot, xyzs, Zs = self._get_data(tar_hartree, name_list_hartree[i_sample])
                if not np.allclose(self.scale_pot, 1):
                    pot *= self.scale_pot
                total_bytes += pot.nbytes
                if use_rho:
                    rho, lvec_rho, _, _ = self._get_data(tar_rho, name_list_rho[i_sample])
                    if not np.allclose(self.scale_rho, 1):
                        rho *= self.scale_rho
                    rho_shape = rho.shape
                    total_bytes += rho.nbytes
                else:
                    lvec_rho = None
                    rho_shape = None

                if self._timings:
                    t1 = time.perf_counter()

                # Put the data to shared memory
                sample_id_pot = f"{i_proc}_{proc_id}_{i_sample}_pot"
                shm_pot = _put_to_shared_memory(pot, sample_id_pot)
                if use_rho:
                    sample_id_rho = f"{i_proc}_{proc_id}_{i_sample}_rho"
                    shm_rho = _put_to_shared_memory(rho, sample_id_rho)
                else:
                    sample_id_rho = None
                    shm_rho = None

                # Inform the main process of the data using the queue
                self.q.put((i_proc, sample_id_pot, sample_id_rho, pot.shape, rho_shape, lvec_pot, lvec_rho, xyzs, Zs, rots))

                if self._timings:
                    t2 = time.perf_counter()

                if shm_prev is not None:
                    # Wait until main process is done with the previous data
                    _wait_and_unlink(i_proc, event, *shm_prev)
                shm_prev = (shm_pot, shm_rho)

                if self._timings:
                    t3 = time.perf_counter()
                    print(
                        f"[Worker {i_proc}, id {sample_id_pot}] Get data / Shm / Wait-unlink: "
                        f"{t1 - t0:.5f} / {t2 - t1:.5f} / {t3 - t2:.5f} "
                    )

            # Wait to unlink the last data. An empty sample list sent nothing, so there is nothing to wait for.
            if shm_prev is not None:
                _wait_and_unlink(i_proc, event, *shm_prev)
        finally:
            tar_hartree.close()
            if tar_rho is not None:
                tar_rho.close()

        return n_sample, total_bytes

    def _get_queue_sample(
        self,
    ) -> tuple[int, np.ndarray, np.ndarray, list[np.ndarray], HartreePotential, SharedMemory, ElectronDensity, SharedMemory, str]:
        """
        Receive the next sample from the workers and attach to its shared memory.

        The reused ``self.pot`` / ``self.rho`` objects are updated with the new arrays. The returned shared memory
        handles must be closed by the caller once the sample is no longer needed.
        """
        if self._timings:
            t0 = time.perf_counter()

        i_proc, sample_id_pot, sample_id_rho, pot_shape, rho_shape, lvec_pot, lvec_rho, xyzs, Zs, rots = self.q.get(timeout=200)

        if self._timings:
            t1 = time.perf_counter()

        shm_pot = SharedMemory(sample_id_pot)
        pot = np.ndarray(pot_shape, dtype=np.float32, buffer=shm_pot.buf)
        if self.pot is None:
            self.pot = HartreePotential(pot, lvec_pot)
        else:
            self.pot.update_array(pot, lvec_pot)

        if self._timings:
            t2 = time.perf_counter()

        if sample_id_rho is not None:
            shm_rho = SharedMemory(sample_id_rho)
            rho = np.ndarray(rho_shape, dtype=np.float32, buffer=shm_rho.buf)
            if self.rho is None:
                self.rho = ElectronDensity(rho, lvec_rho)
            else:
                self.rho.update_array(rho, lvec_rho)
        else:
            shm_rho = None
            rho = None

        if self._timings:
            t3 = time.perf_counter()
            print(f"[Main, receive data, id {sample_id_pot}] Queue / Pot / Rho: " f"{t1 - t0:.5f} / {t2 - t1:.5f} / {t3 - t2:.5f}")

        return i_proc, xyzs, Zs, rots, self.pot, shm_pot, self.rho, shm_rho, sample_id_pot

    def _yield_samples(self):
        """Yield one sample dict per rotation of every sample, releasing each sample's shared memory afterwards."""
        start_time = time.perf_counter()
        n_sample_yielded = 0

        n_sample_total = sum(len(sample_list["rots"]) for sample_list in self.samples)
        for _ in range(n_sample_total):
            if self._timings:
                t0 = time.perf_counter()

            i_proc, xyzs, Zs, rots, pot, shm_pot, rho, shm_rho, sample_id = self._get_queue_sample()

            if self._timings:
                t1 = time.perf_counter()

            for rot in rots:
                sample_dict = {"xyzs": xyzs, "Zs": Zs, "qs": pot, "rho_sample": rho, "rot": rot}
                yield sample_dict
                n_sample_yielded += 1

            if self._timings:
                t2 = time.perf_counter()

            # Close shared memory and inform producer that the shared memory can be unlinked
            shm_pot.close()
            if shm_rho is not None:
                shm_rho.close()
            self.events[i_proc].set()

            if self._timings:
                t3 = time.perf_counter()
                print(f"[Main, id {sample_id}] Receive data / Yield / Event: " f"{t1 - t0:.5f} / {t2 - t1:.5f} / {t3 - t2:.5f}")

        if self._timings:
            dt = time.perf_counter() - start_time
            avg = dt / n_sample_yielded if n_sample_yielded > 0 else float("nan")
            print(f"[Main]: Yielded {n_sample_yielded} samples in {dt}s. Average yield time: {avg}s.")
def _put_to_shared_memory(array, name): shm = SharedMemory(create=True, size=array.nbytes, name=name) b = np.ndarray(array.shape, dtype=np.float32, buffer=shm.buf) b[:] = array[:] return shm def _wait_and_unlink(i_proc, event, shm_pot, shm_rho): if not event.wait(timeout=60): raise RuntimeError(f"[Worker {i_proc}]: Did not receive signal from main process in 60 seconds.") event.clear() shm_pot.close() shm_pot.unlink() if shm_rho: shm_rho.close() shm_rho.unlink()