在我的TensorFlow网络中,我计算
saveChanges
其中:
Y = (X * C) * R,
是X
,(n, d)
是C
,(d, L)
是X * C
,(n, L)
是R
,(L, d)
是Y
。问题在于(n, d)
和n ≈ 1e6
,因此L ≈ 1e5
不适合内存,但是X * C
,因此d = 4
适合内存。我真的不需要存储Y
。
X * C
可以分批计算,即
Y = (X * C) * R
有没有一种简洁的方法告诉TensorFlow这样做?或者这里的标准程序是什么?
答案 0 :(得分:0)
以下似乎有效但创建变量会产生不良副作用!
关于如何改进的任何想法?
n, d, L = 40, 2, 3
num_batches = 10
rows_per_batch = n // num_batches
X = tf.constant(np.arange(n*d, dtype=np.float32).reshape(n, d))
C = tf.constant(np.arange(d*L, dtype=np.float32).reshape(d, L))
R = tf.constant(np.arange(L*d, dtype=np.float32).reshape(L, d))
Y = tf.matmul(tf.matmul(X, C), R)
yvar = tf.Variable(tf.constant(0, shape=Y.get_shape(), dtype=tf.float32))
def condition(current_batch):
return current_batch < num_batches
def body(current_batch):
current_rows_to_update = tf.range(rows_per_batch * current_batch,
rows_per_batch * (current_batch + 1))
xbatch = tf.gather(X, current_rows_to_update)
ybatch = tf.matmul(tf.matmul(xbatch, C), R, name='matmul_da')
yvar_update = tf.scatter_update(yvar, current_rows_to_update, ybatch)
with tf.control_dependencies([yvar_update]):
return current_batch + 1
loop = tf.while_loop(condition, body, loop_vars=[0])
with tf.control_dependencies([loop]):
ybatched = tf.identity(yvar)