eclique - 1 month ago
Python Question

Picking and assigning multiple subsets from multiple sets in numpy

Given a base array X of shape (2, 3, 4), which can be interpreted as two sets of 3 elements each where every element is 4-dimensional, I want to sample from X in the following way.
From each of the 2 sets I want to pick 2 subsets, each defined by a binary mask of length 3; elements not selected by the mask are set to 0. So the sampling process is defined by a mask array of shape (2, 2, 3), and the result should have shape (2, 2, 3, 4).

Here's code that does what I need, but I wonder whether it could be rewritten more efficiently using NumPy indexing.

import numpy as np
np.random.seed(3)

sets = np.random.randint(0, 10, [2, 3, 4])         # 2 sets x 3 elements x 4 features
subset_masks = np.random.randint(0, 2, [2, 2, 3])  # 2 sets x 2 subsets x 3-element binary mask

print('Base set\n', sets, '\n')
print('Subset masks\n', subset_masks, '\n')

result = np.empty([2, 2, 3, 4])
for set_index in range(sets.shape[0]):
    for subset_index, subset in enumerate(subset_masks[set_index]):
        print('----')
        # Reshape the mask to a (3, 1) column so it broadcasts over the 4 features,
        # zeroing every element the mask does not select
        picked_subset = subset.reshape(3, 1) * sets[set_index]
        result[set_index][subset_index] = picked_subset
        print('Picking subset ', subset, 'from set #', set_index)
        print(picked_subset, '\n')


Output

Base set
[[[8 9 3 8]
  [8 0 5 3]
  [9 9 5 7]]

 [[6 0 4 7]
  [8 1 6 2]
  [2 1 3 5]]]

Subset masks
[[[0 0 1]
  [1 0 0]]

 [[1 0 1]
  [0 1 1]]]

----
Picking subset [0 0 1] from set # 0
[[0 0 0 0]
 [0 0 0 0]
 [9 9 5 7]]

----
Picking subset [1 0 0] from set # 0
[[8 9 3 8]
 [0 0 0 0]
 [0 0 0 0]]

----
Picking subset [1 0 1] from set # 1
[[6 0 4 7]
 [0 0 0 0]
 [2 1 3 5]]

----
Picking subset [0 1 1] from set # 1
[[0 0 0 0]
 [8 1 6 2]
 [2 1 3 5]]

Answer

Extend both arrays to 4D by adding a new axis: to subset_masks as the last axis, giving shape (2, 2, 3, 1), and to sets as the second axis, giving shape (2, 1, 3, 4). New axes can be added with None/np.newaxis. NumPy broadcasting then performs the element-wise multiplication, stretching the size-1 axes to produce the (2, 2, 3, 4) result, like so:

subset_masks[...,None]*sets[:,None]
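To see why the one-liner works, the sketch below prints the shapes of the two expanded views and compares the broadcast result with the question's double loop. It regenerates the question's inputs with the same seed; the helper name `loop_version` is hypothetical and is just the question's loop without the prints.

```python
import numpy as np

np.random.seed(3)
sets = np.random.randint(0, 10, [2, 3, 4])
subset_masks = np.random.randint(0, 2, [2, 2, 3])

# New trailing axis on the masks, new second axis on the sets
print(subset_masks[..., None].shape)  # (2, 2, 3, 1)
print(sets[:, None].shape)            # (2, 1, 3, 4)

# Broadcasting stretches both size-1 axes, giving (2, 2, 3, 4)
broadcast = subset_masks[..., None] * sets[:, None]
print(broadcast.shape)                # (2, 2, 3, 4)


def loop_version(sets, subset_masks):
    """Hypothetical helper: the question's double loop, without the prints."""
    n_sets, n_subsets, n_elems = subset_masks.shape
    result = np.empty((n_sets, n_subsets, n_elems, sets.shape[-1]), dtype=sets.dtype)
    for set_idx in range(n_sets):
        for sub_idx in range(n_subsets):
            column_mask = subset_masks[set_idx, sub_idx].reshape(n_elems, 1)
            result[set_idx, sub_idx] = column_mask * sets[set_idx]
    return result


print(np.array_equal(broadcast, loop_version(sets, subset_masks)))  # True
```

Passing `dtype=sets.dtype` to `np.empty` keeps the loop's result integer, whereas the question's `np.empty([2, 2, 3, 4])` defaults to float64; the values are the same either way.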

Just for kicks, we can also use np.einsum:

np.einsum('ijk,ilj->iljk',sets,subset_masks)
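In the subscripts 'ijk,ilj->iljk', i indexes the sets, j the elements, k the 4 features and l the subsets. Because j also appears in the output, nothing is summed, so each output entry is the single product sets[i, j, k] * subset_masks[i, l, j]. The sketch below checks this against the broadcasting one-liner on the question's inputs. As an extra comparison that is my own assumption and not part of the original answer, `np.where` with a boolean version of the masks gives the same values, but only because every mask entry is 0 or 1.

```python
import numpy as np

np.random.seed(3)
sets = np.random.randint(0, 10, [2, 3, 4])
subset_masks = np.random.randint(0, 2, [2, 2, 3])

broadcast = subset_masks[..., None] * sets[:, None]

# i: set, j: element, k: feature, l: subset. j survives into the output,
# so nothing is summed: out[i, l, j, k] = sets[i, j, k] * subset_masks[i, l, j]
via_einsum = np.einsum('ijk,ilj->iljk', sets, subset_masks)

# Assumed alternative: select with a boolean mask instead of multiplying.
# Equivalent here only because every mask entry is 0 or 1.
via_where = np.where(subset_masks[..., None].astype(bool), sets[:, None], 0)

print(np.array_equal(broadcast, via_einsum))  # True
print(np.array_equal(broadcast, via_where))   # True
```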