如何在tf.estimator中处理变量?

时间:2019-04-18 15:17:38

标签: python tensorflow tensorflow-estimator

我正在尝试使用TensorFlow Estimator API从here实现模型。但是,我在处理手动更新的变量时遇到一些问题。 从原则上讲,问题是,据我所知,在任何时候调用model_fn都会重新初始化变量。

def setup_variables(batch_size, params):
    # Mask describing ended sessions, true if session ended
    ended_sessions_mask = tf.get_variable(
        'ended_sessions_mask',
        shape=(batch_size,),
        initializer=tf.zeros_initializer(),
        trainable=False,
        dtype=tf.bool)

    # Mask describing ended users, true if not more user events
    ended_users_mask = tf.get_variable(
        'ended_users_mask',
        shape=(batch_size,),
        initializer=tf.zeros_initializer(),
        trainable=False,
        dtype=tf.bool)

def model_fn(features, labels, mode, params):

    (ended_sessions_mask,
        ending_sessions_mask) = setup_variables(batch_size, params)


    # Ended sessions where the user did not change
    ended_sessions_same_user_mask = tf.logical_and(
        ended_sessions_mask,
        tf.logical_not(ended_users_mask)
    )

    # Get user_hidden_states to update
    # The hidden states to update are the ones where a session ended
    # but the user has stayed the same
    # The other hidden states are 0
    user_hidden_states = tf.map_fn(
        lambda x: tf.cond(
            x[1],
            true_fn=lambda: tf.nn.embedding_lookup(user_embeddings, x[0]),
            false_fn=lambda: tf.zeros(params['user_rnn_units'])
        ),
        [
            features['UserEmbeddingId'],
            ended_sessions_same_user_mask
        ],
        dtype=tf.float32,
        name='get_user_hidden_states_to_update')

...


    # Compute new mask for ended sessions
    ended_sessions_mask = tf.cast(
        tf.where(
            tf.equal(features['ProductId'], -1),
            tf.ones(tf.shape(ended_sessions_mask)),
            tf.zeros(tf.shape(ended_sessions_mask)),
            name='compute_ended_sessions'),
        tf.bool)

    # Compute new mask for ended users
    ended_users_mask = tf.cast(
        tf.where(
            tf.equal(features['UserId'], -1),
            tf.ones(tf.shape(ended_users_mask)),
            tf.zeros(tf.shape(ended_users_mask)),
            name='compute_ended_users'),
        tf.bool)

原则上,模型函数的流程应为:

  • 根据遮罩更新用户嵌入,遮罩是在上一步中计算的
  • 应用模型,计算损失等
  • 计算将在下一步中使用的新蒙版。

即掩码描述了上一步的结束的会话和用户。

据我了解,在使用get_variable时应该有可能,因为那样的话,变量只有在之前不存在的情况下才会创建。但是无论何时我将model_fn称为掩码,都将用零重新初始化。我希望掩码具有上一次计算的值,但事实并非如此。

0 个答案:

没有答案