Mike S. - 1 month ago
Python Question

Find All Binary Splits of a Nominal Attribute

Question



I'm trying to build a binary decision tree classifier in Python from scratch based on a data set that has only nominal attributes.

One step I'm stuck on is finding all possible ways to compute a binary split of a nominal attribute. For example, for an attribute with possible values [a, b, c, d], I am looking for a way to split these in two arrays such that we obtain:

left right
---- -----
a bcd
b acd
c abd
d abc
ab cd
ac bd
ad bc


without duplicate splits (e.g. we don't need "bc" in left and "ad" in right, since this would yield the same binary split as "ad" in left and "bc" in right). Order within each side is also irrelevant (e.g. "ad" is the same as "da" on one side of a split).
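One way to make "order doesn't matter" precise is to compare splits by a canonical key. The sketch below treats a split as an unordered pair of unordered sets, so swapping the sides or reordering values within a side gives the same key. `split_key` is a hypothetical helper, and it assumes the attribute values are distinct and hashable.

```python
def split_key(left, right):
    """Hypothetical helper: a key that is identical for (left, right) and
    (right, left) and ignores the order of values within each side.
    Assumes values are distinct and hashable."""
    return frozenset([frozenset(left), frozenset(right)])

# "ad | bc" and "bc | ad" are the same split ...
assert split_key("ad", "bc") == split_key("bc", "ad")
# ... and order within a side does not matter ("ad" is the same as "da").
assert split_key("ad", "bc") == split_key("da", "cb")
# A genuinely different split gets a different key.
assert split_key("ad", "bc") != split_key("ab", "cd")
```

Because the key is hashable, it can go straight into a `set` to filter out duplicate splits.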

Current Attempt



The exact terminology is escaping me, but I think this is some form of combination/permutation problem. I know it's not quite a powerset I'm after. The closest question to mine that I could find is linked here.
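For reference, the combinatorics term is a partition of the values into two non-empty, unordered blocks. For n distinct values the number of such partitions is the Stirling number of the second kind S(n, 2) = 2^(n-1) - 1: there are 2^n possible left sides, minus the empty and full sets, halved because left and right are interchangeable. A small sketch of that count, assuming n distinct values (`count_binary_splits` is an illustrative name; `math.comb` needs Python 3.8+):

```python
from math import comb  # Python 3.8+

def count_binary_splits(n):
    """Number of distinct binary splits of n distinct nominal values."""
    if n < 1:
        raise ValueError("need at least one value")
    # Choose k values for one side (1 <= k <= n - 1); every split is
    # counted twice (once from each side), so divide by 2.
    by_combinations = sum(comb(n, k) for k in range(1, n)) // 2
    closed_form = 2 ** (n - 1) - 1
    assert by_combinations == closed_form  # sanity check: both agree
    return closed_form

print(count_binary_splits(4))  # 7, matching the table above
```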

So far I've started an iterative procedure:

for i in range(1, len(array) // 2 + 1):
    for j in range(1, i):
        pass  # stuck here


The reason for looping only up to the floor of half the number of the attribute's possible values (len(array)) is that if we store up to len(array) // 2 values in left, right holds the remaining len(array) - len(left) values, which covers all possible splits.
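A sketch of how the loop above might be completed, assuming the attribute values are distinct: the inner index loop is replaced by `itertools.combinations`, which picks every left side of a given size. One case the half-length argument doesn't handle on its own: when the number of values is even, left and right have equal size at the middle step, so every such split shows up twice (ab|cd and cd|ab). The sketch keeps only the copy whose left side contains the first value.

```python
from itertools import combinations

def binary_splits(values):
    """Yield each distinct (left, right) split of `values` exactly once.

    Assumes the values are distinct. `left` takes 1 .. len(values) // 2
    values, as in the loop above; `right` gets the rest.
    """
    values = list(values)
    n = len(values)
    for size in range(1, n // 2 + 1):
        for left in combinations(values, size):
            # With n even and size == n // 2 both sides have equal size,
            # so each split would appear twice ("ab|cd" and "cd|ab").
            # Keep only the version with values[0] on the left.
            if 2 * size == n and values[0] not in left:
                continue
            right = tuple(v for v in values if v not in left)
            yield left, right

for left, right in binary_splits("abcd"):
    print("".join(left), "".join(right))
```

With "abcd" this prints the same seven splits, in the same order, as the table at the top of the question.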

Also, I've heard of the itertools module, so perhaps there's a way to achieve what I'm after with that?

Answer

I would use itertools.product to write a function that yields every way of dividing a sequence into two non-empty parts. I'd then iterate through those and remove duplicates using a set.

import itertools

def binary_splits(seq):
    for result_indices in itertools.product((0,1), repeat=len(seq)):
        result = ([], [])
        for seq_index, result_index in enumerate(result_indices):
            result[result_index].append(seq[seq_index])
        # skip results where one of the sides is empty
        if not result[0] or not result[1]:
            continue
        # convert each side from list to tuple so it can be hashed later
        yield tuple(map(tuple, result))

def binary_splits_no_dupes(seq):
    seen = set()
    for item in binary_splits(seq):
        key = tuple(sorted(item))
        if key in seen: continue
        yield key
        seen.add(key)

for left, right in binary_splits_no_dupes("abcd"):
    print(left, right)

Result:

('a', 'b', 'c') ('d',)
('a', 'b', 'd') ('c',)
('a', 'b') ('c', 'd')
('a', 'c', 'd') ('b',)
('a', 'c') ('b', 'd')
('a', 'd') ('b', 'c')
('a',) ('b', 'c', 'd')
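An alternative sketch that needs no `seen` set, assuming the values in `seq` are distinct and `seq` is non-empty: pin the first value to the left side, so mirrored duplicates are never generated, and let `itertools.product` assign only the remaining values. `binary_splits_no_set` is a hypothetical name, not part of the answer above.

```python
import itertools

def binary_splits_no_set(seq):
    """Yield each distinct binary split of `seq` once, without a seen set.

    Assumes `seq` is non-empty and its values are distinct. The first value
    is pinned to the left side, which rules out left/right mirror images.
    """
    first, rest = seq[0], seq[1:]
    # flags[i] == 0 sends rest[i] left, 1 sends it right
    for flags in itertools.product((0, 1), repeat=len(rest)):
        left = [first] + [v for v, f in zip(rest, flags) if f == 0]
        right = [v for v, f in zip(rest, flags) if f == 1]
        if right:  # skip the split with an empty right side
            yield tuple(left), tuple(right)

for left, right in binary_splits_no_set("abcd"):
    print(left, right)
```

For "abcd" it prints the same seven lines as the Result above, in the same order, while keeping no record of earlier splits.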