Djvu Djvu - 4 months ago 15
Python Question

AssertionError when running the main function

I implemented the k-means algorithm in Python; the code is below. I tested it with some simple data, shown here, which is stored in a file called data.txt:

2 5
3 7
-1 -2
-3 -3
5 4
4 -4
3 -7
3.5 -9

My problem is that during the iterations some clusters seem to become empty, so that the number of non-empty clusters is less than k. From my analysis this can indeed happen, but after searching the web I found nobody who deals with this case in the k-means algorithm.

So I do not know where the fault is. Is it because my test data is so simple?

import sys
import numpy as np
from math import sqrt

"""
usage: python mykmeans.py mydata.txt k

"""

GAP = 2
MIN_VAL = 1000000

def get_distance(point1, point2):
    dis = sqrt(pow(point1[0] - point2[0], 2) + pow(point1[1] - point2[1], 2))

    return dis


def cluster_dis(centroid, cluster):
    dis = 0.0
    for point in cluster:
        dis += get_distance(centroid, point)

    return dis

def update_centroids(centroids, cluster_id, cluster):
    x, y = 0.0, 0.0
    length = len(cluster)
    if length == 0:  # TODO: this is my question: do we need to check for this?
        return

    for item in cluster:
        x += item[0]
        y += item[1]
    centroids[cluster_id] = (x / length, y / length)


def kmeans(data, k):
    assert k <= len(data)

    # Note: randint samples with replacement, so seed_ids may contain repeats.
    seed_ids = np.random.randint(0, len(data), k)
    centroids = [data[idx] for idx in seed_ids]
    clusters = [[] for _ in range(k)]
    cluster_idx = [-1] * len(data)

    pre_dis = 0
    while True:
        # Assignment step: move every point to its nearest centroid.
        for point_id, point in enumerate(data):
            min_distance, tmp_id = MIN_VAL, -1
            for seed_id, seed in enumerate(centroids):
                distance = get_distance(seed, point)
                if distance < min_distance:
                    min_distance = distance
                    tmp_id = seed_id
            if cluster_idx[point_id] != -1:
                dex = clusters[cluster_idx[point_id]].index(point)
                del clusters[cluster_idx[point_id]][dex]
            clusters[tmp_id].append(point)
            cluster_idx[point_id] = tmp_id

        # Update step: accumulate the total distance, then recompute centroids.
        now_dis = 0.0
        for cluster_id, cluster in enumerate(clusters):
            now_dis += cluster_dis(centroids[cluster_id], cluster)
            update_centroids(centroids, cluster_id, cluster)

        delta_dis = now_dis - pre_dis
        pre_dis = now_dis

        # Stop once the total distance changes by less than GAP between iterations.
        if abs(delta_dis) < GAP:
            break

    print(centroids)
    print(clusters)

    return centroids, clusters

def get_data(file_name):
    # Let IOError propagate: without the file there is nothing to cluster.
    with open(file_name) as fr:
        lines = fr.read().splitlines()

    data = []
    for line in lines:
        tmp = line.split()
        if len(tmp) < 2:  # skip blank or incomplete lines
            continue
        x, y = float(tmp[0]), float(tmp[1])
        data.append([x, y])

    return data

def main():
    args = sys.argv[1:]
    assert len(args) > 1
    file_name, k = args[0], int(args[1])

    data = get_data(file_name)
    kmeans(data, k)


if __name__ == '__main__':
    main()

Answer

K-means can indeed produce an empty cluster. Here is one example, shown in figures. I have also copied the figures below in case the link expires some day.

The first figure below shows the distribution of the 7 points. Initially, points 3, 5, and 6 are selected as the cluster centers.

[Figure: distribution of the 7 points]

The '+' markers below show how the cluster centers change after the 1st iteration; points drawn in the same color belong to the same cluster.

[Figure: cluster centers and assignments after the 1st iteration]

From the figure below, you can see that after 2 iterations the blue cluster becomes empty, leaving only 2 clusters instead of the initial 3.

[Figure: clusters after 2 iterations, with the blue cluster empty]
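The same effect can be reproduced numerically. The sketch below is not the data behind the figures; it uses a made-up set of six 1-D points, and the helper names `assign` and `update` are illustrative only. Three distinct data points serve as seeds, and the middle cluster loses all its points in the second assignment step.

```python
import numpy as np

def assign(points, centroids):
    """Return, for each point, the index of its nearest centroid."""
    # 1-D data, so the distance is just the absolute difference.
    return np.argmin(np.abs(points[:, None] - centroids[None, :]), axis=1)

def update(points, labels, centroids):
    """Move each centroid to the mean of its points; leave empty ones in place."""
    new = centroids.copy()
    for j in range(len(centroids)):
        members = points[labels == j]
        if len(members) > 0:
            new[j] = members.mean()
    return new

# Hypothetical data chosen for illustration.
points = np.array([-1.0, 0.0, 10.0, 11.0, 12.0, 21.0])
centroids = np.array([-1.0, 0.0, 21.0])  # three distinct data points as seeds

sizes_per_iteration = []
for _ in range(2):
    labels = assign(points, centroids)
    sizes_per_iteration.append(np.bincount(labels, minlength=3).tolist())
    centroids = update(points, labels, centroids)

print(sizes_per_iteration)  # [[1, 2, 3], [2, 0, 4]]
```

After the first update the centroids sit at -1, 5 and about 14.67. The point 0 is then closer to -1 than to 5, and the point 10 is closer to 14.67 than to 5, so the middle cluster ends up with no members.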

So the empty cluster is probably due to the initialization and an 'incorrect' number of clusters. You may try different values of k in your code and run the program several times, comparing the clustering results to make the outcome more robust.
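If k clusters are required regardless, one common heuristic is to re-seed an empty cluster with the point that lies farthest from its current centroid. The following is only a minimal sketch under that assumption, not a fix prescribed by the question or this answer. The function name `kmeans_no_empty` and the parameters `max_iter` and `seed` are made up for illustration. It also picks the initial seeds without replacement, since `np.random.randint` in the question's code can return the same index twice.

```python
import numpy as np

def kmeans_no_empty(data, k, max_iter=100, seed=None):
    """Lloyd-style k-means that re-seeds any cluster that ends up empty."""
    rng = np.random.RandomState(seed)
    points = np.asarray(data, dtype=float)
    n = len(points)
    assert k <= n

    # Sampling without replacement gives k distinct seed indices.
    centroids = points[rng.choice(n, size=k, replace=False)].copy()

    labels = None
    for _ in range(max_iter):
        # Squared Euclidean distance from every point to every centroid.
        dists = ((points[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        new_labels = dists.argmin(axis=1)
        closest = dists[np.arange(n), new_labels]

        # Heuristic (an assumption of this sketch): give each empty cluster
        # the point that is farthest from its own centroid.
        for j in range(k):
            if np.any(new_labels == j):
                continue
            counts = np.bincount(new_labels, minlength=k)
            # Only take a point from a cluster that can spare one.
            candidates = np.where(counts[new_labels] > 1)[0]
            far = candidates[closest[candidates].argmax()]
            new_labels[far] = j
            closest[far] = 0.0  # do not pick the same point again

        if labels is not None and np.array_equal(new_labels, labels):
            break  # assignments are stable
        labels = new_labels

        # Every cluster is non-empty here, so the mean is well defined.
        for j in range(k):
            centroids[j] = points[labels == j].mean(axis=0)

    return centroids, labels

# Usage with the data from the question.
data = [[2, 5], [3, 7], [-1, -2], [-3, -3], [5, 4], [4, -4], [3, -7], [3.5, -9]]
centroids, labels = kmeans_no_empty(data, 3, seed=0)
print(np.bincount(labels, minlength=3))  # size of each cluster
```

Because k <= len(data), an empty cluster always implies that some other cluster holds at least two points, so a donor point is always available. Whether forcing k non-empty clusters is desirable depends on the data; as noted above, an empty cluster can also be a hint that k is too large.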
