本文原作者:于洋,经授权后发布。
1. 开篇
通常,我们在使用Tensorflow低级API编程时(非Eager模式), 一般有下面三个步骤:
- 使用tensorflow python侧的API构建图。图通常包括了两部分:正向计算图和反向计算图;
构建的关键字是:新建的tf.Operation
(节点)和tf.Tensor
(边)对象并将它们添加到tf.Graph
实例中。例如,典型添加op操作就是tf.matmul
。 - 创建tf.Session会话;
此步骤的关键字是:创建默认本地会话with tf.Session() as sess:
,创建分布式会话with tf.Session("grpc://example.org:2222"):
- 在tf.Session会话中,初始化全局变量,并批量运行图。
此步骤的关键语句是:sess.run(init_op)
,sess.run(train_op)
众所周知,tensorflow使用支持多种前端语言(python,js,swift,go等),执行引擎为C/C++后端实现。
那么,在上述三个步骤中,当用户python构建图,以及运行的图的时候。C/C++后端有在执行哪些工作呢?
按照对应的三个步骤,我们做如下拆解:
- python在构建图的过程中,也是C/C++构造图的过程。
即python在新增的tf.Operation
(节点)和tf.Tensor
(边)的同时,C/C++的后端也生成对应的节点和边,从而构造后端的图。 - 图创建好后,python调用tf.Session语句,C/C++端会根据参数创建对应本地Session运行图,或者分布式Session运行图。
- 通过sess.run触发一次图的正向计算,以及反向计算。
本次分享的设计模式,就是在上述第二阶段时:创建本地session和分布式session时,tensorflow是怎样利用抽象工厂设计模式的?
2. 抽象工厂设计模式(Abstract Factory)
在《设计模式》中描述的23设计模式,分为三类:创建型、结构型、行为型。其中,抽象工厂设计模式属于创建性设计模式。即是解决对象的创建需求。关于抽象工厂模式我的理解是这样的:
调用者有创建不同对象的需求(对象有一定相似性,例如轿车、卡车),调用者无需关注具体的实现类,而是通过抽象类定义的接口,就能创造不同对象。
当然,个人抽象理解和描述还是很难理解的。我们根据GOF书中,抽象工厂的模式结构图(图需要从右上角看起)在来理解一下:
- 调用者(Client)有创建对象ProductA1或ProductA2的需求,
- 但是Client类没有直接调用实现类CreateProductA1、CreateProductA2。
- 而是通过抽象工厂AbstractFactory的接口创建了不同的对象(即:创建对象ProductA1或ProductA2)。
[ 抽象工厂的模式结构图 - 《设计模式》58页 ]
有了上面粗浅的理解后,我们看一下tensorflow是如何使用抽象工厂模式,创建本地session和分布式session?
首先,我们看一下python创建Session调用栈:
NewSession
的代码如下:
Status NewSession(const SessionOptions& options, Session** out_session) {
SessionFactory* factory;
Status s = SessionFactory::GetFactory(options, &factory);
if (!s.ok()) {
*out_session = nullptr;
LOG(ERROR) << s;
return s;
}
s = factory->NewSession(options, out_session);
if (!s.ok()) {
*out_session = nullptr;
}
return s;
}
代码很枯燥,我们看一下上述代码的时序图(以创建DirectSessione为例)。
上述代码对应着时序图的阶段2和阶段3。其中:
- 阶段2对应代码
SessionFactory::GetFactory(options, &factory);
, - 阶段3对应代码
factory->NewSession(options, out_session);
[ NewSession的时序图 ]
看到这里,我们温习一下抽象工厂的理解:
- Client(
NewSession
)有创建GrpcSession
或者DirectSession
的需求; - 但是,Client没有直接调用
new DirectSession
或者new GrpcSession
创建; - 而是,通过调用抽象工厂(
SessionFactory
)接口GetFactory
找到DirectSessionFactory
。最终通过DirectSessionFactory->NewSession
创建; - 最终返回实例为
Session
型(多态可以到GrpcSesion
或者DirectSession
对象)。
值得说明的是:Client在整个过程中,并不清楚里面不同的Factory(GrpcSessionFactory
和DirectSessionFactory
),也不清楚不同的Session类型(GrpcSession
和DirectSession
)。
最后,参考抽象工厂结构图,大致画了如下Session的创建环节,大家可以在回味一下该设计模式(图也是从右上角看起):
[ 抽象工厂模式创建Session ]
至此,创建Session的主题框架已经大致梳理出来了。但是,上面的时序图中的阶段1一直还没有说明吧?
好,这部分涉及了单件设计模式。
后记:按照下面的定义,上述创建Session的模式(因为只创建了一种Session产品)是不是叫“工厂方法”会好一点?
- 简单工厂:一个工厂类,一个产品抽象类。
- 工厂方法:多个工厂类,一个产品抽象类。
- 抽象工厂:多个工厂类,多个产品抽象类。
说一下个人理解,tensorflow在设计这段代码的时候,做了很高程度的抽象,具备完成多个产品抽象的能力。我这里姑且认为应用的是抽象工厂模式。
大家也可以按照“工厂方法”模式理解上述代码,宗旨是:希望大家在学习tensorflow代码的过程,能了解里面蕴含的设计模式。
3. 单件设计模式(Singleton)
NewSession
中有这样的代码,不知道大家是否有注意到SessionFactory::GetFactory(options, &factory);
?这段代码的含义也就是根据传递的options
信息,选择是DirectSessionFactory
还是分布式GrpcSessionFactory
。
但是,大家在看时候,有没有这样的疑问:不同的SessionFactory
的是什么时候写入到SessionFactory map
中的?何况tensorflow这种没有main函数的程序?这个问题曾经一直很困扰我,在gdb debug后,我发现了下面的小trick。
诀窍在这行代码中static DirectSessionRegistrar registrar;
。
SessionFactory map
初始化的能量蕴含在这个static
变量的构造函数。下面的流程图揭示所有的秘密。结合代码,从图的左下角看起(下面的代码对应上面NewSession的时序图)。
和全局变量一样,static变量一直存储在程序的静态存储区。当程序初始化static变量时,通过
DirectSessionRegistrar
和GrpcSessionFactory
的构造函数完成初始化,将不同的SessionFactory
(工厂对象)写入到SessionFactory map
中。
[ SessionFactory map的初始化过程 ]
囧~~~,扯了半天的代码和流程,貌似一点都没有提及单件设计模式。其实,单件设计模式在还是比较简单的。GOF中定义如下:
保证一个类仅有一个实例,并提供一个访问它的全局访问点。
tensorflow这里使用了单例中一种更灵活的模式:单件注册表,也就是使用的一个Singleton类的集合(从上图看到存储结构是std::unordered_map
),Singleton类通过一个注册接口将自己的单件实例注册到集合中。而这里的tensorflow是通过DirectSessionRegistrar
和GrpcSessionFactory
构造函数中的SessionFactory::Register
接口完成注册。
4. 进阶
其实,在tensorflow中,上述模式还有很多资源管理的场景中使用。如下给出代码指引,感兴趣的同学可自行学习:
- DeviceFactory //Tensorflow设备管理的代码
- ExecutorFactory //Tensorflow图执行单元的代码
5. 参考
- 代码参考:tensorflow v1.12.0
- 画图:draw.io
更多优质内容请关注官方微信公众号