Chris - 1 year ago

PyMC3 Multinomial Model doesn't work with non-integer observed data

I'm trying to use PyMC3 to fit a fairly simple multinomial model. It works perfectly when the 'noise' value is set to 0.0. However, when I change it to anything else (for example 0.01), find_MAP() raises the error below, and sampling hangs if I skip find_MAP(). Is there some reason that the multinomial has to be sparse?

import numpy as np
from pymc3 import *
import pymc3 as mc
import pandas as pd
print('pymc3 version: ' + mc.__version__)

sample_size = 10
number_of_experiments = 1

true_probs = [0.2, 0.1, 0.3, 0.4]

k = len(true_probs)

noise = 0.0
y = np.random.multinomial(n=number_of_experiments, pvals=true_probs, size=sample_size)+noise
y_denominator = np.sum(y,axis=1)
y = y/y_denominator[:,None]

with Model() as multinom_test:
    probs = Dirichlet('probs', a = np.ones(k), shape = k)
    for i in range(sample_size):
        data = Multinomial('data_%d' % (i),
                           n = y[i].sum(),
                           p = probs,
                           observed = y[i])

with multinom_test:
    start = find_MAP()
    trace = sample(5000, Slice())


ValueError: Optimization error: max, logp or dlogp at max have non-finite values. Some values may be outside of distribution support. max: {'probs_stickbreaking_': array([ 0.00000000e+00, -4.47034834e-08, 0.00000000e+00])} logp: array(-inf) dlogp: array([ 0.00000000e+00, 2.98023221e-08, 0.00000000e+00])Check that 1) you don't have hierarchical parameters, these will lead to points with infinite density. 2) your distribution logp's are properly specified. Specific issues:
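A quick check with NumPy alone (no PyMC3) shows what the observed rows look like once noise is added and each row is normalized as in the code above. A multinomial likelihood is only defined for non-negative integer counts that sum to n, so the sketch below uses a small hand-written log-pmf, `multinomial_logpmf` (a hypothetical helper for illustration, not PyMC3's implementation), that returns -inf outside that support. If PyMC3 casts observed values of a discrete distribution such as Multinomial to an integer dtype (an assumption about its internals, not something the error message states), the noisy rows would truncate to all zeros, which no longer sum to n; the last print line shows that truncation.

```python
import math
import numpy as np

def multinomial_logpmf(x, n, p):
    """Hypothetical helper: multinomial log-pmf, -inf outside the support.

    The support is vectors of non-negative integer counts that sum to n.
    """
    x = np.asarray(x, dtype=float)
    p = np.asarray(p, dtype=float)
    if np.any(x != np.round(x)) or np.any(x < 0) or not np.isclose(x.sum(), n):
        return -math.inf
    log_coef = math.lgamma(n + 1) - sum(math.lgamma(xi + 1) for xi in x)
    return log_coef + float(np.sum(x * np.log(p)))

def question_rows(noise, seed=0):
    """Rebuild y the way the question does: one-trial draws, plus noise, row-normalized."""
    rng = np.random.RandomState(seed)
    y = rng.multinomial(n=1, pvals=[0.2, 0.1, 0.3, 0.4], size=10) + noise
    return y / y.sum(axis=1)[:, None]

true_probs = [0.2, 0.1, 0.3, 0.4]
for noise in (0.0, 0.01):
    y = question_rows(noise)
    print("noise =", noise)
    print("  first row:", y[0])
    # n is taken from the row sum, as in the question's model
    print("  logp of first row:", multinomial_logpmf(y[0], y[0].sum(), true_probs))
    # Truncating to integers (what an integer cast would do) zeroes the noisy rows
    print("  row sums after int cast:", y.astype(int).sum(axis=1))
```

With noise = 0.0 every row is one-hot (one trial per experiment), so normalizing leaves it unchanged and it is still a valid count vector; with noise = 0.01 the rows become fractions that fall outside the support.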

Answer

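This works for me. A multinomial likelihood is defined only for non-negative integer counts that sum to n, so the main differences from the question's code are that the observed rows are multinomial counts (plus the small noise) rather than row-normalized proportions, and that n is set to the known number of trials per experiment (number_of_experiments = 100) instead of being computed from each row.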

import numpy as np
import pymc3 as pm

sample_size = 10
number_of_experiments = 100   # trials per experiment, i.e. the multinomial n

true_probs = [0.2, 0.1, 0.3, 0.4]
k = len(true_probs)
noise = 0.01
y = np.random.multinomial(n=number_of_experiments, pvals=true_probs, size=sample_size)+noise

with pm.Model() as multinom_test:
    a = pm.Dirichlet('a', a=np.ones(k))
    for i in range(sample_size):
        # n is fixed to the known number of trials, not computed from the data
        data_pred = pm.Multinomial('data_pred_%s'% i, n=number_of_experiments, p=a, observed=y[i])
    trace = pm.sample(50000, pm.Metropolis())
    #trace = pm.sample(1000) # also works with NUTS
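Because a Dirichlet prior is conjugate to the multinomial likelihood, the posterior in this model has a closed form: with prior concentration alpha (all ones here) and per-category counts c summed over all experiments, the posterior is Dirichlet(alpha + c), with mean (alpha + c) / (sum(alpha) + sum(c)). The sketch below computes that with NumPy only, so it can serve as a sanity check against the sampled trace (for example by comparing with `trace['a'].mean(axis=0)`). `dirichlet_multinomial_posterior` is a hypothetical helper, and the sketch uses integer counts; under the same assumed integer cast of observed data mentioned for the question, the answer's +0.01 would simply be truncated back to these counts.

```python
import numpy as np

def dirichlet_multinomial_posterior(counts, alpha):
    """Closed-form posterior for a Dirichlet(alpha) prior with multinomial counts.

    counts: integer array of shape (num_experiments, k).
    Returns the posterior concentration and the posterior mean.
    """
    counts = np.asarray(counts)
    alpha = np.asarray(alpha, dtype=float)
    post_alpha = alpha + counts.sum(axis=0)   # add the total count per category
    return post_alpha, post_alpha / post_alpha.sum()

true_probs = [0.2, 0.1, 0.3, 0.4]
k = len(true_probs)
rng = np.random.RandomState(42)
counts = rng.multinomial(n=100, pvals=true_probs, size=10)   # same shapes as the answer

post_alpha, post_mean = dirichlet_multinomial_posterior(counts, np.ones(k))
print("posterior concentration:", post_alpha)
print("posterior mean:", post_mean)

# Exact posterior draws, comparable to the MCMC samples of 'a'
draws = rng.dirichlet(post_alpha, size=5000)
print("mean of exact posterior draws:", draws.mean(axis=0))

# Adding 0.01 to integer counts and truncating returns the original counts
print("noise truncated away:", np.array_equal((counts + 0.01).astype(int), counts))
```

With 10 experiments of 100 trials each, the posterior mean should land close to true_probs, so a trace whose mean is far from it points to a modelling problem rather than sampler noise.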


