"""
Utility functions for BCART experiments evaluation and analysis.
This module contains helper functions for computing tree probabilities,
comparing trees from BCART runs and provide summary tables of estimated
posterior probabilities for the most commonly visited trees. This can be
used to gauge convergence.
"""
from .tree import Tree
from .bcart import BCART
import numpy as np
import pandas as pd
def _gen_bcart_obj(tree: Tree, run):
"""
Generate a BCARTClassic object from a given tree and run result. Can be used to compute probabilities.
This function makes a copy of the input tree, assigns the data from the run
(X and y) to the root node if not already present, updates the subtree data,
and returns a BCARTClassic object configured with the run setup.
Parameters
----------
tree : Tree
The tree for which to generate a BCART object.
run : dict
A run result dictionary containing keys 'data' and 'setup'. The 'data'
key must include X and y; 'setup' includes hyperparameters such as alpha,
beta, a, mu_bar, nu, lambd, iters, burnin, thinning, and move_prob.
Returns
-------
BCARTClassic
A BCARTClassic object initialized with the given tree and run parameters.
"""
tree = tree.copy()
X, y = run['data']['X'], run['data']['y']
stp = run['setup']
if not tree.nodes[0].has_data():
tree.get_root()._data.X = run['data']['X']
tree.get_root()._data.y = run['data']['y']
tree.update_subtree_data(tree.get_root())
bcart = BCART(X=X, y=y, alpha=stp['alpha'], beta=stp['beta'], a=stp['a'], mu_bar=stp['mu_bar'], nu=stp['nu'], lambd=stp['lambd'], iters=stp['iters'], burnin=stp['burnin'], thinning=stp['thinning'], move_prob=stp['move_prob'], light=stp['light'], seed=stp['seed'])
bcart.tree = tree
return bcart
[docs]
def calc_tree_post_prob(tree: Tree, run):
"""
Calculate the (log) posterior probability of a tree given a run result.
Parameters
----------
tree : Tree
The tree for which to compute the posterior probability.
run : dict
A run result dictionary containing configuration and data information.
Returns
-------
float
The log posterior probability of the given tree.
"""
bcart = _gen_bcart_obj(tree, run)
return bcart.get_log_posterior_prob(tree)
[docs]
def calc_tree_llik(tree: Tree, run):
"""
Calculate the integrated log-likelihood of a tree given a run result.
Parameters
----------
tree : Tree
The tree for which to compute the likelihood.
run : dict
A run result dictionary containing configuration and data information.
Returns
-------
float
The integrated log-likelihood P(Y|X,T) of the tree.
"""
if tree.llik is not None:
return tree.llik
bcart = _gen_bcart_obj(tree, run)
return bcart.calc_llik(bcart.tree)
[docs]
def calc_log_tree_prob(tree: Tree, run):
"""
Calculate the log prior probability of a tree given a run result.
Parameters
----------
tree : Tree
The tree for which to compute the prior probability.
run : dict
A run result dictionary containing configuration and data information.
Returns
-------
float
The log prior probability P(T|X) of the tree.
"""
if tree.log_tree_prior_prob is not None:
return tree.log_tree_prior_prob
bcart = _gen_bcart_obj(tree, run)
return bcart.calc_log_tree_prob(bcart.tree)
[docs]
def calc_log_tree_prob_and_llik(tree: Tree, run):
"""
Calculate both the log prior probability and the integrated log-likelihood of a tree.
Parameters
----------
tree : Tree
The tree for which to compute the probabilities.
run : dict
A run result dictionary containing configuration and data information.
Returns
-------
tuple of float
A tuple (prior, llik) where 'prior' is the log prior probability and 'llik'
is the integrated log-likelihood of the tree.
"""
if tree.llik is not None:
llik = tree.llik
else:
llik = None
if tree.log_tree_prior_prob is not None:
prior = tree.log_tree_prior_prob
else:
prior = None
if llik is None or prior is None:
bcart = _gen_bcart_obj(tree, run)
if llik is None:
llik = bcart.calc_llik(bcart.tree)
tree.llik = llik
if prior is None:
prior = bcart.calc_log_tree_prob(bcart.tree)
tree.log_tree_prior_prob = prior
return prior, llik
#%% Compare trees rigorously
def _get_run_from_res(res_or_run):
if isinstance(res_or_run, list):
return res_or_run[0]
elif isinstance(res_or_run, dict):
return res_or_run
def _calc_cats_if_data_missing(res_or_run):
run = _get_run_from_res(res_or_run)
if not run['tree_store'][0].all_nodes()[0].has_data(): # type: ignore
categories = {}
X, y = run['data']['X'], run['data']['y'] # type: ignore
for col in X.columns:
if hasattr(X[col], 'cat'):
categories[col] = set(X[col].cat.categories)
return categories
else:
return None
[docs]
def compare_trees(tree1: Tree, tree2: Tree, res_or_run, type='prob'):
"""
Compare two trees to check if they are equivalent under a given mode.
There are different levels of comparison:
- Basic, using buil-in tree comaprison (hard=0)
- Probability, using data likelihood and tree prior (hard=1).
Note that different likelihoods imply different partitions,
with high probability. Instead, the prior should tell apart trees
that have same partition but different structure so that the two
trees are not probabilistically equivalent.
Parameters
----------
tree1 : Tree
The first tree to compare.
tree2 : Tree
The second tree to compare.
res_or_run : dict or list
Either a single run result dictionary or a list of run results,
used to extract necessary configuration and data.
type : str, optional
Comparison type ('basic' or 'prob'). Default is 'prob'.
Returns
-------
bool
True if the trees are considered equal under the chosen criteria, False otherwise.
"""
categories = _calc_cats_if_data_missing(res_or_run)
basic = tree1.is_equal(tree2, hard=1, categories=categories)
if basic:
# if this says they are the same, then they are. If No, it might miss some.
return True
if type == 'basic':
return basic
if type == 'prob':
prior1, llik1 = calc_log_tree_prob_and_llik(tree1, _get_run_from_res(res_or_run))
prior2, llik2 = calc_log_tree_prob_and_llik(tree2, _get_run_from_res(res_or_run))
# bcart1 = _gen_bcart_obj(tree1, _get_run_from_res(res_or_run))
# bcart2 = _gen_bcart_obj(tree2, _get_run_from_res(res_or_run))
# llik1 = bcart1.calc_llik(bcart1.tree)
# llik2 = bcart2.calc_llik(bcart2.tree)
# prior1 = bcart1.calc_log_tree_prob(bcart1.tree)
# prior2 = bcart2.calc_log_tree_prob(bcart2.tree)
cond1 = np.isclose(llik1, llik2)
cond2 = np.isclose(prior1, prior2)
prob = cond1 and cond2
return prob
def _find_tree_idx(tree, trees, res_or_run, type):
"""
Find the index of a tree within a list of unique trees.
Parameters
----------
tree : Tree
The tree to search for.
trees : list
A list of unique Tree objects.
res_or_run : dict or list
A run result dictionary or list of run results.
type : str
Comparison type for matching trees (e.g. 'prob' or 'basic').
Returns
-------
int or None
The index of the matching tree if found; otherwise, None.
"""
for idx, unq_tree in enumerate(trees):
if compare_trees(tree, unq_tree, res_or_run, type=type):
return idx
return None
def _summarize_trees(trees, res_or_run, top_n: int|None = None, type='prob'):
tot_trees = len(trees)
unique_trees: list[Tree] = [trees[0]]
counts = {0: 1}
for tree in trees[1:]:
tree_idx = _find_tree_idx(tree, unique_trees, res_or_run, type)
if tree_idx is not None:
counts[tree_idx] += 1
else:
unique_trees.append(tree)
counts[len(unique_trees)-1] = 1
most_freq_trees_idx = sorted([(k,v/tot_trees) for k, v in counts.items()], key=lambda x: x[1], reverse=True)
if top_n is not None:
most_freq_trees_idx = most_freq_trees_idx[:top_n]
else:
top_n = len(unique_trees)
return [most_freq_trees_idx[i][1] for i in range(top_n)], [unique_trees[idx[0]] for idx in most_freq_trees_idx]
[docs]
def summarize_trees(run, top_n=None, type='prob', plot=False):
"""
Summarize the most common trees from a single run.
The function extracts the stored 'tree_store' from the run result, summarizes the top
trees using frequency of occurrence (by the chosen comparison type), and optionally plots
each unique tree.
Parameters
----------
run : dict
A run result dictionary containing a 'tree_store' key.
top_n : int or None, optional
The number of top trees to extract. If None, all unique trees are returned.
type : str, optional
The type of comparison to use ('prob' or 'basic'). Default is 'prob'.
plot : bool, optional
If True, each unique tree is displayed using its show() method. Default is False.
Returns
-------
tuple
A tuple (freq_list, unique_trees) representing the frequencies and corresponding unique trees.
"""
trees: list[Tree] = run['tree_store']
res = _summarize_trees(trees, run, top_n, type)
if plot:
for tree in res[1]:
tree.show()
return res
[docs]
def produce_tree_table(res):
"""
Produce a summary table of tree posterior probability estimates across runs.
This function:
1. Extracts the top 5 most frequent trees from each run.
2. Pools all top trees and computes a unique set.
3. Sorts the unique trees by their average empirical frequency (posterior probability)
across runs.
4. Constructs a table with each unique tree's estimated frequency per run, mean, standard
deviation, and number of terminal nodes.
Parameters
----------
res : list
A list of run result dictionaries.
Returns
-------
pd.DataFrame
A pandas DataFrame summarizing the frequency estimates for the most frequent trees.
"""
# 1. extract the top 5 most likely from each run
freq_runs = []
unique_trees_runs = []
all_trees = []
for run in res:
freq, unique_trees = summarize_trees(run)
freq_runs.append(freq)
unique_trees_runs.append(unique_trees)
# now only keep the top 5 trees
all_trees.extend(unique_trees[:5])
# 2. we pool all top 5's and take the unique set; this number gives us the total rows
_, unique_trees = _summarize_trees(all_trees, res)
nrows = len(unique_trees)
ncols = len(res)
tbl = np.zeros((nrows, ncols))
# for each row, we want to find the corresponding chain probability (col), hence compare prob across all visited trees
for col in range(ncols):
freq, trees = freq_runs[col], unique_trees_runs[col]
for row in range(nrows):
# find the frequency in the relevant chain fo the current row tree
tree_idx = _find_tree_idx(unique_trees[row], trees, res, type='prob')
tbl[row, col] = float(f'{freq[tree_idx]:0.2f}') if tree_idx is not None else 0
sorted_idx = np.argsort(tbl.mean(axis=1))[::-1]
tbl = tbl[sorted_idx,:]
n_leaves = np.array([unique_trees[idx].get_n_leaves() for idx in sorted_idx])
mn = np.round(tbl.mean(axis=1)[:,np.newaxis], decimals=2)
std = np.round(tbl.std(axis=1)[:,np.newaxis], decimals=2)
tbl = np.hstack([tbl, mn, std, n_leaves[:,np.newaxis]])
tbl = pd.DataFrame(tbl, columns=[f'C{i}' for i in range(ncols)] + ['Mn', 'Std', 'b'])
return tbl