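I am trying to implement the model from here using the TensorFlow Estimator API. However, I am having trouble with variables that I update manually. The core problem is that, as far as I can tell, the variables are re-initialized every time model_fn is called.
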
import tensorflow as tf

def setup_variables(batch_size, params):
    # Mask describing ended sessions, True if the session ended
    ended_sessions_mask = tf.get_variable(
        'ended_sessions_mask',
        shape=(batch_size,),
        initializer=tf.zeros_initializer(),
        trainable=False,
        dtype=tf.bool)
    # Mask describing ended users, True if there are no more user events
    ended_users_mask = tf.get_variable(
        'ended_users_mask',
        shape=(batch_size,),
        initializer=tf.zeros_initializer(),
        trainable=False,
        dtype=tf.bool)
    return ended_sessions_mask, ended_users_mask

def model_fn(features, labels, mode, params):
    (ended_sessions_mask,
     ended_users_mask) = setup_variables(batch_size, params)
    # Ended sessions where the user did not change
    ended_sessions_same_user_mask = tf.logical_and(
        ended_sessions_mask,
        tf.logical_not(ended_users_mask)
    )
    # Get user_hidden_states to update
    # The hidden states to update are the ones where a session ended
    # but the user has stayed the same
    # The other hidden states are 0
    user_hidden_states = tf.map_fn(
        lambda x: tf.cond(
            x[1],
            true_fn=lambda: tf.nn.embedding_lookup(user_embeddings, x[0]),
            false_fn=lambda: tf.zeros(params['user_rnn_units'])
        ),
        [
            features['UserEmbeddingId'],
            ended_sessions_same_user_mask
        ],
        dtype=tf.float32,
        name='get_user_hidden_states_to_update')
    ...
    # Compute new mask for ended sessions
    ended_sessions_mask = tf.cast(
        tf.where(
            tf.equal(features['ProductId'], -1),
            tf.ones(tf.shape(ended_sessions_mask)),
            tf.zeros(tf.shape(ended_sessions_mask)),
            name='compute_ended_sessions'),
        tf.bool)
    # Compute new mask for ended users
    ended_users_mask = tf.cast(
        tf.where(
            tf.equal(features['UserId'], -1),
            tf.ones(tf.shape(ended_users_mask)),
            tf.zeros(tf.shape(ended_users_mask)),
            name='compute_ended_users'),
        tf.bool)

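In principle, the flow of the model function should be such that, at the start of each step, the masks describe the sessions and users that ended in the previous step.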
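As I understand it, this should be possible with get_variable, because the variable is then only created if it does not already exist. However, whenever model_fn is called, the masks are re-initialized with zeros. I would expect the masks to hold the values from the last computation, but that is not the case.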
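
For reference, here is a minimal sketch of the behaviour I am after, outside of the Estimator API. It is not taken from my model: it assumes TF1-style graph mode through tf.compat.v1, and the names run_steps, product_id, previous_mask and update_op are made up for this illustration. Unlike my model_fn, which only rebinds the Python name ended_sessions_mask to a new tensor, the sketch writes the new mask back into the variable with tf.assign, and a control dependency makes sure the old value is read first.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph API, also available under TF 2.x


def run_steps(product_ids_per_step):
    """Feed several batches through one graph; return the mask read at each step."""
    batch_size = len(product_ids_per_step[0])
    with tf.Graph().as_default():
        # Hypothetical stand-in for features['ProductId'].
        product_id = tf1.placeholder(tf.int64, shape=(batch_size,))
        ended_sessions_mask = tf1.get_variable(
            'ended_sessions_mask',
            shape=(batch_size,),
            initializer=tf1.zeros_initializer(),
            trainable=False,
            dtype=tf.bool)

        # Use the value left behind by the previous step (cast to 0/1 here).
        previous_mask = tf.cast(ended_sessions_mask, tf.int32)

        # Only after that read, store the new mask in the variable itself.
        with tf.control_dependencies([previous_mask]):
            update_op = tf1.assign(
                ended_sessions_mask, tf.equal(product_id, -1))

        seen = []
        with tf1.Session() as sess:
            sess.run(tf1.global_variables_initializer())
            for ids in product_ids_per_step:
                mask, _ = sess.run(
                    [previous_mask, update_op],
                    feed_dict={product_id: ids})
                seen.append(mask.tolist())
        return seen


if __name__ == "__main__":
    # Each row is the mask as seen at the start of that step.
    print(run_steps([[5, -1, 7], [3, 4, -1], [1, 2, 3]]))
    # [[0, 0, 0], [0, 1, 0], [0, 0, 1]]
```

In this sketch the mask seen at step t is the one computed at step t-1, which is what I expect from the model. I assume that inside model_fn such an update op would have to be run together with the train_op (for example by combining them with tf.group) for the assignment to take effect, but I have not verified this in my setup.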