# Source code for skfolio.distribution.multivariate._utils

"""Utils module for multivariate distribution."""

# Copyright (c) 2025
# Author: Hugo Delatte <delatte.hugo@gmail.com>
# Credits: Matteo Manzi, Vincent Maladière, Carlo Nicolini
# SPDX-License-Identifier: BSD-3-Clause

from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import auto
from functools import cached_property
from itertools import combinations
from typing import Union

import numpy as np
import scipy.sparse.csgraph as ssc
import scipy.stats as st
import sklearn.feature_selection as sf

from skfolio.utils.tools import AutoEnum


class DependenceMethod(AutoEnum):
    """
    Measures available for quantifying the dependence between two variables.

    Attributes
    ----------
    KENDALL_TAU
        Kendall's tau rank correlation coefficient.

    MUTUAL_INFORMATION
        Mutual information, estimated with a k-nearest-neighbors method.

    WASSERSTEIN_DISTANCE
        Wasserstein (Earth Mover's) distance.
    """

    # Values come from `auto()`; members are declared in their original order.
    KENDALL_TAU = auto()
    MUTUAL_INFORMATION = auto()
    WASSERSTEIN_DISTANCE = auto()
@dataclass class EdgeCondSets: """ Container for conditioning sets associated with an edge in an R-vine. Attributes ---------- conditioned : tuple[int, int] A tuple of conditioned variable indices. conditioning : set[int] A set of conditioning variable indices. """ conditioned: tuple[int, int] conditioning: set[int] def to_set(self) -> set[int]: """Union of conditioned and conditioning sets.""" return set(self.conditioned) | self.conditioning def __add__(self, other: "EdgeCondSets") -> "EdgeCondSets": """Combine two EdgeCondSets, merging conditioned and conditioning sets.""" if not isinstance(other, self.__class__): raise TypeError( f"Cannot add a EdgeCondSets with an object of type {type(other)}" ) s1 = self.to_set() s2 = other.to_set() conditioning = s1 & s2 conditioned = tuple(s1 ^ s2) # maintain order if conditioned[0] in other.conditioned: conditioned = conditioned[::-1] return self.__class__(conditioned=conditioned, conditioning=conditioning) def __repr__(self) -> str: """String representation of the EdgeCondSets.""" if self.conditioning: return f"{self.conditioned} | {self.conditioning}" return str(self.conditioned) class BaseNode(ABC): """Base class for Nodes of the R-vine tree. Parameters ---------- ref : int or Edge For RootNode: reference of the variable index. For ChildNode: reference of the edge in the previous tree. Attributes ---------- edges : set[Edge] The set of edges attached to this node. tree : Tree The Tree containing this Node. 
""" def __init__(self, ref: Union[int, "Edge"]): self._ref = ref self.edges: set[Edge] = set() self.tree: Tree | None = None # Reference to the Tree containing this Node @property def ref(self) -> Union[int, "Edge"]: """Return the reference of this node (read-only).""" return self._ref @abstractmethod def clear_cache(self, **kwargs): """Clear the cached pseudo-values and margin values (u and v).""" pass def __repr__(self) -> str: """String representation of the node.""" return f"Node({self.ref})" class RootNode(BaseNode): """Root Node of the R-vine tree. Parameters ---------- ref : int The reference variable index. central : bool True if the node is central; otherwise, False. pseudo_values : ndarray, optional The pseudo-values of the Root Node. Attributes ---------- edges : set[Edge] The set of edges attached to this node. tree : Tree The Tree containing this Node. """ def __init__( self, ref: int, central: bool, pseudo_values: np.ndarray | None = None ): super().__init__(ref=ref) self.central = central self.pseudo_values = pseudo_values def clear_cache(self, **kwargs): """Clear the cached margin values (u and v).""" self.pseudo_values = None class ChildNode(BaseNode): """Child Node of the R-vine tree. A child node is an edge from the previous tree. Parameters ---------- ref : Edge The reference edge in the previous tree. Attributes ---------- edges : set[Edge] The set of edges attached to this node. tree : Tree The Tree containing this Node. """ def __init__(self, ref: "Edge"): super().__init__(ref=ref) # pointer from Edge to Node ref.ref_node = self self._central: bool | None = None self._u: np.ndarray | None = None self._v: np.ndarray | None = None self._u_count: int = 0 self._v_count: int = 0 self._u_count_total: int = 0 self._v_count_total: int = 0 @property def central(self) -> bool: """Determine whether this node is considered central. It is inherited from the associated edge's centrality. 
Returns ------- central: bool True if the node is central; otherwise, False. """ if self._central is None: self._central = self.ref.strongly_central return self._central @property def u(self) -> np.ndarray: """Get the first margin value (u) for the node. It is obtained by computing the partial derivative of the copula with respect to v. Returns ------- u : ndarray The u values for this node. """ is_count = self.tree is not None and self.tree.is_count_visits if is_count: self._u_count_total += 1 else: self._u_count += 1 if self._u is None: X = self.ref.get_X() if is_count: self._u = np.array([np.nan]) else: self._u = self.ref.copula.partial_derivative(X, first_margin=False) value = self._u # Clear cache if ( not is_count and self._u_count_total != 0 and self._u_count == self._u_count_total ): self._u = None self._u_count = 0 return value @u.setter def u(self, value: np.ndarray) -> None: self._u = value @property def v(self) -> np.ndarray: """Get the second margin value (v) for the node. It is obtained by computing the partial derivative of the copula with respect to u. Returns ------- v : ndarray The v values for this node. """ is_count = self.tree is not None and self.tree.is_count_visits if is_count: self._v_count_total += 1 else: self._v_count += 1 if self._v is None: X = self.ref.get_X() if is_count: self._v = np.array([np.nan]) else: self._v = self.ref.copula.partial_derivative(X, first_margin=True) value = self._v # Clear cache if ( not is_count and self._v_count_total != 0 and self._v_count == self._v_count_total ): self._v = None self._v_count = 0 return value @v.setter def v(self, value: np.ndarray): self._v = value def get_var(self, is_left: bool) -> int: """Return the variable index associated with this node. The variable is determined by the conditioned set of the edge. Parameters ---------- is_left : bool Indicates whether to select the left or right node. Returns ------- var : int The variable index corresponding to this node. 
""" if is_left is None: raise ValueError("is_left cannot be None for Child Nodes") var = self.ref.cond_sets.conditioned[0 if is_left else 1] return var def clear_cache(self, clear_count: bool): """Clear the cached margin values (u and v) and counts. Parameters ---------- clear_count : bool If True, the visit counts are also reset. """ self._u = None self._v = None if clear_count: self._u_count = 0 self._v_count = 0 self._u_count_total = 0 self._v_count_total = 0 class Edge: """ Represents an edge in an R-vine tree connecting two nodes. This class encapsulates the information for an edge between two nodes in an R-vine, including the associated copula, the dependence measure, and the conditioning sets. Attributes ---------- node1 : RootNode | ChildNode The first node in the edge. node2 : RootNode | ChildNode The second node in the edge. dependence_method : DependenceMethod The method used to measure dependence between the two nodes. copula : object or None The fitted copula for this edge (if available). ref_node : Node or None A pointer to the node in the next tree constructed from this edge. """ def __init__( self, node1: RootNode | ChildNode, node2: RootNode | ChildNode, dependence_method: DependenceMethod = DependenceMethod.KENDALL_TAU, ): self.node1 = node1 self.node2 = node2 self.dependence_method = dependence_method self.copula = None self.ref_node = None # Pointer to the next tree Node @cached_property def weakly_central(self) -> bool: """Determine if the edge is weakly central. An edge is weakly central if at least one of its two nodes is central. """ return self.node1.central or self.node2.central @cached_property def strongly_central(self) -> bool: """Determine if the edge is strongly central. An edge is strongly central if both of its nodes are central. """ return self.node1.central and self.node2.central @cached_property def dependence(self) -> float: """Dependence measure between the two nodes. 
This is computed on the data from the edge using the specified dependence method. """ X = self.get_X() dep = _dependence(X, dependence_method=self.dependence_method) return dep @cached_property def cond_sets(self) -> EdgeCondSets: """Compute the conditioning sets for the edge. For a root node edge, the conditioned set consists of the two variable indices. For non-root nodes, the conditioning sets are obtained by combining the conditioning sets of the two edges from the previous tree. """ if isinstance(self.node1, RootNode): return EdgeCondSets( conditioned=(self.node1.ref, self.node2.ref), conditioning=set() ) return self.node1.ref.cond_sets + self.node2.ref.cond_sets def ref_to_nodes(self): """Connect this edge to its two nodes.""" self.node1.edges.add(self) self.node2.edges.add(self) def get_X(self) -> np.ndarray: """Retrieve the bivariate pseudo-observation data associated with the edge. For a root edge, this returns the pseudo-values from node1 and node2. For non-root edges, the appropriate margins (u or v) are selected based on the shared node order. Returns ------- X : ndarray of shape (n_observations, 2) The bivariate pseudo-observation data corresponding to this edge. """ if isinstance(self.node1, RootNode): u = self.node1.pseudo_values v = self.node2.pseudo_values else: is_left1, is_left2 = self.node1.ref.shared_node_is_left(self.node2.ref) u = self.node1.v if is_left1 else self.node1.u v = self.node2.v if is_left2 else self.node2.u X = np.stack([u, v]).T return X def shared_node_is_left(self, other: "Edge") -> tuple[bool, bool]: """Determine the ordering of shared nodes between this edge and another edge. If the two edges share one node, this method indicates for each edge whether the shared node is the left node. Parameters ---------- other : Edge Another edge to compare with. 
Returns ------- is_left1, is_left2 : tuple[bool, bool] A tuple (is_left1, is_left2) where is_left1 is True if the shared node is the left node of self and is_left2 is True if the shared node is the left node of other. Raises ------ ValueError If the edges do not share exactly one node. """ if self.node1 == other.node1: return True, True if self.node2 == other.node1: return False, True if self.node2 == other.node2: return False, False # self.node1 == other.node2 raise ValueError("Edges are not correctly ordered") def share_one_node(self, other: "Edge") -> bool: """Check whether two edges share exactly one node. Parameters ---------- other : Edge Another edge to compare with. Returns ------- bool True if the two edges share exactly one node; otherwise, False. """ return len({self.node1, self.node2} & {other.node1, other.node2}) == 1 def __repr__(self) -> str: """String representation of the edge.""" if self.copula is None: return f"Edge({self.cond_sets})" return f"Edge({self.cond_sets}, {self.copula.fitted_repr})" class Tree: """ Represents an R-vine tree at level k. A Tree consists of a set of nodes and the edges connecting them. It represents one level (k) in the R-vine structure. Parameters ---------- level : int The tree level (k) in the R-vine. nodes : list[Node] A list of Node objects representing the nodes in this tree. Attributes ---------- edges : list[Edge] The list of edges in the Tree. is_count_visits : bool Whether to count the number of visit of each Node during sampling. 
""" def __init__(self, level: int, nodes: list[RootNode | ChildNode]): self.level = level self._nodes = nodes for node in nodes: # pointer from Node to Tree node.tree = self self.edges = None self.is_count_visits: bool = False @property def nodes(self) -> list[RootNode | ChildNode]: """Return the tree nodes (read-only).""" return self._nodes def set_edges_from_mst(self, dependence_method: DependenceMethod) -> None: """Construct the Maximum Spanning Tree (MST) from the current nodes using the specified dependence method. The MST is built based on pairwise dependence measures computed between nodes. If any edge is (weakly) central, a central factor is added to the dependence measure to favor edges connected to central nodes. Parameters ---------- dependence_method : DependenceMethod The method used to compute the dependence measure between nodes (e.g., Kendall's tau). Returns ------- None """ n = len(self.nodes) dependence_matrix = np.zeros((n, n)) eligible_edges = {} central = False for i, j in combinations(range(n), 2): node1 = self.nodes[i] node2 = self.nodes[j] if self.level == 0 or node1.ref.share_one_node(node2.ref): edge = Edge( node1=node1, node2=node2, dependence_method=dependence_method ) if not central and edge.weakly_central: central = True # Negate the matrix to use minimum_spanning_tree for maximum spanning # Add a cst to ensure that even if dep is 0, we still build a valid MST dep = abs(edge.dependence) + 1e-5 dependence_matrix[i, j] = dep eligible_edges[(i, j)] = edge if np.any(np.isnan(dependence_matrix)): raise RuntimeError("dependence_matrix contains NaNs") if central: max_dep = np.max(dependence_matrix) for (i, j), edge in eligible_edges.items(): if edge.weakly_central: if edge.strongly_central: central_factor = 3 * max_dep else: central_factor = 2 * max_dep dep = dependence_matrix[i, j] + central_factor dependence_matrix[i, j] = dep # Compute the minimum spanning tree mst = ssc.minimum_spanning_tree(-dependence_matrix, overwrite=True) edges = [] 
# Extract the indices of the non-zero entries (edges) for i, j in zip(*mst.nonzero(), strict=True): edge = eligible_edges[(i, j)] # connect Nodes to Edges edge.ref_to_nodes() edges.append(edge) self.edges = edges def clear_cache(self, clear_count: bool = True): """Clear cached values for all nodes in the tree.""" for node in self.nodes: node.clear_cache(clear_count=clear_count) def __repr__(self): """String representation of the tree.""" return f"Tree(level {self.level})" def _dependence(X, dependence_method: DependenceMethod) -> float: """Compute the dependence between two variables in X using the specified method. Parameters ---------- X : array-like of shape (n_observations, 2) A 2D array of bivariate inputs (u, v), where u and v are assumed to lie in [0, 1]. dependence_method : DependenceMethod The method to use for measuring dependence. Options are: - DependenceMethod.KENDALL_TAU - DependenceMethod.MUTUAL_INFORMATION - DependenceMethod.WASSERSTEIN_DISTANCE Returns ------- dependence : float The computed dependence measure. Raises ------ ValueError If X does not have exactly 2 columns or if an unsupported dependence method is provided. """ X = np.asarray(X) if X.ndim != 2 or X.shape[1] != 2: raise ValueError("X must be a 2D array with exactly 2 columns.") match dependence_method: case DependenceMethod.KENDALL_TAU: dep = st.kendalltau(X[:, 0], X[:, 1]).statistic case DependenceMethod.MUTUAL_INFORMATION: dep = sf.mutual_info_regression(X[:, 0].reshape(-1, 1), X[:, 1])[0] case DependenceMethod.WASSERSTEIN_DISTANCE: dep = st.wasserstein_distance(X[:, 0], X[:, 1]) case _: raise ValueError(f"Dependence method {dependence_method} not valid") return dep