
Interpretability

Model interpretation and explanation tools for understanding DRN predictions and distributional properties.

DRN Explainer

Main class for interpreting DRN models using SHAP and custom visualization methods.

drn.interpretability.DRNExplainer

Functions

__init__(drn, glm, default_cutpoints, background_data_raw, preprocessor=None)

Initialise the DRNExplainer with the given parameters.

Args:
- drn (torch.nn.Module): The DRN neural network model.
- glm (torch.nn.Module): The baseline Generalised Linear Model (GLM).
- default_cutpoints (list): Cutpoints used for training the DRN.
- background_data_raw (pd.DataFrame): Background data prior to preprocessing.
- preprocessor (ColumnTransformer, optional): Converts raw data into a format suitable for the DRN.

cdf_plot(instance, grid=None, cutpoints=None, other_df_models=None, model_names=None, synthetic_data=None, x_range=None, plot_title=None, plot_baseline=True, density_transparency=1.0, dist_property='Mean', quantile_bounds=None, nsamples_background_fraction=0.1, adjustment=True, method='Kernel', labelling_gap=0.01, top_K_features=3, y_range=None, shap_fontsize=25, figsize=None)

Plot the cumulative distribution function.

empirical_cdf(samples, x)

Compute the empirical CDF for a given value x based on the provided samples.
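As an illustration of the quantity described above, the following is a minimal standalone sketch of an empirical CDF: the fraction of samples less than or equal to `x`. The function name `empirical_cdf_sketch` is hypothetical, and the library's own `empirical_cdf` may differ in input types and edge-case handling.

```python
from bisect import bisect_right


def empirical_cdf_sketch(samples, x):
    """Return the fraction of samples that are less than or equal to x."""
    ordered = sorted(samples)
    if not ordered:
        raise ValueError("samples must be non-empty")
    # bisect_right gives the number of sorted values that are <= x.
    return bisect_right(ordered, x) / len(ordered)


print(empirical_cdf_sketch([3.0, 1.0, 4.0, 2.0], 2.5))  # 0.5
```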

kernel_shap(explaining_data, distributional_property, adjustment=True, nsamples_background_fraction=1.0, glm_output=False, other_shap_values=None)

Pass the explaining instance, background data, feature processing and value function to the KernelSHAP_DRN class.

kernel_shap_plot(instance_raw, instance, dist_property, quantile_bounds, method='Kernel', nsamples_background_fraction=1.0, adjustment=True, axes=None, top_K_features=3, y_max=None, y_min=None, labelling_gap=0.05, fontsize=25)

Visualises the impact of SHAP values for the top K features on a specified distributional property, including the option to display adjustment effects.

Parameters:
- instance_raw (pd.DataFrame): The instance data prior to any processing.
- instance (torch.Tensor): Processed instance data ready for the model.
- dist_property (str): Target distributional property (e.g., 'Mean', 'Variance').
- method: Method used for SHAP value computation, defaulting to 'Kernel'.
- nsamples_background_fraction (float): Fraction of background data utilised, defaults to 1.0.
- adjustment (bool): Whether to include adjustment effects in the visualisation, defaults to True.
- axes: Matplotlib axes object for plotting.
- top_K_features (int): Number of top features to highlight based on SHAP values.
- y_max, y_min, labelling_gap, fontsize: Plot styling parameters for visual adjustments.
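To make the "top K features" selection concrete, here is a small sketch that ranks features by the absolute value of their SHAP contribution and keeps the first K. Ranking by absolute value is an assumption about how `top_K_features` is applied. The helper name and the feature names are hypothetical.

```python
def top_k_features(shap_values, k=3):
    """Return the k (feature, SHAP value) pairs with the largest absolute SHAP value."""
    ranked = sorted(shap_values.items(), key=lambda item: abs(item[1]), reverse=True)
    return ranked[:k]


# Hypothetical SHAP values for a single instance.
example = {"age": 0.12, "region": -0.40, "vehicle_value": 0.05, "mileage": 0.31}
print(top_k_features(example, k=3))
# [('region', -0.4), ('mileage', 0.31), ('age', 0.12)]
```

The sign of each value is kept so that a plot can still show whether a feature pushes the distributional property up or down.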

max_pdf_in_region(drn_pdf, glm_pdf, interval_width, cutpoint_idx)

Find the maximum pdf value within the region

mean_drn(instances)

Calculate the mean predicted by the DRN network given the selected instances/features

mean_glm(instances)

Calculate the mean predicted by the GLM given the selected instances/features

mean_value_function(instances, adjustment)

Calculate the mean value function given the selected instances/features

plot_adjustment_factors(instance, observation=None, cutpoints=None, num_interpolations=None, other_df_models=None, model_names=None, percentiles=None, cutpoints_label_bool=False, synthetic_data=None, plot_adjustments_labels=True, axes=None, x_range=None, y_range=None, plot_title=None, plot_mean_adjustment=False, plot_y_label=None, density_transparency=1.0, figsize=None)

Plot the adjustment factors for each of the partitioned intervals. expand: interpolation of cutpoints for density evaluations.

plot_dp_adjustment_shap(instance_raw, dist_property='Mean', quantile_bounds=None, method='Kernel', nsamples_background_fraction=1.0, top_K_features=3, adjustment=True, other_df_models=None, model_names=None, cutpoints=None, num_interpolations=None, labelling_gap=0.05, synthetic_data=None, synthetic_data_samples=int(1000000.0), observation=None, plot_baseline=True, x_range=None, y_range=None, plot_y_label=None, plot_title=None, figsize=None, density_transparency=1.0, shap_fontsize=25, legend_loc='upper left')

Plot SHAP value-based adjustments with an option to include density functions.

Args:
- instance_raw: Raw data before one-hot encoding, used for instance-specific analysis.
- dist_property: Distributional property to adjust ('Mean', 'Variance', 'Quantile').
- method: Technique for SHAP value computation ('Kernel', 'Tree', 'FastSHAP', etc.).
- nsamples_background_fraction: Fraction of background data for SHAP calculation.
- top_K_features: Number of top features based on SHAP values.
- adjustment: Whether to plot the SHAP values of the adjusted or unadjusted distributional property.
- other_df_models: Other distributional forecasting models for comparison.
- model_names: Names of the other distributional forecasting models.
- cutpoints: Cutpoints for partitioning feature space, defaults to self.default_cutpoints.
- num_interpolations: Number of points for density interpolation, defaults to 2000.
- labelling_gap: Gap between labels in the plot for readability.
- synthetic_data, synthetic_data_samples: Synthetic data function for true density comparison and number of samples generated.
- observation: Specific observation value for vertical line plotting.
- plot_baseline: Flag to include baseline model's density plot.
- x_range, y_range: Axis ranges for the plot.
- plot_y_label, plot_title: Custom labels for the plot's axes and title.
- density_transparency: Alpha value for density plot transparency.
- shap_fontsize, figsize, label_adjustment_factor: Plot styling parameters.
- legend_loc: Location of the legend in the plot.
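The call shape might look like the sketch below, which only uses keyword arguments that appear in the documented signature. It assumes an already constructed `DRNExplainer`. The wrapper `plot_mean_adjustment_shap` and the stand-in class `_RecordingExplainer` are hypothetical and exist only so the snippet runs without the library or a trained model. In real use `instance_raw` would be a `pd.DataFrame` row rather than a dict.

```python
def plot_mean_adjustment_shap(explainer, instance_raw, observation=None):
    """Call plot_dp_adjustment_shap for the mean, using documented keyword arguments."""
    return explainer.plot_dp_adjustment_shap(
        instance_raw,
        dist_property="Mean",
        method="Kernel",
        nsamples_background_fraction=0.5,  # use half of the background data
        top_K_features=3,
        adjustment=True,
        observation=observation,
        plot_baseline=True,
        legend_loc="upper left",
    )


class _RecordingExplainer:
    """Hypothetical stand-in for DRNExplainer that records the call it receives."""

    def plot_dp_adjustment_shap(self, instance_raw, **kwargs):
        self.last_call = (instance_raw, kwargs)
        return kwargs


stub = _RecordingExplainer()
call_kwargs = plot_mean_adjustment_shap(stub, instance_raw={"age": 42})
print(sorted(call_kwargs))
```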

quantile_drn(instances, percentile=[90], grid=None)

Calculate the quantile predicted by the DRN network given the selected instances/features

quantile_glm(instances, percentile=[90], grid=None)

Calculate the quantile predicted by the GLM given the selected instances/features

quantile_value_function(instances, adjustment, grid=None, percentile=[90])

Calculate the quantile value function given the selected instances/features
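One common way to obtain a quantile from a distribution evaluated on a grid is to take the smallest grid point at which the CDF reaches the target level. The sketch below illustrates that idea only. It is an assumption that the `grid` and `percentile` arguments above play this role, and `quantile_from_cdf` is a hypothetical helper, not part of the library.

```python
def quantile_from_cdf(grid, cdf_values, percentile):
    """Return the smallest grid point whose CDF value reaches percentile / 100."""
    target = percentile / 100.0
    for y, cumulative in zip(grid, cdf_values):
        if cumulative >= target:
            return y
    # The CDF never reaches the target on this grid; fall back to the last point.
    return grid[-1]


grid = [0.0, 1.0, 2.0, 3.0, 4.0]
cdf_values = [0.0, 0.3, 0.6, 0.95, 1.0]
print(quantile_from_cdf(grid, cdf_values, 90))  # 3.0
```

A finer grid gives a more precise quantile at the cost of more CDF evaluations.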

real_adjustment_factors(instances, cutpoints)

Calculate the real adjustment factors.
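As a rough illustration, an adjustment factor for an interval can be read as the ratio of the probability mass the DRN assigns to that interval to the mass the GLM assigns to it. This definition is an assumption for the sketch below and is not confirmed by this page. The function and the toy CDFs are hypothetical.

```python
def interval_adjustment_factors(drn_cdf, glm_cdf, cutpoints):
    """Ratio of DRN to GLM probability mass in each interval between cutpoints."""
    factors = []
    for left, right in zip(cutpoints[:-1], cutpoints[1:]):
        glm_mass = glm_cdf(right) - glm_cdf(left)
        drn_mass = drn_cdf(right) - drn_cdf(left)
        # Avoid division by zero when the baseline puts no mass in the interval.
        factors.append(drn_mass / glm_mass if glm_mass > 0 else float("nan"))
    return factors


def toy_glm_cdf(y):
    """Toy baseline: uniform distribution on [0, 4]."""
    return y / 4.0


def toy_drn_cdf(y):
    """Toy refined model that shifts mass towards larger values."""
    return (y / 4.0) ** 2


print(interval_adjustment_factors(toy_drn_cdf, toy_glm_cdf, [0.0, 2.0, 4.0]))
# [0.5, 1.5]
```

Under this reading, a factor above 1 means the refined model places more mass in the interval than the baseline, and a factor below 1 means less.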

region_adjustments(instance, region_start, region_end)

Calculate and round the adjustment factors

region_text(instance, interval_width, drn_pdf, glm_pdf, y_max, region_start, region_end, cutpoint_idx, adjustment_idx, cutpoints_label_bool=False, percentiles=None)

Add text labels to the density adjustment regions.

set_value_function(distributional_property, adjustment, model_function)

Extract the numeric part from a distributional property of the form "XX% quantile" and set the value function accordingly.
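The parsing step described above could be sketched as follows. The exact label format accepted by the library is not shown on this page, so the pattern (a number followed by `%`) and the helper name `parse_quantile_level` are assumptions.

```python
import re


def parse_quantile_level(distributional_property):
    """Return the numeric level from a label such as '90% Quantile', or None if absent."""
    match = re.search(r"(\d+(?:\.\d+)?)\s*%", distributional_property)
    return float(match.group(1)) if match else None


print(parse_quantile_level("90% Quantile"))    # 90.0
print(parse_quantile_level("99.5% Quantile"))  # 99.5
print(parse_quantile_level("Mean"))            # None
```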

Kernel SHAP Integration

Specialized SHAP explainer for distributional properties of DRN models.

drn.kernel_shap_explainer.KernelSHAP_DRN

This class produces the raw Kernel SHAP values for the distributional property of interest. It also generates SHAP dependence plots for any pair of features, accounting for categorical features, and beeswarm plots for any set of features.

Functions

__init__(explaining_data, nsamples_background_fraction, background_data_raw, value_function, glm_value_function, other_shap_values=None, random_state=42)

Args:
- explaining_data, nsamples_background_fraction, background_data_raw, preprocessor: See the DRNExplainer class for explanations.
- value_function: v_{M}(S, x), the value function evaluated for any instance x and index set S \subseteq {1, ..., p}.
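A value function v_{M}(S, x) of this kind is often estimated by fixing the features indexed by S to their values in x, filling the remaining features from background rows, and averaging the model output. The sketch below shows that standard construction on a toy linear model. Whether KernelSHAP_DRN computes it in exactly this way is an assumption. All names are hypothetical, and indices are 0-based here rather than 1, ..., p.

```python
def marginal_value(model, x, subset, background):
    """Average model output with features in `subset` taken from x and the rest from background rows."""
    total = 0.0
    for row in background:
        mixed = [x[j] if j in subset else row[j] for j in range(len(x))]
        total += model(mixed)
    return total / len(background)


def linear_model(features):
    """Toy stand-in for a distributional property such as the predicted mean."""
    return 2.0 * features[0] + 3.0 * features[1]


x = [3.0, 2.0]
background = [[0.0, 0.0], [2.0, 2.0]]

print(marginal_value(linear_model, x, set(), background))   # 5.0
print(marginal_value(linear_model, x, {0}, background))     # 9.0
print(marginal_value(linear_model, x, {0, 1}, background))  # 12.0
```

With the empty set the result is the average prediction over the background data, and with all features included it is the prediction for x itself.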

beeswarm_plot(features=None, output='value')

Create the beeswarm summary plots.
- features: a list of feature names required for plotting.
- adjusting: False explains the DRN model; True explains how the DRN adjusts the GLM.

forward()

The raw Kernel SHAP (either adjusted or DRN) output.

global_importance_plot(features=None, output='value')

Creates a global importance plot based on the absolute SHAP values.
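A common convention for global importance is the mean absolute SHAP value per feature across all explained instances. The sketch below assumes that convention. The page only states that the plot is based on absolute SHAP values, so the averaging step and the helper name are assumptions.

```python
def global_importance(shap_matrix, feature_names):
    """Mean absolute SHAP value per feature, sorted from most to least important."""
    n_rows = len(shap_matrix)
    scores = {
        name: sum(abs(row[j]) for row in shap_matrix) / n_rows
        for j, name in enumerate(feature_names)
    }
    return sorted(scores.items(), key=lambda item: item[1], reverse=True)


# Rows are instances, columns are features (hypothetical values).
shap_matrix = [[0.5, -1.0], [-0.5, 3.0]]
print(global_importance(shap_matrix, ["age", "region"]))
# [('region', 2.0), ('age', 0.5)]
```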

shap_dependence_plot(features_tuple, output='value')

Create the SHAP dependence plots.
- features_tuple: the pair of features required for plotting.
- other_shap_values: allows for externally calculated SHAP values, i.e., FastSHAP...

shap_glm_values()

The raw Kernel SHAP (GLM) output.

shap_values_mean_adjustments()

The SHAP values and feature names

Key Methods

Plot Adjustment Factors

drn.interpretability.DRNExplainer.plot_adjustment_factors(instance, observation=None, cutpoints=None, num_interpolations=None, other_df_models=None, model_names=None, percentiles=None, cutpoints_label_bool=False, synthetic_data=None, plot_adjustments_labels=True, axes=None, x_range=None, y_range=None, plot_title=None, plot_mean_adjustment=False, plot_y_label=None, density_transparency=1.0, figsize=None)

Plot the adjustment factors for each of the partitioned intervals. expand: interpolation of cutpoints for density evaluations.

Plot Distributional Property Adjustment with SHAP

drn.interpretability.DRNExplainer.plot_dp_adjustment_shap(instance_raw, dist_property='Mean', quantile_bounds=None, method='Kernel', nsamples_background_fraction=1.0, top_K_features=3, adjustment=True, other_df_models=None, model_names=None, cutpoints=None, num_interpolations=None, labelling_gap=0.05, synthetic_data=None, synthetic_data_samples=int(1000000.0), observation=None, plot_baseline=True, x_range=None, y_range=None, plot_y_label=None, plot_title=None, figsize=None, density_transparency=1.0, shap_fontsize=25, legend_loc='upper left')

Plot SHAP value-based adjustments with an option to include density functions.

Args:
- instance_raw: Raw data before one-hot encoding, used for instance-specific analysis.
- dist_property: Distributional property to adjust ('Mean', 'Variance', 'Quantile').
- method: Technique for SHAP value computation ('Kernel', 'Tree', 'FastSHAP', etc.).
- nsamples_background_fraction: Fraction of background data for SHAP calculation.
- top_K_features: Number of top features based on SHAP values.
- adjustment: Whether to plot the SHAP values of the adjusted or unadjusted distributional property.
- other_df_models: Other distributional forecasting models for comparison.
- model_names: Names of the other distributional forecasting models.
- cutpoints: Cutpoints for partitioning feature space, defaults to self.default_cutpoints.
- num_interpolations: Number of points for density interpolation, defaults to 2000.
- labelling_gap: Gap between labels in the plot for readability.
- synthetic_data, synthetic_data_samples: Synthetic data function for true density comparison and number of samples generated.
- observation: Specific observation value for vertical line plotting.
- plot_baseline: Flag to include baseline model's density plot.
- x_range, y_range: Axis ranges for the plot.
- plot_y_label, plot_title: Custom labels for the plot's axes and title.
- density_transparency: Alpha value for density plot transparency.
- shap_fontsize, figsize, label_adjustment_factor: Plot styling parameters.
- legend_loc: Location of the legend in the plot.

CDF Plot

drn.interpretability.DRNExplainer.cdf_plot(instance, grid=None, cutpoints=None, other_df_models=None, model_names=None, synthetic_data=None, x_range=None, plot_title=None, plot_baseline=True, density_transparency=1.0, dist_property='Mean', quantile_bounds=None, nsamples_background_fraction=0.1, adjustment=True, method='Kernel', labelling_gap=0.01, top_K_features=3, y_range=None, shap_fontsize=25, figsize=None)

Plot the cumulative distribution function.