How to initialize a tf.contrib.lookup.HashTable used in a TensorFlow Estimator model_fn?

Asked: 2019-04-17 01:35:33

Tags: tensorflow lookup-tables tensorflow-estimator

I have a tf.contrib.lookup.HashTable declared inside a TensorFlow Estimator model_fn. Because Estimators do not expose the session directly, I cannot find a way to initialize the table. I am aware that, outside of Estimators, the table can be initialized by calling table.init.run() with an active session.

I tried to initialize the table from a SessionRunHook that I was already using for another purpose. I pass the table init op in the SessionRunArgs returned from before_run, but the table is still not initialized. I also tried passing tf.tables_initializer() instead, which did not work either. Another option I tried without success was the tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS.. command.

# SessionRunHook code below

class SaveToCSVHook(tf.train.SessionRunHook):
    def begin(self):        
        samples_weights_table = session.graph.get_tensor_by_name('samples_weights_table:0')
        self.samples_weights_table_init_op = samples_weights_table.init
        self.table_init_op = tf.tables_initializer() # also tried passing this to self.args instead - same result though
        tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, samples_weights_table.init)

    def after_create_session(self, session, coord):
        self.args ={'table_init_op':self.samples_weights_table_init_op}

    def before_run(self, run_context):
         return tf.train.SessionRunArgs(self.args)

    def after_run(self, run_context, run_values):
        print(f"Got Values: {run_values.results}")  

# Estimator model_fn code below

def model_fn(..)
    samples_weights_table = tf.contrib.lookup.HashTable(
        tf.contrib.lookup.KeyValueTensorInitializer(
            keysb, values,
            key_dtype=tf.string, value_dtype=tf.float32,
            name='samples_weights_table_init_op'),
        -1.0,
        name='samples_weights_table')

I get the error "FailedPreconditionError (see above for traceback): Table not initialized", which means the table is not being initialized.

1 Answer:

Answer 0 (score: 0)

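In case anyone is interested in the answer: hash tables do not need to be explicitly initialized when used with Estimators. By default they are initialized for high-level APIs such as Estimators. After I removed the initialization code, the error went away and the table works as expected.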
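The automatic initialization can be reproduced outside an Estimator with a small sketch. It rests on some assumptions that are not in the original post: it uses tf.compat.v1.lookup.StaticHashTable as a stand-in for tf.contrib.lookup.HashTable (so it can run on TensorFlow versions that no longer ship tf.contrib), the keys and weights are made up, and lookup_weights is a hypothetical helper. Estimators run the graph inside a MonitoredSession, whose default Scaffold includes tf.tables_initializer() in its local init op, so the sketch opens a MonitoredSession directly and never calls the table's init op itself.

```python
import tensorflow as tf

tf1 = tf.compat.v1


def lookup_weights(sample_ids):
    """Hypothetical helper: look up weights without any explicit table init."""
    graph = tf.Graph()
    with graph.as_default():
        # Illustrative keys and weights (assumed, not from the original post).
        keys = tf.constant(["a", "b"])
        values = tf.constant([0.5, 2.0])
        table = tf1.lookup.StaticHashTable(
            tf1.lookup.KeyValueTensorInitializer(
                keys, values, key_dtype=tf.string, value_dtype=tf.float32),
            default_value=-1.0)
        weights = table.lookup(tf.constant(sample_ids))

        # No table init op is run here. The session's default Scaffold
        # runs the ops collected in GraphKeys.TABLE_INITIALIZERS.
        with tf1.train.MonitoredSession() as sess:
            return sess.run(weights).tolist()


print(lookup_weights(["a", "b", "missing"]))
```

Known keys should map to their weights and the unknown key to the default of -1.0. Inside a model_fn the same idea applies: build the table and call lookup, and leave the initialization to the session that the Estimator creates.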