import os
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import MaxNLocator
from scipy.spatial.distance import cdist, directed_hausdorff
from ..utils import elements
from ._molecule_graph import MoleculeGraph
[docs]
class GraphStats:
"""
Gather statistics on graph predictions divided into bins based on the graph sizes.
Arguments:
classes: Classes for categorizing atom based on their chemical elements. Each class is a list of elements either
as atomic numbers or as chemical symbols.
dist_threshold: Distance threshold for considering atoms as matching between prediction and reference.
bin_size: Bin size for graph sizes. The samples are divided into bins based on the number of nodes in the graphs.
"""
def __init__(self, classes: list[list[int]], dist_threshold: float = 0.35, bin_size: int = 4):
self.classes = classes
self.n_classes = len(self.classes)
self.bin_size = bin_size
self.n_bins = 1
self._graph_sizes = [[]]
self._node_count_diffs = [[]]
self._bond_count_diffs = [[]]
self._hausdorff_distances = [[]]
self._matching_distances = [[]]
self._conf_mat_node = [np.zeros((self.n_classes, self.n_classes), dtype=np.int32)]
self._conf_mat_edge = [np.zeros((2, 2), dtype=np.int32)]
self._missing_nodes = [[]]
self._extra_nodes = [[]]
self.dist_threshold = dist_threshold
@property
def largest_graph(self) -> int:
"""Size of the largest graph seen in all batches."""
return max(self.graph_sizes())
@property
def total_nodes(self) -> int:
"""Total number of nodes seen in graphs all batches."""
return sum(self.graph_sizes())
@property
def total_samples(self) -> int:
"""Total number of samples seen in all batches."""
return len(self.graph_sizes())
[docs]
def conf_mat_node(self, size_bin: int = -1) -> np.ndarray:
"""
Get the confusion matrix for node classification.
Arguments:
size_bin: Index of graph size bin. If negative, include all bins.
Returns:
Confusion matrix of predicted vs reference node classes.
"""
return np.sum(self._conf_mat_node, axis=0) if size_bin < 0 else self._conf_mat_node[size_bin]
[docs]
def conf_mat_edge(self, size_bin: int = -1):
"""
Get the confusion matrix for edge classification.
Arguments:
size_bin: Index of graph size bin. If negative, include all bins.
Returns:
Confusion matrix of predicted vs reference edge classes.
"""
return np.sum(self._conf_mat_edge, axis=0) if size_bin < 0 else self._conf_mat_edge[size_bin]
[docs]
def edge_precision(self, size_bin: int = -1) -> np.ndarray:
"""
Get the precision for edge classification.
Arguments:
size_bin: Index of graph size bin. If negative, include all bins.
Returns:
Edge classification precision for every class.
"""
conf_mat = self.conf_mat_edge(size_bin)
return np.diag(conf_mat) / conf_mat.sum(axis=0)
[docs]
def edge_recall(self, size_bin: int = -1) -> np.ndarray:
"""
Get the recall for edge classification.
Arguments:
size_bin: Index of graph size bin. If negative, include all bins.
Returns:
Edge classification recall for every class.
"""
conf_mat = self.conf_mat_edge(size_bin)
return np.diag(conf_mat) / conf_mat.sum(axis=1)
[docs]
def node_precision(self, size_bin: int = -1) -> np.ndarray:
"""
Get the precision for node classification.
Arguments:
size_bin: Index of graph size bin. If negative, include all bins.
Returns:
Node classification precision for every class.
"""
conf_mat = self.conf_mat_node(size_bin)
return np.diag(conf_mat) / conf_mat.sum(axis=0)
[docs]
def node_recall(self, size_bin: int = -1) -> np.ndarray:
"""
Get the precision for node classification.
Arguments:
size_bin: Index of graph size bin. If negative, include all bins.
Returns:
Node classification recall for every class.
"""
conf_mat = self.conf_mat_node(size_bin)
return np.diag(conf_mat) / conf_mat.sum(axis=1)
def _get_array(self, arrays, size_bin):
arrays = [np.array(a) for a in arrays]
return np.concatenate(arrays, axis=0) if size_bin < 0 else arrays[size_bin]
[docs]
def graph_sizes(self, size_bin: int = -1) -> np.ndarray:
"""
Get the full list of graph sizes.
Arguments:
size_bin: Index of graph size bin. If negative, include all bins
Returns:
Array of graph sizes.
"""
return self._get_array(self._graph_sizes, size_bin)
[docs]
def node_count_diffs(self, size_bin: int = -1) -> np.ndarray:
"""
Get the full list of differences in node counts between predictions and references.
Arguments:
size_bin: Index of graph size bin. If negative, include all bins
Returns:
Array of node count differences.
"""
return self._get_array(self._node_count_diffs, size_bin)
[docs]
def bond_count_diffs(self, size_bin: int = -1) -> np.ndarray:
"""
Get the full list of differences in edge counts between predictions and references.
Arguments:
size_bin: Index of graph size bin. If negative, include all bins
Returns:
Array of edge count differences.
"""
return self._get_array(self._bond_count_diffs, size_bin)
[docs]
def hausdorff_distances(self, size_bin: int = -1) -> np.ndarray:
"""
Get the full list of Hausdorff distances between predictions and references.
Arguments:
size_bin: Index of graph size bin. If negative, include all bins
Returns:
Array of Hausdorff distances.
"""
return self._get_array(self._hausdorff_distances, size_bin)
[docs]
def matching_distances(self, size_bin: int = -1) -> np.ndarray:
"""
Get the full list of matching distances between predictions and references. The matching distance
is the distance between a pair of nodes in the prediction and reference that were within the set
threshold distance.
Arguments:
size_bin: Index of graph size bin. If negative, include all bins
Returns:
Array of matching distances.
"""
return self._get_array(self._matching_distances, size_bin)
[docs]
def missing_nodes(self, size_bin: int = -1) -> np.ndarray:
"""
Get the full list of the number of missing nodes in predictions compared to references.
Arguments:
size_bin: Index of graph size bin. If negative, include all bins
Returns:
Array of missing node counts.
"""
return self._get_array(self._missing_nodes, size_bin)
def _check_bins(self, size_bin: int):
for _ in range(self.n_bins, size_bin + 1):
self._graph_sizes.append([])
self._node_count_diffs.append([])
self._bond_count_diffs.append([])
self._hausdorff_distances.append([])
self._matching_distances.append([])
self._conf_mat_node.append(np.zeros((self.n_classes, self.n_classes), dtype=np.int32))
self._conf_mat_edge.append(np.zeros((2, 2), dtype=np.int32))
self._missing_nodes.append([])
self._extra_nodes.append([])
self.n_bins += 1
[docs]
def add_batch(self, pred: list[MoleculeGraph], ref: list[MoleculeGraph]):
"""
Gather stats from one batch of predictions and references.
Arguments:
pred: Predicted molecule graphs.
ref: Reference molecule graphs.
"""
assert len(pred) == len(ref), "Different number of predictions and references."
for p, r in zip(pred, ref):
graph_size = len(r)
size_bin = (graph_size - 1) // self.bin_size
self._check_bins(size_bin)
self._graph_sizes[size_bin].append(graph_size)
# Node and bond count diffs
self._node_count_diffs[size_bin].append(len(p) - len(r))
self._bond_count_diffs[size_bin].append(len(p.bonds) - len(r.bonds))
if len(p.atoms) > 0:
pos1 = p.array(xyz=True)
pos2 = r.array(xyz=True)
# Hausdorff distance
d1 = directed_hausdorff(pos1, pos2)
d2 = directed_hausdorff(pos2, pos1)
self._hausdorff_distances[size_bin].append(max(d1[0], d2[0]))
# Match closest positions in prediction and reference
dist_mat = cdist(pos2, pos1, metric="euclidean")
mapping = []
missing_nodes = []
for i, dists in enumerate(dist_mat):
matches = np.where(dists < self.dist_threshold)[0]
if len(matches) == 0:
missing_nodes.append(i)
elif len(matches) == 1:
mapping.append(matches[0])
self._matching_distances[size_bin].append(dists[matches[0]])
else:
mapping.append(matches[np.argmin(dists[matches])])
self._matching_distances[size_bin].append(dists[mapping[-1]])
p_extra_nodes = list(set(range(len(p))) - set(mapping))
n_matches = len(mapping)
mapping += p_extra_nodes
self._missing_nodes[size_bin].append(len(missing_nodes))
self._extra_nodes[size_bin].append(len(p_extra_nodes))
# Prune graphs to match nodes one-to-one
r_pruned = r.remove_atoms(missing_nodes)[0]
p_pruned = p.permute(mapping).remove_atoms([n_matches + i for i in range(len(p_extra_nodes))])[0]
assert len(r_pruned) == len(p_pruned), f"{len(r_pruned)}, {len(p_pruned)}"
# Node confusion matrix
for ri, pi in zip(r_pruned.atoms, p_pruned.atoms):
self._conf_mat_node[size_bin][ri.class_index, pi.class_index] += 1
# Edge confusion matrix
Ar = r_pruned.adjacency_matrix()
Ap = p_pruned.adjacency_matrix()
Ar = Ar[np.triu_indices(len(Ar), k=1)].flatten()
Ap = Ap[np.triu_indices(len(Ap), k=1)].flatten()
np.add.at(self._conf_mat_edge[size_bin], (Ar, Ap), 1)
[docs]
def plot(self, outdir: str = "./", verbose: str = 1):
"""
Plot histograms of graph sizes, node/bond count differences, Hausdorff distances, and maching distances,
and confusion matrices for node and edge classification.
Arguments:
outdir: Directory where images are saved.
verbose: Whether to print information.
"""
# Import here, because otherwise leads to a circular import
from mlspm.visualization import plot_confusion_matrix
if not os.path.exists(outdir):
os.makedirs(outdir)
def count_histogram(counts):
fig = plt.figure(figsize=(8, 8))
bin_min = np.min(counts) - 0.5
bin_max = np.max(counts) + 1.5
bins = np.arange(bin_min, bin_max, 1)
plt.hist(counts, bins=bins, edgecolor="black", density=True)
plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))
fig.tight_layout(rect=[0, 0.03, 1, 0.95])
# Histogram of graph sizes
count_histogram(self.graph_sizes())
plt.title("Reference graph size")
plt.xlabel("Number of nodes in graph")
plt.ylabel("Normalized count")
plt.tight_layout()
plt.savefig(savepath := os.path.join(outdir, "graph_size.png"))
if verbose > 0:
print(f"Graph size histogram saved to {savepath}")
plt.close()
# Histogram of node count diffs
count_histogram(self.node_count_diffs())
plt.title("Node count difference")
plt.xlabel("Difference in number of nodes")
plt.ylabel("Normalized count")
plt.tight_layout()
plt.savefig(savepath := os.path.join(outdir, "node_diff.png"))
if verbose > 0:
print(f"Node count difference histogram saved to {savepath}")
plt.close()
# Histogram of bond count diffs
count_histogram(self.bond_count_diffs())
plt.title("Bond count difference")
plt.xlabel("Difference in number of bonds")
plt.ylabel("Normalized count")
plt.tight_layout()
plt.savefig(savepath := os.path.join(outdir, "bond_diff.png"))
if verbose > 0:
print(f"Bond count difference histogram saved to {savepath}")
plt.close()
# Node confusion matrix
fig = plt.figure(figsize=(max(6, 1.2 * self.n_classes), max(5, 1.0 * self.n_classes)))
ax = fig.add_subplot(111)
node_tick_labels = [", ".join([elements[e - 1] for e in c]) for c in self.classes]
plot_confusion_matrix(ax, self.conf_mat_node(), node_tick_labels)
ax.set_title("Node confusion matrix")
plt.tight_layout()
plt.savefig(savepath := os.path.join(outdir, "conf_mat_node.png"))
if verbose > 0:
print(f"Node confusion matrix saved to {savepath}")
plt.close()
# Edge confusion matrix
fig = plt.figure(figsize=(6, 5))
ax = fig.add_subplot(111)
plot_confusion_matrix(ax, self.conf_mat_edge(), ["No edge", "Edge"])
ax.set_title("Edge confusion matrix")
plt.tight_layout()
plt.savefig(savepath := os.path.join(outdir, "conf_mat_edge.png"))
if verbose > 0:
print(f"Edge confusion matrix saved to {savepath}")
plt.close()
# Binned node precision
fig = plt.figure(figsize=(6, 5))
bins = [(i + 1) * self.bin_size for i in range(self.n_bins)]
for c in range(self.n_classes):
prec = [self.node_precision(b)[c] for b in range(self.n_bins)]
plt.plot(bins, prec)
plt.legend(node_tick_labels)
plt.xlabel("Number of nodes in graph")
plt.ylabel("Node classification precision")
plt.tight_layout()
plt.savefig(savepath := os.path.join(outdir, "binned_node_precision.png"))
if verbose > 0:
print(f"Plot of node precision as a function of graph size saved to {savepath}")
plt.close()
# Binned node recall
fig = plt.figure(figsize=(6, 5))
bins = [(i + 1) * self.bin_size for i in range(self.n_bins)]
for c in range(self.n_classes):
prec = [self.node_recall(b)[c] for b in range(self.n_bins)]
plt.plot(bins, prec)
plt.legend(node_tick_labels)
plt.xlabel("Number of nodes in graph")
plt.ylabel("Node classification recall")
plt.tight_layout()
plt.savefig(savepath := os.path.join(outdir, "binned_node_recall.png"))
if verbose > 0:
print(f"Plot of node recall as a function of graph size saved to {savepath}")
plt.close()
# Binned edge precision
fig = plt.figure(figsize=(6, 5))
bins = [(i + 1) * self.bin_size for i in range(self.n_bins)]
for c in range(2):
prec = [self.edge_precision(b)[c] for b in range(self.n_bins)]
plt.plot(bins, prec)
plt.legend(["No edge", "Edge"])
plt.xlabel("Number of nodes in graph")
plt.ylabel("Edge classification precision")
plt.tight_layout()
plt.savefig(savepath := os.path.join(outdir, "binned_edge_precision.png"))
if verbose > 0:
print(f"Plot of edge precision as a function of graph size saved to {savepath}")
plt.close()
# Binned edge recall
fig = plt.figure(figsize=(6, 5))
bins = [(i + 1) * self.bin_size for i in range(self.n_bins)]
for c in range(2):
prec = [self.edge_recall(b)[c] for b in range(self.n_bins)]
plt.plot(bins, prec)
plt.legend(["No edge", "Edge"])
plt.xlabel("Number of nodes in graph")
plt.ylabel("Edge classification recall")
plt.tight_layout()
plt.savefig(savepath := os.path.join(outdir, "binned_edge_recall.png"))
if verbose > 0:
print(f"Plot of edge recall as a function of graph size saved to {savepath}")
plt.close()
if len(self.hausdorff_distances()) > 0:
# Histogram of Hausdorff distances
plt.figure(figsize=(8, 8))
plt.hist(self.hausdorff_distances(), bins=20, edgecolor="black", density=True)
plt.title("Hausdorff distances")
plt.xlabel(f"Distance ($\AA$)")
plt.ylabel("Normalized count")
plt.tight_layout()
plt.savefig(savepath := os.path.join(outdir, "hausdorff.png"))
if verbose > 0:
print(f"Hausdorff distance histogram saved to {savepath}")
plt.close()
# Histograms of matching distances
plt.figure(figsize=(8, 8))
plt.hist(self.matching_distances(), bins=20, edgecolor="black", density=True)
plt.title("Matching distance")
plt.xlabel(f"Distance ($\AA$)")
plt.ylabel("Normalized count")
plt.tight_layout()
plt.savefig(savepath := os.path.join(outdir, "matching_distances.png"))
if verbose > 0:
print(f"Histogram of matching distances saved to {savepath}")
plt.close()
[docs]
def report(self, outdir: str = "./", verbose: int = 1):
"""
Save to file mean absolute node/bond count diffs, mean Hausdorff, mean matching distance, missing/extra atoms,
total samples/nodes, average/largest graph size, and node/edge precision/recall.
Arguments:
outdir: Directory where files are saved.
verbose: Whether to print information.
"""
if not os.path.exists(outdir):
os.makedirs(outdir)
with open(os.path.join(outdir, "seq_stats.csv"), "w") as f:
f.write(f"Mean absolute node diff, {np.abs(self.node_count_diffs()).mean()}\n")
f.write(f"Mean absolute bond diff, {np.abs(self.bond_count_diffs()).mean()}\n")
f.write(f"Mean Hausdorff distance, {np.mean(self.hausdorff_distances())}\n")
f.write(f"Mean matching distance, {np.mean(self.matching_distances())}\n")
f.write(f"Average missing atoms, {np.mean(self.missing_nodes())}\n")
f.write(f"Average extra atoms, {np.mean(self.extra_nodes())}\n")
f.write(f"Total samples, {self.total_samples}\n")
f.write(f"Total nodes, {self.total_nodes}\n")
f.write(f"Average graph size, {self.total_nodes / self.total_samples}\n")
f.write(f"Largest graph size, {self.largest_graph}\n")
if verbose > 0:
print(f"Sequence stats saved to {f.name}")
# Node precision and recall
with open(os.path.join(outdir, "stats_node.csv"), "w") as f:
f.write("Ref class,Precision,Recall\n")
for i, (prec, rec) in enumerate(zip(self.node_precision(), self.node_recall())):
f.write(f"{i},{prec:.4f},{rec:.4f}")
if i < self.n_classes - 1:
f.write("\n")
if verbose > 0:
print(f"Sequence node prediction stats saved to {f.name}")
# Edge precision and recall
with open(os.path.join(outdir, "stats_edge.csv"), "w") as f:
f.write("Ref class,Precision,Recall\n")
for label, prec, rec in zip(["No edge", "Edge"], self.edge_precision(), self.edge_recall()):
f.write(f"{label},{prec:.4f},{rec:.4f}")
if label == "No edge":
f.write("\n")
if verbose > 0:
print(f"Sequence edge prediction stats saved to {f.name}")
# Node confusion matrix
np.savetxt(savepath := os.path.join(outdir, "conf_mat_node.csv"), self.conf_mat_node(), delimiter=",")
if verbose > 0:
print(f"Node confusion matrix data saved to {savepath}")
# Edge confusion matrix
np.savetxt(savepath := os.path.join(outdir, "conf_mat_edge.csv"), self.conf_mat_edge(), delimiter=",")
if verbose > 0:
print(f"Edge confusion matrix data saved to {savepath}")
# Binned stats
with open(os.path.join(outdir, "binned_seq_stats.csv"), "w") as f:
f.write("," + ",".join([str((i + 1) * self.bin_size) for i in range(self.n_bins)]) + "\n")
f.write("Number of samples," + ",".join([str(len(self.graph_sizes(b))) for b in range(self.n_bins)]) + "\n")
f.write(
"Mean absolute node diff,"
+ ",".join([str(np.abs(self.node_count_diffs(b)).mean()) for b in range(self.n_bins)])
+ "\n"
)
f.write(
"Mean absolute bond diff,"
+ ",".join([str(np.abs(self.bond_count_diffs(b)).mean()) for b in range(self.n_bins)])
+ "\n"
)
f.write(
"Mean Hausdorff distance," + ",".join([str(self.hausdorff_distances(b).mean()) for b in range(self.n_bins)]) + "\n"
)
f.write("Mean matching distance," + ",".join([str(self.matching_distances(b).mean()) for b in range(self.n_bins)]) + "\n")
f.write("Missing atoms (mean)," + ",".join([str(self.missing_nodes(b).mean()) for b in range(self.n_bins)]) + "\n")
f.write("Missing atoms (std)," + ",".join([str(self.missing_nodes(b).std()) for b in range(self.n_bins)]) + "\n")
f.write("Extra atoms (mean)," + ",".join([str(self.extra_nodes(b).mean()) for b in range(self.n_bins)]) + "\n")
f.write("Extra atoms (std)," + ",".join([str(self.extra_nodes(b).std()) for b in range(self.n_bins)]))
# Binned node precision and recall
with open(os.path.join(outdir, "binned_node_recall.csv"), "w") as f1, open(
os.path.join(outdir, "binned_node_precision.csv"), "w"
) as f2:
f1.write("Ref class," + ",".join([str((i + 1) * self.bin_size) for i in range(self.n_bins)]) + "\n")
f2.write("Ref class," + ",".join([str((i + 1) * self.bin_size) for i in range(self.n_bins)]) + "\n")
for c in range(self.n_classes):
f1.write(f"{c},")
f2.write(f"{c},")
for b in range(self.n_bins):
f1.write(f"{self.node_recall(b)[c]}")
f2.write(f"{self.node_precision(b)[c]}")
if b < self.n_bins - 1:
f1.write(",")
f2.write(",")
elif c < self.n_classes - 1:
f1.write("\n")
f2.write("\n")
# Binned edge precision and recall
with open(os.path.join(outdir, "binned_edge_recall.csv"), "w") as f1, open(
os.path.join(outdir, "binned_edge_precision.csv"), "w"
) as f2:
f1.write("Ref class," + ",".join([str((i + 1) * self.bin_size) for i in range(self.n_bins)]) + "\n")
f2.write("Ref class," + ",".join([str((i + 1) * self.bin_size) for i in range(self.n_bins)]) + "\n")
for c, label in enumerate(["No edge", "Edge"]):
f1.write(f"{label},")
f2.write(f"{label},")
for b in range(self.n_bins):
f1.write(f"{self.edge_recall(b)[c]}")
f2.write(f"{self.edge_precision(b)[c]}")
if b < self.n_bins - 1:
f1.write(",")
f2.write(",")
elif c == 0:
f1.write("\n")
f2.write("\n")