Source code for mlspm.graph._visualization

import os
from typing import Optional, Union

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm, gridspec

from ..utils import _calc_plot_dim, elements
from . import MoleculeGraph

# Default atom-class colors used by plot_graphs: the character at position k is the
# matplotlib single-letter color code used for atoms whose class index is k.
CLASS_COLORS = "rkbgcmy"


def plot_graphs(
    pred: Optional[list[MoleculeGraph]] = None,
    ref: Optional[list[MoleculeGraph]] = None,
    box_borders: np.ndarray = np.array(((0, 0, -1.4), (16, 16, 0.5))),
    outdir: str = "./graphs/",
    classes: Optional[list[list[Union[int, str]]]] = None,
    class_colors: list[str] = CLASS_COLORS,
    start_ind: int = 0,
    verbose: int = 1,
):
    """
    Plot batch of graphs into files 0_graph.png, 1_graph.png, ... etc.

    Each image contains, for the prediction and/or the reference, an xy projection and an xz
    projection of the molecule graph, plus a legend that explains the atom class colors and the
    z-dependent marker sizes.

    Arguments:
        pred: Predicted molecule graphs.
        ref: Reference molecule graphs. At least one of ``pred`` and ``ref`` has to be given, and
            if both are given they must have the same length.
        box_borders: Real-space extent of the plotting region in Ångströms. The array should be of the form
            ``((x_start, y_start, z_start), (x_end, y_end, z_end))``.
        outdir: Directory where files are saved. Created if it does not exist.
        classes: Classes for categorizing atoms based on their chemical elements. Each class is a list of elements
            either as atomic numbers or as chemical symbols. Only used for the legend labels. If None, the
            classes are labelled ``Class 0``, ``Class 1``, ... up to the largest class index found in the graphs.
        class_colors: Colors for each atom class, indexed by class index.
        start_ind: Starting index for file naming.
        verbose: Whether to print output information.

    Raises:
        ValueError: If both ``pred`` and ``ref`` are None, if they have different lengths, if the z range of
            ``box_borders`` is empty, or if an atom lies below the lower z border.
    """
    # Columns of the figure, in plotting order.
    panels = []
    if pred is not None:
        panels.append(("Prediction", pred))
    if ref is not None:
        panels.append(("Reference", ref))

    if not panels:
        raise ValueError("pred and ref cannot both be None.")
    if (pred is not None) and (ref is not None) and (len(pred) != len(ref)):
        raise ValueError(f"pred ({len(pred)}) and ref ({len(ref)}) have different number of samples.")
    n_plot = len(panels)
    n_samples = len(panels[0][1])

    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(outdir, exist_ok=True)

    if classes is None:
        # Flatten the per-graph arrays so that this does not depend on whether the class indices
        # come back as a flat array or as a single-column 2D array.
        class_indices = [
            int(k) for _, mols in panels for mol in mols for k in np.ravel(mol.array(class_index=True))
        ]
        # default=-1 gives zero classes (instead of a crash) when every graph is empty.
        n_classes = max(class_indices, default=-1) + 1
        labels = [f"Class {k}" for k in range(n_classes)]
    else:
        n_classes = len(classes)
        # Atomic numbers are translated to symbols; symbols are used as they are.
        labels = [", ".join(e if isinstance(e, str) else elements[e - 1] for e in c) for c in classes]

    z_min = box_borders[0][2]
    z_max = box_borders[1][2]
    if z_max == z_min:
        # The marker size scaling divides by the z range.
        raise ValueError("box_borders must span a non-zero range in z.")
    scatter_size = 160

    def get_marker_size(z, max_size):
        # Linear in z: zero at the lower z border, max_size at the upper z border.
        return max_size * (z - z_min) / (z_max - z_min)

    def draw_projection(ax, mol, vert_axis):
        # Draw bonds and atoms of mol with x on the horizontal axis and coordinate vert_axis
        # (1 = y, 2 = z) on the vertical axis.
        if len(mol) > 0:
            mol_pos = mol.array(xyz=True)
            sizes = get_marker_size(mol_pos[:, 2], scatter_size)
            if (sizes < 0).any():
                raise ValueError("Encountered atom z position(s) below box borders.")
            colors = np.array([class_colors[atom.class_index] for atom in mol.atoms])
            for bond in mol.bonds:
                ends = np.vstack([mol_pos[bond[0]], mol_pos[bond[1]]])
                ax.plot(ends[:, 0], ends[:, vert_axis], "k", linewidth=2, zorder=1)
            # Later-drawn markers cover earlier ones: in the xy view atoms are drawn in order of
            # increasing z, in the xz view in order of decreasing y. Sorting here makes a separate
            # reordering of the graph itself unnecessary.
            if vert_axis == 1:
                draw_order = mol_pos[:, 2].argsort()
            else:
                draw_order = (-mol_pos[:, 1]).argsort()
            ax.scatter(
                mol_pos[draw_order, 0],
                mol_pos[draw_order, vert_axis],
                c=colors[draw_order],
                s=sizes[draw_order],
                edgecolors="k",
                zorder=2,
            )
        ax.set_xlim(box_borders[0][0], box_borders[1][0])
        ax.set_ylim(box_borders[0][vert_axis], box_borders[1][vert_axis])
        ax.set_aspect("equal", "box")

    # Figure and legend geometry is the same for every sample, so compute it once.
    x_size = 5 * n_plot
    longest_label = max((len(label) for label in labels), default=0)
    # Floor of one character keeps the legend column at a non-zero width when there are no labels.
    x_extra = 0.35 * max(longest_label, 1)
    dy = 0.08
    dx = 0.35 / x_extra
    y_start = 0.5 + dy * (n_classes + 3) / 2
    y_start2 = y_start - (n_classes + 1) * dy
    legend_zs = np.array([z_max, (z_min + z_max + 0.2) / 2, z_min + 0.2])
    legend_sizes = get_marker_size(legend_zs, scatter_size)

    for i_sample in range(n_samples):
        # Setup plot grid: graphs on the left, legend on the right
        fig = plt.figure(figsize=(x_size + x_extra, 6.5))
        fig_grid = gridspec.GridSpec(1, 2, width_ratios=(x_size, x_extra), wspace=1 / (x_size + x_extra))
        grid_graphs = fig_grid[0, 0].subgridspec(2, n_plot, height_ratios=(5, 1.5), hspace=0.1, wspace=0.2)

        # Prediction and/or reference
        for i_col, (title, mols) in enumerate(panels):
            ax_xy = fig.add_subplot(grid_graphs[0, i_col])
            ax_xz = fig.add_subplot(grid_graphs[1, i_col])
            draw_projection(ax_xy, mols[i_sample], vert_axis=1)
            draw_projection(ax_xz, mols[i_sample], vert_axis=2)
            ax_xy.set_xlabel("x (Å)", fontsize=12)
            ax_xy.set_ylabel("y (Å)", fontsize=12)
            ax_xz.set_xlabel("x (Å)", fontsize=12)
            ax_xz.set_ylabel("z (Å)", fontsize=12)
            ax_xy.set_title(title, fontsize=20)

        # Plot legend
        ax_legend = fig.add_subplot(fig_grid[0, 1])

        # Class colors
        for k, label in enumerate(labels):
            y = y_start - dy * k
            ax_legend.scatter(dx, y, s=scatter_size, c=class_colors[k], edgecolors="k")
            ax_legend.text(2 * dx, y, label, fontsize=16, ha="left", va="center_baseline")

        # Marker sizes
        for k, (size, z) in enumerate(zip(legend_sizes, legend_zs)):
            y = y_start2 - dy * k
            ax_legend.scatter(dx, y, s=size, c="w", edgecolors="k")
            ax_legend.text(2 * dx, y, f"z = {z:.2f}Å", fontsize=16, ha="left", va="center_baseline")

        ax_legend.set_xlim(0, 1)
        ax_legend.set_ylim(0, 1)
        ax_legend.axis("off")

        save_path = os.path.join(outdir, f"{start_ind + i_sample}_graph.png")
        fig.savefig(save_path)
        if verbose > 0:
            print(f"Graph image saved to {save_path}")
        # Close this figure explicitly rather than whatever happens to be the current one.
        plt.close(fig)
def plot_distribution_grid(
    pred_dist: np.ndarray,
    ref_dist: Optional[np.ndarray] = None,
    box_borders: np.ndarray = np.array(((2, 2, -1.5), (18, 18, 0))),
    outdir: str = "./graphs/",
    start_ind: int = 0,
    verbose: int = 1,
):
    """
    Plot batch of position distribution grids into files 0_pred_dist.png, 1_pred_dist.png, ..., and
    0_pred_dist2D.png 1_pred_dist2D.png, ... etc. The full grids are divided into separate images for each
    z-slice in the arrays. The 2D grids are averaged over the z-dimension of the full grids.

    Arguments:
        pred_dist: Predicted position distribution grid. Indexed as ``(sample, x, y, z)``.
        ref_dist: Reference position distribution grid. If given, must have the same shape as ``pred_dist``
            and is plotted next to the prediction on a shared color scale.
        box_borders: Real-space extent of the distribution grid region in Ångströms. The array should be of the form
            ``((x_start, y_start, z_start), (x_end, y_end, z_end))``.
        outdir: Directory where files are saved. Created if it does not exist.
        start_ind: Starting index for file naming.
        verbose: Whether to print output information.
    """
    if ref_dist is not None:
        assert pred_dist.shape == ref_dist.shape, (pred_dist.shape, ref_dist.shape)

    n_img = 2 if ref_dist is not None else 1

    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(outdir, exist_ok=True)

    fontsize = 24

    n_z = pred_dist.shape[-1]
    z_start = box_borders[0][2]
    # With a single z-slice there is no spacing between slices; avoid dividing by zero.
    z_res = (box_borders[1][2] - z_start) / (n_z - 1) if n_z > 1 else 0.0
    extent = [box_borders[0][0], box_borders[1][0], box_borders[0][1], box_borders[1][1]]

    def shared_range(pred_arr, ref_arr):
        # Common (min, max) of prediction and reference, so that both use the same color scale.
        if ref_arr is None:
            return pred_arr.min(), pred_arr.max()
        return min(ref_arr.min(), pred_arr.min()), max(ref_arr.max(), pred_arr.max())

    for i_sample, p in enumerate(pred_dist):
        ind = start_ind + i_sample
        r = ref_dist[i_sample] if ref_dist is not None else None

        # Plot grid in 2D (mean over z)
        p_mean = p.mean(axis=-1)
        r_mean = r.mean(axis=-1) if r is not None else None
        vmin, vmax = shared_range(p_mean, r_mean)

        fig, axes = plt.subplots(1, n_img, figsize=(2 + 5 * n_img, 6), squeeze=False)
        axes = axes[0]
        # Transposed so that the first array axis (x) runs horizontally in the image.
        axes[0].imshow(p_mean.T, origin="lower", vmin=vmin, vmax=vmax, extent=extent)
        axes[0].set_title("Prediction")
        if r_mean is not None:
            axes[1].imshow(r_mean.T, origin="lower", vmin=vmin, vmax=vmax, extent=extent)
            axes[1].set_title("Reference")

        # Colorbar, placed in the strip that tight_layout leaves free on the right
        fig.tight_layout(rect=[0, 0, 0.9, 1])
        pos = axes[-1].get_position()
        cax = fig.add_axes([0.9, pos.ymin, 0.03, pos.ymax - pos.ymin])
        mappable = cm.ScalarMappable()
        mappable.set_array([vmin, vmax])
        fig.colorbar(mappable, cax=cax)

        save_path = os.path.join(outdir, f"{ind}_pred_dist2D.png")
        fig.savefig(save_path)
        if verbose > 0:
            print(f"Position distribution 2D prediction image saved to {save_path}")
        plt.close(fig)

        # Plot each z-slice separately
        vmin, vmax = shared_range(p, r)
        nrows, ncols = _calc_plot_dim(n_z, f=0.5)
        fig = plt.figure(figsize=(4 * ncols, 4.25 * nrows * n_img))
        fig_grid = fig.add_gridspec(
            nrows, ncols, wspace=0.05, hspace=0.15, left=0.03, right=0.98, bottom=0.02, top=0.98
        )
        for iz in range(n_z):
            iy, ix = divmod(iz, ncols)
            # One grid cell per slice, holding the prediction on top of the reference.
            axes = fig_grid[iy, ix].subgridspec(n_img, 1, hspace=0.03).subplots(squeeze=False)[:, 0]
            axes[0].imshow(p[:, :, iz].T, origin="lower", vmin=vmin, vmax=vmax, extent=extent)
            axes[0].axis("off")
            # NOTE(review): the label uses a half-step offset (iz + 0.5) while z_res is computed with
            # n_z - 1 intervals, so the last label lies above z_end. Kept as is; confirm which
            # convention (slice centers or slice edges) is intended.
            axes[0].set_title(f"z = {z_start + (iz + 0.5) * z_res:.2f}Å", fontsize=fontsize)
            if r is not None:
                axes[1].imshow(r[:, :, iz].T, origin="lower", vmin=vmin, vmax=vmax, extent=extent)
                axes[1].axis("off")
            if ix == 0:
                # Row labels on the first column only
                axes[0].text(
                    -0.1,
                    0.5,
                    "Prediction",
                    ha="center",
                    va="center",
                    transform=axes[0].transAxes,
                    rotation="vertical",
                    fontsize=fontsize,
                )
                if r is not None:
                    axes[1].text(
                        -0.1,
                        0.5,
                        "Reference",
                        ha="center",
                        va="center",
                        transform=axes[1].transAxes,
                        rotation="vertical",
                        fontsize=fontsize,
                    )

        save_path = os.path.join(outdir, f"{ind}_pred_dist.png")
        fig.savefig(save_path)
        if verbose > 0:
            print(f"Position distribution prediction image saved to {save_path}")
        plt.close(fig)