CodyBugstein - 1 month ago
Python Question

How does the axis parameter from NumPy work?

Can someone explain exactly what the axis parameter in NumPy does?

I am terribly confused.

I'm trying to use the method myArray.sum(axis=num).


At first I thought that if the array itself has 3 dimensions, axis=0 would return three elements, each consisting of the sum of all nested items in that same position. If each dimension contained five elements, I expected axis=1 to return a result of five items, and so on.

However, this is not the case, and the documentation does not do a good job of helping me out (it uses a 3x3x3 array, so it's hard to tell which axis is which).

Here's what I did:

>>> e
array([[[1, 0],
        [0, 0]],

       [[1, 1],
        [1, 0]],

       [[1, 0],
        [0, 1]]])
>>> e.sum(axis=0)
array([[3, 1],
       [1, 1]])
>>> e.sum(axis=1)
array([[1, 0],
       [2, 1],
       [1, 1]])
>>> e.sum(axis=2)
array([[1, 0],
       [2, 1],
       [1, 1]])
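The session above can be reproduced as a standalone script. This is a minimal sketch that assumes e was created with np.array from exactly the values shown in its printed form. It also shows why axis=1 and axis=2 happen to print the same result here: every 2x2 slice e[i] is symmetric, so its column sums equal its row sums. The copy f is a hypothetical modification that breaks that symmetry. Finally, passing a tuple, axis=(1, 2), which ndarray.sum accepts, gives one total per outer slice (three values), which may be closer to what the question first expected.

```python
import numpy as np

# Rebuild e from the values printed in the session (assumed construction).
e = np.array([[[1, 0],
               [0, 0]],
              [[1, 1],
               [1, 0]],
              [[1, 0],
               [0, 1]]])

print(e.shape)          # (3, 2, 2)
print(e.sum(axis=0))    # element-wise e[0] + e[1] + e[2]
print(e.sum(axis=1))    # within each e[i], sums down the columns
print(e.sum(axis=2))    # within each e[i], sums along the rows

# axis=1 and axis=2 agree here only because every 2x2 slice e[i]
# equals its own transpose, so its column sums equal its row sums.
print(all((e[i] == e[i].T).all() for i in range(e.shape[0])))  # True

# Hypothetical copy with one asymmetric slice: now the results differ.
f = e.copy()
f[0, 0, 1] = 5          # f[0] becomes [[1, 5], [0, 0]]
print(f.sum(axis=1))    # first row is [1, 5] (column sums of f[0])
print(f.sum(axis=2))    # first row is [6, 0] (row sums of f[0])

# One total per outer slice: reduce over both of the other axes at once.
print(e.sum(axis=(1, 2)))  # [1 3 2]
```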


Clearly, the result is not intuitive.

Answer

Clearly,

e.shape == (3, 2, 2)

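Summing over an axis is a reduction operation: each entry of the result is the sum over that axis's index with all the other indices held fixed, so the specified axis disappears from the shape. Hence,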

e.sum(axis=0).shape == (2, 2)
e.sum(axis=1).shape == (3, 2)
e.sum(axis=2).shape == (3, 2)
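The shape rule above holds for any axis: the reduced axis is dropped from the shape, and keepdims=True (a standard parameter of ndarray.sum) keeps it with length 1 instead. Below is a minimal sketch using a hypothetical array a of shape (2, 3, 4); giving every axis a different length makes it obvious which one a reduction removed, unlike a 3x3x3 example. The triple loop spells out what a.sum(axis=1) computes: for each fixed i and k, it adds up a[i, j, k] over j.

```python
import numpy as np

# Hypothetical example array: distinct axis lengths make each axis identifiable.
a = np.arange(24).reshape(2, 3, 4)

for axis in range(a.ndim):
    # The reduced axis is removed from the shape...
    expected_shape = a.shape[:axis] + a.shape[axis + 1:]
    assert a.sum(axis=axis).shape == expected_shape
    # ...or kept with length 1 when keepdims=True.
    print(axis, a.sum(axis=axis).shape, a.sum(axis=axis, keepdims=True).shape)

# Explicit-loop equivalent of a.sum(axis=1): fix i and k, run over j.
manual = np.zeros((2, 4), dtype=a.dtype)
for i in range(a.shape[0]):
    for k in range(a.shape[2]):
        for j in range(a.shape[1]):
            manual[i, k] += a[i, j, k]

print(np.array_equal(manual, a.sum(axis=1)))  # True
```

The loop prints 0 (3, 4) (1, 3, 4), then 1 (2, 4) (2, 1, 4), then 2 (2, 3) (2, 3, 1): the same pattern as the shapes of e.sum listed above, just with distinguishable lengths.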