"""Matrix multivariate normal distribution (trilearn.distributions.matrix_multivariate_normal)."""
import numpy as np
def sample(M, S, Sigma):
    """Generate one sample from the matrix (multivariate) normal distribution.

    The draw uses the vectorised form of the distribution. The row-major
    (C-order) flattening of the sample X is multivariate normal, with mean
    equal to the row-major flattening of M and covariance
    ``np.kron(S, Sigma)``. With this flattening, S is the covariance between
    rows and Sigma is the covariance between columns.

    Args:
        M: Mean matrix of shape (N, K). Anything ``np.asarray`` accepts
            (ndarray, np.matrix, nested lists).
        S: Row covariance of shape (N, N). A scalar is accepted when N == 1.
        Sigma: Column covariance of shape (K, K). A scalar is accepted when
            K == 1.

    Returns:
        np.ndarray of shape (N, K) holding a single draw. Randomness comes
        from NumPy's global ``np.random`` state.

    Raises:
        ValueError: If M is not 2-dimensional, or if S / Sigma do not have
            shapes (N, N) / (K, K). This catches S and Sigma passed in
            swapped order when N != K. Without the check,
            ``np.kron(Sigma, S)`` still has shape (N*K, N*K), so the swap
            would silently produce a sample with the wrong covariance
            structure.
    """
    M = np.asarray(M)
    if M.ndim != 2:
        raise ValueError(
            "M must be a 2-dimensional mean matrix, got shape %s" % (M.shape,))
    N, K = M.shape

    # Validate only; the original S / Sigma objects are passed to np.kron
    # unchanged. atleast_2d lets a scalar stand in for a 1x1 covariance.
    if np.atleast_2d(S).shape != (N, N):
        raise ValueError(
            "S (row covariance) must have shape %s, got %s"
            % ((N, N), np.shape(S)))
    if np.atleast_2d(Sigma).shape != (K, K):
        raise ValueError(
            "Sigma (column covariance) must have shape %s, got %s"
            % ((K, K), np.shape(Sigma)))

    # The Kronecker product is an (N*K) x (N*K) matrix, so memory use, and
    # the decomposition multivariate_normal performs on it, grow quickly
    # with N*K.
    kron = np.kron(S, Sigma)
    # C-order flattening makes the row index vary slowest, which matches the
    # block layout of kron(S, Sigma).
    vec_M = M.reshape(N * K)
    vec_X = np.random.multivariate_normal(vec_M, kron)
    return vec_X.reshape(N, K)