estimator.get_variable_value如何组装分区变量?

时间:2019-07-26 04:36:26

标签: python tensorflow

我正在尝试使用Estimator API在Tensorflow中训练嵌入模型。训练后,我想使用get_variable_value对象中的Estimator函数以numpy数组的形式检索嵌入矩阵,然后根据索引列表选择嵌入的几行。在代码中,我想执行以下操作:

estimator = tf.Estimator(build_graph_fn=...)
estimator.train(...)

embeddings = estimator.get_variable_value(name=name_of_embedding_variable)

indices = [0, 10, 100]
print(embeddings[indices])

如果嵌入变量不大(例如2GB),那么我不需要做任何分区来遵守协议缓冲区的2GB限制。但是,如果嵌入大于2GB,则需要用一个分区程序对嵌入变量进行分区,例如

mb = 2 ** 20
partitioner=tf.variable_axis_size_partitioner(64 * mb)
embeddings = tf.get_variable(
  name="embedding",
  initializer=tf.random.uniform([2000000, 300]),
  partitioner=partitioner)

我的问题是,如果对变量进行了分区,那么通过get_variable_value检索到的numpy嵌入会对应于原始嵌入变量的连续切片吗?

此外,如果在训练过程中使用tf.nn.embedding_lookup分区策略而不是mod策略通过div检索嵌入行中的行,则对结果有何影响。

谢谢您的帮助!

0 个答案:

没有答案