Asked by Chris (5 days ago)

PyMC3 Multinomial model doesn't work with non-integer observed data

I'm trying to use PyMC3 to fit a fairly simple multinomial model. It works perfectly when the 'noise' value is set to 0.0. However, when I change it to anything else, for example 0.01, find_MAP() raises an error, and sampling hangs if I skip find_MAP(). Is there some reason that the multinomial has to be sparse?


import numpy as np
import pymc3 as pm
print('pymc3 version: ' + pm.__version__)

sample_size = 10
number_of_experiments = 1

true_probs = [0.2, 0.1, 0.3, 0.4]
k = len(true_probs)

noise = 0.0  # works; any other value (e.g. 0.01) triggers the error below
y = np.random.multinomial(n=number_of_experiments, pvals=true_probs, size=sample_size) + noise
y_denominator = np.sum(y, axis=1)
y = y / y_denominator[:, None]  # normalize each row to proportions

with pm.Model() as multinom_test:
    probs = pm.Dirichlet('probs', a=np.ones(k), shape=k)
    for i in range(sample_size):
        data = pm.Multinomial('data_%d' % i,
                              n=y[i].sum(),
                              p=probs,
                              observed=y[i])

with multinom_test:
    start = pm.find_MAP()
    trace = pm.sample(5000, step=pm.Slice(), start=start)

print(trace['probs'].mean(0))


Error:

ValueError: Optimization error: max, logp or dlogp at max have non-finite values. Some values may be outside of distribution support. max: {'probs_stickbreaking_': array([ 0.00000000e+00, -4.47034834e-08, 0.00000000e+00])} logp: array(-inf) dlogp: array([ 0.00000000e+00, 2.98023221e-08, 0.00000000e+00])Check that 1) you don't have hierarchical parameters, these will lead to points with infinite density. 2) your distribution logp's are properly specified. Specific issues:
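For reference, the "outside of distribution support" hint in this error refers to the following: the standard Multinomial pmf, with n trials and category probabilities p_i, is only defined for vectors of non-negative integer counts x_i that add up to n:

$$
P(X_1 = x_1, \dots, X_k = x_k) = \frac{n!}{x_1! \cdots x_k!} \prod_{i=1}^{k} p_i^{x_i},
\qquad x_i \in \{0, 1, 2, \dots\}, \quad \sum_{i=1}^{k} x_i = n .
$$

With noise = 0.0 and number_of_experiments = 1, each normalized row in the code above is still a one-hot count vector, which lies in this support. With noise = 0.01, the normalized rows are fractional proportions such as 0.971 and 0.0096, which are not integer count vectors.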

Answer

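This works for me. The main changes from the question's code: each row of y is left as raw counts out of number_of_experiments = 100 trials instead of being normalized to proportions, and n is set to number_of_experiments.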

import numpy as np
import pymc3 as pm

sample_size = 10
number_of_experiments = 100

true_probs = [0.2, 0.1, 0.3, 0.4]
k = len(true_probs)
noise = 0.01
# Rows stay as raw counts out of number_of_experiments trials (not normalized)
y = np.random.multinomial(n=number_of_experiments, pvals=true_probs, size=sample_size) + noise

with pm.Model() as multinom_test:
    a = pm.Dirichlet('a', a=np.ones(k))
    for i in range(sample_size):
        data_pred = pm.Multinomial('data_pred_%s' % i, n=number_of_experiments, p=a, observed=y[i])
    trace = pm.sample(50000, pm.Metropolis())
    # trace = pm.sample(1000)  # also works with NUTS

pm.traceplot(trace[500:])  # discard the first 500 draws as burn-in

[Figure: output of pm.traceplot(trace[500:])]
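A likely reason the two versions behave differently, assuming your PyMC3 version casts observed values to the distribution's dtype before evaluating logp (Multinomial is integer-valued, so int64; PyMC3 3.x appears to do this when it wraps observed data, but check your version): the question's normalized noisy rows truncate to all-zero counts, so sum(x) no longer equals n and logp is -inf, while this answer's raw counts plus 0.01 truncate straight back to the original integer counts. The NumPy-only sketch below mimics that cast without PyMC3; `make_counts` and `as_int_observed` are hypothetical helpers, not PyMC3 API.

```python
import numpy as np

TRUE_PROBS = [0.2, 0.1, 0.3, 0.4]


def make_counts(number_of_experiments, noise, sample_size=10, seed=0):
    """Hypothetical helper: draw multinomial count rows and add a constant float offset."""
    rng = np.random.RandomState(seed)
    counts = rng.multinomial(n=number_of_experiments, pvals=TRUE_PROBS, size=sample_size)
    return counts, counts + noise


def as_int_observed(y):
    """Assumed behaviour: observed data cast to int64, which truncates toward zero."""
    return np.asarray(y).astype('int64')


# Question's setup: one trial per row, noise added, rows normalized to proportions.
_, y_q = make_counts(number_of_experiments=1, noise=0.01)
y_q = y_q / y_q.sum(axis=1)[:, None]
n_q = y_q.sum(axis=1)                      # the n passed to Multinomial (about 1.0 per row)
print(as_int_observed(y_q).sum(axis=1))    # every entry is < 1, so each row truncates to 0 counts

# Answer's setup: 100 trials per row, counts left unnormalized.
counts_a, y_a = make_counts(number_of_experiments=100, noise=0.01)
print(np.array_equal(as_int_observed(y_a), counts_a))  # truncation recovers the original counts
print(as_int_observed(y_a).sum(axis=1))    # each row sums to n = 100
```

If this assumption holds for your version, keeping the observations as integer counts (as in the code above) is what matters; the +0.01 offset is simply discarded by the cast.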
