mlspm.logging

class mlspm.logging.LossLogPlot(log_path: str, plot_path: str, loss_labels: list[str], loss_weights: list[float] | None = None, print_interval: int = 10, init_epoch: int | None = None, stream: TextIO = sys.stdout)

Bases: object

Log and plot model training loss history. Add losses for each batch with add_train_loss() and add_val_loss(), and at the end of each epoch call next_epoch() to print the status. Works with distributed training: batch losses are averaged over parallel processes.
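The call pattern described above can be sketched as a single-process loop. `ToyLossLog` is a hypothetical stand-in written for illustration, not the mlspm implementation: it only collects per-batch losses and averages them when the epoch ends. It does no log-file writing, plotting, printing, or averaging over processes.

```python
class ToyLossLog:
    """Hypothetical single-process stand-in for the LossLogPlot call pattern."""

    def __init__(self, num_losses: int):
        self.num_losses = num_losses
        self.epoch = 0
        self.train_losses: list[list[float]] = []
        self.val_losses: list[list[float]] = []

    def add_train_loss(self, losses: list[float]) -> None:
        self.train_losses.append(list(losses))

    def add_val_loss(self, losses: list[float]) -> None:
        self.val_losses.append(list(losses))

    @staticmethod
    def _mean(batches: list[list[float]]) -> list[float]:
        # Column-wise mean over batches: one average per loss component.
        return [sum(col) / len(batches) for col in zip(*batches)]

    def next_epoch(self) -> tuple[list[float], list[float]]:
        # Average this epoch's batch losses, empty the lists, advance the epoch.
        train_avg = self._mean(self.train_losses)
        val_avg = self._mean(self.val_losses)
        self.train_losses.clear()
        self.val_losses.clear()
        self.epoch += 1
        return train_avg, val_avg


log = ToyLossLog(num_losses=2)
for _ in range(2):
    for batch_losses in [[1.0, 0.5], [3.0, 1.5]]:  # training batches
        log.add_train_loss(batch_losses)
    for batch_losses in [[2.0, 1.0], [4.0, 3.0]]:  # validation batches
        log.add_val_loss(batch_losses)
    train_avg, val_avg = log.next_epoch()
    print(log.epoch, train_avg, val_avg)  # e.g. "1 [2.0, 1.0] [3.0, 2.0]"
```

The real class additionally handles logging, plotting, and cross-process averaging; only the order of calls carries over.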

Parameters:
  • log_path – Path where loss log is saved.

  • plot_path – Path where plot of loss history is saved.

  • loss_labels – Labels for different loss components. If length > 1, an additional component 'Total' is prepended to the list.

  • loss_weights – Weights for different loss components when there is more than one.

  • print_interval – Loss values are printed every print_interval batches.

  • init_epoch – Initial epoch. If not None and the existing log has more epochs, the extra epochs are discarded.

  • stream – Stream to which the log is printed.
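The `'Total'` rule for `loss_labels` and the `print_interval` cadence can be illustrated with two small helpers. `expand_labels` and `should_print` are hypothetical, and the zero-based batch counting in `should_print` is an assumption rather than documented behaviour.

```python
def expand_labels(loss_labels: list[str]) -> list[str]:
    # Documented rule: with more than one component, 'Total' is prepended.
    return ["Total"] + loss_labels if len(loss_labels) > 1 else list(loss_labels)


def should_print(batch_idx: int, print_interval: int = 10) -> bool:
    # Assumption: batch_idx counts from 0, so printing happens after
    # print_interval, 2 * print_interval, ... batches have been added.
    return (batch_idx + 1) % print_interval == 0


print(expand_labels(["MSE"]))                      # ['MSE']
print(expand_labels(["MSE", "KL"]))                # ['Total', 'MSE', 'KL']
print([i for i in range(25) if should_print(i)])   # [9, 19]
```

The labels `"MSE"` and `"KL"` are placeholders for whatever loss components a model reports.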

add_train_loss(losses: Tensor | ndarray | float | list[float])

Add losses for one training batch. Averaged over parallel processes.

Parameters:

losses – Losses to append to the list.

add_val_loss(losses: Tensor | ndarray | float | list[float])

Add losses for one validation batch. Averaged over parallel processes.

Parameters:

losses – Losses to append to the list.
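Averaging over parallel processes means that every process ends up with the element-wise mean of the losses all processes reported for the same batch step. The hypothetical helper below simulates that result in plain Python. In a real `torch.distributed` setup such a mean would typically come from an all-reduce sum divided by the world size, but how mlspm performs the reduction is not stated here, so treat that as an assumption.

```python
def average_over_ranks(per_rank_losses: list[list[float]]) -> list[float]:
    """Element-wise mean of one batch's losses as reported by each rank.

    Plain-Python simulation only; no torch.distributed calls are made.
    """
    world_size = len(per_rank_losses)
    return [sum(values) / world_size for values in zip(*per_rank_losses)]


# Two hypothetical ranks, each reporting [total, component] for the same step.
print(average_over_ranks([[1.0, 0.25], [3.0, 0.75]]))  # [2.0, 0.5]
```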

get_joinable(mode: str = 'train')

Return a joinable for uneven training/validation inputs.

Parameters:

mode – Choose 'train' or 'val'.
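With uneven inputs, ranks run out of batches at different times while loss averaging still needs every rank to take part in each collective step. In PyTorch, joinables are passed to the `torch.distributed.algorithms.Join` context manager for this purpose. The plain-Python simulation below only illustrates the idea under our own assumption: a rank that has finished keeps participating with a neutral contribution (zero sum, zero count), so the averages over the remaining ranks are unchanged. It does not call torch and is not mlspm's join hook; `collective_mean` and `simulate_uneven` are hypothetical.

```python
def collective_mean(contributions: list[tuple[list[float], int]]) -> list[float]:
    # Sum each loss component and the batch counts over all ranks, then divide.
    total_count = sum(count for _, count in contributions)
    sums = [sum(col) for col in zip(*(values for values, _ in contributions))]
    return [s / total_count for s in sums]


def simulate_uneven(per_rank_batches: list[list[list[float]]], num_losses: int) -> list[list[float]]:
    """Step through batches; a rank that has run out contributes zeros."""
    steps = max(len(batches) for batches in per_rank_batches)
    results = []
    for step in range(steps):
        contributions = []
        for batches in per_rank_batches:
            if step < len(batches):
                contributions.append((batches[step], 1))
            else:
                # "Joined" rank: still part of the collective, but neutral.
                contributions.append(([0.0] * num_losses, 0))
        results.append(collective_mean(contributions))
    return results


rank0 = [[1.0], [2.0], [6.0]]  # three batches
rank1 = [[3.0], [4.0]]         # only two batches
print(simulate_uneven([rank0, rank1], num_losses=1))  # [[2.0], [3.0], [6.0]]
```

Without a neutral contribution, the final step would either hang (rank 1 never joins the collective) or be skewed by a spurious value.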

loss_str(losses: list[float] | ndarray | Tensor) → str

Get a pretty string for loss values.

Parameters:

losses – List of losses of the same length as the number of loss labels.

Returns:

String representation of the losses.
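The exact output format of `loss_str()` is not documented. As a hypothetical illustration, a formatter that pairs each label with its value and checks the lengths match could look like this. `loss_str_sketch`, its `label: value` layout, and the six-decimal precision are all assumptions.

```python
def loss_str_sketch(labels: list[str], losses: list[float], precision: int = 6) -> str:
    # Hypothetical layout: "label: value" pairs joined by commas.
    if len(losses) != len(labels):
        raise ValueError(f"expected {len(labels)} losses, got {len(losses)}")
    return ", ".join(f"{label}: {value:.{precision}f}" for label, value in zip(labels, losses))


print(loss_str_sketch(["Total", "MSE", "KL"], [0.75, 0.5, 0.25]))
# Total: 0.750000, MSE: 0.500000, KL: 0.250000
```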

next_epoch()

Increment epoch by one, write current average batch losses to log, empty batch losses, report epoch time to terminal, and update loss history plot.
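The log-writing and timing steps of `next_epoch()` can be sketched as follows. The row layout (epoch number, then training averages, then validation averages, space-separated) and the helper `write_epoch_row` are hypothetical. An in-memory `io.StringIO` stands in for the file at `log_path`.

```python
import io
import time


def write_epoch_row(log_file, epoch: int, train_avg: list[float], val_avg: list[float]) -> None:
    # Hypothetical row layout; the real mlspm log format is not documented here.
    values = [str(epoch)] + [f"{x:.6f}" for x in train_avg + val_avg]
    log_file.write(" ".join(values) + "\n")


log_file = io.StringIO()  # stands in for the file at log_path
epoch_start = time.perf_counter()
# ... training and validation batches would run here ...
write_epoch_row(log_file, epoch=1, train_avg=[0.5], val_avg=[0.75])
elapsed = time.perf_counter() - epoch_start
print(f"Epoch 1 done in {elapsed:.2f} s")
print(log_file.getvalue(), end="")  # 1 0.500000 0.750000
```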

plot_history(show: bool = False)

Plot the history of the current losses and save the figure to self.plot_path.

Parameters:

show – Whether to show the plot on screen.

class mlspm.logging.SyncedLoss(num_losses: int)

Bases: Joinable

Gather loss values into a list that is averaged over parallel ranks.

Parameters:

num_losses – Number of different loss values.

append(losses: Tensor | ndarray | float | list[float])

Append a new batch of loss values.

Parameters:

losses – Loss values. Length should match self.num_losses.
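`append()`, like `add_train_loss()` and `add_val_loss()`, accepts a tensor, an array, a single float, or a list of floats. A hypothetical normalizer that turns all of these into a list of `num_losses` floats can rely on the fact that both `torch.Tensor` and `numpy.ndarray` provide `tolist()`. The function name `to_loss_list` and the error raised on a length mismatch are our assumptions.

```python
def to_loss_list(losses, num_losses: int) -> list[float]:
    """Hypothetical normalizer for the accepted loss input types."""
    if hasattr(losses, "tolist"):
        # torch.Tensor and numpy.ndarray both provide tolist(); a 0-d
        # tensor or array returns a plain Python scalar.
        losses = losses.tolist()
    if isinstance(losses, (int, float)):
        losses = [losses]
    losses = [float(x) for x in losses]
    if len(losses) != num_losses:
        raise ValueError(f"expected {num_losses} loss values, got {len(losses)}")
    return losses


print(to_loss_list(0.5, num_losses=1))        # [0.5]
print(to_loss_list([1, 2.5], num_losses=2))   # [1.0, 2.5]
```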

property join_device

Return the device from which to perform collective communications needed by the join context manager.

join_hook(**kwargs)

Return a JoinHook instance for the given Joinable.

Parameters:

kwargs (dict) – a dict containing any keyword arguments to modify the behavior of the join hook at run time; all Joinable instances sharing the same join context manager are forwarded the same value for kwargs.

property join_process_group

Returns the process group for the collective communications needed by the join context manager itself.

mean() → ndarray

Get average loss over batches.

reset()

Empty the list of losses.
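Ignoring the distributed part, `SyncedLoss` behaves like a list of per-batch loss vectors with a column-wise mean. `ToySyncedLoss` below is a hypothetical single-process sketch of `append()`, `mean()`, and `reset()`. It returns a plain list instead of an `ndarray`, skips input normalization, and does no averaging over ranks or join handling.

```python
class ToySyncedLoss:
    """Hypothetical single-process stand-in for SyncedLoss append/mean/reset."""

    def __init__(self, num_losses: int):
        self.num_losses = num_losses
        self.losses: list[list[float]] = []

    def append(self, losses: list[float]) -> None:
        # Store one batch worth of loss values.
        self.losses.append([float(x) for x in losses])

    def mean(self) -> list[float]:
        # Average each loss component over all batches appended so far.
        n = len(self.losses)
        return [sum(col) / n for col in zip(*self.losses)]

    def reset(self) -> None:
        self.losses = []


synced = ToySyncedLoss(num_losses=2)
synced.append([1.0, 4.0])
synced.append([3.0, 2.0])
batch_mean = synced.mean()
print(batch_mean)     # [2.0, 3.0]
synced.reset()
print(synced.losses)  # []
```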

mlspm.logging.setup_file_logger(save_path: str, logger_name: str, first_line: str = '')