Source code for mlspm.visualization
import os
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from matplotlib.colors import Colormap
# Make subpackage plotting tools also available
from .graph._visualization import *
from .image._visualization import *
from .utils import _calc_plot_dim
[docs]
def plot_input(X: np.ndarray, constant_range: bool = False, cmap: str | Colormap = "afmhot") -> plt.Figure:
    """
    Plot a single stack of AFM images, one subplot (with its own colorbar) per slice.

    Arguments:
        X: AFM image stack to plot. Slices are taken along the last axis as ``X[:, :, k]``,
            transposed and drawn with ``origin="lower"``.
        constant_range: If True, every slice uses the global min/max of ``X`` as its colour range;
            otherwise each slice is scaled independently.
        cmap: Colormap to use for plotting.

    Returns:
        Figure on which the image was plotted.
    """
    n_slices = X.shape[-1]
    rows, cols = _calc_plot_dim(n_slices)
    fig = plt.figure(figsize=(3.2 * cols, 2.5 * rows))

    # Global value range of the whole stack; only forwarded to imshow when a shared scale is requested.
    shared_range = {"vmin": X.min(), "vmax": X.max()}
    imshow_kwargs = shared_range if constant_range else {}

    for slice_idx in range(n_slices):
        # add_subplot makes the new axes current, so the pyplot calls below draw onto it.
        fig.add_subplot(rows, cols, slice_idx + 1)
        plt.imshow(X[:, :, slice_idx].T, cmap=cmap, origin="lower", **imshow_kwargs)
        plt.colorbar()

    plt.tight_layout()
    return fig
[docs]
def make_input_plots(
    Xs: list[np.ndarray],
    outdir: str = "./predictions/",
    start_ind: int = 0,
    constant_range: bool = False,
    cmap: str | Colormap = "afmhot",
    verbose: int = 1,
):
    """
    Plot a batch of AFM images to files 0_input.png, 1_input.png, ... etc.

    When more than one tip is given, the 1-based tip index is appended to the name,
    e.g. ``0_input1.png``, ``0_input2.png``, ``1_input1.png``, ...

    Arguments:
        Xs: Input AFM images to plot. Each list element corresponds to one AFM tip and is an array of shape ``(batch, x, y, z)``.
        outdir: Directory where images are saved. Created if it does not exist.
        start_ind: Starting index for file naming.
        constant_range: Whether the different slices should use the same value range or not.
        cmap: Colormap to use for plotting.
        verbose: Whether to print output information.
    """
    # exist_ok avoids the race between checking for the directory and creating it.
    os.makedirs(outdir, exist_ok=True)

    if not Xs:
        # Nothing to plot; previously this crashed with an IndexError on Xs[0].
        return

    multiple_tips = len(Xs) > 1
    for i in range(Xs[0].shape[0]):
        img_ind = start_ind + i
        for tip_ind, X in enumerate(Xs):
            fig = plot_input(X[i], constant_range, cmap=cmap)
            save_name = f"{img_ind}_input"
            if multiple_tips:
                save_name += str(tip_ind + 1)
            save_name = os.path.join(outdir, save_name + ".png")
            # Save/close the figure explicitly instead of relying on pyplot's "current figure".
            fig.savefig(save_name)
            plt.close(fig)
            if verbose > 0:
                print(f"Input image saved to {save_name}")
[docs]
def plot_confusion_matrix(ax: Axes, conf_mat: np.ndarray, tick_labels: list[str] | None = None):
    """
    Plot a row-normalized confusion matrix on matplotlib axes.

    Each cell is coloured by its count divided by the total of its row (true class) and is
    annotated with that fraction and the raw count. Rows whose total is zero are shown as 0
    instead of NaN.

    Arguments:
        ax: Axes object on which the confusion matrix is plotted.
        conf_mat: Confusion matrix counts. Rows are true classes, columns are predicted classes.
        tick_labels: Labels for classes. Defaults to ``"0", "1", ...``. The same labels are
            used on both axes, so the matrix is expected to be square.

    Raises:
        ValueError: If ``tick_labels`` is given and its length differs from the number of rows.
    """
    conf_mat = np.asarray(conf_mat)
    n_true, n_pred = conf_mat.shape

    if tick_labels:
        if len(tick_labels) != n_true:
            raise ValueError(
                f"Got {len(tick_labels)} tick labels for a confusion matrix with {n_true} classes."
            )
    else:
        tick_labels = [str(i) for i in range(n_true)]

    # Normalize each row by its total. Rows summing to zero stay 0 instead of producing NaN
    # (and a divide-by-zero warning).
    row_sums = conf_mat.sum(axis=1, keepdims=True)
    conf_mat_norm = np.divide(
        conf_mat,
        row_sums,
        out=np.zeros(conf_mat.shape, dtype=np.float64),
        where=row_sums != 0,
    )

    im = ax.imshow(conf_mat_norm, cmap="Blues")
    # Attach the colorbar to the figure that owns `ax`, not whatever pyplot's current figure is.
    ax.figure.colorbar(im, ax=ax)

    # Columns run along x (predicted), rows along y (true).
    ax.set_xticks(np.arange(n_pred))
    ax.set_yticks(np.arange(n_true))
    ax.set_xlabel("Predicted class")
    ax.set_ylabel("True class")
    ax.set_xticklabels(tick_labels)
    ax.set_yticklabels(tick_labels, rotation="vertical", va="center")

    for i in range(n_true):
        for j in range(n_pred):
            frac = conf_mat_norm[i, j]
            # White text on dark cells, black on light ones.
            color = "white" if frac > 0.5 else "black"
            # int() keeps the count format valid for float-dtype count matrices too.
            label = f"{frac:.3f}\n({int(conf_mat[i, j]):d})"
            ax.text(j, i, label, ha="center", va="center", color=color)