mlspm.logging
- class mlspm.logging.LossLogPlot(log_path: str, plot_path: str, loss_labels: list[str], loss_weights: list[float] | None = None, print_interval: int = 10, init_epoch: int | None = None, stream: TextIO = sys.stdout)
Bases: object
Log and plot model training loss history. Add losses for each batch with add_train_loss() and add_val_loss(), and at the end of each epoch call next_epoch() to print the status. Works with distributed training.
- Parameters:
log_path – Path where loss log is saved.
plot_path – Path where plot of loss history is saved.
loss_labels – Labels for different loss components. If length > 1, an additional component 'Total' is prepended to the list.
loss_weights – Weights for different loss components when there is more than one.
print_interval – Loss values are printed every print_interval batches.
init_epoch – Initial epoch. If not None and the existing log has more epochs than this, the extra epochs are discarded.
stream – Stream where log is printed to.
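The interaction between loss_labels and loss_weights can be illustrated with a small standalone sketch. It does not use mlspm itself. The helper name with_total is hypothetical, and the sketch assumes that 'Total' is the weighted sum of the components with unit weights when loss_weights is None. The documentation above does not state how the total is computed or whether callers are expected to supply it.

```python
def with_total(loss_labels, losses, loss_weights=None):
    """Return (labels, values), prepending a 'Total' entry when there is more than one component.

    Hypothetical helper for illustration only; not part of mlspm.
    """
    if len(losses) != len(loss_labels):
        raise ValueError("expected one loss value per label")
    if len(loss_labels) <= 1:
        # A single component gets no extra 'Total' entry.
        return list(loss_labels), list(losses)
    # Assumption: unit weights when no weights are given.
    weights = loss_weights if loss_weights is not None else [1.0] * len(loss_labels)
    # Assumption: 'Total' is the weighted sum of the components.
    total = sum(w * value for w, value in zip(weights, losses))
    return ["Total", *loss_labels], [total, *losses]


labels, values = with_total(["MSE", "L1"], [0.5, 0.25], loss_weights=[1.0, 2.0])
print(labels, values)  # ['Total', 'MSE', 'L1'] [1.0, 0.5, 0.25]
```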
- add_train_loss(losses: Tensor | ndarray | float | list[float])
Add losses for one training batch. Averaged over parallel processes.
- Parameters:
losses – Losses to append to the list.
- add_val_loss(losses: Tensor | ndarray | float | list[float])
Add losses for one validation batch. Averaged over parallel processes.
- Parameters:
losses – Losses to append to the list.
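The call pattern described in the class summary (per-batch add_train_loss() and add_val_loss(), then next_epoch()) might be driven as in the sketch below. The function run_epoch and the RecordingLogger stand-in are hypothetical and exist only so the example runs without mlspm installed. In real use, logger would be a LossLogPlot instance constructed with your own paths and labels.

```python
def run_epoch(logger, train_batches, val_batches):
    """Feed per-batch losses to a LossLogPlot-like logger, then close the epoch."""
    for losses in train_batches:
        logger.add_train_loss(losses)
    for losses in val_batches:
        logger.add_val_loss(losses)
    # Called once at the end of each epoch, as the class description says.
    logger.next_epoch()


class RecordingLogger:
    """Hypothetical stand-in that only records calls; not part of mlspm."""

    def __init__(self):
        self.calls = []

    def add_train_loss(self, losses):
        self.calls.append(("train", losses))

    def add_val_loss(self, losses):
        self.calls.append(("val", losses))

    def next_epoch(self):
        self.calls.append(("next_epoch", None))


logger = RecordingLogger()
# Scalar losses, as for a logger with a single loss label.
run_epoch(logger, train_batches=[0.9, 0.7], val_batches=[0.8])
print(logger.calls)
```

Passing the logger in as an argument keeps the loop independent of how the logger was configured.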
- get_joinable(mode: str = 'train')
Return a joinable for uneven training/validation inputs.
- Parameters:
mode – Choose 'train' or 'val'.
- loss_str(losses: list[float] | ndarray | Tensor) → str
Get a pretty string for loss values.
- Parameters:
losses – List of losses of the same length as the number of loss labels.
- Returns:
String representation of the losses.
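The exact format produced by loss_str() is not specified here. As a rough illustration of pairing each loss value with its label, a formatter might look like the following sketch. The name format_losses, the separator, and the precision are all assumptions, not the library's actual output.

```python
def format_losses(labels, losses, precision=4):
    """Join label/value pairs into one line (hypothetical format, not mlspm's)."""
    if len(losses) != len(labels):
        raise ValueError("expected one loss value per label")
    return ", ".join(
        f"{label}: {float(value):.{precision}f}" for label, value in zip(labels, losses)
    )


line = format_losses(["Total", "MSE", "L1"], [1.0, 0.5, 0.25])
print(line)  # Total: 1.0000, MSE: 0.5000, L1: 0.2500
```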
- class mlspm.logging.SyncedLoss(num_losses: int)
Bases: Joinable
Gather loss values to a list that is averaged over parallel ranks.
- Parameters:
num_losses – Number of different loss values.
- append(losses: Tensor | ndarray | float | list[float])
Append a new batch of loss values.
- Parameters:
losses – Loss values. Length should match self.num_losses.
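To make "averaged over parallel ranks" concrete, the sketch below simulates the reduction in a single process, with each rank's batch of losses represented as a plain list. The helper average_over_ranks is hypothetical. A real distributed run would perform this step with a collective operation rather than by collecting lists, and nothing here reflects how SyncedLoss is actually implemented.

```python
def average_over_ranks(per_rank_losses, num_losses):
    """Element-wise mean of one batch of losses across simulated ranks."""
    for losses in per_rank_losses:
        # Mirrors the documented requirement that the length matches num_losses.
        if len(losses) != num_losses:
            raise ValueError("each rank must supply num_losses values")
    world_size = len(per_rank_losses)
    # zip(*...) groups the i-th loss component from every rank together.
    return [sum(component) / world_size for component in zip(*per_rank_losses)]


# Two simulated ranks, two loss components each.
averaged = average_over_ranks([[1.0, 0.5], [3.0, 1.5]], num_losses=2)
print(averaged)  # [2.0, 1.0]
```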
- property join_device
Return the device from which to perform collective communications needed by the join context manager.
- join_hook(**kwargs)
Return a JoinHook instance for the given Joinable.
- Parameters:
kwargs (dict) – a dict containing any keyword arguments to modify the behavior of the join hook at run time; all Joinable instances sharing the same join context manager are forwarded the same value for kwargs.
- property join_process_group
Return the process group for the collective communications needed by the join context manager itself.