博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Tensorflow 数据导入
阅读量:6595 次
发布时间:2019-06-24

本文共 21339 字,大约阅读时间需要 71 分钟。

导入数据

借助 API,您可以根据简单的可重用片段构建复杂的输入管道。例如,图片模型的管道可能会汇聚分布式文件系统中的文件中的数据、对每个图片应用随机扰动,并将随机选择的图片合并成用于训练的批次。文本模型的管道可能包括从原始文本数据中提取符号、根据对照表将其转换为嵌入标识符,以及将不同长度的序列组合成批次数据。使用 API 可以轻松处理大量数据、不同的数据格式以及复杂的转换。

API 在 TensorFlow 中引入了两个新的抽象类:

  • 表示一系列元素,其中每个元素包含一个或多个 Tensor 对象。例如,在图像管道中,元素可能是单个训练样本,具有一对表示图像数据和标签的张量。可以通过两种不同的方式来创建数据集:

    • 创建来源(例如 Dataset.from_tensor_slices()),以通过一个或多个 对象构建数据集。

    • 应用转换(例如 Dataset.batch()),以通过一个或多个 对象构建数据集。

  • 提供了从数据集中提取元素的主要方法。Iterator.get_next() 返回的操作会在执行时生成 Dataset 的下一个元素,并且此操作通常充当输入管道代码和模型之间的接口。最简单的迭代器是“单次迭代器”,它与特定的 Dataset 相关联,并对其进行一次迭代。要实现更复杂的用途,您可以通过 Iterator.initializer 操作使用不同的数据集重新初始化和参数化迭代器,这样一来,您就可以在同一个程序中对训练和验证数据进行多次迭代(举例而言)。

基本机制

本指南的这一部分介绍了创建不同种类的 DatasetIterator 对象的基础知识,以及如何从这些对象中提取数据。

要启动输入管道,您必须定义来源。例如,要通过内存中的某些张量构建 Dataset,您可以使用 tf.data.Dataset.from_tensors()tf.data.Dataset.from_tensor_slices()。或者,如果输入数据以推荐的 TFRecord 格式存储在磁盘上,那么您可以构建 。

一旦有了 Dataset 对象,可以将其转换为新的 Dataset,方法是链接 对象上的方法调用。例如,您可以应用单元素转换,例如 Dataset.map()(为每个元素应用一个函数),也可以应用多元素转换(例如 Dataset.batch())。要了解转换的完整列表,请参阅 的文档。

消耗 Dataset 中值的最常见方法是构建迭代器对象。通过此对象,可以一次访问数据集中的一个元素(例如通过调用 Dataset.make_one_shot_iterator())。 提供了两个操作:Iterator.initializer,您可以通过此操作(重新)初始化迭代器的状态;以及 Iterator.get_next(),此操作返回对应于有符号下一个元素的 对象。根据您的使用情形,您可以选择不同类型的迭代器,下文介绍了具体选项。

数据集结构

一个数据集包含多个元素,每个元素的结构都相同。一个元素包含一个或多个 对象,这些对象称为组件。每个组件都有一个 ,表示张量中元素的类型;以及一个 ,表示每个元素(可能部分指定)的静态形状。您可以通过 Dataset.output_typesDataset.output_shapes 属性检查数据集元素各个组件的推理类型和形状。这些属性的嵌套结构映射到元素的结构,此元素可以是单个张量、张量元组,也可以是张量的嵌套元组。例如:

dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10])) print(dataset1.output_types) # ==> "tf.float32" print(dataset1.output_shapes) # ==> "(10,)" dataset2 = tf.data.Dataset.from_tensor_slices( (tf.random_uniform([4]), tf.random_uniform([4, 100], maxval=100, dtype=tf.int32))) print(dataset2.output_types) # ==> "(tf.float32, tf.int32)" print(dataset2.output_shapes) # ==> "((), (100,))" dataset3 = tf.data.Dataset.zip((dataset1, dataset2)) print(dataset3.output_types) # ==> (tf.float32, (tf.float32, tf.int32)) print(dataset3.output_shapes) # ==> "(10, ((), (100,)))"

为元素的每个组件命名通常会带来便利性,例如,如果它们表示训练样本的不同特征。除了元组之外,还可以使用 collections.namedtuple 或将字符串映射到张量的字典来表示 Dataset 的单个元素。

dataset = tf.data.Dataset.from_tensor_slices( { "a": tf.random_uniform([4]), "b": tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)}) print(dataset.output_types) # ==> "{'a': tf.float32, 'b': tf.int32}" print(dataset.output_shapes) # ==> "{'a': (), 'b': (100,)}"

Dataset 转换支持任何结构的数据集。在使用 Dataset.map()Dataset.flat_map()Dataset.filter() 转换时(这些转换会对每个元素应用一个函数),元素结构决定了函数的参数:

dataset1 = dataset1.map(lambda x: ...) dataset2 = dataset2.flat_map(lambda x, y: ...) # Note: Argument destructuring is not available in Python 3. dataset3 = dataset3.filter(lambda x, (y, z): ...)

创建迭代器

构建了表示输入数据的 Dataset 后,下一步就是创建 Iterator 来访问该数据集中的元素。 API 目前支持下列迭代器,复杂程度逐渐增大:

  • 单次
  • 可初始化
  • 可重新初始化,以及
  • 可馈送

单次迭代器是最简单的迭代器形式,仅支持对数据集进行一次迭代,不需要显式初始化。单次迭代器可以处理基于队列的现有输入管道支持的几乎所有情况,但它们不支持参数化。以 Dataset.range() 为例:

dataset = tf.data.Dataset.range(100) iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() for i in range(100): value = sess.run(next_element) assert i == value

注意:目前,单次迭代器是唯一易于与 Estimator 搭配使用的类型。

您需要先运行显式 iterator.initializer 操作,然后才能使用可初始化迭代器。虽然有些不便,但它允许您使用一个或多个 tf.placeholder() 张量(可在初始化迭代器时馈送)参数化数据集的定义。继续以 Dataset.range() 为例:

max_value = tf.placeholder(tf.int64, shape=[]) dataset = tf.data.Dataset.range(max_value) iterator = dataset.make_initializable_iterator() next_element = iterator.get_next() # Initialize an iterator over a dataset with 10 elements. sess.run(iterator.initializer, feed_dict={ max_value: 10}) for i in range(10): value = sess.run(next_element) assert i == value # Initialize the same iterator over a dataset with 100 elements. sess.run(iterator.initializer, feed_dict={ max_value: 100}) for i in range(100): value = sess.run(next_element) assert i == value

可重新初始化迭代器可以通过多个不同的 Dataset 对象进行初始化。例如,您可能有一个训练输入管道,它会对输入图片进行随机扰动来改善泛化;还有一个验证输入管道,它会评估对未修改数据的预测。这些管道通常会使用不同的 Dataset 对象,这些对象具有相同的结构(即每个组件具有相同类型和兼容形状)。

# Define training and validation datasets with the same structure. training_dataset = tf.data.Dataset.range(100).map( lambda x: x + tf.random_uniform([], -10, 10, tf.int64)) validation_dataset = tf.data.Dataset.range(50) # A reinitializable iterator is defined by its structure. We could use the # `output_types` and `output_shapes` properties of either `training_dataset` # or `validation_dataset` here, because they are compatible. iterator = tf.data.Iterator.from_structure(training_dataset.output_types, training_dataset.output_shapes) next_element = iterator.get_next() training_init_op = iterator.make_initializer(training_dataset) validation_init_op = iterator.make_initializer(validation_dataset) # Run 20 epochs in which the training dataset is traversed, followed by the # validation dataset. for _ in range(20): # Initialize an iterator over the training dataset. sess.run(training_init_op) for _ in range(100): sess.run(next_element) # Initialize an iterator over the validation dataset. sess.run(validation_init_op) for _ in range(50): sess.run(next_element)

可馈送迭代器可以与 一起使用,以选择所使用的 Iterator(在每次调用 时)(通过熟悉的 feed_dict 机制)。它提供的功能与可重新初始化迭代器的相同,但在迭代器之间切换时不需要从数据集的开头初始化迭代器。例如,以上面的同一训练和验证数据集为例,您可以使用 定义一个可让您在两个数据集之间切换的可馈送迭代器:

# Define training and validation datasets with the same structure. training_dataset = tf.data.Dataset.range(100).map( lambda x: x + tf.random_uniform([], -10, 10, tf.int64)).repeat() validation_dataset = tf.data.Dataset.range(50) # A feedable iterator is defined by a handle placeholder and its structure. We # could use the `output_types` and `output_shapes` properties of either # `training_dataset` or `validation_dataset` here, because they have # identical structure. handle = tf.placeholder(tf.string, shape=[]) iterator = tf.data.Iterator.from_string_handle( handle, training_dataset.output_types, training_dataset.output_shapes) next_element = iterator.get_next() # You can use feedable iterators with a variety of different kinds of iterator # (such as one-shot and initializable iterators). training_iterator = training_dataset.make_one_shot_iterator() validation_iterator = validation_dataset.make_initializable_iterator() # The `Iterator.string_handle()` method returns a tensor that can be evaluated # and used to feed the `handle` placeholder. training_handle = sess.run(training_iterator.string_handle()) validation_handle = sess.run(validation_iterator.string_handle()) # Loop forever, alternating between training and validation. while True: # Run 200 steps using the training dataset. Note that the training dataset is # infinite, and we resume from where we left off in the previous `while` loop # iteration. for _ in range(200): sess.run(next_element, feed_dict={ handle: training_handle}) # Run one pass over the validation dataset. sess.run(validation_iterator.initializer) for _ in range(50): sess.run(next_element, feed_dict={ handle: validation_handle})

消耗迭代器中的值

Iterator.get_next() 方法返回一个或多个 对象,这些对象对应于迭代器有符号的下一个元素。每次评估这些张量时,它们都会获取底层数据集中下一个元素的值。(请注意,与 TensorFlow 中的其他有状态对象一样,调用 Iterator.get_next() 并不会立即使迭代器进入下个状态。您必须在 TensorFlow 表达式中使用此函数返回的 对象,并将该表达式的结果传递到 tf.Session.run(),以获取下一个元素并使迭代器进入下个状态。)

如果迭代器到达数据集的末尾,则执行 Iterator.get_next() 操作会产生 。在此之后,迭代器将处于不可用状态;如果需要继续使用,则必须对其重新初始化。

dataset = tf.data.Dataset.range(5) iterator = dataset.make_initializable_iterator() next_element = iterator.get_next() # Typically `result` will be the output of a model, or an optimizer's # training operation. result = tf.add(next_element, next_element) sess.run(iterator.initializer) print(sess.run(result)) # ==> "0" print(sess.run(result)) # ==> "2" print(sess.run(result)) # ==> "4" print(sess.run(result)) # ==> "6" print(sess.run(result)) # ==> "8" try: sess.run(result) except tf.errors.OutOfRangeError: print("End of dataset") # ==> "End of dataset"

一种常见模式是将“训练循环”封装在 try-except 块中:

sess.run(iterator.initializer) while True: try: sess.run(result) except tf.errors.OutOfRangeError: break

如果数据集的每个元素都具有嵌套结构,则 Iterator.get_next() 的返回值将是一个或多个 对象,这些对象具有相同的嵌套结构:

dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10])) dataset2 = tf.data.Dataset.from_tensor_slices((tf.random_uniform([4]), tf.random_uniform([4, 100]))) dataset3 = tf.data.Dataset.zip((dataset1, dataset2)) iterator = dataset3.make_initializable_iterator() sess.run(iterator.initializer) next1, (next2, next3) = iterator.get_next()

请注意,next1next2next3 是由同一个操作/节点(通过 Iterator.get_next() 创建)生成的张量。因此,评估其中任何一个张量都会使所有组件的迭代器进入下个状态。典型的迭代器消耗方会在一个表达式中包含所有组件。

保存迭代器状态

函数通过迭代器创建一个 SaveableObject,该对象可用于保存和恢复迭代器(实际上是整个输入管道)的当前状态。以这种方式创建的可保存对象可以添加到 变量列表或 集合中,以便采用与 相同的方式进行保存和恢复。请参阅,详细了解如何保存和恢复变量。

# Create saveable object from iterator. saveable = tf.contrib.data.make_saveable_from_iterator(iterator) # Save the iterator state by adding it to the saveable objects collection. tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable) saver = tf.train.Saver() with tf.Session() as sess: if should_checkpoint: saver.save(path_to_checkpoint) # Restore the iterator state. with tf.Session() as sess: saver.restore(sess, path_to_checkpoint)

读取输入数据

消耗 NumPy 数组

如果您的所有输入数据都适合存储在内存中,则根据输入数据创建 Dataset 的最简单方法是将它们转换为 对象,并使用 Dataset.from_tensor_slices()

# Load the training data into two NumPy arrays, for example using `np.load()`. with np.load("/var/data/training_data.npy") as data: features = data["features"] labels = data["labels"] # Assume that each row of `features` corresponds to the same row as `labels`. assert features.shape[0] == labels.shape[0] dataset = tf.data.Dataset.from_tensor_slices((features, labels))

请注意,上面的代码段会将 featureslabels 数组作为 tf.constant() 指令嵌入在 TensorFlow 图中。这样非常适合小型数据集,但会浪费内存,因为会多次复制数组的内容,并可能会达到 协议缓冲区的 2GB 上限。

作为替代方案,您可以根据 tf.placeholder() 张量定义 Dataset,并在对数据集初始化 Iterator 时馈送 NumPy 数组。

# Load the training data into two NumPy arrays, for example using `np.load()`. with np.load("/var/data/training_data.npy") as data: features = data["features"] labels = data["labels"] # Assume that each row of `features` corresponds to the same row as `labels`. assert features.shape[0] == labels.shape[0] features_placeholder = tf.placeholder(features.dtype, features.shape) labels_placeholder = tf.placeholder(labels.dtype, labels.shape) dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder)) # [Other transformations on `dataset`...] dataset = ... iterator = dataset.make_initializable_iterator() sess.run(iterator.initializer, feed_dict={ features_placeholder: features, labels_placeholder: labels})

消耗 TFRecord 数据

API 支持多种文件格式,因此您可以处理那些不适合存储在内存中的大型数据集。例如,TFRecord 文件格式是一种面向记录的简单二进制格式,很多 TensorFlow 应用采用此格式来训练数据。通过 类,您可以将一个或多个 TFRecord 文件的内容作为输入管道的一部分进行流式传输。

# Creates a dataset that reads all of the examples from two files. filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"] dataset = tf.data.TFRecordDataset(filenames)

TFRecordDataset 初始化程序的 filenames 参数可以是字符串、字符串列表,也可以是字符串 。因此,如果您有两组分别用于训练和验证的文件,则可以使用 tf.placeholder(tf.string) 来表示文件名,并使用适当的文件名初始化迭代器:

filenames = tf.placeholder(tf.string, shape=[None]) dataset = tf.data.TFRecordDataset(filenames) dataset = dataset.map(...) # Parse the record into tensors. dataset = dataset.repeat() # Repeat the input indefinitely. dataset = dataset.batch(32) iterator = dataset.make_initializable_iterator() # You can feed the initializer with the appropriate filenames for the current # phase of execution, e.g. training vs. validation. # Initialize `iterator` with training data. training_filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"] sess.run(iterator.initializer, feed_dict={ filenames: training_filenames}) # Initialize `iterator` with validation data. validation_filenames = ["/var/data/validation1.tfrecord", ...] sess.run(iterator.initializer, feed_dict={ filenames: validation_filenames})

消耗文本数据

很多数据集都是作为一个或多个文本文件分布的。 提供了一种从一个或多个文本文件中提取行的简单方法。给定一个或多个文件名,TextLineDataset 会为这些文件的每行生成一个字符串值元素。像 TFRecordDataset 一样,TextLineDataset 将接受 filenames(作为 ),因此您可以通过传递 tf.placeholder(tf.string) 进行参数化。

filenames = ["/var/data/file1.txt", "/var/data/file2.txt"] dataset = tf.data.TextLineDataset(filenames)

默认情况下,TextLineDataset 会生成每个文件的每一行,这可能是不可取的(例如,如果文件以标题行开头或包含注释)。可以使用 Dataset.skip()Dataset.filter() 转换来移除这些行。为了将这些转换分别应用于每个文件,我们使用 Dataset.flat_map() 为每个文件创建一个嵌套的 Dataset

filenames = ["/var/data/file1.txt", "/var/data/file2.txt"] dataset = tf.data.Dataset.from_tensor_slices(filenames) # Use `Dataset.flat_map()` to transform each file as a separate nested dataset, # and then concatenate their contents sequentially into a single "flat" dataset. # * Skip the first line (header row). # * Filter out lines beginning with "#" (comments). dataset = dataset.flat_map( lambda filename: ( tf.data.TextLineDataset(filename) .skip(1) .filter(lambda line: tf.not_equal(tf.substr(line, 0, 1), "#"))))

消耗 CSV 数据

CSV 文件格式是用于以纯文本格式存储表格数据的常用格式。 类提供了一种从符合 的一个或多个 CSV 文件中提取记录的方法。给定一个或多个文件名以及默认值列表后,CsvDataset 将生成一个元素元组,元素类型对应于为每个 CSV 记录提供的默认元素类型。像 TFRecordDatasetTextLineDataset 一样,CsvDataset 将接受 filenames(作为 ),因此您可以通过传递 tf.placeholder(tf.string) 进行参数化。

# Creates a dataset that reads all of the records from two CSV files, each with # eight float columns filenames = ["/var/data/file1.csv", "/var/data/file2.csv"] record_defaults = [tf.float32] * 8 # Eight required float columns dataset = tf.contrib.data.CsvDataset(filenames, record_defaults)

如果某些列为空,则可以提供默认值而不是类型。

# Creates a dataset that reads all of the records from two CSV files, each with # four float columns which may have missing values record_defaults = [[0.0]] * 8 dataset = tf.contrib.data.CsvDataset(filenames, record_defaults)

默认情况下,CsvDataset 生成文件的每一列或每一行,这可能是不可取的;例如,如果文件以应忽略的标题行开头,或如果输入中不需要某些列。可以分别使用 headerselect_cols 参数移除这些行和字段。

# Creates a dataset that reads all of the records from two CSV files with # headers, extracting float data from columns 2 and 4. record_defaults = [[0.0]] * 2 # Only provide defaults for the selected columns dataset = tf.contrib.data.CsvDataset(filenames, record_defaults, header=True, select_cols=[2,4])

使用 Dataset.map() 预处理数据

Dataset.map(f) 转换通过将指定函数 f 应用于输入数据集的每个元素来生成新数据集。此转换基于 (通常应用于函数式编程语言中的列表和其他结构)。函数 f 会接受表示输入中单个元素的 对象,并返回表示新数据集中单个元素的 对象。此函数的实现使用标准的 TensorFlow 指令将一个元素转换为另一个元素。

本部分介绍了如何使用 Dataset.map() 的常见示例。

解析 tf.Example 协议缓冲区消息

许多输入管道都从 TFRecord 格式的文件中提取 协议缓冲区消息(例如这种文件使用 编写而成)。每个 记录都包含一个或多个“特征”,输入管道通常会将这些特征转换为张量。

# Transforms a scalar string `example_proto` into a pair of a scalar string and # a scalar integer, representing an image and its label, respectively. def _parse_function(example_proto): features = { "image": tf.FixedLenFeature((), tf.string, default_value=""), "label": tf.FixedLenFeature((), tf.int64, default_value=0)} parsed_features = tf.parse_single_example(example_proto, features) return parsed_features["image"], parsed_features["label"] # Creates a dataset that reads all of the examples from two files, and extracts # the image and label features. filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"] dataset = tf.data.TFRecordDataset(filenames) dataset = dataset.map(_parse_function)

解码图片数据并调整其大小

在用真实的图片数据训练神经网络时,通常需要将不同大小的图片转换为通用大小,这样就可以将它们批处理为具有固定大小的数据。

# Reads an image from a file, decodes it into a dense tensor, and resizes it # to a fixed shape. def _parse_function(filename, label): image_string = tf.read_file(filename) image_decoded = tf.image.decode_jpeg(image_string) image_resized = tf.image.resize_images(image_decoded, [28, 28]) return image_resized, label # A vector of filenames. filenames = tf.constant(["/var/data/image1.jpg", "/var/data/image2.jpg", ...]) # `labels[i]` is the label for the image in `filenames[i]. labels = tf.constant([0, 37, ...]) dataset = tf.data.Dataset.from_tensor_slices((filenames, labels)) dataset = dataset.map(_parse_function)

使用 tf.py_func() 应用任意 Python 逻辑

为了确保性能,我们建议您尽可能使用 TensorFlow 指令预处理数据。不过,在解析输入数据时,调用外部 Python 库有时很有用。为此,请在 Dataset.map() 转换中调用 tf.py_func() 指令。

import cv2 # Use a custom OpenCV function to read the image, instead of the standard # TensorFlow `tf.read_file()` operation. def _read_py_function(filename, label): image_decoded = cv2.imread(filename.decode(), cv2.IMREAD_GRAYSCALE) return image_decoded, label # Use standard TensorFlow operations to resize the image to a fixed shape. def _resize_function(image_decoded, label): image_decoded.set_shape([None, None, None]) image_resized = tf.image.resize_images(image_decoded, [28, 28]) return image_resized, label filenames = ["/var/data/image1.jpg", "/var/data/image2.jpg", ...] labels = [0, 37, 29, 1, ...] dataset = tf.data.Dataset.from_tensor_slices((filenames, labels)) dataset = dataset.map( lambda filename, label: tuple(tf.py_func( _read_py_function, [filename, label], [tf.uint8, label.dtype]))) dataset = dataset.map(_resize_function)

批处理数据集元素

简单的批处理

最简单的批处理形式是将数据集中的 n 个连续元素堆叠为一个元素。Dataset.batch() 转换正是这么做的,它与 tf.stack() 运算符具有相同的限制(被应用于元素的每个组件):即对于每个组件 i,所有元素的张量形状都必须完全相同。

inc_dataset = tf.data.Dataset.range(100) dec_dataset = tf.data.Dataset.range(0, -100, -1) dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset)) batched_dataset = dataset.batch(4) iterator = batched_dataset.make_one_shot_iterator() next_element = iterator.get_next() print(sess.run(next_element)) # ==> ([0, 1, 2, 3], [ 0, -1, -2, -3]) print(sess.run(next_element)) # ==> ([4, 5, 6, 7], [-4, -5, -6, -7]) print(sess.run(next_element)) # ==> ([8, 9, 10, 11], [-8, -9, -10, -11])

使用填充批处理张量

上述方法适用于具有相同大小的张量。不过,很多模型(例如序列模型)处理的输入数据可能具有不同的大小(例如序列的长度不同)。为了解决这种情况,可以通过 Dataset.padded_batch() 转换来指定一个或多个会被填充的维度,从而批处理不同形状的张量。

dataset = tf.data.Dataset.range(100) dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x)) dataset = dataset.padded_batch(4, padded_shapes=[None]) iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() print(sess.run(next_element)) # ==> [[0, 0, 0], [1, 0, 0], [2, 2, 0], [3, 3, 3]] print(sess.run(next_element)) # ==> [[4, 4, 4, 4, 0, 0, 0], # [5, 5, 5, 5, 5, 0, 0], # [6, 6, 6, 6, 6, 6, 0], # [7, 7, 7, 7, 7, 7, 7]]

您可以通过 Dataset.padded_batch() 转换为每个组件的每个维度设置不同的填充,并且可以采用可变长度(在上面的示例中用 None 表示)或恒定长度。也可以替换填充值,默认设置为 0。

训练工作流程

处理多个周期

API 提供了两种主要方式来处理同一数据的多个周期。

要迭代数据集多个周期,最简单的方法是使用 Dataset.repeat() 转换。例如,要创建一个将其输入重复 10 个周期的数据集:

filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"] dataset = tf.data.TFRecordDataset(filenames) dataset = dataset.map(...) dataset = dataset.repeat(10) dataset = dataset.batch(32)

应用不带参数的 Dataset.repeat() 转换将无限次地重复输入。Dataset.repeat() 转换将其参数连接起来,而不会在一个周期结束和下一个周期开始时发出信号。

如果您想在每个周期结束时收到信号,则可以编写在数据集结束时捕获 的训练循环。此时,您可以收集关于该周期的一些统计信息(例如验证错误)。

filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"] dataset = tf.data.TFRecordDataset(filenames) dataset = dataset.map(...) dataset = dataset.batch(32) iterator = dataset.make_initializable_iterator() next_element = iterator.get_next() # Compute for 100 epochs. for _ in range(100): sess.run(iterator.initializer) while True: try: sess.run(next_element) except tf.errors.OutOfRangeError: break # [Perform end-of-epoch calculations here.]

随机重排输入数据

Dataset.shuffle() 转换会使用类似于 的算法随机重排输入数据集:它会维持一个固定大小的缓冲区,并从该缓冲区统一地随机选择下一个元素。

filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"] dataset = tf.data.TFRecordDataset(filenames) dataset = dataset.map(...) dataset = dataset.shuffle(buffer_size=10000) dataset = dataset.batch(32) dataset = dataset.repeat()

使用高阶 API

API 简化了在分布式设置下运行 TensorFlow 的很多方面。MonitoredTrainingSession 使用 表示训练已完成,因此要将其与 API 结合使用,我们建议使用 Dataset.make_one_shot_iterator()。例如:

filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"] dataset = tf.data.TFRecordDataset(filenames) dataset = dataset.map(...) dataset = dataset.shuffle(buffer_size=10000) dataset = dataset.batch(32) dataset = dataset.repeat(num_epochs) iterator = dataset.make_one_shot_iterator() next_example, next_label = iterator.get_next() loss = model_function(next_example, next_label) training_op = tf.train.AdagradOptimizer(...).minimize(loss) with tf.train.MonitoredTrainingSession(...) as sess: while not sess.should_stop(): sess.run(training_op)

要在 input_fn 中使用 Dataset(input_fn 属于 ),只需返回 Dataset 即可,框架将负责为您创建和初始化迭代器。例如:

def dataset_input_fn(): filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"] dataset = tf.data.TFRecordDataset(filenames) # Use `tf.parse_single_example()` to extract data from a `tf.Example` # protocol buffer, and perform any additional per-record preprocessing. def parser(record): keys_to_features = { "image_data": tf.FixedLenFeature((), tf.string, default_value=""), "date_time": tf.FixedLenFeature((), tf.int64, default_value=""), "label": tf.FixedLenFeature((), tf.int64, default_value=tf.zeros([], dtype=tf.int64)), } parsed = tf.parse_single_example(record, keys_to_features) # Perform additional preprocessing on the parsed data. image = tf.image.decode_jpeg(parsed["image_data"]) image = tf.reshape(image, [299, 299, 1]) label = tf.cast(parsed["label"], tf.int32) return { "image_data": image, "date_time": parsed["date_time"]}, label # Use `Dataset.map()` to build a pair of a feature dictionary and a label # tensor for each example. dataset = dataset.map(parser) dataset = dataset.shuffle(buffer_size=10000) dataset = dataset.batch(32) dataset = dataset.repeat(num_epochs) # Each element of `dataset` is tuple containing a dictionary of features # (in which each value is a batch of values for that feature), and a batch of # labels. return dataset

转载于:https://www.cnblogs.com/llfctt/p/10770246.html

你可能感兴趣的文章
React组件: 提取图片颜色
查看>>
3D应用开发中的欧拉角和旋转矩阵
查看>>
RxJava2.0的初学者必备教程(九)
查看>>
记一次omi的项目之旅
查看>>
Android API级别、代号、发布时间及平台亮点整理
查看>>
LLDP(链路层发现协议)
查看>>
Ubuntu14 添加程序启动
查看>>
我的友情链接
查看>>
windows网络安全以及常见网络***方式
查看>>
警告 初始化默认驱动器时出错“找不到运行 Active Directory Web 服务的默认服务器。”...
查看>>
JS字符串转换数字
查看>>
使用IntelliJ IDEA开发SpringMVC网站(四)用户管理
查看>>
js 验证中文
查看>>
Linux下运行java DES AES加解密
查看>>
牛津词典 2018 年度词汇 ——「有毒」!
查看>>
Android Arcface人脸识别sdk使用工具类
查看>>
android studio单个工程文件的代理设置
查看>>
我的友情链接
查看>>
一行命令获取当前JVM所有可设置的参数以及当前默认值
查看>>
Linux mint 14下的powerDNS+mysql+powerAdmin搭建个性DNS域名解析服务器
查看>>