9th Dimension - 1 year ago


I am aiming to do big things with TensorFlow, but I'm trying to start small.

I have small greyscale squares (with a little noise) and I want to classify them according to their colour (e.g. 3 categories: black, grey, white). I wrote a small Python class that generates the squares and their one-hot label vectors, and modified TensorFlow's basic MNIST example to feed them in.

But it doesn't learn anything: with 3 categories, for example, accuracy always stays at about 33%, i.e. chance level.

```python
import tensorflow as tf
import generate_data.generate_greyscale

data_generator = generate_data.generate_greyscale.GenerateGreyScale(28, 28, 3, 0.05)
ds = data_generator.generate_data(10000)
ds_validation = data_generator.generate_data(500)
xs = ds[0]
ys = ds[1]
num_categories = data_generator.num_categories

x = tf.placeholder("float", [None, 28*28])
W = tf.Variable(tf.zeros([28*28, num_categories]))
b = tf.Variable(tf.zeros([num_categories]))
y = tf.nn.softmax(tf.matmul(x, W) + b)
y_ = tf.placeholder("float", [None, num_categories])
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

# let batch_size = 100 --> therefore there are 100 batches of training data
xs = xs.reshape(100, 100, 28*28)  # reshape into 100 minibatches of size 100
ys = ys.reshape((100, 100, num_categories))  # reshape into 100 minibatches of size 100

for i in range(100):
    batch_xs = xs[i]
    batch_ys = ys[i]
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

xs_validation = ds_validation[0]
ys_validation = ds_validation[1]
print sess.run(accuracy, feed_dict={x: xs_validation, y_: ys_validation})
```

My data generator looks like this:

```python
import numpy as np
import random

class GenerateGreyScale():
    def __init__(self, num_rows, num_cols, num_categories, noise):
        self.num_rows = num_rows
        self.num_cols = num_cols
        self.num_categories = num_categories
        # set a level of noisiness for the data
        self.noise = noise

    def generate_label(self):
        lab = np.zeros(self.num_categories)
        lab[random.randint(0, self.num_categories-1)] = 1
        return lab

    def generate_datum(self, lab):
        i = np.where(lab==1)[0][0]
        frac = float(1)/(self.num_categories-1) * i
        arr = np.random.uniform(max(0, frac-self.noise), min(1, frac+self.noise), self.num_rows*self.num_cols)
        return arr

    def generate_data(self, num):
        data_arr = np.zeros((num, self.num_rows*self.num_cols))
        label_arr = np.zeros((num, self.num_categories))
        for i in range(0, num):
            label = self.generate_label()
            datum = self.generate_datum(label)
            data_arr[i] = datum
            label_arr[i] = label
        #data_arr = data_arr.astype(np.float32)
        #label_arr = label_arr.astype(np.float32)
        return data_arr, label_arr
```
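The generator places class `i` at intensity `i / (num_categories - 1)` and draws every pixel uniformly within `noise` of that value, clipped to [0, 1]. With 3 categories and noise 0.05 the classes therefore occupy the disjoint bands [0, 0.05], [0.45, 0.55] and [0.95, 1], so the mean pixel value alone separates them. Below is a rough Python 3 / NumPy port for checking that; the function names, the `rng` argument and the vectorized one-hot labels are my own illustration, not the asker's class.

```python
import numpy as np

def class_intensity_band(category, num_categories, noise):
    # Mirrors generate_datum: the class centre is category / (num_categories - 1),
    # and pixels are drawn uniformly from centre +/- noise, clipped to [0, 1].
    frac = category / (num_categories - 1)
    return max(0.0, frac - noise), min(1.0, frac + noise)

def generate_data(num, num_pixels=28 * 28, num_categories=3, noise=0.05, rng=None):
    # Hypothetical port of GenerateGreyScale.generate_data returning float32 arrays.
    rng = np.random.default_rng() if rng is None else rng
    categories = rng.integers(0, num_categories, size=num)  # 0 .. num_categories-1
    labels = np.eye(num_categories, dtype=np.float32)[categories]  # one-hot rows
    data = np.empty((num, num_pixels), dtype=np.float32)
    for row, category in enumerate(categories):
        low, high = class_intensity_band(category, num_categories, noise)
        data[row] = rng.uniform(low, high, size=num_pixels)
    return data, labels

for k in range(3):
    print(k, class_intensity_band(k, 3, 0.05))
```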


Answer

While dga's and syncd's responses were helpful, the changes I tried from them (non-zero weight initialization and larger datasets) made no difference. What finally worked was switching to a different optimization algorithm.

I replaced:

`train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)`

with

`train_step = tf.train.AdamOptimizer(0.0005).minimize(cross_entropy)`
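For reference, here is a NumPy sketch of the per-parameter update rule from the Adam paper (my own illustration; the function names are hypothetical, and TensorFlow's `AdamOptimizer` differs in small details such as where epsilon enters). Adam keeps running averages of the gradient and of its square and divides one by the square root of the other, so its first step is about `lr * sign(grad)` however large the raw gradient is, whereas a plain gradient-descent step is `lr * grad`.

```python
import numpy as np

def adam_step(theta, grad, m, v, t, lr=0.0005, beta1=0.9, beta2=0.999, eps=1e-8):
    # One Adam update; t is the 1-based step count.
    m = beta1 * m + (1 - beta1) * grad          # running mean of gradients
    v = beta2 * v + (1 - beta2) * grad ** 2     # running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)                # bias corrections for the
    v_hat = v / (1 - beta2 ** t)                # zero-initialized averages
    theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v

def sgd_step(theta, grad, lr=0.01):
    # Plain gradient descent: the step is proportional to the raw gradient.
    return theta - lr * grad

theta0 = np.zeros(3)
grad = np.array([16.5, 0.001, -1e4])
theta_adam, m, v = adam_step(theta0, grad, np.zeros(3), np.zeros(3), t=1)
theta_sgd = sgd_step(theta0, grad)
print("Adam first step:", theta_adam)
print("SGD step:       ", theta_sgd)
```

For example, gradients of 16.5, 0.001 and -10000 all give an Adam first step of magnitude about 0.0005, while the gradient-descent steps range from 0.00001 to 100.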

I also wrapped the training loop in an outer loop so that it trains for several epochs, which gave convergence like this:

```
===# EPOCH 0 #===
Error: 0.370000004768
===# EPOCH 1 #===
Error: 0.333999991417
===# EPOCH 2 #===
Error: 0.282000005245
===# EPOCH 3 #===
Error: 0.222000002861
===# EPOCH 4 #===
Error: 0.152000010014
===# EPOCH 5 #===
Error: 0.111999988556
===# EPOCH 6 #===
Error: 0.0680000185966
===# EPOCH 7 #===
Error: 0.0239999890327
===# EPOCH 8 #===
Error: 0.00999999046326
===# EPOCH 9 #===
Error: 0.00400000810623
```
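The outer epoch loop described above can be sketched as a small generator. This is an illustration under my own assumptions, not the author's code: the name `iterate_minibatches` is hypothetical, it works on the un-reshaped 2-D arrays, and it reshuffles the data at every epoch, which the original nested loop (presumably iterating the same 100 reshaped batches in order) did not do.

```python
import numpy as np

def iterate_minibatches(xs, ys, batch_size, num_epochs, rng=None):
    """Yield (epoch, batch_xs, batch_ys), reshuffling at the start of each epoch."""
    rng = np.random.default_rng() if rng is None else rng
    num_batches = len(xs) // batch_size  # a trailing partial batch is dropped
    for epoch in range(num_epochs):              # outer loop: epochs
        order = rng.permutation(len(xs))
        for i in range(num_batches):             # inner loop: one pass over the batches
            idx = order[i * batch_size:(i + 1) * batch_size]
            yield epoch, xs[idx], ys[idx]
```

In the question's script, the `sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})` call would move into the body of `for epoch, batch_xs, batch_ys in iterate_minibatches(xs, ys, 100, 10):`, with the validation accuracy printed whenever `epoch` changes.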

EDIT - WHY IT WORKS: I suspect the problem was that I hadn't chosen a good learning rate (or learning rate schedule) by hand, whereas Adam adapts its step sizes automatically.
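To illustrate how much the learning rate alone can matter for this model (my own reconstruction, not something established in the thread), the NumPy sketch below replays a single plain gradient-descent step of the question's setup in float32: zero-initialized weights, the loss summed over the batch, learning rate 0.01. The idealized noise-free batch of 33 examples per class, the float32 choice and the `softmax` helper are assumptions. After one step the logits are in the hundreds, the losing classes' probabilities underflow to exactly 0, and `y_ * log(y)` evaluates 0 * (-inf) = NaN, which would then propagate into the gradients and could leave accuracy stuck at chance.

```python
import numpy as np

def softmax(z):
    # Max-shifted softmax: avoids overflow, but tiny probabilities can still underflow to 0.
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

num_pixels, num_categories, per_class = 28 * 28, 3, 33
# Noise-free stand-ins for the three classes: black (0.0), grey (0.5), white (1.0).
levels = np.array([0.0, 0.5, 1.0], dtype=np.float32)
x = np.repeat(levels, per_class)[:, None] * np.ones((1, num_pixels), dtype=np.float32)
y_true = np.repeat(np.eye(num_categories, dtype=np.float32), per_class, axis=0)

W = np.zeros((num_pixels, num_categories), dtype=np.float32)
b = np.zeros(num_categories, dtype=np.float32)

# Gradient of the summed cross-entropy -sum(y_ * log(softmax(xW + b))).
y = softmax(x @ W + b)
grad_W = x.T @ (y - y_true)
grad_b = (y - y_true).sum(axis=0)

lr = 0.01
W -= lr * grad_W
b -= lr * grad_b

logits = x @ W + b
y = softmax(logits)
with np.errstate(divide="ignore", invalid="ignore"):
    loss = -np.sum(y_true * np.log(y))  # log(0) = -inf, and 0 * -inf = nan
print("largest logit after one step:", logits.max())
print("loss after one step:", loss)
```

Computing the loss from the logits with `tf.nn.softmax_cross_entropy_with_logits`, which avoids taking `log` of an underflowed probability, or averaging over the batch with `tf.reduce_mean` instead of summing, are common alternatives to switching optimizers.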
