我正在尝试在两个循环中的tf节点上执行某些操作。但是,TensorFlow倾向于在循环的每次迭代中为图形创建更多节点。有没有办法以Python的方式在tf节点上进行多个操作,比如没有创建所有节点,只是遍历循环并进行操作?
为了举例说明下面的代码,似乎TensorFlow在循环中创建了所有节点(xi,xj和loss):
def get_marginal_loss(features, labels, threshold, margin):
with tf.variable_scope("marginal", reuse=tf.AUTO_REUSE):
xi = tf.get_variable('xi', [], dtype=tf.float32, initializer=tf.constant_initializer(0), trainable=False)
xj = tf.get_variable('xj', [], dtype=tf.float32, initializer=tf.constant_initializer(0), trainable=False)
loss = tf.get_variable('marginalLoss', [], dtype=tf.float32, initializer=tf.constant_initializer(0), trainable=False)
assign_op = loss.assign(0), xi.assign(0), xj.assign(0)
for i in range(batch_size):
for j in range(batch_size):
if i == j: break
xi = features[i] / tf.norm(features[i], ord=1)
xj = features[j] / tf.norm(features[j], ord=1)
print(xi)
print(xj)
dis = tf.norm(xi-xj, ord=2)**2
print(dis)
y = 1 if labels[i] == labels[j] else -1
loss = loss + tf.maximum(margin - y * (threshold - dis), 0)
return loss / (batch_size*(batch_size - 1)), assign_op
答案 0 :(得分:0)
如果您指的是变量节点 - 您应该使用tf.get_variable()方法。 来自official TensorFlow documentation:“使用这些参数获取现有变量或创建新参数”