databricks.labs.dqx.anomaly.ensemble_training
Ensemble trainer for row anomaly detection models.
Encapsulates the logic for training multiple models with different random seeds to create a robust ensemble.
EnsembleTrainer Objects
class EnsembleTrainer()
Trains an ensemble of anomaly detection models, each with a different random seed.
Responsibilities:
- Feature engineering (done once, reused for all models)
- Training multiple models with varied seeds
- Computing validation metrics for each model
- Registering models to the registry
- Aggregating ensemble metrics
Uses dependency injection for the model registry, enabling testing with mocks.
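The injection pattern described above can be sketched as follows. This is a minimal illustration, not the library's code: `RegistryLike`, `InMemoryRegistry`, `SketchTrainer`, and the `register` method are hypothetical stand-ins, and the real `ModelRegistryBase` interface may look different.

```python
from __future__ import annotations

from typing import Protocol


class RegistryLike(Protocol):
    """Hypothetical stand-in for ModelRegistryBase; the real interface may differ."""

    def register(self, model: object, name: str) -> str: ...


class InMemoryRegistry:
    """Test double that records registrations instead of calling a real registry."""

    def __init__(self) -> None:
        self.registered: dict[str, object] = {}

    def register(self, model: object, name: str) -> str:
        self.registered[name] = model
        return f"memory://{name}"


class SketchTrainer:
    """Illustrative trainer (not EnsembleTrainer) that accepts an injected registry."""

    def __init__(self, registry: RegistryLike | None = None) -> None:
        # Assumption: a production class would fall back to its default registry here.
        self._registry = registry if registry is not None else InMemoryRegistry()

    def register_all(self, models: list[object], model_name: str) -> list[str]:
        # Register each ensemble member under an indexed name and collect the URIs.
        return [
            self._registry.register(model, f"{model_name}_{i}")
            for i, model in enumerate(models)
        ]


fake = InMemoryRegistry()
trainer = SketchTrainer(registry=fake)
uris = trainer.register_all(["model-a", "model-b"], "demo")
```

Because the trainer only talks to the object it was given, a test can inspect `fake.registered` afterwards without touching any external service.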
__init__
def __init__(registry: ModelRegistryBase | None = None) -> None
Initialize ensemble trainer.
Arguments:
- registry - Model registry to use. Defaults to MLflow/Unity Catalog.
train
def train(train_df: DataFrame, val_df: DataFrame, columns: list[str],
params: AnomalyParams, ensemble_size: int,
model_name: str) -> EnsembleTrainingResult
Train an ensemble of models.
Arguments:
- train_df - Training DataFrame
- val_df - Validation DataFrame
- columns - Feature columns to use
- params - Training parameters
- ensemble_size - Number of models in ensemble
- model_name - Base name for registered models
Returns:
EnsembleTrainingResult with model URIs, metrics, and metadata
train_ensemble
def train_ensemble(
train_df: DataFrame, val_df: DataFrame, columns: list[str],
params: AnomalyParams, ensemble_size: int, model_name: str
) -> tuple[list[str], dict[str, Any], dict[str, float], dict[str, float],
SparkFeatureMetadata]
Train ensemble of models with different random seeds.
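The general idea of seed-varied ensemble training can be sketched with the standard library alone. Everything here is an assumption made for illustration: the toy z-score "model", the `base_seed + i` seeding scheme, and the metric names are hypothetical and are not taken from the library, which operates on Spark DataFrames.

```python
import random
import statistics


def train_one(values: list[float], seed: int) -> tuple[float, float]:
    """Fit a toy model: mean and sample std of a seeded random subsample."""
    rng = random.Random(seed)
    sample = rng.sample(values, k=max(2, len(values) // 2))
    return statistics.fmean(sample), statistics.stdev(sample)


def score(model: tuple[float, float], x: float) -> float:
    """Anomaly score as an absolute z-score (0.0 if the std is degenerate)."""
    mean, std = model
    return abs(x - mean) / std if std > 0 else 0.0


def train_ensemble_sketch(train, val, ensemble_size, base_seed=42):
    """Train one model per seed, then aggregate per-model validation metrics."""
    models, per_model = [], []
    for i in range(ensemble_size):
        model = train_one(train, seed=base_seed + i)  # assumed seeding scheme
        models.append(model)
        per_model.append(statistics.fmean(score(model, x) for x in val))
    aggregated = {
        "mean_val_score": statistics.fmean(per_model),
        "std_val_score": statistics.pstdev(per_model),
    }
    return models, per_model, aggregated


def ensemble_score(models, x: float) -> float:
    """Average the members' scores so no single seed dominates."""
    return statistics.fmean(score(m, x) for m in models)


train_values = [float(i % 10) for i in range(100)]
val_values = [0.0, 4.5, 9.0, 50.0]
models, per_model, aggregated = train_ensemble_sketch(train_values, val_values, 3)
```

Averaging over members trained with different seeds reduces the variance that any one random subsample introduces, which is the usual motivation for this kind of ensemble.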