"""
Simple implementation of analytic Variational Bayes to infer a 
nonlinear forward model

This implements section 4 of the FMRIB Variational Bayes tutorial 1
for a single exponential decay model
"""
import numpy as np

# This starts the random number generator off with the same seed value
# each time, so the results are repeatable. However it is worth changing
# the seed (or simply removing this line) to see how different data samples
# affect the results
np.random.seed(0)

# Ground truth parameters
A_TRUTH = 42
LAM_TRUTH = 1.5
NOISE_PREC_TRUTH = 100
NOISE_VAR_TRUTH = 1/NOISE_PREC_TRUTH
NOISE_STD_TRUTH = np.sqrt(NOISE_VAR_TRUTH)

# The nonlinear model we are going to fit
def model(t, a, lam):
    """
    Simple exponential decay model
    """
    return a * np.exp(-lam * t)

# Observed data samples are generated by NumPy from the ground truth
# Gaussian distribution. Reducing the number of samples should make
# the inference less 'confident' - i.e. the posterior variances for
# the model parameters and the noise precision BETA will increase
N = 100
DT = 0.2
t = np.arange(N) * DT
DATA_CLEAN = model(t, A_TRUTH, LAM_TRUTH)
DATA_NOISY = DATA_CLEAN + np.random.normal(0, NOISE_STD_TRUTH, [N])
print("Data samples are:")
print(t)
print(DATA_CLEAN)
print(DATA_NOISY)
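
# Optional visual check of the data. This is an illustrative sketch which
# assumes matplotlib is available - it is not otherwise required by this
# tutorial, so it is disabled by default
PLOT_DATA = False
if PLOT_DATA:
    import matplotlib.pyplot as plt
    plt.plot(t, DATA_CLEAN, label="Clean data")
    plt.plot(t, DATA_NOISY, ".", label="Noisy samples")
    plt.xlabel("t")
    plt.legend()
    plt.show()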

# Priors - noninformative because of high variance
#
# Note that the noise posterior is a gamma distribution
# with scale and shape parameters s, c. The mean of this
# distribution is s*c and the variance is s^2 * c. To make
# this more intuitive we define a prior mean and variance
# for the noise parameter BETA and express the prior scale
# and shape parameters in terms of these
#
# So long as the priors stay noninformative they should not 
# have a big impact on the inferred values - this is the 
# point of noninformative priors. However if you start to
# reduce the prior variances the inferred values will be
# drawn towards the prior values and away from the values
# suggested by the data
a0 = 1.0
a_var0 = 1000
lam0 = 1.0
lam_var0 = 1.0

beta_mean0 = 1
beta_var0 = 1000
# s=scale, c=shape parameters for Gamma distribution
s0 = beta_var0 / beta_mean0
c0 = beta_mean0**2 / beta_var0
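
# Quick sanity check of the moment matching above (illustrative, not part
# of the tutorial code): for a Gamma distribution with scale s and shape c
# the mean is s*c and the variance is s^2 * c, so s0 and c0 should
# reproduce the mean and variance we specified for BETA
assert np.isclose(s0 * c0, beta_mean0)
assert np.isclose(s0**2 * c0, beta_var0)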

# Priors as vectors/matrices - M=means, C=covariance, P=precision
M0 = np.array([a0, lam0])
C0 = np.array([[a_var0, 0], [0, lam_var0]])
P0 = np.linalg.inv(C0)

def calc_jacobian(M):
    """
    Numerical differentiation to calculate Jacobian matrix
    of partial derivatives of model prediction with respect to
    parameters
    """
    J = None
    for param_idx, param_value in enumerate(M):
        ML = np.array(M)
        MU = np.array(M)
        delta = max(abs(param_value) * 1e-5, 1e-10)

        MU[param_idx] += delta
        ML[param_idx] -= delta
        
        YU = model(t, MU[0], MU[1])
        YL = model(t, ML[0], ML[1])
        if J is None:
            J = np.zeros([len(YU), len(M)], dtype=np.float32)
        J[:, param_idx] = (YU - YL) / (2*delta)
    return J
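
# For this particular model the Jacobian is also available in closed form,
# and can be used to cross-check the numerical differentiation above, e.g.
# np.allclose(calc_jacobian(M0), calc_jacobian_analytic(M0), rtol=1e-4).
# This helper is an illustrative addition, not part of the tutorial code
def calc_jacobian_analytic(M):
    """
    Closed-form Jacobian of the exponential decay model y = a*exp(-lam*t):
    dy/da = exp(-lam*t) and dy/dlam = -a * t * exp(-lam*t)
    """
    a, lam = M
    decay = np.exp(-lam * t)
    return np.stack([decay, -a * t * decay], axis=-1)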

def update_model_params(k, M, P, s, c, J):
    """
    Update model parameters

    From section 4.2 of the FMRIB Variational Bayes Tutorial I
    
    k = data - prediction
    M = means (prior = M0)
    P = precision (prior=P0)
    s = noise shape (prior = s0)
    c = noise scale (prior = c0)
    J = Jacobian
    """
    P_new = s*c*np.dot(J.transpose(), J) + P0
    C_new = np.linalg.inv(P_new)
    M_new = np.dot(C_new, (s * c * np.dot(J.transpose(), (k + np.dot(J, M))) + np.dot(P0, M0)))
    return M_new, P_new

def update_noise(k, P, J):
    """
    Update noise parameters

    From section 4.2 of the FMRIB Variational Bayes Tutorial I
    
    k = data - prediction
    P = precision (prior=P0)
    J = Jacobian
    """
    C = np.linalg.inv(P)
    c_new = N/2 + c0
    s_new = 1/(1/s0 + 0.5 * np.dot(k, k) + 0.5 * np.trace(np.dot(C, np.dot(J.transpose(), J))))
    return c_new, s_new

# Initial posterior parameters. Note that the initial noise precision
# mean (s*c) is very small, so the first model parameter update will be
# dominated by the priors
M = np.array([1.0, 1.0])
C = np.array([[1.0, 0], [0.0, 1.0]])
P = np.linalg.inv(C)
c = 1e-8
s = 50.0

print("Iteration 0: A=%f, lam=%f, noise=%f" % (M[0], M[1], c*s))

# Update model and noise parameters iteratively
for idx in range(20):
    k = DATA_NOISY - model(t, M[0], M[1])
    J = calc_jacobian(M)
    M, P = update_model_params(k, M, P, s, c, J)
    c, s = update_noise(k, P, J)
    print("Iteration %i: A=%f, lam=%f, noise=%f" % (idx+1, M[0], M[1], c*s))
