I have a custom estimator with a tf.estimator.BestExporter of the following form:
exporter = tf.estimator.BestExporter(
    name="best_exporter",
    serving_input_receiver_fn=serving_input_receiver_fn,
    exports_to_keep=5
)  # this will keep the 5 best exports
So under model_dir I now have:
# (inside model_dir/)
...
export/
  - best_exporter/
    - <timestamp>
      - variables/
        - variables.data-00000-of-00001
        - variables.index
      - saved_model.pb
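Because every export lands in its own <timestamp> directory, that path has to be resolved before the model can be loaded. A minimal sketch, assuming the export subdirectories have purely numeric names; find_latest_export is a hypothetical helper and not part of TensorFlow:

```python
import os


def find_latest_export(best_exporter_dir):
    """Return the path of the most recent <timestamp> export, or None.

    Assumption: export subdirectories have purely numeric names, so
    temporary or unrelated entries are skipped.
    """
    if not os.path.isdir(best_exporter_dir):
        return None
    timestamps = [
        name for name in os.listdir(best_exporter_dir)
        if name.isdigit() and os.path.isdir(os.path.join(best_exporter_dir, name))
    ]
    if not timestamps:
        return None
    # Compare numerically, not lexicographically.
    return os.path.join(best_exporter_dir, max(timestamps, key=int))
```

The returned path could then be used wherever os.path.join(best_exporter_dir, timestamp) is built by hand.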
I can load and use the exported estimator via:
predict_fn = predictor.from_saved_model(os.path.join(best_exporter_dir, timestamp))
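I would like to be able to update values of this estimator (e.g. the weights of a layer, some_layer/kernel:0).
There is a related (but not identical) GitHub issue that addresses how to do this with model checkpoints (relevant part of issue), and the poster confirmed that it works with TensorFlow v1.4.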
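One detail that matters when moving between a graph and a checkpoint: graph variables are named like some_layer/kernel:0, where the :0 suffix is the output index of the op, while checkpoint keys written by a default tf.train.Saver normally use the name without that suffix. A minimal sketch of the conversion; split_tensor_name is a hypothetical helper, not a TensorFlow function:

```python
def split_tensor_name(tensor_name):
    """Split 'scope/var:0' into ('scope/var', 0).

    Names without a numeric output index are returned unchanged
    together with index 0.
    """
    op_name, sep, index = tensor_name.rpartition(':')
    if not sep or not index.isdigit():
        return tensor_name, 0
    return op_name, int(index)
```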
I tried to stitch together the relevant parts of that code so that I can at least update some weights:
import tensorflow as tf  # TensorFlow 1.x


def load_estimator_graph(export_dir:str)->None:
    '''Solves import issues when using tf.estimator.(Best)Exporter for saving
    models rather than using the last checkpoint.

    Arguments:
        export_dir (str): the full path to exported tf.estimator model

    Returns:
        None
    '''
    with tf.Session(graph=tf.Graph()) as sess:
        meta_graph = tf.saved_model.loader.load(sess, ['serve'], export_dir)

    with tf.Session() as sess:
        loaded_graph = tf.train.import_meta_graph(meta_graph)


def lazy_fetch_variable_values(variable_names:list)->dict:
    '''
    Notes:
        "lazy" refers to:
            1. the use of `tf.initialize_all_variables()` to ensure
               variables have values
            2. the use of `tf.trainable_variables()` to search the likely
               relevant values

    Arguments:
        variable_names (list): list of variable names (str) to retrieve from the
            default tensorflow graph

    Returns:
        variables (dict): key:value of the variables and the values as pythonic
            data types.
    '''
    init_op = tf.initialize_all_variables()

    variables = {}

    with tf.Session() as sess:
        sess.run(init_op)
        tvars = tf.trainable_variables()
        tvars_vals = sess.run(tvars)
        for var, val in zip(tvars, tvars_vals):
            if var.name in variable_names:
                variables[var.name] = val
    return variables


def lazy_set_variable_values(variables_to_set:dict):
    '''
    Arguments:
        variables_to_set (dict): variable_name, variable_value pairs for which
            to be updated in the graph
    '''
    init_op = tf.initialize_all_variables()

    with tf.Session() as sess:
        sess.run(init_op)
        tf_global_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        for var_to_find, val_to_set in variables_to_set.items():
            var = [v for v in tf_global_vars if v.name == var_to_find][0]
            sess.run(var)
            var = var.assign(val_to_set)
            sess.run(var)
And then something like:
load_estimator_graph(best_exported_model_dir)
layer_name = 'some_layer/kernel:0'
weights = lazy_fetch_variable_values([layer_name])[layer_name]
new_weights = np.copy(weights)
new_weights[...] = 0  # <-- still an np.ndarray: sets all values to 0 _element-wise_,
                      # so it keeps the same shape as the original weight tensor
lazy_set_variable_values({layer_name: new_weights})
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    saver = tf.train.Saver()
    saver.save(sess, os.path.join(best_exported_model_dir, '..', 'best_updated'))
Notably, this reads in the model exported by tf.estimator.BestExporter and tries to export it as a checkpoint.
So, if I try to restore the checkpoint:
est = tf.estimator.Estimator(
    model_fn=model_fn,
    model_dir=os.path.join(best_exported_model_dir, '..', 'best_updated'),
    config=tf.estimator.RunConfig(**_config['RunConfig']),  # same as runtime call
    params=_config,  # same as runtime call
)
eval_fn = lambda: input_fn(mode='eval')
est.evaluate(eval_fn)
I get:
ValueError: Input 0 of layer some_layer is incompatible with the layer: : expected min_ndim=2, found ndim=1. Full shape received: [an_integer]
where, in the code above,
weights.shape[0] == new_weights.shape[0] == an_integer
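Ideally, I would like to save the updated model in the same form that tf.estimator.BestExporter and tf.estimator.Estimator.export_savedmodel produce.
However, the export methods above require an instance of the estimator and the corresponding serving_input_receiver_fn. The method predictor.from_saved_model(exported_dir) does not initialize an estimator! So there does not seem to be a direct way to do this.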
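Notes:
- predictor comes from: from tensorflow.contrib import predictor
- I would like to import from the exported model, update some values (e.g. biases / weights), and then export in the same form (without overwriting the original model).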
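To keep the original export untouched, the updated model could go into a fresh sibling directory that follows the same <timestamp> naming. A minimal sketch, assuming the timestamps are integer seconds since the epoch; new_export_dir is a hypothetical helper that only computes a path and does not write a SavedModel:

```python
import os
import time


def new_export_dir(best_exporter_dir, now=None):
    """Return an unused <timestamp> path under best_exporter_dir.

    Assumption: directory names are integer seconds since the epoch.
    The path is only computed, nothing is created on disk.
    """
    timestamp = int(time.time() if now is None else now)
    candidate = os.path.join(best_exporter_dir, str(timestamp))
    # Step forward until the name is free, so an existing export is never reused.
    while os.path.exists(candidate):
        timestamp += 1
        candidate = os.path.join(best_exporter_dir, str(timestamp))
    return candidate
```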