Peter - 1 year ago 112

Python Question

I'm trying to port a gradient descent algorithm from MATLAB to Python with NumPy, but I'm getting results that are similar but not identical.

Here's the MATLAB code:

```matlab
function [theta] = gradientDescentMulti(X, y, theta, alpha, num_iters)
    m = length(y);
    num_features = size(X,2);
    for iter = 1:num_iters;
        temp_theta = theta;
        for i = 1:num_features
            temp_theta(i) = theta(i)-((alpha/m)*(X * theta - y)'*X(:,i));
        end
        theta = temp_theta;
    end
end
```
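For reference, the inner MATLAB loop is the standard batch gradient-descent step for linear regression, i.e. the gradient of the squared-error cost $J(\theta) = \frac{1}{2m}\lVert X\theta - y\rVert^2$ (the cost function itself is an assumption; the code only shows its gradient). Every component is computed from the same old $\theta$ (a simultaneous update), which is why the code writes into `temp_theta` and only assigns back to `theta` after the loop. Here $i$ indexes rows (examples), $x^{(i)}$ is the $i$-th row of $X$, and $j$ indexes features; the MATLAB loop variable `i` plays the role of $j$:

$$
\theta_j \leftarrow \theta_j - \frac{\alpha}{m}\sum_{i=1}^{m}\left(x^{(i)}\theta - y^{(i)}\right)x^{(i)}_j,
\qquad j = 1,\dots,n,
\qquad\text{i.e.}\qquad
\theta \leftarrow \theta - \frac{\alpha}{m}X^\top\left(X\theta - y\right)
$$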

And my Python version:

```python
import numpy as np

def gradient_descent(X, y, alpha, trials):
    m = X.shape[0]  # number of training examples
    n = X.shape[1]  # number of features
    theta = np.zeros((n, 1))
    for i in range(trials):
        temp_theta = theta
        for p in range(n):
            thetaX = np.dot(X, theta)
            tMinY = thetaX - y
            temp_theta[p] = temp_theta[p] - (alpha/m)*np.dot(tMinY.T, X[:,p:p+1])
        theta = temp_theta
    return theta
```

Test case and result in MATLAB:

```matlab
X = [1 2 1 3; 1 7 1 9; 1 1 8 1; 1 3 7 4]
y = [2 ; 5 ; 5 ; 6];
[theta] = gradientDescentMulti(X, y, zeros(4,1), 0.01, 1);
```

```
theta =

    0.0450
    0.1550
    0.2225
    0.2000
```
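As a hand check of these numbers: with $\theta = 0$ the residual $X\theta - y$ is just $-y$, so one step reduces to $\theta = \frac{\alpha}{m}X^\top y$, with $\alpha/m = 0.01/4 = 0.0025$:

$$
X^\top y =
\begin{bmatrix}1&1&1&1\\2&7&1&3\\1&1&8&7\\3&9&1&4\end{bmatrix}
\begin{bmatrix}2\\5\\5\\6\end{bmatrix}
=
\begin{bmatrix}18\\62\\89\\80\end{bmatrix},
\qquad
\theta = 0.0025\begin{bmatrix}18\\62\\89\\80\end{bmatrix}
=
\begin{bmatrix}0.045\\0.155\\0.2225\\0.2\end{bmatrix}
$$

This agrees with the MATLAB output, whereas the Python output in this question agrees only in the first component.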

Test case and result in Python:

```python
test_X = np.array([[1,2,1,3],[1,7,1,9],[1,1,8,1],[1,3,7,4]])
test_y = np.array([[2], [5], [5], [6]])
theta = gradient_descent(test_X, test_y, 0.01, 1)
print(theta)
```

```
[[ 0.045     ]
 [ 0.1535375 ]
 [ 0.20600144]
 [ 0.14189214]]
```


Answer

This line in your Python:

```
temp_theta = theta
```

doesn't do what you think it does. It doesn't make a copy of `theta` and "assign" it to the "variable" `temp_theta`; it just says "`temp_theta` is now a new name for the object currently named by `theta`".
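A quick way to see this at the interpreter is to compare object identity with `is`, and to ask NumPy whether two arrays share a data buffer with `np.shares_memory`. A minimal sketch, with a zero vector standing in for `theta`:

```python
import numpy as np

theta = np.zeros((4, 1))

# Plain assignment binds a second name to the same array object.
temp_theta = theta
print(temp_theta is theta)
print(np.shares_memory(temp_theta, theta))

# .copy() allocates a new array with its own data buffer.
temp_copy = theta.copy()
print(temp_copy is theta)
print(np.shares_memory(temp_copy, theta))
```

The first pair prints `True` twice and the second pair `False` twice.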

So when you modify `temp_theta` here:

```
temp_theta[p] = temp_theta[p]-(alpha/m)*np.dot(tMinY.T, X[:,p:p+1])
```

you're actually modifying `theta`, because there's only the one array, now with two names.
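To see the effect on the numbers, here is a sketch that reruns one trial of the question's inner loop on the test data, once with the alias and once with a copy. `one_step` is a hypothetical helper written only for this illustration, and it extracts the scalar gradient with `.item()` instead of assigning the `(1, 1)` result of `np.dot` directly:

```python
import numpy as np

X = np.array([[1, 2, 1, 3], [1, 7, 1, 9], [1, 1, 8, 1], [1, 3, 7, 4]], dtype=float)
y = np.array([[2], [5], [5], [6]], dtype=float)
alpha, m = 0.01, X.shape[0]

def one_step(theta, copy):
    """Hypothetical helper: one trial of the question's inner loop.

    copy=False reproduces `temp_theta = theta`; copy=True uses `theta.copy()`.
    """
    temp_theta = theta.copy() if copy else theta
    for p in range(X.shape[1]):
        # Recomputed every pass from whatever `theta` holds right now.
        residual = X @ theta - y
        grad_p = (residual.T @ X[:, p:p + 1]).item()
        temp_theta[p, 0] = temp_theta[p, 0] - (alpha / m) * grad_p
    return temp_theta

aliased = one_step(np.zeros((4, 1)), copy=False)
copied = one_step(np.zeros((4, 1)), copy=True)
print(aliased.ravel())
print(copied.ravel())
```

With the alias, every pass after the first computes the residual from a `theta` whose earlier entries have already changed, so the aliased run reproduces the question's Python output (0.045, 0.1535375, 0.20600144, 0.14189214), while the copied run reproduces MATLAB's (0.045, 0.155, 0.2225, 0.2).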

If you instead write

```
temp_theta = theta.copy()
```

you'll get something like

```
(3.5) dsm@notebook:~/coding$ python peter.py
[[ 0.045 ]
[ 0.155 ]
[ 0.2225]
[ 0.2 ]]
```

which matches your MATLAB results.
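Because the copy exists only to keep every component updated from the same old θ, another option is to drop the inner loop and write the whole step as one matrix expression, θ ← θ − (α/m)Xᵀ(Xθ − y). The right-hand side is evaluated into a new array before it is bound to `theta`, so there is nothing to alias. A minimal sketch, not the asker's code; the name `gradient_descent_vectorized` is hypothetical:

```python
import numpy as np

def gradient_descent_vectorized(X, y, alpha, trials):
    """Hypothetical vectorized variant: batch gradient descent from theta = 0."""
    m, n = X.shape
    theta = np.zeros((n, 1))
    for _ in range(trials):
        residual = X @ theta - y  # (m, 1), computed from the old theta
        # Builds a new array; the old theta is never modified in place.
        theta = theta - (alpha / m) * (X.T @ residual)
    return theta

test_X = np.array([[1, 2, 1, 3], [1, 7, 1, 9], [1, 1, 8, 1], [1, 3, 7, 4]])
test_y = np.array([[2], [5], [5], [6]])
theta = gradient_descent_vectorized(test_X, test_y, 0.01, 1)
print(theta.ravel())
```

Note that the augmented form `theta -= ...` would update the array in place, which brings back the same pitfall if any other name is bound to that array.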
