mlspm.losses
- class mlspm.losses.GraphLoss(node_factor: float = 1.0, edge_factor: float = 1.0)
Loss that compares two graphs.
- Parameters:
node_factor – Weight for node classification loss.
edge_factor – Weight for edge classification loss.
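The two factors scale the node and edge terms of the total loss. The sketch below assumes the total is a plain weighted sum of two mean softmax cross-entropy terms. This is an illustration only, not the actual mlspm implementation, and `cross_entropy` and `weighted_graph_loss` are hypothetical pure-Python helpers.

```python
import math


def cross_entropy(logits: list[list[float]], targets: list[int]) -> float:
    """Mean softmax cross-entropy over rows (pure-Python stand-in for a framework loss)."""
    total = 0.0
    for row, target in zip(logits, targets):
        # Log-sum-exp with the row maximum subtracted for numerical stability.
        m = max(row)
        log_norm = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_norm - row[target]
    return total / len(targets)


def weighted_graph_loss(node_logits: list[list[float]], node_targets: list[int],
                        edge_logits: list[list[float]], edge_targets: list[int],
                        node_factor: float = 1.0, edge_factor: float = 1.0) -> float:
    """Assumed combination (hypothetical): node_factor * node_loss + edge_factor * edge_loss."""
    node_loss = cross_entropy(node_logits, node_targets)
    edge_loss = cross_entropy(edge_logits, edge_targets)
    return node_factor * node_loss + edge_factor * edge_loss


# Usage: raising edge_factor increases the weight of the edge term only.
node_logits = [[2.0, 0.5, 0.1], [0.2, 1.5, 0.3]]
node_targets = [0, 1]
edge_logits = [[0.3, 1.2], [1.0, -0.5]]
edge_targets = [1, 0]
base = weighted_graph_loss(node_logits, node_targets, edge_logits, edge_targets)
doubled = weighted_graph_loss(node_logits, node_targets, edge_logits, edge_targets, edge_factor=2.0)
print(base, doubled)
```

Under this assumption, setting a factor to 0 would switch off the corresponding term entirely, which can be handy for debugging one head of a model in isolation.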
- forward(pred: Tuple[list[Tensor], list[Tensor], list[Tensor]], ref: Tuple[list[Tensor], list[Tensor]], separate_loss_factors=False) → Tensor | list[Tensor, Tensor, Tensor]
- Parameters:
pred – Predicted graph batch as returned by GraphImgNet.forward().
ref – Reference graph batch. A tuple (node_classes, edges), where
node_classes - Node classes as class index numbers. List of tensors of shape (n_atoms,).
edges - Edges as pairs of node indices. List of tensors of shape (2, n_edges).
separate_loss_factors – Whether to return a single total loss or a list containing the total loss together with each individual loss component.
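A reference batch in the documented (node_classes, edges) format could be built as follows for two small graphs. The values are made up for illustration, and `check_ref_batch` is a hypothetical helper, not part of mlspm, that only verifies the shapes described above. It additionally assumes that edge node indices are local to each graph.

```python
import torch


def check_ref_batch(ref: tuple[list[torch.Tensor], list[torch.Tensor]]) -> None:
    """Hypothetical shape check for a reference graph batch (node_classes, edges)."""
    node_classes, edges = ref
    if len(node_classes) != len(edges):
        raise ValueError("node_classes and edges must have one entry per graph")
    for classes, graph_edges in zip(node_classes, edges):
        # node_classes: one class index per atom, shape (n_atoms,).
        if classes.dim() != 1:
            raise ValueError(f"expected shape (n_atoms,), got {tuple(classes.shape)}")
        # edges: pairs of node indices stored column-wise, shape (2, n_edges).
        if graph_edges.dim() != 2 or graph_edges.shape[0] != 2:
            raise ValueError(f"expected shape (2, n_edges), got {tuple(graph_edges.shape)}")
        # Assumption: each edge endpoint indexes a node of the same graph.
        if graph_edges.numel() > 0 and int(graph_edges.max()) >= classes.shape[0]:
            raise ValueError("edge refers to a node index outside the graph")


# Graph 1: three atoms, edges 0-1 and 1-2. Graph 2: two atoms, edge 0-1.
node_classes = [torch.tensor([0, 1, 1]), torch.tensor([2, 0])]
edges = [torch.tensor([[0, 1], [1, 2]]), torch.tensor([[0], [1]])]
ref = (node_classes, edges)
check_ref_batch(ref)
```

Each column of an edges tensor is one edge, so torch.tensor([[0, 1], [1, 2]]) encodes the pairs (0, 1) and (1, 2).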
- Returns:
Computed loss value. Either a single value when separate_loss_factors==False, or a list [total_loss, node_loss, edge_loss] when separate_loss_factors==True.
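Because separate_loss_factors changes the return type, calling code typically backpropagates the total and only logs the components. The sketch below mimics the documented return convention with a hypothetical stand-in, `loss_with_components`, which is not an mlspm API; it also assumes the total is a weighted sum of the two components.

```python
from typing import Union

import torch


def loss_with_components(node_loss: torch.Tensor, edge_loss: torch.Tensor,
                         node_factor: float = 1.0, edge_factor: float = 1.0,
                         separate_loss_factors: bool = False
                         ) -> Union[torch.Tensor, list[torch.Tensor]]:
    """Hypothetical stand-in mirroring the documented return convention."""
    # Assumption: the total loss is the weighted sum of the two components.
    total = node_factor * node_loss + edge_factor * edge_loss
    if separate_loss_factors:
        return [total, node_loss, edge_loss]
    return total


node_loss = torch.tensor(0.8, requires_grad=True)
edge_loss = torch.tensor(0.3, requires_grad=True)

# Single value: use it directly for backpropagation.
loss = loss_with_components(node_loss, edge_loss)
loss.backward()

# Separated values: unpack, then log each component.
total, node_part, edge_part = loss_with_components(
    node_loss, edge_loss, edge_factor=2.0, separate_loss_factors=True)
print(f"total={total.item():.3f} node={node_part.item():.3f} edge={edge_part.item():.3f}")
```

Returning the total as the first list element means a training loop can call backward() on the first element in either mode and treat the remaining entries purely as diagnostics.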