文 / Ruoyu Liu 和 Robert Crowe,来自 TFX 团队
TensorFlow Extended (TFX) 是 Google 专为生产环境的机器学习流水线 (ML Pipeline) 部署而打造的平台,是 Google 机器学习服务和应用的中坚力量。目前我们已开放 TFX 的源代码,各地的开发者可在生产级 TFX 流水线上创建与部署自己的模型。
TFX 可以用多种方式扩展与自定义。我们曾在之前的文章中讲述过如何通过自定义 Executor
来变更 TFX 组件的行为。在本文中,我们将展示如何通过创建一个全新的 TFX 组件以及使用 TFX 流水线来自定义 TFX。
简介
TFX 提供了一套标准组件,可以进行组合从而形成标准的 ML 工作流。尽管这一套标准组件可以满足许多场景的需求,但仍有部分场景有额外需求,需进行定制。这些场景可以使用我们接下介绍的 自定义组件 来扩展 TFX。
在之前的一篇文章中,我们介绍过上下游语义(组件的输入和输出)与相同的场景,这类情况可以通过复用现有的组件并替换 Executor
的行为来创建新的 “半自定义” 组件。现有组件既可以是标准组件之一,也可以是您或其他人创建的自定义组件。
但是,如果新组件的上下游语义与现有组件不同,那么您需要创建新的 “完全自定义” 的自定义组件,这也是本文的主题。
文章后半部分将说明如何使用简单的HelloWorld
组件从头开始创建自定义组件。为简单起见,HelloWorld 组件只会将所有输入复制为自己的输出,并提供给下游组件使用,以演示消耗和发出数据工件。
改进的流水线工作流
ExampleGen
和所有依赖示例数据的下游组件之间加入新的 HelloWorld 组件。这意味着新组件:
需要生成与 ExampleGen 相同类型的输出,以便最初依赖 ExampleGen 的组件得到相同的输入类型
图 1 原工作流
图 2 加入新的自定义组件之后
构建自己的自定义组件
接下来,我们将逐步构建新组件。
通道
TFX 通道 (Channel) 是一个将数据生成者和数据消费者模型连接起来的抽象概念。从概念上讲,一个组件从通道读取输入工件,并将输出工件写入通道,作为下游组件的输入。通道使用工件类型进行类型化(如下一节所述),这意味着写入通道或从通道读取的所有工件都具有相同的工件类型。
ComponentSpec
ComponentSpec
类中,我们将定义带有详细类型信息的协定。需要三个参数:
INPUTS
:传递到组件
Executor
的输入工件的类型化参数字典。通常,输入工件是上游组件的输出,因此具有相同的类型。
OUTPUTS
:由组件生成的输出工件的类型化参数字典。
PARAMETERS
:传递到组件 Executor
的额外的ExecutionParameter
项目字典。我们希望在 DSL 流水线中能灵活定义并将这些非工件参数传递至执行。
ExampleGen
的输出直接传递给
HelloWorld
组件并作为输入之一,所以两者类型需要相同。如
图 3 所示,
'input_data'
是它的规格。
因为原先下游组件得到的是 ExampleGen
的输出,而现在是 HelloWorld
组件的输出之一,所以两者类型需要相同。如 图 3 所示,'output_data'
是它的规格 (Spec)。
在 Parameters 规格部分,出于演示目的,只声明'name'
。
class HelloComponentSpec(types.ComponentSpec):
"""ComponentSpec for Custom TFX Hello World Component."""
# The following declares inputs to the component.
INPUTS = {
'input_data': ChannelParameter(type=standard_artifacts.Examples),
}
# The following declares outputs from the component.
OUTPUTS = {
'output_data': ChannelParameter(type=standard_artifacts.Examples),
}
# The following declares extra parameters used to create an instance of
# this component
PARAMETERS = {
'name': ExecutionParameter(type=Text),
}
图 3 HelloWorld 组件的 ComponentSpec
Executor
下一步,我们来为新组件的 Executor
编写代码。如另一篇文章所讨论的,我们需要创建 base_executor.BaseExecutor
的新子类并覆写其 Do
函数。
class Executor(base_executor.BaseExecutor):
"""Executor for HelloWorld component."""
...
def Do(self, input_dict: Dict[Text, List[types.Artifact]],
output_dict: Dict[Text, List[types.Artifact]],
exec_properties: Dict[Text, Any]) -> None:
...
split_to_instance = {}
for artifact in input_dict['input_data']:
for split in json.loads(artifact.split_names):
uri = os.path.join(artifact.uri, split)
split_to_instance[split] = uri
for split, instance in split_to_instance.items():
input_dir = instance
output_dir = artifact_utils.get_split_uri(
output_dict['output_data'], split)
for filename in tf.io.gfile.listdir(input_dir):
input_uri = os.path.join(input_dir, filename)
output_uri = os.path.join(output_dir, filename)
io_utils.copy_file(src=input_uri, dst=output_uri, overwrite=True)
图 4 HelloWorld 组件的 Executor
如 图 4 所示,我们可以使用之前在 ComponentSpec
中定义的相同键值来获得输入和输出工件以及运行环境参数。在获得所有需要的值之后,我们可以继续使用这些值来添加更多的逻辑,并将输出写入输出工件 ('output_data'
) 所指向的 URI 中。
在继续下一步之前,先进行测试!我们已创建一个脚本,供您在投入生产之前测试您的 Executor
。您需要编写类似的代码来对您的代码进行单元测试。与其他生产软件的部署一样,在为 TFX 开发时,应确保具有良好的测试覆盖范围和强大的 CI/CD 框架。
组件接口
base_component.BaseComponent
子类;
HelloComponentSpec
类为
SPEC_CLASS
指定一个类变量;
Executor
类为
EXECUTOR_SPEC
指定一个类变量;
用参数定义 __init__()
函数,以构造HelloComponentSpec
的实例,并使用值和可选名调用super()
函数。
创建组件实例后,将调用 base_component.BaseComponent
类中的类型检查逻辑,以确保传入的参数与 HelloComponentSpec
类中定义的参数类型兼容。
from hello_component import executor
class HelloComponent(base_component.BaseComponent):
"""Custom TFX HelloWorld Component."""
SPEC_CLASS = HelloComponentSpec
EXECUTOR_SPEC = executor_spec.ExecutorClassSpec(executor.Executor)
def __init__(self,
input_data: channel.Channel,
output_data: channel.Channel,
name: Text):
if not output_data:
examples_artifact = standard_artifacts.Examples()
examples_artifact.split_names = input_data.get()[0].split_names
output_data = channel_utils.as_channel([examples_artifact])
spec = HelloComponentSpec(input_data=input_data,
output_data=output_data, name=name)
super(HelloComponent, self).__init__(spec=spec)
图 5 组件接口
加入 TFX 流水线
ExampleGen
的输出,现在需调整参数至我们新组件的输出
图 6 突出显示了这些变更。可以在我们的 GitHub 代码库中找到完整的示例。
def _create_pipeline():
...
example_gen = CsvExampleGen(input_base=examples)
hello = component.HelloComponent(
input_data=example_gen.outputs['examples'], name=u'HelloWorld')
statistics_gen = StatisticsGen(examples=hello.outputs['output_data'])
return pipeline.Pipeline(
...
components=[example_gen, hello, statistics_gen],
...
)
图 6 使用新的组件
更多信息
若要了解有关 TFX 的更多信息,请访问 TFX 网站,加入 TFX 讨论组,阅读 TFX 博客,在 YT 上观看我们的 TFX 播放列表,并订阅 TensorFlow 频道。
如果您想详细了解 本文提及 的相关内容,请参阅以下文档。这些文档深入探讨了这篇文章中提及的许多主题:
TFX
https://tensorflow.google.cn/tfx
自定义 Executor
https://blog.tensorflow.org/2019/09/creating-custom-tfx-executor_19.html
TFX 组件
https://tensorflow.google.cn/tfx/guide#anatomy_of_a_component
全新的 TFX 组件
https://tensorflow.google.cn/tfx/guide/custom_component
标准组件
https://tensorflow.google.cn/tfx/guide#tfx_pipeline_components
文章
https://blog.tensorflow.org/2019/09/creating-custom-tfx-executor_19.html
HelloWorld 组件
https://github.com/tensorflow/tfx/tree/master/tfx/examples/custom_components/hello_world
ExampleGen
https://tensorflow.google.cn/tfx/guide/examplegen
ComponentSpec
https://github.com/tensorflow/tfx/blob/master/tfx/types/component_spec.py
ExecutionParameter
https://github.com/tensorflow/tfx/blob/54aa6fbec6bffafa8352fe51b11251b1e44a2bf1/tfx/types/component_spec.py#L274
脚本
https://github.com/tensorflow/tfx/blob/master/tfx/scripts/run_executor.py
流水线
https://github.com/tensorflow/tfx/blob/master/tfx/examples/custom_components/hello_world/example/taxi_pipeline_hello.py
GitHub 代码库
https://github.com/tensorflow/tfx/tree/master/tfx/examples/custom_components/hello_world
TFX 网站
https://tensorflow.google.cn/tfx
TFX 讨论组
https://groups.google.com/a/tensorflow.org/forum/#!forum/tfx
TFX 博客
https://goo.gle/tfx-blog
TFX 播放列表
https://goo.gle/2xVkwt4
订阅
https://goo.gle/2WtM7Ak