blerud - 11 months ago
Python Question

Theano shared variable has the wrong shape in a scan function

I have a Theano shared variable of shape (1, 500), but when it is passed to a scan function, its shape turns out to be (1, 1, 500). An example snippet is below.

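import numpy as np
import theano
import theano.tensor as T

X = T.matrix('X')  # assumed: input sequence, scan iterates over its rows
y_t1 = theano.shared(name='y_t1', value=np.zeros((1, 500), dtype=theano.config.floatX))

def forward(X, y_t1):
    # X is the current row of the sequence; y_t1 is the previous output
    return y_t1

(hyp), _ = theano.scan(fn=forward, sequences=X, outputs_info=[y_t1])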
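To make the recurrence in the snippet concrete, here is a minimal NumPy emulation of the loop that scan builds: the step function is called once per row of the sequence, each result is fed into the next step, and the per-step outputs are stacked along a new leading axis. This is a hypothetical sketch, not Theano's implementation; `scan_emulation` is an illustrative name, and `X` is assumed to have 3 rows of 10 features.

```python
import numpy as np


def forward(x_t, y_prev):
    # Mirrors the question's step function: ignores the sequence
    # element and returns the previous output unchanged.
    return y_prev


def scan_emulation(fn, sequence, initial):
    """Hypothetical NumPy stand-in for a scan loop: call fn(x_t, y_prev)
    once per row of `sequence`, feed each result into the next step,
    and stack all step outputs along a new axis 0."""
    outputs = []
    y_prev = initial
    for x_t in sequence:
        y_prev = fn(x_t, y_prev)
        outputs.append(y_prev)
    return np.stack(outputs)


X = np.zeros((3, 10))      # assumed: 3 time steps, 10 features each
y_t1 = np.zeros((1, 500))  # same shape as the question's shared variable

hyp = scan_emulation(forward, X, y_t1)
print(y_t1.shape)  # (1, 500): the shape each step works with
print(hyp.shape)   # (3, 1, 500): step outputs stacked along axis 0
```

In this emulation, a single-row `X` would give a stacked result of shape (1, 1, 500).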

y_t1 is created with shape (1, 500) and reports that shape outside the function "forward", but inside "forward" it has shape (1, 1, 500). Why does this happen?


Answer

Pass it in as:

(hyp), _ = theano.scan(fn=forward, sequences=X, outputs_info=y_t1)

It should work fine then. (I removed the brackets around y_t1 in outputs_info.)

Explanation: Theano converts whatever you pass to outputs_info into a tensor. If you pass a list, that list is converted into a tensor whose shape includes the list itself as an extra dimension. So when you pass [y_t1], you are effectively adding an extra leading dimension to your input.
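The list-to-tensor effect described in the explanation can be seen with NumPy's own array conversion. This is an analogy only: `np.asarray` stands in for the conversion step the answer describes and is not Theano's scan code.

```python
import numpy as np

y_t1 = np.zeros((1, 500), dtype=np.float32)

# Converting the bare array keeps its shape.
as_is = np.asarray(y_t1)

# Converting a one-element list that contains the array adds a leading
# axis, because the list itself becomes the outermost dimension.
wrapped = np.asarray([y_t1])

print(as_is.shape)    # (1, 500)
print(wrapped.shape)  # (1, 1, 500)
```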