
Disclaimer: this page is a translated copy of a popular StackOverflow question and its answers, provided under the CC BY-SA 4.0 license. If you reuse or share it, you must do so under the same license and attribute it to the original authors (not me). Original question: http://stackoverflow.com/questions/40368697/

Date: 2020-08-19 23:26:52  Source: igfitidea

Where does next_batch in the TensorFlow tutorial batch_xs, batch_ys = mnist.train.next_batch(100) come from?

Tags: python, numpy, tensorflow

Asked by Dan

I am trying out the TensorFlow tutorial and don't understand where next_batch in this line comes from:

 batch_xs, batch_ys = mnist.train.next_batch(100)

I looked at

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

And didn't see next_batch there either.

Now when trying out next_batch in my own code, I am getting

AttributeError: 'numpy.ndarray' object has no attribute 'next_batch'

So I would like to understand where next_batch comes from.
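The error itself is easy to reproduce: a plain numpy.ndarray simply has no next_batch method, so calling it on raw arrays fails regardless of how the data was loaded (the array shapes below are just illustrative):

```python
import numpy as np

# A raw ndarray (e.g. images you loaded yourself) has no next_batch method;
# that method lives on the tutorial's dataset wrapper, not on the array.
images = np.zeros((10, 784))
print(hasattr(images, "next_batch"))  # → False
```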

Answered by Nick Becker

next_batch is a method of the DataSet class (see https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/learn/python/learn/datasets/mnist.py for more information on what's in the class).

When you load the MNIST data and assign it to the variable mnist with:

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

look at the class of mnist.train. You can see it by typing:

print(mnist.train.__class__)

You'll see the following:

<class 'tensorflow.contrib.learn.python.learn.datasets.mnist.DataSet'>

Because mnist.train is an instance of the DataSet class, you can use the class's next_batch method. For more information on classes, check out the documentation.
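To make the idea concrete, here is a minimal pure-NumPy sketch of such a class — this is my own simplified illustration, not the actual TensorFlow implementation (the real DataSet also handles dtype conversion, reshaping, and one-hot labels):

```python
import numpy as np

class MiniDataSet:
    """Minimal sketch of a DataSet-like wrapper exposing next_batch."""

    def __init__(self, images, labels):
        assert images.shape[0] == labels.shape[0]
        self._images = images
        self._labels = labels
        self._num_examples = images.shape[0]
        self._index = 0

    def next_batch(self, batch_size):
        # Reshuffle and start over when the current epoch is exhausted.
        if self._index + batch_size > self._num_examples:
            perm = np.random.permutation(self._num_examples)
            self._images = self._images[perm]
            self._labels = self._labels[perm]
            self._index = 0
        start = self._index
        self._index += batch_size
        return self._images[start:self._index], self._labels[start:self._index]

train = MiniDataSet(np.zeros((500, 784)), np.zeros((500, 10)))
batch_xs, batch_ys = train.next_batch(100)
print(batch_xs.shape, batch_ys.shape)  # → (100, 784) (100, 10)
```

Because mnist.train is an object of this kind, mnist.train.next_batch(100) just calls this method on it.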

Answered by Dark Element

After looking through the TensorFlow repository, it seems to originate here:

https://github.com/tensorflow/tensorflow/blob/9230423668770036179a72414482d45ddde40a3b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py#L905

However, if you're looking to implement it in your own code (for your own dataset), it would likely be much simpler to write it yourself in a dataset object, as I did. As I understand it, it's a method that shuffles the entire dataset and returns mini_batch_size samples from the shuffled data.

Here's some pseudocode:

 shuffle data.x and data.y while retaining relation
 return [data.x[:mb_n], data.y[:mb_n]]
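A runnable NumPy version of that pseudocode might look like this (the function and variable names are my own):

```python
import numpy as np

def next_batch(x, y, mb_n):
    """Shuffle x and y with the same permutation, return the first mb_n pairs."""
    perm = np.random.permutation(x.shape[0])  # one permutation keeps x/y aligned
    return x[perm][:mb_n], y[perm][:mb_n]

x = np.arange(10).reshape(10, 1)
y = np.arange(10) * 2            # y[i] is always 2 * x[i], so alignment is checkable
bx, by = next_batch(x, y, 4)
print(bx.shape, by.shape)        # → (4, 1) (4,)
```

Note that a single shared permutation is what "retaining relation" means: each returned label still belongs to its image.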

Answered by Bright Chang

You can just use the help function:

help(tf.contrib.learn.datasets.mnist.DataSet.next_batch)

and get the documentation for the next_batch method.