加速训练之并行化 tf.data.Dataset 生成器

虚幻大学 xuhss 171℃ 0评论

? 优质资源分享 ?

学习路线指引(点击解锁) 知识定位 人群定位
? Python实战微信订餐小程序 ? 进阶级 本课程是python flask+微信小程序的完美结合,从项目搭建到腾讯云部署上线,打造一个全栈订餐系统。
?Python量化交易实战? 入门级 手把手带你打造一个易扩展、更安全、效率更高的量化交易系统

在处理大规模数据时,数据无法全部载入内存,我们通常用两个选项

  • 使用tfrecords
  • 使用 tf.data.Dataset.from_generator()

tfrecords的并行化使用前文已经有过介绍,这里不再赘述。如果我们不想生成tfrecord中间文件,那么生成器就是你所需要的。

本文主要记录针对 from_generator()的并行化方法,在 tf.data 中,并行化主要通过 mapnum_parallel_calls 实现,但是对一些场景,我们的generator()中有一些处理逻辑,是无法直接并行化的,最简单的方法就是将generator()中的逻辑抽出来,使用map实现。

tf.data.Dataset generator 并行

generator()中的复杂逻辑,我们对其进行简化,即仅在生成器中做一些下标取值的类型操作,将generator()中处理部分使用py_function 包裹(wrapped) ,然后调用map处理。

def func(i):
    i = i.numpy() # Decoding from the EagerTensor object
    x, y = your_processing_function(training_set[i])
    return x, y

z = list(range(len(training_set))) # The index generator

dataset = tf.data.Dataset.from_generator(lambda: z, tf.uint8)

dataset = dataset.map(lambda i: tf.py_function(func=func, 
                                               inp=[i], 
                                               Tout=[tf.uint8,
                                                     tf.float32]
                                               ), 
                      num_parallel_calls=tf.data.AUTOTUNE)

由于隐式推断的原因,有时tensor的输出shape是未知的,需要额外处理

dataset = dataset.batch(8)
def \_fixup\_shape(x, y):
    x.set_shape([None, None, None, nb_channels]) # n, h, w, c
    y.set_shape([None, nb_classes]) # n, nb\_classes
    return x, y
dataset = dataset.map(_fixup_shape)

tf.Tensor与tf.EagerTensor

为什么需要 tf.py_function,先来看下tf.Tensortf.EagerTensor

EagerTensor是实时的,可以在任何时候获取到它的值,即通过numpy获取

Tensor是非实时的,它是静态图中的组件,只有当喂入数据、运算完成才能获得该Tensor的值,

map中映射的函数运算,而仅仅是告诉dataset,你每一次拿出来的样本时要先进行一遍function运算之后才使用的,所以function的调用是在每次迭代dataset的时候才调用的,属于静态图逻辑

tensorflow.python.framework.ops.EagerTensor
tensorflow.python.framework.ops.Tensor

tf.py_function在这里起了什么作用?

Wraps a python function into a TensorFlow op that executes it eagerly.

刚才说到map数据静态图逻辑,默认参数都是Tensor。而 使用tf.py_function()包装后,参数就变成了EagerTensor。

references

【1】https://medium.com/@acordier/tf-data-dataset-generators-with-parallelization-the-easy-way-b5c5f7d2a18

【2】https://blog.csdn.net/qq_27825451/article/details/105247211

【3】https://www.tensorflow.org/guide/data_performance#parallelizing_data_extraction

转载请注明:xuhss » 加速训练之并行化 tf.data.Dataset 生成器

喜欢 (0)

您必须 登录 才能发表评论!