一.Saved Model模块介绍
saved_model模块主要用于TensorFlow Serving。TF Serving是一个将训练好的模型部署至生产环境的系统,主要的优点在于可以保持Server端与API不变的情况下,部署新的算法或进行试验,同时还有很高的性能。
- 仅用Saver来保存/载入变量。这个方法显然不行,仅保存变量就必须在inference的时候重新定义Graph(定义模型),这样不同的模型代码肯定要修改。即使同一种模型,参数变化了,也需要在代码中有所体现,至少需要一个配置文件来同步,这样就很繁琐了。
- 使用
tf.train.import_meta_graph
导入graph信息并创建Saver, 再使用Saver restore变量。相比第一种,不需要重新定义模型,但是为了从graph中找到输入输出的tensor,还是得用graph.get_tensor_by_name()
来获取,也就是还需要知道在定义模型阶段所赋予这些tensor的名字。如果创建各模型的代码都是同一个人完成的,还相对好控制,强制这些输入输出的命名都一致即可。如果是不同的开发者,要在创建模型阶段就强制tensor的命名一致就比较困难了。这样就不得不再维护一个配置文件,将需要获取的tensor名称写入,然后从配置文件中读取该参数。
1.利用tf.train.Saver()保存和加载模型
1 | """保存模型和变量""" |
1 | """恢复模型和变量""" |
2.saved_model 保存/载入模型
保存
1 | builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir) |
首先构造SavedModelBuilder对象,初始化方法只需要传入用于保存模型的目录名,目录不用预先创建。
add_meta_graph_and_variables
方法导入graph的信息以及变量,这个方法假设变量都已经初始化好了,对于每个SavedModelBuilder这个方法一定要执行一次用于导入第一个meta graph。
第一个参数传入当前的session,包含了graph的结构与所有变量。
第二个参数是给当前需要保存的meta graph一个标签,标签名可以自定义,在之后载入模型的时候,需要根据这个标签名去查找对应的MetaGraphDef,找不到就会报如RuntimeError: MetaGraphDef associated with tags 'foo' could not be found in SavedModel
这样的错。标签也可以选用系统定义好的参数,如tf.saved_model.tag_constants.SERVING
与tf.saved_model.tag_constants.TRAINING
。
save方法就是将模型序列化到指定目录底下。
保存好以后到saved_model_dir目录下,会有一个saved_model.pb
文件以及variables
文件夹。顾名思义,variables
保存所有变量,saved_model.pb
用于保存模型结构等信息。
载入
1 | # 使用`tf.saved_model.loader.load`方法就可以载入模型。如 |
第一个参数就是当前的session,第二个参数是在保存的时候定义的meta graph的标签,标签一致才能找到对应的meta graph。第三个参数就是模型保存的目录。
load完以后,也是从sess对应的graph中获取需要的tensor来inference。如
1 | x = sess.graph.get_tensor_by_name('input_x:0') |
3.使用SignatureDef
保存
SignatureDef定义了一些协议,对我们所需的信息进行封装,我们根据这套协议来获取信息,从而实现创建与使用模型的解耦。SignatureDef,将输入输出tensor的信息都进行了封装,并且给他们一个自定义的别名,所以在构建模型的阶段,可以随便给tensor命名,只要在保存训练好的模型的时候,在SignatureDef中给出统一的别名即可。
TensorFlow的关于这部分的例子中用到了不少signature_constants,这些constants的用处主要是提供了一个方便统一的命名。假设定义模型输入的别名为“input_x”,输出的别名为“output” ,使用SignatureDef的代码如下
1 | builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir) |
上述inputs增加一个keep_prob是为了说明inputs可以有多个, build_tensor_info
方法将tensor相关的信息序列化为TensorInfo protocol buffer。
inputs,outputs都是dict,key是我们约定的输入输出别名,value就是对具体tensor包装得到的TensorInfo。
然后使用build_signature_def
方法构建SignatureDef,第三个参数method_name暂时先随便给一个。
创建好的SignatureDef是用在add_meta_graph_and_variables
的第三个参数signature_def_map
中,但不是直接传入SignatureDef对象。事实上signature_def_map
接收的是一个dict,key是我们自己命名的signature名称,value是SignatureDef对象。
载入
1 | ## 略去构建sess的代码 |
我们只需要约定好输入输出的别名,在保存模型的时候使用这些别名创建signature,输入输出tensor的具体名称已经完全隐藏,这就实现创建模型与使用模型的解耦。