harry lakins - 2 months ago
Python Question

Why is my very simple neural network not doing at all well?

I have created an extremely simple neural network to help my understanding. It has one neuron, one input, and one weight. The idea is simple: given many random numbers between 0 and 200, it should learn that anything over 100 is correct and anything 100 or under is incorrect (instead of just being told).

import random

weight = random.uniform(-1,1)


def train(g,c,i):
    global weight
    weight = weight + (i*(c-g)) #change weight by error change
    if(g==c):
        return True
    else:
        return False


def trial(i):
    global weight
    sum = i*weight
    if(sum>0):
        return 1
    else:
        return -1


def feedData():
    suc = 0
    for x in range(0,10000):
        d = random.randint(0,200)
        if(d>100): #tell what is correct and not (this is like the dataset)
            correct = 1
        else:
            correct = -1

        g = trial(d)
        if(train(g,correct, d)==True):
            suc += 1

    print(suc)


feedData();


Out of 10000, I would expect at least 8000 to be correct. However, the number of successes always falls between 4990 and 5100.

I obviously have a slight flaw in my understanding. Cheers for any advice.

Answer

I think your problem here is that you're lacking a bias term. The network you've built multiplies a non-negative integer (d) by a weight and then checks whether the result is positive or negative. Because d is never negative, the sign of d*weight is just the sign of weight (or zero when d is 0), so whatever weight it learns, the network gives the same answer for almost every input. In an ideal universe, what should the value of weight be? There is no good choice: if weight is positive, the network will get about 50% of the inputs right; if it's negative, it will also be right about 50% of the time.
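To make the 50% figure concrete, here is a small sketch (not part of the original answer) that evaluates the question's bias-free rule on every integer from 0 to 200 for a few arbitrary weights. The helpers label and count_correct are hypothetical names for illustration, and trial mirrors the question's function but takes weight as a parameter instead of a global.

```python
def trial(i, weight):
    """Bias-free neuron from the question: 1 if i*weight > 0, else -1."""
    return 1 if i * weight > 0 else -1


def label(i):
    """Ground truth used in feedData(): over 100 is correct (1), else -1."""
    return 1 if i > 100 else -1


def count_correct(weight):
    """How many of the inputs 0..200 the bias-free neuron gets right."""
    return sum(trial(i, weight) == label(i) for i in range(201))


for w in (-5.0, -0.3, 0.0, 0.3, 5.0):
    print(w, count_correct(w), "/ 201")
```

Every weight scores 101 out of 201: a positive weight calls everything from 1 to 200 correct, while a zero or negative weight calls everything incorrect, so no amount of training on weight alone can do better.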

You'll see that the network can't solve this problem until you introduce a second "weight" as a bias term. If trial computes sum = i * weight + bias, and train also updates bias, then you should be able to correctly classify all inputs. I would initialise bias the same way as weight, and then do the update as:

bias = bias + (c-g)
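A minimal sketch of the question's program with that bias added, assuming the same perceptron-style update for both parameters. It passes weight and bias around explicitly instead of using globals, uses a random.Random instance so runs can be seeded, and adds a hypothetical scale parameter that is not part of the original answer.

```python
import random


def trial(i, weight, bias):
    """Neuron with a bias: 1 if i*weight + bias > 0, else -1."""
    return 1 if i * weight + bias > 0 else -1


def feed_data(n=10000, scale=1.0, seed=None):
    """Train weight and bias online and return how many guesses were right.

    scale is a hypothetical knob (not in the original): each input d is
    multiplied by it before reaching the neuron, e.g. 1/200 maps 0..200
    onto 0..1.
    """
    rng = random.Random(seed)
    # Initialise bias the same way as weight.
    weight = rng.uniform(-1, 1)
    bias = rng.uniform(-1, 1)
    successes = 0
    for _ in range(n):
        d = rng.randint(0, 200)
        correct = 1 if d > 100 else -1  # same labelling rule as the question
        x = d * scale
        guess = trial(x, weight, bias)
        # Error-driven update; both terms are zero when guess == correct.
        weight += x * (correct - guess)
        bias += correct - guess
        if guess == correct:
            successes += 1
    return successes


print(feed_data())               # raw inputs, as in the question
print(feed_data(scale=1 / 200))  # inputs scaled to 0..1
```

Because the two classes are linearly separable, this update should eventually stop making mistakes either way. With raw inputs, though, a single weight step can be as large as 400 while a bias step is only 2, so the decision threshold (-bias/weight) jumps around and can take far more than 10000 samples to settle; scaling the inputs makes the two step sizes comparable and tends to settle much faster. It's worth comparing the two printed counts yourself.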

Bias terms are often used in machine learning systems to account for a "bias" or "skew" in the input data (e.g., in a spam email classifier, maybe 80-95% of emails that we get are not spam, so the system should be biased against marking something as spam). In this case, the bias will allow the network to learn that it should produce some negative outputs, even though all of your inputs are non-negative.

To put it another way, let's think in terms of linear algebra. Your input classes (that is, {x | x ≤ 100} and {x | x > 100}) are linearly separable. The function that separates them is something like y = x - 100. This is a straight line on a 2D plot with positive slope, which intersects the y axis at y = -100 and the x axis at x = 100. Using this line, you can say that all values of x below 100 map to negative values of y (i.e., are incorrect), and all those above 100 map to positive values of y (i.e., are correct). The point x = 100 lands exactly on y = 0, which the question's sum > 0 test also treats as incorrect.
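As a quick check of that geometry: the line y = x - 100 corresponds to weight = 1 and bias = -100 in a neuron with a bias term. The sketch below (hypothetical helper names, same sum > 0 convention as the question) lists the misclassified inputs in 0..200, then repeats the check with bias = 0, i.e. the line y = x through the origin.

```python
def trial(i, weight, bias):
    """Bias-augmented neuron: 1 if i*weight + bias > 0, else -1."""
    return 1 if i * weight + bias > 0 else -1


def label(i):
    """Question's rule: over 100 is correct (1), anything else is -1."""
    return 1 if i > 100 else -1


def mistakes(weight, bias):
    """Inputs in 0..200 that the line y = weight*x + bias misclassifies."""
    return [i for i in range(201) if trial(i, weight, bias) != label(i)]


print(mistakes(1.0, -100.0))    # → []
print(len(mistakes(1.0, 0.0)))  # → 100 (every input from 1 to 100)
```

With the bias, every input is classified correctly; forcing the line through the origin gets half the range wrong no matter how the slope is chosen.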

The difficulty with your code is that you can only express lines which go through the origin (because you're lacking a bias term).

Comments