
MDN - Mixture Density Network

Neural network that models complex distributions as mixtures of simpler components.


Class Definition

Bases: BaseModel

Mixture density network whose components can be either gamma or Gaussian distributions. Each distributional forecast is a mixture of num_components components from the chosen family.
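Written out, the predictive density is a weighted sum of num_components component densities. The symbols below (K, \pi_k, \alpha_k, \beta_k, \mu_k, \sigma_k) are notation introduced here rather than names used by the library, and the gamma components are assumed to use the shape/rate parameterization.

```latex
f(y \mid x) = \sum_{k=1}^{K} \pi_k(x)\, f_k\big(y;\, \theta_k(x)\big),
\qquad \pi_k(x) \ge 0, \quad \sum_{k=1}^{K} \pi_k(x) = 1, \quad K = \texttt{num\_components}
\text{gamma:} \quad f_k(y) = \frac{\beta_k^{\alpha_k}}{\Gamma(\alpha_k)}\, y^{\alpha_k - 1} e^{-\beta_k y}, \quad y > 0
\text{Gaussian:} \quad f_k(y) = \frac{1}{\sigma_k \sqrt{2\pi}} \exp\!\left(-\frac{(y - \mu_k)^2}{2\sigma_k^2}\right)
```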

Functions

__init__(p, distribution='gamma', num_hidden_layers=2, num_components=5, hidden_size=100, dropout_rate=0.2, learning_rate=0.001)

Args:
    p: the number of features in the model.
    num_hidden_layers: the number of hidden layers in the network.
    num_components: the number of components in the mixture.
    hidden_size: the number of neurons in each hidden layer.
    distribution: the type of distribution for the MDN components ('gamma' or 'gaussian').
    dropout_rate: the dropout rate used in the network.
    learning_rate: the learning rate used by the optimizer.
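To see how these hyperparameters size the network, the sketch below counts the trainable parameters of a plain fully connected MDN. The layout is an assumption for illustration, not taken from the library: num_hidden_layers dense layers of width hidden_size, followed by an output layer with three values per component (one mixing logit and two distribution parameters). The helper names are hypothetical.

```python
def mdn_output_size(num_components: int) -> int:
    # Hypothetical layout: one mixing logit plus two distribution parameters
    # (gamma: shape, rate; Gaussian: mean, std) for every component.
    return 3 * num_components


def mlp_parameter_count(p: int, num_hidden_layers: int = 2, hidden_size: int = 100,
                        num_components: int = 5) -> int:
    """Count weights and biases of an assumed fully connected MDN.

    Assumes num_hidden_layers dense layers of width hidden_size followed by a
    dense output layer of size 3 * num_components. Dropout adds no parameters.
    """
    sizes = [p] + [hidden_size] * num_hidden_layers + [mdn_output_size(num_components)]
    # A dense layer n_in -> n_out has an n_in x n_out weight matrix and n_out biases.
    return sum(n_in * n_out + n_out for n_in, n_out in zip(sizes[:-1], sizes[1:]))


print(mlp_parameter_count(p=8))  # 12515
```

With the default settings and p=8 this gives 900 + 10,100 + 1,515 = 12,515 parameters; note that hidden_size enters quadratically once num_hidden_layers is at least 2.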

forward(x)

Calculate the parameters of the mixture components.

Args:
    x: the input features (shape: (n, p)).

Returns:
    A list containing the mixture weights followed by the distribution-specific component parameters.
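The raw outputs of the network's last layer are unconstrained, so they have to be mapped onto valid mixture parameters. A common approach, sketched here in plain Python for a single observation, is a softmax for the mixture weights and a softplus for parameters that must be positive. The split of the raw vector and the function names are assumptions for illustration; the library may use different transforms (for example exp instead of softplus).

```python
import math


def softmax(logits):
    # Subtract the max logit for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def softplus(z):
    # log(1 + exp(z)) rewritten as log1p(exp(-|z|)) + max(z, 0) to avoid overflow.
    return math.log1p(math.exp(-abs(z))) + max(z, 0.0)


def split_mdn_outputs(raw, num_components, distribution="gamma"):
    """Turn one observation's raw outputs into [weights, param_1, param_2].

    Hypothetical layout: raw = [K weight logits, K first params, K second params].
    """
    k = num_components
    if len(raw) != 3 * k:
        raise ValueError("expected 3 * num_components raw outputs")
    weights = softmax(raw[:k])
    first, second = raw[k:2 * k], raw[2 * k:]
    if distribution == "gamma":
        # Gamma shape and rate must both be positive.
        return [weights, [softplus(z) for z in first], [softplus(z) for z in second]]
    if distribution == "gaussian":
        # Means are unconstrained; standard deviations must be positive.
        return [weights, list(first), [softplus(z) for z in second]]
    raise ValueError("distribution must be 'gamma' or 'gaussian'")


params = split_mdn_outputs([0.0, 0.0, 1.0, -2.0, 0.5, 3.0], num_components=2,
                           distribution="gaussian")
print(params[0])  # [0.5, 0.5]
```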

mean(x)

Calculate the predicted means for the given observations, depending on the mixture distribution.

Args:
    x: the input features (shape: (n, p)).

Returns:
    The predicted means (shape: (n,)).
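For a finite mixture the mean is the weight-averaged mean of the components, sum_k pi_k * E[Y_k]. The sketch below assumes gamma components with shape alpha and rate beta (mean alpha / beta) and Gaussian components with mean mu; the function names are hypothetical and the inputs are plain Python lists rather than tensors.

```python
def mixture_mean(weights, first, second, distribution="gamma"):
    """Mean of one observation's mixture: sum_k weights[k] * E[Y_k].

    gamma: first = shapes, second = rates, so E[Y_k] = shape / rate.
    gaussian: first = means, second = standard deviations, so E[Y_k] = mean.
    """
    if distribution == "gamma":
        component_means = [a / b for a, b in zip(first, second)]
    elif distribution == "gaussian":
        component_means = list(first)
    else:
        raise ValueError("distribution must be 'gamma' or 'gaussian'")
    return sum(w * m for w, m in zip(weights, component_means))


def batch_mixture_mean(batch, distribution="gamma"):
    # batch holds one (weights, first, second) tuple per observation,
    # so the result has one mean per observation, i.e. shape (n,).
    return [mixture_mean(w, f, s, distribution) for w, f, s in batch]


print(mixture_mean([0.25, 0.75], [2.0, 9.0], [1.0, 3.0]))  # 2.75
```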

Overview

MDN (Mixture Density Network) is well suited to:
- Multi-modal distributions, i.e. data with multiple peaks
- Complex, non-linear relationships between features and the target
- Uncertainty quantification through rich distributional representations
- Flexible modeling with a configurable number of mixture components (num_components)
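To make the multi-modality point concrete, the sketch below evaluates a two-component Gaussian mixture on a grid and counts its local maxima. The helper names are illustrative only. Two well-separated components produce two peaks, which no single Gaussian can represent.

```python
import math


def normal_pdf(y, mu, sigma):
    return math.exp(-0.5 * ((y - mu) / sigma) ** 2) / (sigma * math.sqrt(2 * math.pi))


def mixture_pdf(y, weights, mus, sigmas):
    # Weighted sum of the component densities.
    return sum(w * normal_pdf(y, m, s) for w, m, s in zip(weights, mus, sigmas))


def count_modes(weights, mus, sigmas, lo=-10.0, hi=10.0, n=2001):
    # Evaluate on an evenly spaced grid and count strict interior local maxima.
    ys = [lo + (hi - lo) * i / (n - 1) for i in range(n)]
    dens = [mixture_pdf(y, weights, mus, sigmas) for y in ys]
    return sum(1 for i in range(1, n - 1) if dens[i - 1] < dens[i] > dens[i + 1])


print(count_modes([0.5, 0.5], [-2.0, 2.0], [0.5, 0.5]))  # 2
print(count_modes([0.5, 0.5], [-0.5, 0.5], [1.0, 1.0]))  # 1
```

The second call shows that mixing alone does not guarantee several peaks: for two equal-weight components with a common standard deviation sigma, the mixture is bimodal only when the means are more than 2 sigma apart.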

Architecture

graph LR
    A[Input Features] --> B[Neural Network]
    B --> C[Mixing Weights]
    B --> D["Component Parameter 1 (mean or shape)"]
    B --> E["Component Parameter 2 (std or rate)"]

    C --> F[Mixture Distribution]
    D --> F
    E --> F

    F --> G[Predictions]
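The last step of the diagram, turning a mixture distribution into predictions, can go beyond the mean. As a sketch for the Gaussian case only, the mixture CDF is the weighted sum of component CDFs, and quantiles (for example prediction intervals) can be obtained by inverting it numerically. Bisection is used here as a simple, assumed choice; the function names are hypothetical and not part of the library.

```python
import math


def normal_cdf(y, mu, sigma):
    # Phi(z) = 0.5 * (1 + erf(z / sqrt(2))) with z = (y - mu) / sigma.
    return 0.5 * (1.0 + math.erf((y - mu) / (sigma * math.sqrt(2.0))))


def mixture_cdf(y, weights, mus, sigmas):
    # The CDF of a mixture is the weighted sum of the component CDFs.
    return sum(w * normal_cdf(y, m, s) for w, m, s in zip(weights, mus, sigmas))


def mixture_quantile(q, weights, mus, sigmas, tol=1e-10):
    # Mixture quantiles have no closed form, so bisect the monotone CDF.
    # Bracket: 10 standard deviations beyond the outermost components (0 < q < 1).
    lo = min(m - 10 * s for m, s in zip(mus, sigmas))
    hi = max(m + 10 * s for m, s in zip(mus, sigmas))
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if mixture_cdf(mid, weights, mus, sigmas) < q:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


w, m, s = [0.5, 0.5], [-2.0, 2.0], [1.0, 1.0]
interval = (mixture_quantile(0.05, w, m, s), mixture_quantile(0.95, w, m, s))
print(interval)
```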

Quick Example

from drn.models import MDN
import torch

# Initialize an MDN with 3 mixture components for 8 input features
# (X_train, y_train and X_test are assumed to be prepared beforehand)
mdn_model = MDN(
    p=8,
    num_components=3,
    hidden_size=128,
    num_hidden_layers=2
)

# Train on complex multimodal data
mdn_model.fit(X_train, y_train, epochs=150)

# Generate predictions
predictions = mdn_model.predict(X_test)

# Access mixture properties
mixing_weights = predictions.mixture_distribution.probs
component_means = predictions.mixture_distribution.component_distribution.mean
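MDNs are commonly trained by minimizing the negative log-likelihood (NLL) of the observed targets under the predicted mixture. This page does not state which loss fit uses, so treat the following as an assumption: a plain-Python sketch of the mixture NLL for gamma components (shape/rate parameterization), computed with log-sum-exp for numerical stability. All function names are hypothetical.

```python
import math


def gamma_log_pdf(y, shape, rate):
    # log of rate^shape / Gamma(shape) * y^(shape - 1) * exp(-rate * y), for y > 0
    return (shape * math.log(rate) - math.lgamma(shape)
            + (shape - 1.0) * math.log(y) - rate * y)


def logsumexp(values):
    # Numerically stable log(sum(exp(v))).
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def mixture_nll(y, weights, shapes, rates):
    # -log sum_k pi_k f_k(y), evaluated as -logsumexp(log pi_k + log f_k(y)).
    # Weights must be strictly positive (a softmax output satisfies this).
    terms = [math.log(w) + gamma_log_pdf(y, a, b)
             for w, a, b in zip(weights, shapes, rates)]
    return -logsumexp(terms)


def mean_nll(ys, params):
    # Batch loss: params holds one (weights, shapes, rates) tuple per observation.
    return sum(mixture_nll(y, *p) for y, p in zip(ys, params)) / len(ys)


print(mixture_nll(1.0, [0.3, 0.7], [2.0, 3.0], [1.0, 1.0]))
```

Working in log space keeps the loss finite even when every component density is tiny at the observed y, which would otherwise underflow to log(0).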