
Training

Training framework and utilities for DRN models, including the main training loop, loss functions, and optimization helpers.

Main Training Function

drn.train.train(model, train_dataset, val_dataset, epochs=200, patience=5, lr=None, device=None, log_interval=10, batch_size=128, optimizer=torch.optim.Adam, print_details=True, keep_best=True, gradient_clipping=False)

A generic neural network training function for a given model and datasets.

Args:
    model: The model to train.
    train_dataset: Dataset for training.
    val_dataset: Dataset for validation.
    epochs: Number of epochs to train for.
    patience: Number of epochs with no improvement after which training will be stopped.
    lr: Learning rate for the optimizer.
    device: Device to use for training (default is determined automatically).
    log_interval: How often to log training progress.
    batch_size: Batch size for training and validation.
    optimizer: Optimizer class to use (default is Adam).
    print_details: Whether to print detailed logs during training.
    keep_best: Whether to return the best model found during training.
    gradient_clipping: Whether to apply gradient clipping.
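The `patience` and `keep_best` arguments describe a standard early-stopping rule. The sketch below replays that rule in plain Python so its behaviour is easy to see. It is not drn's implementation: `early_stopping` is a hypothetical helper that takes a precomputed list of per-epoch validation losses instead of training a model, and it assumes that only a strict decrease counts as an improvement and that training halts once `patience` consecutive epochs fail to improve.

```python
from typing import Sequence, Tuple


def early_stopping(
    val_losses: Sequence[float], patience: int = 5, keep_best: bool = True
) -> Tuple[int, int]:
    """Replay patience-based early stopping over per-epoch validation losses.

    Hypothetical helper. Returns (stop_epoch, returned_epoch): the 0-based
    epoch after which training halts, and the epoch whose weights would be
    returned.
    """
    if not val_losses:
        raise ValueError("need at least one epoch of validation losses")
    best_loss = float("inf")
    best_epoch = 0
    bad_epochs = 0  # consecutive epochs without a strict improvement
    stop_epoch = len(val_losses) - 1  # default: every epoch was run
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            # In a real loop, keep_best would snapshot the model weights here.
            best_loss, best_epoch, bad_epochs = loss, epoch, 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                stop_epoch = epoch
                break
    returned_epoch = best_epoch if keep_best else stop_epoch
    return stop_epoch, returned_epoch


losses = [1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.9, 0.5]
print(early_stopping(losses, patience=3))                   # (5, 2)
print(early_stopping(losses, patience=3, keep_best=False))  # (5, 5)
```

With `patience=3`, the run stops at epoch 5 even though epoch 7 would have improved; `keep_best=True` then hands back the weights from epoch 2 rather than the last epoch trained.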

Loss Functions

DRN Loss

drn.models.drn_loss(pred, y, kind='jbce', kl_alpha=0.0, mean_alpha=0.0, tv_alpha=0.0, dv_alpha=0.0, kl_direction='forwards')

JBCE Loss

drn.models.ddr.jbce_loss(dists, y, alpha=0.0)

The joint binary cross entropy (JBCE) loss.

Args:
    dists: The predicted distributions.
    y: The observed values.
    alpha: The penalty parameter.
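A common formulation of the joint binary cross entropy treats each cutpoint c_k as its own binary classification problem: is y <= c_k, with predicted probability F(c_k) given by the predicted CDF? A minimal pure-Python sketch under stated assumptions: the distributions are summarised by their CDF values at the cutpoints, per-cutpoint losses are averaged (drn may sum them instead), and the `alpha` penalty is not modelled. `jbce_sketch` and its arguments are hypothetical.

```python
import math
from typing import Sequence


def jbce_sketch(
    cdf_values: Sequence[Sequence[float]],
    cutpoints: Sequence[float],
    y: Sequence[float],
    eps: float = 1e-7,
) -> float:
    """Hypothetical joint binary cross entropy.

    cdf_values[i][k] is the predicted CDF F_i(c_k) of observation i at
    cutpoint c_k. Each cutpoint defines the binary target 1{y_i <= c_k};
    the loss averages the binary cross entropy over cutpoints and observations.
    """
    total = 0.0
    for cdf_row, y_i in zip(cdf_values, y):
        for F, c in zip(cdf_row, cutpoints):
            F = min(max(F, eps), 1.0 - eps)  # clip so log() never sees 0
            target = 1.0 if y_i <= c else 0.0
            total -= target * math.log(F) + (1.0 - target) * math.log(1.0 - F)
    return total / (len(y) * len(cutpoints))


# An uninformative prediction (F = 0.5 at every cutpoint) costs log 2 per cutpoint.
print(jbce_sketch([[0.5, 0.5]], [1.0, 3.0], [2.0]))  # ≈ 0.6931 (log 2)
```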

NLL Loss

drn.models.ddr.nll_loss(dists, y, alpha=0.0)
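`nll_loss` carries no description here. Assuming the predicted distributions are histograms (piecewise-uniform densities over the cutpoints, as in deep distribution regression), the negative log-likelihood of an observation is minus the log of its bin probability divided by the bin width. The sketch below computes the mean of that quantity; `histogram_nll` is hypothetical, works on plain lists, and ignores the `alpha` penalty.

```python
import bisect
import math
from typing import Sequence


def histogram_nll(
    probs: Sequence[Sequence[float]],
    cutpoints: Sequence[float],
    y: Sequence[float],
    eps: float = 1e-12,
) -> float:
    """Mean NLL under a hypothetical piecewise-uniform (histogram) density.

    probs[i][k] is the probability that observation i falls in bin k, i.e.
    [cutpoints[k], cutpoints[k + 1]); the last bin also includes its right edge.
    """
    n_bins = len(cutpoints) - 1
    total = 0.0
    for p_row, y_i in zip(probs, y):
        if not cutpoints[0] <= y_i <= cutpoints[-1]:
            raise ValueError(f"y={y_i} lies outside the cutpoint range")
        # Index of the bin containing y_i; y_i == cutpoints[-1] maps to the last bin.
        k = min(bisect.bisect_right(cutpoints, y_i) - 1, n_bins - 1)
        density = p_row[k] / (cutpoints[k + 1] - cutpoints[k])
        total -= math.log(max(density, eps))  # clip to avoid log(0)
    return total / len(y)


# Two bins, [0, 1) and [1, 3], each with probability 0.5.
print(histogram_nll([[0.5, 0.5]], [0.0, 1.0, 3.0], [2.0]))  # log(4) ≈ 1.386
```

Dividing by the bin width matters: the same bin probability spread over a wider bin gives a lower density and therefore a larger NLL.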

Model Utilities

Cutpoint Generation

drn.models.drn_cutpoints(c_0, c_K, y, proportion=None, num_cutpoints=None, min_obs=1)
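The signature suggests cutpoints spanning `c_0` to `c_K`, a count set either by `proportion` or by `num_cutpoints`, and at least `min_obs` observations per bin. The sketch below is one way to honour those constraints and should not be read as drn's algorithm. It assumes that `proportion` is a fraction of the number of observations, builds a uniform grid, and greedily drops interior grid points whose bins would be too sparse. `cutpoints_sketch` is hypothetical.

```python
import bisect
import math
from typing import List, Optional, Sequence


def cutpoints_sketch(
    c_0: float,
    c_K: float,
    y: Sequence[float],
    proportion: Optional[float] = None,
    num_cutpoints: Optional[int] = None,
    min_obs: int = 1,
) -> List[float]:
    """Hypothetical cutpoints: a uniform grid on [c_0, c_K], then bin merging."""
    if num_cutpoints is None:
        if proportion is None:
            raise ValueError("pass either proportion or num_cutpoints")
        # Assumption: proportion is a fraction of the number of observations.
        num_cutpoints = math.ceil(proportion * len(y))
    num_cutpoints = max(2, num_cutpoints)  # need at least the two bounds
    step = (c_K - c_0) / (num_cutpoints - 1)
    grid = [c_0 + i * step for i in range(num_cutpoints)]
    grid[-1] = c_K  # guard against floating-point drift at the right end

    sorted_y = sorted(y)

    def count_in(lo: float, hi: float, closed: bool = False) -> int:
        # Observations in [lo, hi), or in [lo, hi] when closed=True.
        upper = bisect.bisect_right if closed else bisect.bisect_left
        return upper(sorted_y, hi) - bisect.bisect_left(sorted_y, lo)

    # Keep an interior grid point only if the bin it closes has >= min_obs points.
    merged = [grid[0]]
    for c in grid[1:-1]:
        if count_in(merged[-1], c) >= min_obs:
            merged.append(c)
    merged.append(grid[-1])

    # The last bin [merged[-2], c_K] may still be too sparse: absorb it leftwards.
    while len(merged) > 2 and count_in(merged[-2], c_K, closed=True) < min_obs:
        del merged[-2]
    return merged


print(cutpoints_sketch(0.0, 10.0, [0.1, 0.2, 0.3, 5.5, 9.9], num_cutpoints=11))
# [0.0, 1.0, 6.0, 10.0]
```

The outer bounds `c_0` and `c_K` are always kept, so the support of the resulting distribution is unchanged; only the interior resolution adapts to where the data actually lie.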

GLM Utilities

drn.models.glm.gaussian_deviance_loss(y_pred, y_true)

Calculate the deviance loss for the Gaussian (normal) distribution.

Args:
    y_pred: The predicted values (shape: (n,)).
    y_true: The observed values (shape: (n,)).

Returns:
    The deviance loss as a scalar (shape: ()).
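For the Gaussian distribution the unit deviance is d(y, mu) = (y - mu)^2, so the deviance loss reduces to a squared error. A sketch on plain lists, assuming the loss is the mean over the n observations (drn may use the sum or a scaled version); `gaussian_deviance` is hypothetical.

```python
from typing import Sequence


def gaussian_deviance(y_pred: Sequence[float], y_true: Sequence[float]) -> float:
    """Mean Gaussian unit deviance (y - mu)^2, i.e. the mean squared error."""
    if len(y_pred) != len(y_true):
        raise ValueError("y_pred and y_true must both have shape (n,)")
    return sum((y - mu) ** 2 for mu, y in zip(y_pred, y_true)) / len(y_true)


print(gaussian_deviance([1.0, 2.0], [2.0, 4.0]))  # (1 + 4) / 2 = 2.5
```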

drn.models.glm.gamma_deviance_loss(y_pred, y_true)

Calculate the Tweedie deviance loss for the gamma distribution (Tweedie power p = 2).

Args:
    y_pred: The predicted values (shape: (n,)).
    y_true: The observed values (shape: (n,)).

Returns:
    The deviance loss as a scalar (shape: ()).
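For the gamma distribution the unit deviance is d(y, mu) = 2[(y - mu)/mu - log(y/mu)], defined only for strictly positive y and mu; it is zero exactly when y = mu. A sketch on plain lists, again assuming the loss is the mean over observations; `gamma_deviance` is hypothetical.

```python
import math
from typing import Sequence


def gamma_deviance(y_pred: Sequence[float], y_true: Sequence[float]) -> float:
    """Mean gamma unit deviance 2 * ((y - mu) / mu - log(y / mu))."""
    if len(y_pred) != len(y_true):
        raise ValueError("y_pred and y_true must both have shape (n,)")
    if any(v <= 0 for v in (*y_pred, *y_true)):
        raise ValueError("the gamma deviance needs strictly positive values")
    total = sum(
        2.0 * ((y - mu) / mu - math.log(y / mu)) for mu, y in zip(y_pred, y_true)
    )
    return total / len(y_true)


print(gamma_deviance([1.0, 2.0], [1.0, 2.0]))  # 0.0, perfect predictions
print(gamma_deviance([1.0], [2.0]))            # 2 * (1 - log 2) ≈ 0.614
```

Because the deviance depends on y and mu only through the ratio y/mu, it penalises relative rather than absolute errors, which suits positive, right-skewed targets such as claim sizes.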