MDN - Mixture Density Network
Neural network that models complex distributions as mixtures of simpler components.
Class Definition
Bases: BaseModel
Mixture density network whose components are either gamma or Gaussian distributions, as selected by the distribution argument.
Each distributional forecast is a mixture of num_components distributions from the chosen family.
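To make the idea of a mixture forecast concrete, here is a minimal sketch that builds a single three-component gamma mixture directly with torch.distributions. The parameter values are made up for illustration, and this is not necessarily how MDN constructs its forecasts internally.

```python
import torch
from torch.distributions import Categorical, Gamma, MixtureSameFamily

# Hypothetical parameters for one observation and three gamma components.
weights = torch.tensor([0.5, 0.3, 0.2])        # mixing weights, sum to 1
concentration = torch.tensor([2.0, 5.0, 9.0])  # gamma shape (alpha)
rate = torch.tensor([1.0, 1.0, 3.0])           # gamma rate (beta)

mixture = MixtureSameFamily(
    mixture_distribution=Categorical(probs=weights),
    component_distribution=Gamma(concentration, rate),
)

# Log-density of the mixture at a positive target value.
log_density = mixture.log_prob(torch.tensor(2.5))

# Draw samples: a component is picked by weight, then a value is drawn from it.
torch.manual_seed(0)
samples = mixture.sample((1000,))
```

Swapping Gamma for torch.distributions.Normal (with means and standard deviations) gives the Gaussian variant of the same construction.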
Functions
__init__(distribution='gamma', num_hidden_layers=2, num_components=5, hidden_size=100, dropout_rate=0.2, learning_rate=0.001)
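Args:
    p: the number of features in the model (listed in the docstring, but not a parameter in the signature above).
    distribution: the type of distribution for the MDN ('gamma' or 'gaussian').
    num_hidden_layers: the number of hidden layers in the network.
    num_components: the number of components in the mixture.
    hidden_size: the number of neurons in each hidden layer.
    dropout_rate: the dropout rate used in the network (default 0.2).
    learning_rate: the learning rate used for training (default 0.001).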
forward(x)
Calculate the parameters of the mixture components.
Args:
    x: the input features (shape: (n, p))
Returns:
    A list containing the mixture weights and the distribution-specific parameters.
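As an illustration of what "mixture weights and distribution-specific parameters" can look like, the sketch below turns raw network outputs into valid parameters. The helper name to_mixture_params, the use of softmax and softplus, and the layout of the raw output are all assumptions made for this example; the actual forward method may use different transforms.

```python
import torch
import torch.nn.functional as F

def to_mixture_params(raw, num_components, distribution="gamma"):
    """Split raw outputs of shape (n, 3 * num_components) into mixture parameters.

    Hypothetical helper for illustration only; not part of the drn API.
    """
    logits, a, b = torch.split(raw, num_components, dim=-1)
    weights = F.softmax(logits, dim=-1)  # each row sums to 1
    if distribution == "gamma":
        # Shape and rate must both be strictly positive.
        return [weights, F.softplus(a), F.softplus(b)]
    if distribution == "gaussian":
        # Means are unconstrained; standard deviations must be positive.
        return [weights, a, F.softplus(b)]
    raise ValueError(f"unknown distribution: {distribution!r}")

torch.manual_seed(0)
raw = torch.randn(4, 9)  # n = 4 observations, 3 components
weights, shape_param, rate_param = to_mixture_params(raw, num_components=3)
```

Each returned tensor has shape (n, num_components), so row i holds the parameters of the mixture for observation i.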
mean(x)
Calculate the predicted means for the given observations, depending on the mixture distribution.
Args:
    x: the input features (shape: (n, p))
Returns:
    the predicted means (shape: (n,))
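The mean of a mixture is the weighted sum of its component means, which is why the result depends on the component family. A small sketch of that calculation, assuming weights and parameters of shape (n, K) and a rate parameterisation for the gamma case (the function name mixture_mean and the numbers are illustrative):

```python
import torch

def mixture_mean(weights, component_means):
    """Weighted average of component means; inputs (n, K), output (n,)."""
    return (weights * component_means).sum(dim=-1)

weights = torch.tensor([[0.5, 0.5], [0.9, 0.1]])

# Gamma components: each component mean is shape / rate.
shape_param = torch.tensor([[2.0, 6.0], [1.0, 4.0]])
rate_param = torch.tensor([[1.0, 2.0], [2.0, 1.0]])
gamma_means = mixture_mean(weights, shape_param / rate_param)
# -> tensor([2.5000, 0.8500])

# Gaussian components: the component mean is the location parameter itself.
mu = torch.tensor([[-1.0, 3.0], [0.0, 10.0]])
gaussian_means = mixture_mean(weights, mu)
# -> tensor([1., 1.])
```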
Overview
MDN (Mixture Density Network) is ideal for:
- Multi-modal distributions: data with multiple peaks
- Complex relationships: non-linear feature-target mappings
- Uncertainty quantification: rich distributional representations
- Flexible modeling: adaptive number of mixture components
Architecture
graph LR
A[Input Features] --> B[Neural Network]
B --> C[Mixing Weights]
B --> D[Component Means]
B --> E[Component Stds]
C --> F[Mixture Distribution]
D --> F
E --> F
F --> G[Predictions]
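The diagram shows the Gaussian case: a shared network feeding three heads whose outputs define a mixture. The following is a rough sketch of that flow in plain PyTorch. The class TinyGaussianMDN, its layer sizes, and the log-standard-deviation head are assumptions made for illustration; this is not the drn implementation.

```python
import torch
from torch import nn
from torch.distributions import Categorical, MixtureSameFamily, Normal

class TinyGaussianMDN(nn.Module):
    """Hypothetical minimal network following the diagram; not drn's MDN."""

    def __init__(self, p, hidden_size=16, num_components=3):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(p, hidden_size), nn.ReLU())
        self.logits = nn.Linear(hidden_size, num_components)    # mixing weights
        self.means = nn.Linear(hidden_size, num_components)     # component means
        self.log_stds = nn.Linear(hidden_size, num_components)  # component stds

    def forward(self, x):
        h = self.body(x)
        return MixtureSameFamily(
            Categorical(logits=self.logits(h)),
            # exp keeps the standard deviations strictly positive.
            Normal(self.means(h), self.log_stds(h).exp()),
        )

torch.manual_seed(0)
net = TinyGaussianMDN(p=8)
dists = net(torch.randn(5, 8))               # one mixture per input row
point_predictions = dists.mean               # shape (5,)
log_density = dists.log_prob(torch.zeros(5)) # shape (5,)
```

Returning a distribution object rather than raw tensors means the same output supports means, densities, and sampling without extra code.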
Quick Example
from drn.models import MDN
import torch
# Initialize MDN with 3 mixture components
mdn_model = MDN(
    num_components=3,
    hidden_size=128,
    num_hidden_layers=2,
)
# Train on complex multimodal data
mdn_model.fit(X_train, y_train, epochs=150)
# Generate predictions
predictions = mdn_model.predict(X_test)
# Access mixture properties (assumes predict returns a torch MixtureSameFamily)
mixing_weights = predictions.mixture_distribution.probs
component_means = predictions.component_distribution.mean
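The example above assumes that X_train, y_train, and X_test already exist. One possible way to create toy bimodal data for experimenting is sketched below. The sizes, the gamma regimes, and the use of torch tensors as inputs are assumptions; check which input types fit and predict actually accept. Targets are kept strictly positive so they are compatible with gamma components.

```python
import torch

torch.manual_seed(0)
n, p = 1000, 8
X = torch.randn(n, p)

# The first feature decides which of two regimes generates the target,
# so the conditional distribution of y has two well-separated peaks overall.
regime = X[:, 0] > 0
low = torch.distributions.Gamma(2.0, 2.0).sample((n,))    # mean 1
high = torch.distributions.Gamma(50.0, 5.0).sample((n,))  # mean 10
y = torch.where(regime, low, high)

# Simple 80/20 split into training and test sets.
X_train, X_test = X[:800], X[800:]
y_train, y_test = y[:800], y[800:]
```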