我正在尝试使用Estimator
API在Tensorflow中训练嵌入模型。训练后,我想使用get_variable_value
对象中的Estimator
函数以numpy数组的形式检索嵌入矩阵,然后根据索引列表选择嵌入的几行。在代码中,我想执行以下操作:
estimator = tf.Estimator(build_graph_fn=...)
estimator.train(...)
embeddings = estimator.get_variable_value(name=name_of_embedding_variable)
indices = [0, 10, 100]
print(embeddings[indices])
如果嵌入变量不大(例如2GB),那么我不需要做任何分区来遵守协议缓冲区的2GB限制。但是,如果嵌入大于2GB,则需要用一个分区程序对嵌入变量进行分区,例如
mb = 2 ** 20
partitioner=tf.variable_axis_size_partitioner(64 * mb)
embeddings = tf.get_variable(
name="embedding",
initializer=tf.random.uniform([2000000, 300]),
partitioner=partitioner)
我的问题是,如果对变量进行了分区,那么通过get_variable_value
检索到的numpy嵌入会对应于原始嵌入变量的连续切片吗?
此外,如果在训练过程中使用tf.nn.embedding_lookup
分区策略而不是mod
策略通过div
检索嵌入行中的行,则对结果有何影响。
谢谢您的帮助!