"""
Sequential Monte Carlo sampler for junction tree distributions.
"""
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import trilearn.distributions.sequential_junction_tree_distributions as seqdist
import trilearn.graph.graph as glib
import trilearn.graph.junction_tree as jtlib
#import trilearn.graph.junction_tree_gt as jtgt
import trilearn.graph.junction_tree_collapser
import trilearn.graph.junction_tree_expander
import trilearn.set_process as sp
import trilearn.auxiliary_functions as aux
[docs]
def smc_ggm_graphs(N, alpha, beta, radius, X, D, delta):
    """Run SMC on a GGM junction tree posterior and return weighted graphs.

    A ``seqdist.GGMJTPosterior`` is created and initialised with
    ``(X, D, delta)`` and a fresh, empty cache. ``approximate`` is then run
    on it and every resulting junction tree is converted with
    ``jtlib.graph``.

    Args:
        N (int): number of particles passed to ``approximate``.
        alpha (float): passed through to ``approximate``.
        beta (float): passed through to ``approximate``.
        radius: passed through to ``approximate``.
        X: passed to ``GGMJTPosterior.init_model``.
        D: passed to ``GGMJTPosterior.init_model``.
        delta: passed to ``GGMJTPosterior.init_model``.

    Returns:
        tuple: ``(graphs, norm_w)`` where ``graphs`` holds one graph per
        particle and ``norm_w`` holds the normalised weights taken from the
        last column (index ``p - 1``) of the log-weight matrix.
    """
    posterior = seqdist.GGMJTPosterior()
    posterior.init_model(X, D, delta, {})
    trees, log_w = approximate(N, alpha, beta, radius, posterior)

    # Log weights of the final SMC step, shifted by their maximum before
    # exponentiating.
    final_log_w = np.array(log_w.T)[posterior.p - 1]
    shifted = final_log_w - max(final_log_w)
    unnormalised = np.exp(shifted)
    norm_w = unnormalised / sum(unnormalised)

    graphs = [jtlib.graph(tree) for tree in trees]
    return (graphs, norm_w)
[docs]
def smc_approximate_ggm(N, alpha, beta, radius, X, D, delta):
    """Return the SMC output of ``smc_ggm_graphs`` as a graph -> weight dict.

    All arguments are forwarded unchanged to ``smc_ggm_graphs``.

    Returns:
        dict: maps each returned graph object to the normalised weight at
        the same position.
    """
    graphs, probs = smc_ggm_graphs(N, alpha, beta, radius, X, D, delta)
    return {graph: probs[idx] for idx, graph in enumerate(graphs)}
[docs]
def approximate(N, alpha, beta, radius, seq_dist, debug=False, neig_set_cache=None):
    """ Sequential Monte Carlo for junction trees using the Christmas
    tree algorithm as proposal kernel.

    Args:
        N (int): number of particles
        alpha (float): sparsity parameter for the Christmas tree algorithm
        beta (float): sparsity parameter for the Christmas tree algorithm
        radius (float): defines the radius within which new nodes are selected
        seq_dist (SequentialJTDistributions): the distribution to be sampled
            from; its attribute ``p`` gives the number of SMC steps
        debug (bool): if True, print progress every 5000 particles
        neig_set_cache (dict): optional cache mapping ``frozenset(order)`` to
            the result of ``sp.order_neigh_set(order, radius, total)``. The
            key does not include ``radius`` or ``p``, so a cache must only be
            reused between calls that share both. Defaults to a fresh dict
            for every call.

    Returns:
        (new_trees, log_w): the list of the N junction trees produced in the
        last step, and an N x p ``np.matrix`` where ``log_w[i, n]`` is the
        log weight assigned to particle i in step n.
    """
    # A mutable default would be shared by every call, leaking neighbour
    # sets computed for another radius / p into this run.
    if neig_set_cache is None:
        neig_set_cache = {}

    p = seq_dist.p
    log_w = np.matrix(np.zeros((N, p)))
    new_trees = [None for _ in range(N)]
    old_trees = [None for _ in range(N)]
    ind_perms = np.matrix(np.zeros((N, p)), dtype=object)
    total = set(range(p))

    for n in range(p):
        new_trees = [None for _ in range(N)]
        if n > 0:
            # Resample ancestor indices proportionally to the previous
            # step's weights (shifted by the maximum before exponentiating).
            prev_log_w = np.array(log_w.T)[n - 1]
            log_w_rescaled = prev_log_w - max(prev_log_w)
            norm_w = np.exp(log_w_rescaled) / sum(np.exp(log_w_rescaled))
            I = np.random.choice(N, size=N, p=norm_w)

        for i in range(N):
            if i % 5000 == 0 and not i == 0 and debug:
                print("n: " + str(n) + ", i: " + str(i))

            if n == 0:
                # First step: draw the initial node order and start from a
                # junction tree with a single one-node clique.
                ind_perms[i, n] = sp.gen_order_neigh([], radius, total)
                node = ind_perms[i, n][n]
                T = jtlib.JunctionTree()
                T.add_node(frozenset([node]), label=tuple([node]), color="red")
                new_trees[i] = T
                log_w[i, n] = 0.0
            else:
                parent = I[i]
                parent_order = ind_perms[parent, n - 1]
                T_old = old_trees[parent]

                # Extend the ancestor's order with a random element of its
                # (cached) neighbour set.
                order_frozenset = frozenset(parent_order)
                if order_frozenset not in neig_set_cache:
                    neig_set_cache[order_frozenset] = sp.order_neigh_set(parent_order, radius, total)
                ind_perms[i, n] = parent_order + [aux.random_element_from_coll(neig_set_cache[order_frozenset])]
                node = ind_perms[i, n][n]

                # Forward kernel: expand the ancestor tree with the new node.
                (new_trees[i], K_st, old_cliques, old_separators,
                 new_cliques, new_separators) = trilearn.graph.junction_tree_expander.sample(
                    T_old, node, alpha, beta, only_tree=False)

                # Backward kernel
                log_R = trilearn.graph.junction_tree_collapser.log_pdf(new_trees[i], T_old, node)
                log_density_ratio = seq_dist.log_ratio(old_cliques,
                                                       old_separators,
                                                       new_cliques,
                                                       new_separators,
                                                       T_old, new_trees[i])
                log_w[i, n] = log_density_ratio + log_R - np.log(K_st)
        old_trees = new_trees
    return (new_trees, log_w)
[docs]
def approximate_cond(N, alpha, beta, radius, seq_dist, T_cond, perm_cond, debug=False, neig_set_cache=None):
    """ SMC on junction trees conditioned on the trajectories T_cond
    and perm_cond.

    From step 1 on, particle 0 is forced to follow the given trajectory
    (``T_cond[n]`` / ``perm_cond[n]``); all other particles are proposed as
    in ``approximate``.

    Args:
        N (int): number of particles
        alpha (float): passed to the junction tree expander
        beta (float): passed to the junction tree expander
        radius (float): passed to the node-order functions in ``sp``
        seq_dist: the distribution to be sampled from; its attribute ``p``
            gives the number of SMC steps
        T_cond: indexable by step; ``T_cond[n]`` is the fixed junction tree
            at step n
        perm_cond: indexable by step; ``perm_cond[n]`` is the fixed node
            order at step n
        debug (bool): if True, print progress every 500 particles
        neig_set_cache (dict): optional cache mapping ``frozenset(order)`` to
            the result of ``sp.order_neigh_set(order, radius, total)``. The
            key does not include ``radius`` or ``p``, so a cache must only be
            reused between calls that share both. Defaults to a fresh dict
            for every call.

    Returns:
        (new_trees, log_w, Is): the trees of the last step, the N x p
        ``np.matrix`` of log weights, and the N x p integer ``np.matrix`` of
        resampled ancestor indices (column 0 is left at zero).
    """
    # A mutable default would be shared by every call, leaking neighbour
    # sets computed for another radius / p into this run.
    if neig_set_cache is None:
        neig_set_cache = {}

    p = seq_dist.p
    log_w = np.matrix(np.zeros((N, p)))
    Is = np.matrix(np.zeros((N, p)), dtype=int)
    old_trees = [None for _ in range(N)]
    new_trees = [None for _ in range(N)]
    ind_perms = np.matrix(np.zeros((N, p)), dtype=object)
    total = range(p)
    maxradius = radius >= p

    for n in range(p):
        # Reset the new trees and perms so that we do not alter the old ones
        new_trees = [None for _ in range(N)]
        if n > 0:
            prev_log_w = np.array(log_w.T)[n - 1]
            log_w_rescaled = prev_log_w - max(prev_log_w)
            norm_w = np.exp(log_w_rescaled) / sum(np.exp(log_w_rescaled))
            I = np.random.choice(N, size=N, p=norm_w)
            # Record this step's ancestors once; the column is the same for
            # every particle, so it does not belong in the inner loop.
            Is[np.ix_(range(N), [n])] = np.matrix(I).reshape((N, 1))

        for i in range(N):
            if i % 500 == 0 and not i == 0 and debug:
                print("n: " + str(n) + ", i: " + str(i))

            if n == 0:
                # NOTE(review): particle 0 is not seeded from T_cond[0] /
                # perm_cond[0] here - confirm that this is intended.
                # Index permutation
                ind_perms[i, n] = sp.gen_order_neigh([], radius, total)
                node = ind_perms[i, n][n]
                T = jtlib.JunctionTree()
                T.add_node(frozenset(ind_perms[i, n]))
                new_trees[i] = T
                log_w[i, n] = 0.0
                continue

            if i == 0:
                # Weights for the fixed trajectory
                T_old = T_cond[n - 1]
                T = T_cond[n]
                new_trees[i] = T
                ind_perms[i, n] = perm_cond[n]
                old_cliques = T_old.nodes()
                old_separators = T_old.get_separators()
                new_cliques = T.nodes()
                new_separators = T.get_separators()
                # The node added in this step is the one element of the new
                # order that is missing from the previous order.
                node = list(set(perm_cond[n]) - set(perm_cond[n - 1]))[0]
                K_st = trilearn.graph.junction_tree_expander.pdf(T_old, T, alpha, beta, node)
                log_order_pr = sp.backward_order_neigh_log_prob(perm_cond[n - 1],
                                                                perm_cond[n],
                                                                radius, maxradius)
            else:
                # Weights for rest
                parent = I[i]
                parent_order = ind_perms[parent, n - 1]
                T_old = old_trees[parent]

                # Get permutation
                order_frozenset = frozenset(parent_order)
                if order_frozenset not in neig_set_cache:
                    neig_set_cache[order_frozenset] = sp.order_neigh_set(parent_order, radius, total)
                ind_perms[i, n] = parent_order + [aux.random_element_from_coll(neig_set_cache[order_frozenset])]
                node = ind_perms[i, n][n]  # the added node

                # Expand the junction tree T
                (new_trees[i], K_st, old_cliques, old_separators,
                 new_cliques, new_separators) = trilearn.graph.junction_tree_expander.sample(
                    T_old, node, alpha, beta, only_tree=False)
                log_order_pr = sp.backward_order_neigh_log_prob(parent_order,
                                                                ind_perms[i, n],
                                                                radius, maxradius)
                T = new_trees[i]

            # Backward kernel: order term plus tree-collapse term.
            log_R = log_order_pr + trilearn.graph.junction_tree_collapser.log_pdf(T, T_old, node)
            log_w[i, n] = seq_dist.log_ratio(old_cliques,
                                             old_separators,
                                             new_cliques,
                                             new_separators,
                                             T_old,
                                             T) + log_R - np.log(K_st)
        old_trees = new_trees
    return (new_trees, log_w, Is)
[docs]
def est_log_norm_consts(order, n_particles, sequential_distribution, alpha=0.5, beta=0.5, n_smc_estimates=1,
                        debug=False):
    """Estimate the log normalising constants of a sequential distribution.

    Runs ``approximate`` ``n_smc_estimates`` times (with radius set to
    ``sequential_distribution.p``). For each run, entry ``n`` of the estimate
    is the running sum over steps ``1..n`` of the log of the mean weight of
    that step; entry 0 is 0.

    Args:
        order (int): number of entries to estimate per run
        n_particles (int): number of SMC particles
        sequential_distribution: distribution passed to ``approximate``
        alpha (float): passed to ``approximate``
        beta (float): passed to ``approximate``
        n_smc_estimates (int): number of independent SMC runs
        debug (bool): if True, print the number of distinct junction trees
            and graphs sampled in each run

    Returns:
        numpy.ndarray: shape ``(n_smc_estimates, order)``, flattened to
        ``(order,)`` when ``n_smc_estimates == 1``.
    """
    def log_mean_exp(log_values):
        # log(mean(exp(v))) computed in log space, so that large-magnitude
        # log weights do not overflow or underflow in exp().
        return np.logaddexp.reduce(log_values) - np.log(len(log_values))

    def estimate_norm_const(order, log_weights):
        consts = np.zeros(order)
        for n in range(1, order):
            consts[n] = consts[n - 1] + log_mean_exp(log_weights[:, n])
        return consts

    log_consts = np.zeros((n_smc_estimates, order))

    for t in tqdm(range(n_smc_estimates), desc="Const estimates"):
        (trees, log_w) = approximate(n_particles, alpha, beta, sequential_distribution.p, sequential_distribution)
        # Plain 2-D array so that a column slice is one-dimensional.
        log_consts[t, :] = estimate_norm_const(order, np.asarray(log_w))

        if debug:
            # A tree is identified by its node set and its (unordered) edges.
            unique_trees = {
                (frozenset(tree.nodes()), frozenset([frozenset(e) for e in tree.edges()]))
                for tree in trees
            }
            print("Sampled unique junction trees: " + str(len(unique_trees)))
            unique_graphs = {glib.hash_graph(jtlib.graph(tree)) for tree in trees}
            print("Sampled unique chordal graphs: {n_unique_chordal_graphs}".format(
                n_unique_chordal_graphs=len(unique_graphs)),
            )

    if n_smc_estimates == 1:
        log_consts = log_consts.flatten()
    return log_consts
[docs]
def est_n_dec_graphs(order, n_particles, alpha=0.5, beta=0.5, n_smc_estimates=1, debug=False):
    """Exponentiate the log constants estimated for a CondUniformJTDistribution.

    Builds ``seqdist.CondUniformJTDistribution(order)`` and passes it, with
    all other arguments, to ``est_log_norm_consts``.

    Returns:
        numpy.ndarray: ``np.exp`` of the array returned by
        ``est_log_norm_consts``.
    """
    uniform_dist = seqdist.CondUniformJTDistribution(order)
    return np.exp(
        est_log_norm_consts(order, n_particles, uniform_dist, alpha, beta, n_smc_estimates, debug)
    )
[docs]
def est_dec_max_clique_size(order, n_particles, alpha=0.5, beta=0.5, n_smc_estimates=1, debug=False):
    """Collect weighted-mean maximal clique sizes over repeated SMC runs.

    Each of the ``n_smc_estimates`` rounds obtains ``(max_clique_sizes,
    norm_w)`` from ``uniform_dec_maxl_clique_size_samples`` and stores
    ``(max_clique_sizes * norm_w).sum()``.

    Returns:
        list: one weighted mean per round (empty if ``n_smc_estimates`` is 0).
    """
    # NOTE(review): uniform_dec_maxl_clique_size_samples is not defined in the
    # visible part of this module - confirm it is available at call time.
    estimates = []
    for round_idx in range(n_smc_estimates):
        if debug:
            print("Iteration: " + str(round_idx + 1) + "/" + str(n_smc_estimates))
        sizes, weights = uniform_dec_maxl_clique_size_samples(
            order, n_particles, alpha=alpha, beta=beta, debug=debug)
        weighted_mean = (sizes * weights).sum()  # weighted expected value
        estimates.append(weighted_mean)
        if debug:
            print(round_idx, weighted_mean)
    return estimates
[docs]
def get_smc_trajs(Is):
    """ Plot the ancestral trajectory of every particle, for visualizing
    the collapsing in SMC.

    Args:
        Is: two-dimensional ancestor-index table with one row per particle
            and one column per step, as consumed by ``get_traj``.

    Particle 0 is drawn in red and all others in blue. Particles are
    iterated from the highest index down, so particle 0 is plotted last.
    The figure is displayed with ``plt.show()``.
    """
    n_particles = Is.shape[0]
    n_steps = Is.shape[1]
    steps = range(n_steps)
    for particle in range(n_particles - 1, -1, -1):
        # Ancestors for steps 1..p-1, followed by the particle's own index.
        trajectory = get_traj(n_steps - 1, particle, Is) + [particle]
        line_color = "r" if particle == 0 else "b"
        plt.plot(steps, trajectory, color=line_color)
    plt.show()
[docs]
def get_traj(n, i, Is):
    """Trace the ancestor indices of particle ``i`` back from step ``n``.

    Starting at ``Is[i, n]``, each ancestor is looked up in the column of
    the preceding step until step 1 is reached.

    Args:
        n (int): step to start tracing from
        i (int): particle index at step ``n``
        Is: two-dimensional table where ``Is[k, m]`` is the ancestor index
            of particle ``k`` at step ``m``

    Returns:
        list: the ``n`` ancestor indices ordered from step 1 to step ``n``
        (empty when ``n`` is 0).
    """
    # Iterative walk: a recursive version needs one stack frame per step and
    # hits the interpreter's recursion limit on long trajectories.
    ancestors = []
    current = i
    for step in range(n, 0, -1):
        current = Is[current, step]
        ancestors.append(current)
    ancestors.reverse()
    return ancestors