Interpretability
Model interpretation and explanation tools for understanding DRN predictions and distributional properties.
DRN Explainer
Main class for interpreting DRN models using SHAP and custom visualization methods.
drn.interpretability.DRNExplainer
Functions
__init__(drn, glm, default_cutpoints, background_data_raw, preprocessor=None)
Initialise the DRNExplainer with the given parameters.
Args:
drn (torch.nn.Module): The DRN neural network model.
glm (torch.nn.Module): The baseline Generalised Linear Model (GLM).
default_cutpoints (list): Cutpoints used for training the DRN.
background_data_raw (pd.DataFrame): Background data prior to preprocessing.
preprocessor (ColumnTransformer, optional): Converts raw data into a format suitable for the DRN.
cdf_plot(instance, grid=None, cutpoints=None, other_df_models=None, model_names=None, synthetic_data=None, x_range=None, plot_title=None, plot_baseline=True, density_transparency=1.0, dist_property='Mean', quantile_bounds=None, nsamples_background_fraction=0.1, adjustment=True, method='Kernel', labelling_gap=0.01, top_K_features=3, y_range=None, shap_fontsize=25, figsize=None)
Plot the cumulative distribution function.
empirical_cdf(samples, x)
Compute the empirical CDF for a given value x based on the provided samples.
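As a rough illustration of what an empirical CDF computes, the sketch below returns the fraction of samples that are less than or equal to x. The function name `empirical_cdf_sketch` is hypothetical and uses plain Python lists; it is not the library's implementation, which may operate on tensors and handle ties or batching differently.

```python
from bisect import bisect_right


def empirical_cdf_sketch(samples, x):
    """Return the fraction of samples that are <= x (hypothetical helper)."""
    ordered = sorted(samples)
    if not ordered:
        raise ValueError("samples must be non-empty")
    # bisect_right gives the number of sorted values that are <= x.
    return bisect_right(ordered, x) / len(ordered)


samples = [3.0, 1.0, 4.0, 2.0]
print(empirical_cdf_sketch(samples, 2.5))
```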
kernel_shap(explaining_data, distributional_property, adjustment=True, nsamples_background_fraction=1.0, glm_output=False, other_shap_values=None)
Pass the explaining instances, background data, feature processing and value function to the KernelSHAP_DRN class.
kernel_shap_plot(instance_raw, instance, dist_property, quantile_bounds, method='Kernel', nsamples_background_fraction=1.0, adjustment=True, axes=None, top_K_features=3, y_max=None, y_min=None, labelling_gap=0.05, fontsize=25)
Visualises the impact of SHAP values for the top K features on a specified distributional property, including the option to display adjustment effects.
Parameters:
- instance_raw (pd.DataFrame): The instance data prior to any processing.
- instance (torch.Tensor): Processed instance data ready for the model.
- dist_property (str): Target distributional property (e.g., 'Mean', 'Variance').
- method: Method used for SHAP value computation, defaulting to 'Kernel'.
- nsamples_background_fraction (float): Fraction of background data utilised, defaults to 1.0.
- adjustment (bool): Whether to include adjustment effects in the visualisation, defaults to True.
- axes: Matplotlib axes object for plotting.
- top_K_features (int): Number of top features to highlight based on SHAP values.
- y_max, y_min, labelling_gap, fontsize: Plot styling parameters for visual adjustments.
max_pdf_in_region(drn_pdf, glm_pdf, interval_width, cutpoint_idx)
Find the maximum PDF value within the region.
mean_drn(instances)
Calculate the mean predicted by the DRN network given the selected instances/features
mean_glm(instances)
Calculate the mean predicted by the GLM given the selected instances/features
mean_value_function(instances, adjustment)
Calculate the mean value function given the selected instances/features
plot_adjustment_factors(instance, observation=None, cutpoints=None, num_interpolations=None, other_df_models=None, model_names=None, percentiles=None, cutpoints_label_bool=False, synthetic_data=None, plot_adjustments_labels=True, axes=None, x_range=None, y_range=None, plot_title=None, plot_mean_adjustment=False, plot_y_label=None, density_transparency=1.0, figsize=None)
Plot the adjustment factors for each of the partitioned intervals.
expand: Interpolation of cutpoints for density evaluations.
plot_dp_adjustment_shap(instance_raw, dist_property='Mean', quantile_bounds=None, method='Kernel', nsamples_background_fraction=1.0, top_K_features=3, adjustment=True, other_df_models=None, model_names=None, cutpoints=None, num_interpolations=None, labelling_gap=0.05, synthetic_data=None, synthetic_data_samples=int(1000000.0), observation=None, plot_baseline=True, x_range=None, y_range=None, plot_y_label=None, plot_title=None, figsize=None, density_transparency=1.0, shap_fontsize=25, legend_loc='upper left')
Plot SHAP value-based adjustments with an option to include density functions.
Args:
instance_raw: Raw data before one-hot encoding, used for instance-specific analysis.
dist_property: Distributional property to adjust ('Mean', 'Variance', 'Quantile').
method: Technique for SHAP value computation ('Kernel', 'Tree', 'FastSHAP', etc.).
nsamples_background_fraction: Fraction of background data for SHAP calculation.
top_K_features: Number of top features based on SHAP values.
adjustment: Whether to plot the SHAP values of the adjusted or unadjusted distributional property
other_df_models: Other distributional forecasting models for comparison.
model_names: Names of the other distributional forecasting models.
cutpoints: Cutpoints for partitioning feature space, defaults to self.default_cutpoints.
num_interpolations: Number of points for density interpolation, defaults to 2000.
labelling_gap: Gap between labels in the plot for readability.
synthetic_data, synthetic_data_samples: Synthetic data function for true density comparison and number of samples generated.
observation: Specific observation value for vertical line plotting.
plot_baseline: Flag to include baseline model's density plot.
x_range, y_range: Axis ranges for the plot.
plot_y_label, plot_title: Custom labels for the plot's axes and title.
density_transparency: Alpha value for density plot transparency.
shap_fontsize, figsize, label_adjustment_factor: Plot styling parameters.
legend_loc: Location of the legend in the plot.
quantile_drn(instances, percentile=[90], grid=None)
Calculate the quantile predicted by the DRN network given the selected instances/features
quantile_glm(instances, percentile=[90], grid=None)
Calculate the quantile predicted by the GLM given the selected instances/features
quantile_value_function(instances, adjustment, grid=None, percentile=[90])
Calculate the quantile value function given the selected instances/features
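The quantile methods above take a `grid` argument, which suggests that quantiles are read off a CDF evaluated on a grid of response values. A minimal sketch of that idea, assuming a non-decreasing list of CDF values and a single percentile rather than a list, is given below. The name `quantile_from_cdf` is hypothetical, and the library may interpolate or search differently.

```python
from bisect import bisect_left


def quantile_from_cdf(grid, cdf_values, percentile):
    """Return the smallest grid point whose CDF value reaches percentile / 100.

    Hypothetical helper: assumes cdf_values is non-decreasing and aligned with grid.
    """
    if len(grid) != len(cdf_values) or not grid:
        raise ValueError("grid and cdf_values must be non-empty and equally long")
    target = percentile / 100.0
    # First index at which the CDF is >= target; clamp to the last grid point.
    idx = min(bisect_left(cdf_values, target), len(grid) - 1)
    return grid[idx]


grid = [0.0, 1.0, 2.0, 3.0, 4.0]
cdf_values = [0.0, 0.3, 0.6, 0.95, 1.0]
print(quantile_from_cdf(grid, cdf_values, 90))
```

A finer grid gives a closer approximation, at the cost of more CDF evaluations per instance.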
real_adjustment_factors(instances, cutpoints)
Calculate the real adjustment factors.
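One common way to read an adjustment factor, assumed here rather than confirmed by this page, is as the ratio between the probability mass the refined model assigns to an interval between two cutpoints and the mass the baseline assigns to the same interval. The sketch below illustrates that reading with a hypothetical exponential baseline; `exponential_cdf`, `adjustment_factors_sketch` and the example probabilities are all made up for illustration and do not reproduce `real_adjustment_factors`.

```python
import math


def exponential_cdf(y, rate=1.0):
    """CDF of an exponential distribution, standing in for a baseline model."""
    return 1.0 - math.exp(-rate * y) if y > 0 else 0.0


def adjustment_factors_sketch(refined_probs, baseline_cdf, cutpoints):
    """Ratio of refined to baseline probability mass in each interval (assumed definition)."""
    if len(refined_probs) != len(cutpoints) - 1:
        raise ValueError("need exactly one probability per interval")
    factors = []
    for k, p_refined in enumerate(refined_probs):
        # Baseline mass in the interval between cutpoints k and k + 1.
        p_base = baseline_cdf(cutpoints[k + 1]) - baseline_cdf(cutpoints[k])
        factors.append(p_refined / p_base)
    return factors


cutpoints = [0.0, 1.0, 2.0]
refined_probs = [0.5, 0.3]  # illustrative interval probabilities
factors = adjustment_factors_sketch(refined_probs, exponential_cdf, cutpoints)
print(factors)
```

Under this reading, a factor above 1 means the refined model places more mass in that interval than the baseline, and a factor below 1 means less.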
region_adjustments(instance, region_start, region_end)
Calculate and round the adjustment factors
region_text(instance, interval_width, drn_pdf, glm_pdf, y_max, region_start, region_end, cutpoint_idx, adjustment_idx, cutpoints_label_bool=False, percentiles=None)
Add text labels to the density adjustment regions.
set_value_function(distributional_property, adjustment, model_function)
Extract the numeric part from a distributional property of the form "XX% quantile", then set the value function accordingly.
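To make the "XX% quantile" parsing step concrete, the following sketch pulls the numeric percentile out of such a string with a regular expression. The exact string format the library accepts is not documented on this page, so the pattern and the helper name `parse_quantile_property` are assumptions.

```python
import re

# Assumed format: a number, a percent sign, then the word "quantile" (any case).
_QUANTILE_PATTERN = re.compile(r"\s*(\d+(?:\.\d+)?)\s*%\s*quantile\s*", re.IGNORECASE)


def parse_quantile_property(distributional_property):
    """Return the percentile as a float, or None if the string is not a quantile."""
    match = _QUANTILE_PATTERN.fullmatch(distributional_property)
    if match is None:
        return None
    return float(match.group(1))


print(parse_quantile_property("90% Quantile"))
print(parse_quantile_property("Mean"))
```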
Kernel SHAP Integration
Specialised SHAP explainer for distributional properties of DRN models.
drn.kernel_shap_explainer.KernelSHAP_DRN
This class produces the raw Kernel SHAP values for the distributional property of interest. It also generates SHAP dependence plots for any pair of features, taking categorical features into account, and beeswarm plots for any set of features.
Functions
__init__(explaining_data, nsamples_background_fraction, background_data_raw, value_function, glm_value_function, other_shap_values=None, random_state=42)
Args:
See the DRNExplainer class for explanations of explaining_data, nsamples_background_fraction, background_data_raw and preprocessor.
value_function: v_{M}(S, x), given any instance x and indices S \subseteq \{1, ..., p\}.
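A value function v_{M}(S, x) of this kind is often defined, in Kernel SHAP settings, as the average model output when the features indexed by S are fixed to the values in x and the remaining features are drawn from background data. The sketch below implements that marginal definition in plain Python; it is an assumption about the general technique rather than a description of this class, and `value_function_sketch` and `linear_model` are hypothetical names.

```python
def value_function_sketch(model, x, subset, background):
    """Average model output with features in `subset` fixed to x's values.

    Features outside `subset` take their values from each background row in turn.
    """
    if not background:
        raise ValueError("background must be non-empty")
    total = 0.0
    for row in background:
        hybrid = [x[j] if j in subset else row[j] for j in range(len(x))]
        total += model(hybrid)
    return total / len(background)


def linear_model(features):
    """Toy model standing in for a distributional property such as the mean."""
    return 2.0 * features[0] + 3.0 * features[1]


background = [[0.0, 0.0], [2.0, 4.0]]
x = [3.0, 1.0]
for subset in (set(), {0}, {1}, {0, 1}):
    print(sorted(subset), value_function_sketch(linear_model, x, subset, background))
```

With the empty subset the result is the average prediction over the background data, and with the full subset it is the prediction at x itself; SHAP values attribute the difference between these two to individual features.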
beeswarm_plot(features=None, output='value')
Create the beeswarm summary plots.
features: A list of feature names required for plotting.
adjusting: False --> explaining the DRN model; True --> explaining how the DRN adjusts the GLM.
forward()
The raw Kernel SHAP (either adjusted or DRN) output.
global_importance_plot(features=None, output='value')
Creates a global importance plot based on the absolute SHAP values.
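A frequently used summary for global importance is the mean absolute SHAP value per feature. The sketch below computes that ranking from a small matrix of SHAP values; whether this class averages, sums or otherwise aggregates the absolute values is not stated here, so treat the aggregation and the helper name `global_importance_sketch` as assumptions.

```python
def global_importance_sketch(shap_values, feature_names):
    """Rank features by mean absolute SHAP value (assumed aggregation).

    shap_values is a list of rows, one per explained instance.
    """
    if not shap_values:
        raise ValueError("shap_values must be non-empty")
    n = len(shap_values)
    importance = {
        name: sum(abs(row[j]) for row in shap_values) / n
        for j, name in enumerate(feature_names)
    }
    # Most important feature first.
    return dict(sorted(importance.items(), key=lambda item: item[1], reverse=True))


shap_values = [[1.0, -2.0], [-3.0, 0.5]]
print(global_importance_sketch(shap_values, ["age", "region"]))
```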
shap_dependence_plot(features_tuple, output='value')
Create the SHAP dependence plots.
features_tuple: The pair of features required for plotting.
other_shap_values: Allows for externally calculated SHAP values, e.g., from FastSHAP.
shap_glm_values()
The raw Kernel SHAP (GLM) output.
shap_values_mean_adjustments()
The SHAP values and feature names
Key Methods
Plot Adjustment Factors
drn.interpretability.DRNExplainer.plot_adjustment_factors(instance, observation=None, cutpoints=None, num_interpolations=None, other_df_models=None, model_names=None, percentiles=None, cutpoints_label_bool=False, synthetic_data=None, plot_adjustments_labels=True, axes=None, x_range=None, y_range=None, plot_title=None, plot_mean_adjustment=False, plot_y_label=None, density_transparency=1.0, figsize=None)
Plot the adjustment factors for each of the partitioned intervals.
expand: Interpolation of cutpoints for density evaluations.
Plot Distributional Property Adjustment with SHAP
drn.interpretability.DRNExplainer.plot_dp_adjustment_shap(instance_raw, dist_property='Mean', quantile_bounds=None, method='Kernel', nsamples_background_fraction=1.0, top_K_features=3, adjustment=True, other_df_models=None, model_names=None, cutpoints=None, num_interpolations=None, labelling_gap=0.05, synthetic_data=None, synthetic_data_samples=int(1000000.0), observation=None, plot_baseline=True, x_range=None, y_range=None, plot_y_label=None, plot_title=None, figsize=None, density_transparency=1.0, shap_fontsize=25, legend_loc='upper left')
Plot SHAP value-based adjustments with an option to include density functions.
Args:
instance_raw: Raw data before one-hot encoding, used for instance-specific analysis.
dist_property: Distributional property to adjust ('Mean', 'Variance', 'Quantile').
method: Technique for SHAP value computation ('Kernel', 'Tree', 'FastSHAP', etc.).
nsamples_background_fraction: Fraction of background data for SHAP calculation.
top_K_features: Number of top features based on SHAP values.
adjustment: Whether to plot the SHAP values of the adjusted or unadjusted distributional property
other_df_models: Other distributional forecasting models for comparison.
model_names: Names of the other distributional forecasting models.
cutpoints: Cutpoints for partitioning feature space, defaults to self.default_cutpoints.
num_interpolations: Number of points for density interpolation, defaults to 2000.
labelling_gap: Gap between labels in the plot for readability.
synthetic_data, synthetic_data_samples: Synthetic data function for true density comparison and number of samples generated.
observation: Specific observation value for vertical line plotting.
plot_baseline: Flag to include baseline model's density plot.
x_range, y_range: Axis ranges for the plot.
plot_y_label, plot_title: Custom labels for the plot's axes and title.
density_transparency: Alpha value for density plot transparency.
shap_fontsize, figsize, label_adjustment_factor: Plot styling parameters.
legend_loc: Location of the legend in the plot.
CDF Plot
drn.interpretability.DRNExplainer.cdf_plot(instance, grid=None, cutpoints=None, other_df_models=None, model_names=None, synthetic_data=None, x_range=None, plot_title=None, plot_baseline=True, density_transparency=1.0, dist_property='Mean', quantile_bounds=None, nsamples_background_fraction=0.1, adjustment=True, method='Kernel', labelling_gap=0.01, top_K_features=3, y_range=None, shap_fontsize=25, figsize=None)
Plot the cumulative distribution function.