Training
Training framework and utilities for DRN models, including the main training loop, loss functions, and optimization helpers.
Main Training Function
drn.train.train(model, train_dataset, val_dataset, epochs=200, patience=5, lr=None, device=None, log_interval=10, batch_size=128, optimizer=torch.optim.Adam, print_details=True, keep_best=True, gradient_clipping=False)
A generic neural network training function given a model and datasets.

Args:
    model: The model to train.
    train_dataset: Dataset for training.
    val_dataset: Dataset for validation.
    epochs: Number of epochs to train for.
    patience: Number of epochs with no improvement after which training will be stopped.
    lr: Learning rate for the optimizer.
    device: Device to use for training (default is determined automatically).
    log_interval: How often to log training progress.
    batch_size: Batch size for training and validation.
    optimizer: Optimizer class to use (default is Adam).
    print_details: Whether to print detailed logs during training.
    keep_best: Whether to return the best model found during training.
    gradient_clipping: Whether to apply gradient clipping.
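The interaction between patience and keep_best can be illustrated with a small pure-Python sketch. This is not the library's implementation: simulate_early_stopping is a hypothetical helper, and it assumes that "no improvement" means the validation loss did not fall below the best value seen so far.

```python
def simulate_early_stopping(val_losses, patience=5):
    """Return (best_epoch, stop_epoch) for a list of per-epoch validation losses.

    Hypothetical helper (not part of drn): training stops once `patience`
    consecutive epochs fail to improve on the best validation loss so far.
    """
    best_loss = float("inf")
    best_epoch = None
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            # New best: remember it and reset the patience counter.
            best_loss = loss
            best_epoch = epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return best_epoch, epoch
    # Ran out of epochs before patience was exhausted.
    return best_epoch, len(val_losses) - 1


losses = [1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.9]
best_epoch, stop_epoch = simulate_early_stopping(losses, patience=3)
```

Under these assumptions, training on the losses above would stop at epoch 5, and with keep_best=True the weights from epoch 2 would be the ones returned.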
Loss Functions
DRN Loss
drn.models.drn_loss(pred, y, kind='jbce', kl_alpha=0.0, mean_alpha=0.0, tv_alpha=0.0, dv_alpha=0.0, kl_direction='forwards')
JBCE Loss
drn.models.ddr.jbce_loss(dists, y, alpha=0.0)
The joint binary cross entropy loss.

Args:
    dists: the predicted distributions
    y: the observed values
    alpha: the penalty parameter
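This page does not define the loss, so the following is only a sketch of one common reading of "joint binary cross entropy": for each cutpoint c_k, the indicator 1{y <= c_k} is treated as a binary label and scored against the predicted CDF value at c_k, and the result is averaged over observations and cutpoints. The function name, the cutpoints argument, the eps clamp, and the averaging are all assumptions, and the alpha penalty is omitted because its form is not stated here.

```python
import math


def jbce_sketch(cdf_at_cutpoints, y, cutpoints, eps=1e-12):
    """Average binary cross entropy over observations and cutpoints.

    Hypothetical helper (not drn.models.ddr.jbce_loss).
    cdf_at_cutpoints[i][k] is the predicted P(Y_i <= cutpoints[k]).
    """
    total = 0.0
    count = 0
    for probs, y_i in zip(cdf_at_cutpoints, y):
        for p, c in zip(probs, cutpoints):
            # Clamp to avoid log(0) for fully confident predictions.
            p = min(max(p, eps), 1.0 - eps)
            target = 1.0 if y_i <= c else 0.0
            total -= target * math.log(p) + (1.0 - target) * math.log(1.0 - p)
            count += 1
    return total / count


cutpoints = [1.0, 2.0, 3.0]
y = [1.5, 2.5]
uninformed = jbce_sketch([[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]], y, cutpoints)
informed = jbce_sketch([[0.1, 0.9, 0.95], [0.05, 0.1, 0.9]], y, cutpoints)
```

In this sketch a predicted CDF that places mass near the observed values scores lower than a flat prediction of 0.5 at every cutpoint, whose loss is log 2.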
NLL Loss
drn.models.ddr.nll_loss(dists, y, alpha=0.0)
Model Utilities
Cutpoint Generation
drn.models.drn_cutpoints(c_0, c_K, y, proportion=None, num_cutpoints=None, min_obs=1)
GLM Utilities
drn.models.glm.gaussian_deviance_loss(y_pred, y_true)
Calculate the Normal deviance loss for the Gaussian distribution.

Args:
    y_pred: the predicted values (shape: (n,))
    y_true: the observed values (shape: (n,))

Returns:
    the deviance loss (shape: (,))
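For the Gaussian distribution the unit deviance is the squared error (y - mu)^2. The sketch below works on plain Python lists rather than tensors and assumes the scalar result is the mean over observations; whether the library averages or sums is not stated on this page, and gaussian_deviance_sketch is a hypothetical name.

```python
def gaussian_deviance_sketch(y_pred, y_true):
    """Mean Gaussian unit deviance, i.e. the mean squared error.

    Hypothetical helper (not drn.models.glm.gaussian_deviance_loss);
    the reduction by mean is an assumption.
    """
    n = len(y_true)
    return sum((y - mu) ** 2 for mu, y in zip(y_pred, y_true)) / n


gaussian_loss = gaussian_deviance_sketch([1.0, 2.0], [1.0, 4.0])
```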
drn.models.glm.gamma_deviance_loss(y_pred, y_true)
Calculate the Tweedie deviance loss for the gamma distribution.

Args:
    y_pred: the predicted values (shape: (n,))
    y_true: the observed values (shape: (n,))

Returns:
    the deviance loss (shape: (,))
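The gamma distribution is the Tweedie family member with power 2, and its unit deviance is 2 * ((y - mu) / mu - log(y / mu)) for positive y and mu. The sketch below uses plain Python lists and again assumes a mean reduction, which this page does not confirm; gamma_deviance_sketch is a hypothetical name.

```python
import math


def gamma_deviance_sketch(y_pred, y_true):
    """Mean gamma unit deviance for strictly positive predictions and targets.

    Hypothetical helper (not drn.models.glm.gamma_deviance_loss);
    the reduction by mean is an assumption.
    """
    unit = [
        2.0 * ((y - mu) / mu - math.log(y / mu))
        for mu, y in zip(y_pred, y_true)
    ]
    return sum(unit) / len(unit)


perfect_fit = gamma_deviance_sketch([1.0, 2.0], [1.0, 2.0])
under_prediction = gamma_deviance_sketch([1.0], [2.0])
```

The deviance is zero when predictions match the observations exactly and grows with the relative, rather than absolute, size of the error.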