user850760 user850760 - 1 year ago 161
Python Question

CDF of MultivariateNormalDiag in tensorflow

I can run this example from here:

# Example from the question; assumes `import tensorflow as tf` was run first.
# Location and per-dimension scale of a 3-dimensional diagonal Gaussian.
mu = [1, 2, 3.]
diag_stdev = [4, 5, 6.]
dist = tf.contrib.distributions.MultivariateNormalDiag(mu, diag_stdev)
# Evaluating the density at a single 3-d point (this is the call that works).
dist.pdf([-1., 0, 1])

but when I replace the last line with
dist.cdf([-1., 0, 1])
I get a not implemented error:

NotImplementedError: log_cdf is not implemented

Can anybody suggest a workaround for the time being at least?

Answer Source

Based on the solutions here and here, I've implemented the following workaround:

import tensorflow as tf
import numpy as np
from scipy.stats import mvn

def py_func(func, inp, Tout, stateful=True, name=None, grad=None):
    """Wrap ``tf.py_func`` so that a custom gradient can be attached to it.

    Args:
        func: Python callable executed by the op at run time.
        inp: List of input tensors passed to ``func``.
        Tout: Output dtype(s) of ``func``, forwarded to ``tf.py_func``.
        stateful: Forwarded to ``tf.py_func``.
        name: Optional op name, forwarded to ``tf.py_func``.
        grad: Optional gradient function with signature ``grad(op, *grads)``.
            When ``None`` no gradient override is installed.

    Returns:
        Whatever ``tf.py_func`` returns for the given arguments.
    """
    import uuid

    if grad is None:
        # Nothing to register: behave like a plain tf.py_func call instead of
        # pointing the override map at a gradient name that does not exist.
        return tf.py_func(func, inp, Tout, stateful=stateful, name=name)

    # Gradient registrations are keyed by name and cannot be re-registered,
    # so every call needs a name that has never been used before.
    grad_name = 'PyFuncGrad' + uuid.uuid4().hex
    tf.RegisterGradient(grad_name)(grad)

    g = tf.get_default_graph()
    # tf.py_func creates a "PyFunc" op when stateful and a "PyFuncStateless"
    # op otherwise; map both so the override applies in either mode.
    override = {"PyFunc": grad_name, "PyFuncStateless": grad_name}
    with g.gradient_override_map(override):
        return tf.py_func(func, inp, Tout, stateful=stateful, name=name)

def np_cdf(mean, diag_sigma, value, name=None):
  """Evaluate a batch of diagonal multivariate-normal CDFs with SciPy.

  Each row of ``value`` is treated as the upper integration limit for the
  distribution described by the matching rows of ``mean`` and ``diag_sigma``.
  The lower limit is fixed at -30 in every dimension as a stand-in for
  minus infinity.

  Args:
    mean: Batch of mean vectors, one per row.
    diag_sigma: Batch of diagonal entries, one vector per row.
    value: Batch of points at which to evaluate the CDF, one per row.
    name: Unused; kept for signature compatibility.

  Returns:
    A float32 array of shape ``[batch, 1]`` holding one probability per row.
  """
  cdf = []
  for upper, mu, sigma in zip(value, mean, diag_sigma):
    # Size the lower bound from the data instead of assuming 2 dimensions.
    low = np.full(np.size(mu), -30.0)
    # NOTE(review): the entries are placed on the covariance diagonal as-is,
    # i.e. treated as variances, while cdf_gradient passes the same tensor
    # to MultivariateNormalDiag as a stdev -- confirm which is intended.
    S = np.diag(sigma)
    p, _ = mvn.mvnun(low, upper, mu, S)
    # Collect the per-row probability; without this the result is empty.
    cdf.append(p)

  cdfs = np.asarray(cdf, dtype=np.float32).reshape([-1, 1])
  return cdfs

def cdf_gradient(op, grad):
  """Gradient function for the CDF op created through ``py_func``.

  Args:
    op: The forward op; its inputs are, in order, the mean, the diagonal
      sigma and the evaluation point.
    grad: Incoming gradient with respect to the op's output.

  Returns:
    A tuple of gradients with respect to (mean, diag_sigma, value), in the
    same order as the op's inputs.
  """
  # A tf.Operation is not subscriptable; its tensors live under `.inputs`.
  mu = op.inputs[0]
  diag_sigma = op.inputs[1]
  value = op.inputs[2]
  dist = tf.contrib.distributions.MultivariateNormalDiag(mu, diag_sigma)
  pdf = dist.pdf(value)
  # NOTE(review): these expressions scale the joint pdf per dimension; this
  # looks like an approximation of the exact CDF partials -- confirm.
  dc_dv = tf.inv(diag_sigma) * pdf
  dc_dm = -1 * dc_dv
  # The small constant keeps the division finite when a sigma entry is zero.
  dc_ds = tf.div(value-mu,tf.square(diag_sigma)+1e-6) * pdf
  return grad * dc_dm, grad * dc_ds, grad * dc_dv

def tf_cdf(mean, diag_sigma, value, name=None):
    """Build a graph op that evaluates the diagonal multivariate-normal CDF.

    The forward pass runs ``np_cdf`` through ``py_func``; the backward pass
    uses ``cdf_gradient``.

    Args:
        mean: Tensor of mean vectors.
        diag_sigma: Tensor of diagonal sigma entries.
        value: Tensor of points at which to evaluate the CDF.
        name: Optional name scope for the op (defaults to "MyCDF").

    Returns:
        The first (and only) output tensor of the wrapped op.
    """
    with tf.name_scope(name, "MyCDF", [mean, diag_sigma, value]) as name:
        cdf = py_func(np_cdf,
                      [mean, diag_sigma, value],
                      [tf.float32],  # np_cdf returns a single float32 array
                      name=name,
                      grad=cdf_gradient)  # <-- here's the call to the gradient
        # py_func was given a list of dtypes, so it returns a list of tensors.
        return cdf[0]
Recommended from our users: Dynamic Network Monitoring from WhatsUp Gold from IPSwitch. Free Download