Dan - 1 month ago
Python Question

Where does next_batch in the TensorFlow tutorial line batch_xs, batch_ys = mnist.train.next_batch(100) come from?

I am working through the TensorFlow tutorial and don't understand where next_batch in this line comes from:

batch_xs, batch_ys = mnist.train.next_batch(100)


I looked at

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)


And didn't see next_batch there either.

Now, when I try to use next_batch in my own code, I get:

AttributeError: 'numpy.ndarray' object has no attribute 'next_batch'


So where does next_batch come from?

Answer

next_batch is a method of the DataSet class (see https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/learn/python/learn/datasets/mnist.py for more information on what's in the class).

When you load the mnist data and assign it to the variable mnist with:

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

look at the class of mnist.train. You can see it by typing:

print(mnist.train.__class__)

You'll see the following:

<class 'tensorflow.contrib.learn.python.learn.datasets.mnist.DataSet'>
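You can run the same kind of check on the object your own code is calling next_batch on. The following is a small sketch using a plain NumPy array; the zero-filled array is placeholder data, assumed to stand in for whatever array your code builds. It reproduces the AttributeError from the question:

```python
import numpy as np

# Placeholder data standing in for the array used in your own code
images = np.zeros((5, 784), dtype=np.float32)

print(type(images))                   # <class 'numpy.ndarray'>
print(hasattr(images, "next_batch"))  # False: ndarray defines no such method

try:
    images.next_batch(100)
except AttributeError as err:
    print(err)  # 'numpy.ndarray' object has no attribute 'next_batch'
```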

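Because mnist.train is an instance of the DataSet class, you can call that class's next_batch method on it. The AttributeError in your own code comes from calling next_batch on a plain numpy.ndarray: NumPy arrays have no such method, so you need either a DataSet object or your own batching code. For more on how classes and methods work, see the classes section of the official Python tutorial.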
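If you want next_batch-style batching for your own NumPy arrays, here is a minimal sketch. The class name ArrayDataSet and its seed parameter are hypothetical and not part of TensorFlow; the class only mimics the general idea behind the tutorial's DataSet.next_batch (hand out the next batch_size examples, reshuffle when the epoch runs out) and is not TensorFlow's code, so details may differ from the real implementation.

```python
import numpy as np


class ArrayDataSet(object):
    """Hypothetical minimal stand-in for the tutorial's DataSet class.

    Wraps a pair of aligned NumPy arrays and serves them in mini-batches,
    reshuffling both arrays with the same permutation at every new epoch.
    """

    def __init__(self, images, labels, seed=None):
        if images.shape[0] != labels.shape[0]:
            raise ValueError("images and labels must have the same length")
        self._images = images
        self._labels = labels
        self._num_examples = images.shape[0]
        self._rng = np.random.RandomState(seed)
        self._index_in_epoch = 0
        self.epochs_completed = 0
        self._shuffle()

    def _shuffle(self):
        # One permutation for both arrays keeps each image with its label.
        perm = self._rng.permutation(self._num_examples)
        self._images = self._images[perm]
        self._labels = self._labels[perm]

    def next_batch(self, batch_size):
        """Return the next `batch_size` examples as (images, labels)."""
        if batch_size > self._num_examples:
            raise ValueError("batch_size is larger than the data set")
        if self._index_in_epoch + batch_size > self._num_examples:
            # Not enough examples left: start a new, reshuffled epoch.
            # (Simplification: leftover examples of the old epoch are skipped.)
            self.epochs_completed += 1
            self._shuffle()
            self._index_in_epoch = 0
        start = self._index_in_epoch
        self._index_in_epoch += batch_size
        return self._images[start:self._index_in_epoch], self._labels[start:self._index_in_epoch]


# Usage with toy data: 10 examples with 2 features and one-hot labels
x = np.arange(20, dtype=np.float32).reshape(10, 2)
y = np.eye(10)
train = ArrayDataSet(x, y, seed=0)
batch_xs, batch_ys = train.next_batch(4)
print(batch_xs.shape, batch_ys.shape)  # (4, 2) (4, 10)
```

Shuffling images and labels with one shared permutation is the key design choice: shuffling them independently would silently break the pairing between each example and its label.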