import datetime
import logging
import os
import sys
import time
from typing import Optional, TextIO
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.distributed as dist
from torch.distributed.algorithms.join import Join, Joinable, JoinHook
from .utils import _calc_plot_dim, _get_distributed
[docs]
def setup_file_logger(save_path: str, logger_name: str, first_line: str = ""):
logger = logging.getLogger(logger_name)
logger.handlers = []
f_handler = logging.FileHandler(save_path)
f_handler.setLevel(logging.DEBUG)
logger.addHandler(f_handler)
# create console handler
ch = logging.StreamHandler()
ch.setLevel(logging.WARNING)
logger.addHandler(ch)
logger.setLevel(logging.INFO)
logger.info(first_line)
return logger
[docs]
class SyncedLoss(Joinable):
"""Gather loss values to a list that is averaged over parallel ranks.
Arguments:
num_losses: Number of different loss values.
"""
def __init__(self, num_losses: int):
super().__init__()
self.num_losses = num_losses
self.world_size, self.local_rank, self.global_rank, self.group = _get_distributed()
self.reset()
[docs]
def reset(self):
"""Empty list of losses"""
self.losses = []
self.n_batches = 0
def __len__(self):
return len(self.losses)
def __getitem__(self, index: int):
return self.losses[index]
[docs]
def mean(self) -> np.ndarray:
"""Get average loss over batches."""
return np.mean(self.losses, axis=0)
@property
def join_process_group(self):
return self.group
@property
def join_device(self):
return self.local_rank
[docs]
def join_hook(self, **kwargs):
return _SyncedLossJoinHook(self)
def _sync_losses(self, losses, shadow=False):
assert self.world_size > 1, self.world_size
if not shadow:
assert len(losses) == self.num_losses, (losses, self.num_losses)
# We haven't joined yet
Join.notify_join_context(self)
# Count non-joined ranks
world_size_eff = torch.ones(1, device=self.local_rank)
dist.all_reduce(world_size_eff, op=dist.ReduceOp.SUM)
# Sum losses over non-joined ranks
losses = torch.tensor(losses, device=self.local_rank)
dist.all_reduce(losses, op=dist.ReduceOp.SUM)
else: # We joined already, so shadow the reduce operations
# Don't count towards non-joined ranks
world_size_eff = torch.zeros(1, device=self.local_rank)
dist.all_reduce(world_size_eff, op=dist.ReduceOp.SUM)
# Also don't count towards sum of losses
losses = torch.zeros(self.num_losses, device=self.local_rank)
dist.all_reduce(losses, op=dist.ReduceOp.SUM)
# Add averaged losses to list
losses /= world_size_eff
losses = list(losses.cpu().numpy())
self.losses.append(losses)
return losses
[docs]
def append(self, losses: torch.Tensor | np.ndarray | float | list[float]):
"""
Append a new batch of loss values.
Arguments:
losses: Loss values. Length should match ``self.num_losses``.
"""
if not isinstance(losses, list):
losses = [losses]
losses_ = []
for loss in losses:
if isinstance(loss, torch.Tensor):
if loss.size() == ():
losses_.append(loss.item())
else:
losses_ += list(loss.cpu().detach().numpy())
elif isinstance(loss, np.ndarray):
if loss.size == 1:
losses_.append(loss.item())
else:
losses_ += list(loss)
elif isinstance(loss, (int, float)):
losses_.append(loss)
else:
raise ValueError(f"Loss has unsupported type `{type(loss)}`")
losses = losses_
if np.isnan(losses).any():
raise ValueError(
f"Found a nan in losses ({losses}) at rank {self.global_rank} after {len(self.losses)} batches. "
f"Some of the previous losses were: {self.losses[-5:]}"
)
if self.world_size > 1:
losses = self._sync_losses(losses)
else:
self.losses.append(losses)
self.n_batches += 1
return losses
class _SyncedLossJoinHook(JoinHook):
"""Hook for when the number of batches does not match between processes."""
def __init__(self, synced_loss):
self.synced_loss = synced_loss
def main_hook(self):
self.synced_loss._sync_losses([], shadow=True)
def post_hook(self, is_last_joiner):
pass
[docs]
class LossLogPlot:
"""
Log and plot model training loss history. Add losses for each batch with :meth:`add_train_loss` and :meth:`add_train_loss`,
and at the end of each epoch call :meth:`next_epoch` to print the status. Works with distributed training.
Arguments:
log_path: Path where loss log is saved.
plot_path: Path where plot of loss history is saved.
loss_labels: Labels for different loss components. If length > 1, an additional component ``'Total'`` is prepended to the list.
loss_weights: Weights for different loss components when there is more than one.
print_interval: Loss values are printed every **print_interval** batches.
init_epoch: Initial epoch. If not None and existing log has more epochs, discard them.
stream: Stream where log is printed to.
"""
def __init__(
self,
log_path: str,
plot_path: str,
loss_labels: list[str],
loss_weights: Optional[list[float]] = None,
print_interval: int = 10,
init_epoch: Optional[int] = None,
stream: TextIO = sys.stdout,
):
self.log_path = log_path
self.plot_path = plot_path
self.print_interval = print_interval
self.stream = stream
if len(loss_labels) > 1:
loss_labels = ["Total"] + loss_labels
self.loss_labels = loss_labels
if loss_weights is None or len(loss_weights) == 0:
loss_weights = [""] * len(self.loss_labels)
else:
if len(loss_labels) == 1:
assert len(loss_weights) == 1
else:
assert len(loss_weights) == (len(loss_labels) - 1)
loss_weights = [""] + loss_weights
self.loss_weights = loss_weights
self.train_losses = np.empty((0, len(loss_labels)))
self.val_losses = np.empty((0, len(loss_labels)))
self.world_size, self.local_rank, self.global_rank, _ = _get_distributed()
self.epoch = 1
self._synced_losses = {"train": SyncedLoss(len(self.loss_labels)), "val": SyncedLoss(len(self.loss_labels))}
self._init_log(init_epoch)
def _init_log(self, init_epoch: Optional[int]):
log_exists = os.path.isfile(self.log_path)
if self.world_size > 1:
dist.barrier()
if not log_exists:
if self.global_rank > 0:
return
self._write_log()
print(f"Created log at {self.log_path}", file=self.stream, flush=True)
else:
with open(self.log_path, "r") as f:
header = f.readline().rstrip("\r\n").split(";")
hl = (len(header) - 1) // 2
if len(self.loss_labels) != hl:
raise ValueError(
f"The length of the given list of loss names and the length of the header of the existing log at {self.log_path} do not match."
)
for line in f:
if init_epoch is not None and self.epoch >= init_epoch:
break
line = line.rstrip("\n").split(";")
if len(line) < 3:
continue
self.train_losses = np.append(self.train_losses, [[float(s) for s in line[1 : hl + 1]]], axis=0)
self.val_losses = np.append(self.val_losses, [[float(s) for s in line[hl + 1 :]]], axis=0)
self.epoch += 1
if self.global_rank == 0:
if init_epoch is not None:
self._write_log() # Make sure there are no additional rows in the log
print(f"Using existing log at {self.log_path}", file=self.stream, flush=True)
def _write_log(self):
with open(self.log_path, "w") as f:
f.write("epoch")
for i, label in enumerate(self.loss_labels):
label = f";train_{label}"
if self.loss_weights[i]:
label += f" (x {self.loss_weights[i]})"
f.write(label)
for i, label in enumerate(self.loss_labels):
label = f";val_{label}"
if self.loss_weights[i]:
label += f" (x {self.loss_weights[i]})"
f.write(label)
f.write("\n")
for epoch, (train_loss, val_loss) in enumerate(zip(self.train_losses, self.val_losses)):
f.write(str(epoch + 1))
for l in train_loss:
f.write(f";{l}")
for l in val_loss:
f.write(f";{l}")
f.write("\n")
def _add_loss(self, losses: torch.Tensor | np.ndarray | float | list[float], mode: str="train"):
synced_loss = self._synced_losses[mode]
losses = synced_loss.append(losses)
if len(losses) != len(self.loss_labels):
raise ValueError(f"Length of losses ({len(losses)}) does not match with number of loss labels ({len(self.loss_labels)}).")
if self.global_rank == 0 and len(synced_loss) % self.print_interval == 0:
self._print_losses(mode)
def _print_losses(self, mode: str = "train"):
if self.global_rank > 0:
return
synced_loss = self._synced_losses[mode]
losses = np.mean(synced_loss[-self.print_interval :], axis=0)
print(f"Epoch {self.epoch}, {mode} batch {len(synced_loss)} - Loss: " + self.loss_str(losses), file=self.stream, flush=True)
[docs]
def loss_str(self, losses: list[float] | np.ndarray | torch.Tensor) -> str:
"""
Get a pretty string for loss values.
Arguments:
losses: List of losses of the same length as the number of loss labels.
Returns:
String representation of the losses.
"""
if len(losses) != len(self.loss_labels):
raise ValueError(f"Length of losses ({len(losses)}) does not match with number of loss labels ({len(self.loss_labels)}).")
if len(self.loss_labels) == 1:
msg = f"{self.loss_labels[0]}: {losses[0]:.6f}"
else:
msg = f"{losses[0]:.6f}"
msg_loss = [f"{label}: {loss:.6f}" for label, loss in zip(self.loss_labels[1:], losses[1:])]
msg += " (" + ", ".join(msg_loss) + ")"
return msg
[docs]
def add_train_loss(self, losses: torch.Tensor | np.ndarray | float | list[float]):
"""Add losses for one training batch. Averaged over parallel processes.
Arguments:
losses: Losses to append to the list.
"""
if len(self._synced_losses["train"]) == 0:
self.epoch_start = time.perf_counter()
self._add_loss(losses, mode="train")
[docs]
def add_val_loss(self, losses: torch.Tensor | np.ndarray | float | list[float]):
"""Add losses for one validation batch. Averaged over parallel processes.
Arguments:
losses: Losses to append to the list.
"""
if len(self._synced_losses["val"]) == 0:
self.val_start = time.perf_counter()
self._add_loss(losses, mode="val")
[docs]
def next_epoch(self):
"""
Increment epoch by one, write current average batch losses to log, empty batch losses,
report epoch time to terminal, and update loss history plot.
"""
train_loss = self._synced_losses["train"].mean()
val_loss = self._synced_losses["val"].mean()
self.train_losses = np.append(self.train_losses, train_loss[None], axis=0)
self.val_losses = np.append(self.val_losses, val_loss[None], axis=0)
n_train = self._synced_losses["train"].n_batches
n_val = self._synced_losses["val"].n_batches
print(
f"Epoch {self.epoch} at rank {self.global_rank} contained {n_train} training batches " f"and {n_val} validation batches",
file=self.stream,
flush=True,
)
if self.global_rank == 0:
epoch_end = time.perf_counter()
train_step = (self.val_start - self.epoch_start) / n_train
val_step = (epoch_end - self.val_start) / n_val
print(f"Completed epoch {self.epoch} at {datetime.datetime.now()}", file=self.stream, flush=True)
print(f"Train loss: {self.loss_str(train_loss)}", file=self.stream, flush=True)
print(f"Val loss: {self.loss_str(val_loss)}", file=self.stream, flush=True)
print(
f"Epoch time: {epoch_end - self.epoch_start:.2f}s - Train step: {train_step:.5f}s " f"- Val step: {val_step:.5f}s",
file=self.stream,
flush=True,
)
self._write_log()
self.plot_history()
self.epoch += 1
self._synced_losses["train"].reset()
self._synced_losses["val"].reset()
[docs]
def plot_history(self, show: bool = False):
"""
Plot history of current losses into ``self.plot_path``.
Arguments:
show: Whether to show the plot on screen.
"""
if self.global_rank > 0:
return
x = range(1, len(self.train_losses) + 1)
n_rows, n_cols = _calc_plot_dim(len(self.loss_labels), f=0)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(7 * n_cols, 6 * n_rows))
if n_rows == 1 and n_cols == 1:
axes = np.expand_dims(axes, axis=0)
for i, (label, ax) in enumerate(zip(self.loss_labels, axes.flatten())):
ax.semilogy(x, self.train_losses[:, i], "-bx")
ax.semilogy(x, self.val_losses[:, i], "-gx")
ax.legend(["Training", "Validation"])
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")
if self.loss_weights[i]:
label = f"{label} (x {self.loss_weights[i]})"
ax.set_title(label)
fig.tight_layout()
plt.savefig(self.plot_path)
print(f"Loss history plot saved to {self.plot_path}", file=self.stream, flush=True)
if show:
plt.show()
else:
plt.close()
[docs]
def get_joinable(self, mode: str = "train"):
"""Return a joinable for uneven training/validation inputs.
Arguments:
mode: Choose 'train or 'val'.
"""
if mode not in ["train", "val"]:
raise ValueError(f"mode should be 'train' or 'val', but got `{mode}`")
return self._synced_losses[mode]