mlspm.losses

class mlspm.losses.GraphLoss(node_factor: float = 1.0, edge_factor: float = 1.0)

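Loss that compares a predicted graph batch to a reference graph batch. It combines a node classification loss and an edge classification loss, each with its own weight.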

Parameters:
  • node_factor – Weight for the node classification loss.

  • edge_factor – Weight for the edge classification loss.

forward(pred: Tuple[list[Tensor], list[Tensor], list[Tensor]], ref: Tuple[list[Tensor], list[Tensor]], separate_loss_factors=False) → Tensor | list[Tensor, Tensor, Tensor]
Parameters:
  • pred – Predicted graph batch, as returned by GraphImgNet.forward().

  • ref – Reference graph batch. A tuple (node_classes, edges), where:

    • node_classes – Node classes as class index numbers. List of tensors of shape (n_atoms,).

    • edges – Edges as pairs of node indices. List of tensors of shape (2, n_edges).

  • separate_loss_factors – If False, return only the total loss. If True, return a list containing the total loss followed by each individual loss component.
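To make the ref layout concrete, the sketch below builds a reference batch of two small graphs with PyTorch and checks it against the shapes described above. The helper make_ref_batch and the example graphs are hypothetical and exist only for illustration; they are not part of mlspm. Storing the indices as torch.long is also an assumption (the usual dtype for class-index and index tensors in PyTorch), not something this page states.

```python
import torch


def make_ref_batch(graphs):
    """Hypothetical helper (not part of mlspm): build a (node_classes, edges) batch.

    `graphs` is a list of (classes, edge_pairs), where `classes` holds one class
    index per atom and `edge_pairs` is a list of (i, j) node-index pairs.
    """
    node_classes = []
    edges = []
    for classes, edge_pairs in graphs:
        # Node classes: one class index per atom, shape (n_atoms,)
        nc = torch.tensor(classes, dtype=torch.long)
        # Edges: each (i, j) pair becomes one column, shape (2, n_edges);
        # reshape(-1, 2) also handles graphs with no edges, giving shape (2, 0)
        e = torch.tensor(edge_pairs, dtype=torch.long).reshape(-1, 2).T
        # Every edge must refer to an existing node
        assert e.numel() == 0 or int(e.max()) < len(classes)
        node_classes.append(nc)
        edges.append(e)
    return node_classes, edges


# Two graphs: 3 atoms with 2 edges, and 2 atoms with 1 edge
ref = make_ref_batch([
    ([0, 1, 1], [(0, 1), (1, 2)]),
    ([2, 0], [(0, 1)]),
])
node_classes, edges = ref
print([tuple(t.shape) for t in node_classes])  # [(3,), (2,)]
print([tuple(t.shape) for t in edges])         # [(2, 2), (2, 1)]
```

Keeping the batch as Python lists of per-graph tensors, rather than one padded tensor, matches the "list of tensors" wording above and lets each graph have its own number of atoms and edges.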

Returns:
  The computed loss: a single value when separate_loss_factors==False, or a list [total_loss, node_loss, edge_loss] when separate_loss_factors==True.
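The following sketch illustrates the documented return contract. It assumes, without confirmation from this page, that the total loss is the weighted sum node_factor * node_loss + edge_factor * edge_loss, and that the separated list holds the unweighted node and edge losses. combine_losses is a hypothetical function, not part of mlspm, and plain floats stand in for the tensors that forward returns.

```python
def combine_losses(node_loss, edge_loss, node_factor=1.0, edge_factor=1.0,
                   separate_loss_factors=False):
    """Hypothetical sketch of how GraphLoss might combine its two components.

    ASSUMPTION: the total is a weighted sum of the node and edge losses, and
    the separated list reports the unweighted components.
    """
    total_loss = node_factor * node_loss + edge_factor * edge_loss
    if separate_loss_factors:
        # Same order as the documented list [total_loss, node_loss, edge_loss]
        return [total_loss, node_loss, edge_loss]
    return total_loss


# Single value: 1.0 * 2.0 + 0.5 * 1.0
total = combine_losses(2.0, 1.0, node_factor=1.0, edge_factor=0.5)
print(total)  # 2.5

# Separated components, e.g. for logging node and edge losses individually
total, node_loss, edge_loss = combine_losses(
    2.0, 1.0, node_factor=1.0, edge_factor=0.5, separate_loss_factors=True
)
print(total, node_loss, edge_loss)  # 2.5 2.0 1.0
```

Returning the components alongside the total is convenient during training: the total is what gets backpropagated, while the individual values show whether node or edge predictions dominate the loss.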