Source code for skfolio.optimization.ensemble._base
"""Base Composition estimator.
Follows the same implementation as the base composition (``_BaseComposition``) from scikit-learn.
"""
# Copyright (c) 2023
# Author: Hugo Delatte <delatte.hugo@gmail.com>
# License: BSD 3 clause
# Implementation derived from:
# scikit-learn, Copyright (c) 2007-2010 David Cournapeau, Fabian Pedregosa, Olivier
# Grisel Licensed under BSD 3 clause.
from abc import ABC, abstractmethod
from contextlib import suppress
import sklearn.base as skb
[docs]
class BaseComposition(skb.BaseEstimator, ABC):
    """Handles parameter management for ensemble estimators.

    The private helpers below take ``attr``, the name of the instance
    attribute holding the sub-estimators as a list of ``(name, estimator)``
    pairs, and expose those sub-estimators (and their nested parameters)
    through the scikit-learn parameter API.
    """

    @abstractmethod
    def __init__(self):
        pass

    def _get_params(self, attr, deep=True):
        """Get the parameters of the estimator and of its sub-estimators.

        Parameters
        ----------
        attr : str
            Name of the attribute holding the ``(name, estimator)`` pairs.

        deep : bool, default=True
            If True, also include each sub-estimator under its own name and,
            for sub-estimators exposing ``get_params``, their parameters under
            ``"<name>__<param>"`` keys.

        Returns
        -------
        params : dict
            Parameter names mapped to their values.
        """
        out = super().get_params(deep=deep)
        if not deep:
            return out
        estimators = getattr(self, attr)
        try:
            out.update(estimators)
        except (TypeError, ValueError):
            # Ignore TypeError for cases where estimators is not a list of
            # (name, estimator) and ignore ValueError when the list is not
            # formatted correctly. This is to prevent errors when calling
            # `set_params`. `BaseEstimator.set_params` calls `get_params` which
            # can error for invalid values for `estimators`.
            return out
        for name, estimator in estimators:
            if hasattr(estimator, "get_params"):
                for key, value in estimator.get_params(deep=True).items():
                    out[f"{name}__{key}"] = value
        return out

    def _set_params(self, attr, **params):
        """Set the parameters of the estimator and of its sub-estimators.

        Parameters are applied in a strict order: first the attribute
        ``attr`` itself, then whole sub-estimators replaced by name, and
        finally every remaining parameter (including nested
        ``"<name>__<param>"`` ones) through ``BaseEstimator.set_params``.

        Parameters
        ----------
        attr : str
            Name of the attribute holding the ``(name, estimator)`` pairs.

        **params : dict
            Parameters to set.

        Returns
        -------
        self : BaseComposition
            The estimator instance.
        """
        # 1. All steps
        if attr in params:
            setattr(self, attr, params.pop(attr))
        # 2. Replace items with estimators in params
        items = getattr(self, attr)
        if isinstance(items, list) and items:
            # Get item names used to identify valid names in params.
            # `zip` raises a TypeError when the elements of `items` are not
            # iterable and, with `strict=True`, a ValueError when they have
            # different lengths; the unpacking raises a ValueError when they
            # are not pairs. Malformed `items` are tolerated here, as in
            # `_get_params`: no name is matched.
            item_names = ()
            with suppress(TypeError, ValueError):
                item_names, _ = zip(*items, strict=True)
            # Iterate over a snapshot of the keys: entries are popped from
            # `params` inside the loop, and changing a dict's size while
            # iterating over it raises a RuntimeError.
            for name in list(params):
                if "__" not in name and name in item_names:
                    self._replace_estimator(attr, name, params.pop(name))
        # 3. Step parameters and other initialisation arguments
        super().set_params(**params)
        return self

    def _replace_estimator(self, attr, name, new_val):
        """Replace the sub-estimator called ``name`` in attribute ``attr``.

        The attribute is rebound to a new list rather than mutated in place,
        and only the first pair whose name matches is replaced.

        Parameters
        ----------
        attr : str
            Name of the attribute holding the ``(name, estimator)`` pairs.

        name : str
            Name of the sub-estimator to replace.

        new_val : object
            Replacement value stored in the pair.
        """
        # assumes `name` is a valid estimator name; if it is not, the
        # attribute is simply rebound to an unchanged list copy.
        new_estimators = list(getattr(self, attr))
        for i, (estimator_name, _) in enumerate(new_estimators):
            if estimator_name == name:
                new_estimators[i] = (name, new_val)
                break
        setattr(self, attr, new_estimators)

    def _validate_names(self, names):
        """Check that sub-estimator names are usable as parameter prefixes.

        Parameters
        ----------
        names : collection of str
            Names of the sub-estimators (must support ``len``).

        Raises
        ------
        ValueError
            If the names are not unique, clash with a constructor argument of
            the estimator (``get_params(deep=False)`` keys), or contain the
            ``"__"`` nested-parameter separator.
        """
        if len(set(names)) != len(names):
            raise ValueError(f"Names provided are not unique: {list(names)!r}")
        invalid_names = set(names).intersection(self.get_params(deep=False))
        if invalid_names:
            raise ValueError(
                f"Estimator names conflict with constructor arguments: {sorted(invalid_names)!r}"
            )
        invalid_names = [name for name in names if "__" in name]
        if invalid_names:
            raise ValueError(
                f"Estimator names must not contain __: got {invalid_names!r}"
            )