我试图用tf.dynamic_partition()
代替tf.gather()
来避免将稀疏表示形式隐式转换为密集矩阵。
这是我的代码,
# edge_source_states = tf.gather(params=node_states_per_layer[-1], indices=edge_sources)
partitions0 = tf.reduce_sum(tf.one_hot(edge_sources, tf.shape(node_states_per_layer[-1])[0], dtype='int32'),
0)
edge_source_states = tf.dynamic_partition(node_states_per_layer[-1], partitions0, 2)
edge_source_states = edge_source_states[1]
注释是tf.gather()
的原始用法。使用tf.gather()
时没有任何错误,唯一的问题是它将稀疏表示转换为密集矩阵,从而消耗了大量内存。
但是,当我改用tf.dynamic_partition()
方法时,出现错误:
InvalidArgumentError (see above for traceback): partitions[21] = 2 is not in [0, 2)
根据追溯,此错误是由以下句子引起的:
edge_source_states = tf.dynamic_partition(node_states_per_layer[-1], partitions0, 2)
作为一个新手,我真的不明白。
我的问题是:
1)我认为使用tf.dynamic_partition()
的新代码在功能上等同于使用tf.gather()
的原始代码。那为什么会有错误呢?
2)tf.dynamic_partition()
是否可以避免像tf.gather()
那样从稀疏到密集的隐式转换?还有其他解决方案吗?我真的需要严格控制内存消耗。