Source code for bayescart.node

"""Node class for bayescart.

This module defines the Node class used to represent nodes in a tree.
It extends the Node implementation from treelib.
"""

import numpy as np
from treelib import Node as TreelibNode
import pandas as pd
from typing import Sequence, Any
from copy import deepcopy
from .mytyping import NDArrayInt, NDArrayFloat, T
from .exceptions import InvalidTreeError
from .node_data import NodeData
from collections import defaultdict


[docs] class Node(TreelibNode): """ Extended node class for Bayesian CART. Provides functionalities for adding data, splitting, and updating parameters. This class is mostly a wrapper around the NodeData object, which handles the data and parameters associated with the node. Attributes ---------- is_l : bool Flag indicating if this node is a left child. _data : NodeData The node data (parameters and associated data). _rng : np.random.Generator The random number generator. debug : bool If True, enable debug checks. _depth : int The depth of the node. """
[docs] def __init__(self, id: int, is_l: bool, data: NodeData, rng: np.random.Generator, debug: bool): super().__init__(identifier=id) self.is_l: bool = is_l self._data: NodeData = data self._rng = rng self.debug = debug self._depth = -1
@property def id(self): return self.identifier @property def depth(self): return self._depth @depth.setter def depth(self, val: int): if val < 0: raise ValueError('Node depth must be non-negative') self._depth = val def __deepcopy__(self, memo): return self.copy(light=False, memo=memo)
[docs] def copy(self, light: bool = False, no_data: bool = False, memo: dict|None = None) -> 'Node': if memo is None: memo = {} cls = self.__class__ result = cls.__new__(cls) for k, v in self.__dict__.items(): if k == '_data': setattr(result, k, v.copy(light=light, no_data=no_data, memo=memo)) elif k == '_rng': setattr(result, k, v) else: setattr(result, k, deepcopy(v, memo)) return result
[docs] def has_data(self) -> bool: return self._data.has_data()
def _gen_tags(self): """ Generate a string tag for the node for plotting purposes. """ left_or_right = 'L' if self.is_l else 'R' if self.is_leaf(): if self.has_data(): self.tag = f"{left_or_right}_{self.identifier}_{self._data.get_nobs()}_{self._data.get_params(print=True)}" else: self.tag = f"{left_or_right}_{self.identifier}_{self._data.nobs}_{self._data.get_params(print=True)}" # type: ignore else: self.tag = f'{left_or_right}_{self.identifier}_{self._data.get_split_var(print=True)} {self._data.get_split_set(print=True)}'
[docs] def get_nobs(self) -> int: return self._data.get_nobs()
[docs] def get_available_splits(self, *args, **kw) -> tuple[dict[str, Sequence[T]], dict[str, bool]]: """ Get available splits from the underlying NodeData. Returns ------- tuple (avail_vars, is_cat) """ return self._data.get_available_splits(*args, **kw)
[docs] def get_new_split(self) -> tuple[str, Sequence[T] | T]: """ Sample a new split for the node. Returns ------- tuple (split_var, split_val) """ # Note that the splitting procedure depends on whether the # # variable is categorical or not. This is checked by the # # NodeData object (which is the only object having access to the data). # Get all the possible splits avail_vars,_ = self._data.get_available_splits() # If no split available, return None if len(avail_vars) == 0: raise InvalidTreeError('No available variable to split') # sample a split variable uniformly split_var = self._rng.choice(list(avail_vars.keys())) avail_vals = avail_vars[split_var] # sample a split value is_cat, split_val = self._data.sample_split(split_var, avail_vals) return split_var, split_val
[docs] def get_split_info(self) -> tuple[str, Sequence[T] | T, bool]: """ Retrieve the current split rule: split variable (str), split value (array if cat, float else), whether the split is categorical. Returns ------- tuple (split_var, split_val, is_cat) """ split_var, split_val = self._data.get_split_var(), self._data.get_split_set() is_cat = self._data.is_cat_split return split_var, split_val, is_cat
[docs] def update_split_info(self, split_var: str, split_val: Sequence[T] | T): self._data.update_split_info(split_var, split_val)
# self._gen_tags()
[docs] def update_split_data(self, X: pd.DataFrame, y: pd.Series): self._data.update_split_data(X, y)
[docs] def get_split_data(self, split_var: str, split_val: Sequence[T] | T, left_params: Any, right_params: Any) -> tuple[NodeData, NodeData]: """ Split the data at this node into two parts. Returns the children. Parameters ---------- split_var : str The feature to split on. split_val : Sequence[T] or T The split rule. left_params : Any Parameters for the left child. right_params : Any Parameters for the right child. Returns ------- tuple (left_node_data, right_node_data) """ return self._data.get_split_data(split_var, split_val, left_params, right_params)
[docs] def get_data_split(self, split_var: str|None = None, split_val: Sequence[T] | T|None = None) -> tuple[pd.DataFrame, pd.DataFrame, pd.Series, pd.Series]: """ Split the node's data using the current (or provided) split rule. Returns the left and right data subsets. Parameters ---------- split_var : str or None, optional The splitting feature. If None, uses the current split_var. split_val : Sequence[T] or T or None, optional The splitting value(s). If None, uses the current split_set. Returns ------- tuple (left_X, right_X, left_y, right_y) """ if split_var is None: split_var = self._data.get_split_var() if split_val is None: split_val = self._data.get_split_set() return self._data.get_data_split(split_var, split_val)
[docs] def is_split_rule_empty(self) -> bool: return self._data.is_split_rule_emtpy()
[docs] def count_values(self) -> NDArrayInt: """ Count the of observations per class. Returns ------- NDArrayInt Array of counts per class. """ return self._data.count_values()
[docs] def get_data_averages(self) -> tuple[float, float, float]: """ Compute data averages (number of observations, mean, un-normalized variance) for regression. Returns ------- tuple (n, mean, un-normalized variance) """ return self._data.get_data_averages()
[docs] def update_node_params(self, params: tuple[float, float] | NDArrayFloat): self._data.update_node_params(params)
# self._gen_tags()
[docs] def get_preds(self) -> tuple[NDArrayFloat, float]: """ Get the predictions from the node's model. Returns ------- tuple (indices, predictions) """ return self._data.get_preds()
[docs] def get_true_preds(self) -> tuple[NDArrayFloat, NDArrayFloat]: """ Return the true predictions (i.e. the actual response values). Returns ------- tuple (indices, true response values) """ y = self._data.y return y.index.to_numpy(), y.to_numpy()
[docs] def get_params(self, print: bool = False) -> Any: return self._data.get_params(print)
[docs] def calc_avail_split_and_vars(self) -> tuple[int, int]: """ Compute the number of available variables for splitting, and the number of available splits for the current split variable. Returns ------- tuple (number of available variables, number of available splits) """ return self._data.calc_avail_split_and_vars()
[docs] def reset_split_info(self): self._data.reset_split_info()
[docs] class NodeFast(Node): """ Fast implementation of Node with optimized copy operations. """
[docs] def copy(self, light: bool = False, no_data: bool = False, memo: dict|None = None) -> 'Node': """ Copy the node, with an option for a light (optimized) copy. Change: Instead of thoroughly copying the object, just copy what we know we need. """ # _initial_tree_id contains the hash (str) of the first tree the node was attached to. In our case, this is unique. # _predecessor is a dict mapping tree_ids to node_ids, with the idea of sharing nodes across trees, potentially. # I am still not sure about this implementation, but it doesn't harm me. # Similarly, _successors maps onto the list of children, which in my case is up to two. if not light: return super().copy(light=light, no_data=no_data, memo=memo) else: cls = self.__class__ result = cls.__new__(cls) result.is_l = self.is_l result._data = self._data.copy(light=light, no_data=no_data, memo=memo) result.debug = self.debug result.identifier = self.identifier result._depth = self._depth result._tag = self._tag result.data = self.data result.expanded = self.expanded result._rng = self._rng result.ADD = self.ADD result.DELETE = self.DELETE result.INSERT = self.INSERT result.REPLACE = self.REPLACE k = self._initial_tree_id if self.debug: assert len(self._predecessor) == 1 assert len(self._successors[k]) <= 2 if k is not None: result._predecessor = {k: self._predecessor[k]} result._successors = self._successors result._successors = defaultdict(list) result._successors[k] = [x for x in self._successors[k]] result._initial_tree_id = k else: result._predecessor = {} result._successors = defaultdict(list) result._initial_tree_id = None return result