我需要使用 `tf.while_loop`,但运行时收到以下错误消息:
ValueError: Number of inputs and outputs of body must match loop_vars: 1, 2
# Failing attempt: a plain Python list is handed to tf.while_loop as a loop
# variable. total_loss starts out empty and the body appends one tensor to it,
# so the structure returned by the body holds one more tensor than the
# loop_vars passed in -- consistent with the "1, 2" counts in the ValueError
# quoted above.
total_loss, i = list(), tf.constant(0)

def sampled_softmax_body(i, total_loss):
    """Loop body: compute the sampled-softmax loss for step ``i``.

    Appends the loss tensor to ``total_loss`` and returns the incremented
    counter together with that same list.
    """
    # Index i selects one slice along the second axis of h2_decoder_outputs.
    # NOTE(review): presumably the layout is (batch, time, features) -- confirm.
    loss = tf.nn.sampled_softmax_loss(weights=logits_weights, biases=logits_biases, labels=target, inputs=h2_decoder_outputs[:,i,:], num_sampled=n_sampled_softmax, num_classes=n_fra_words, partition_strategy="div")
    # In-place mutation of the Python list that was passed in as a loop variable.
    total_loss.append(loss)
    i = i + 1
    return i, total_loss

def condition(i, total_loss):
    """Loop predicate: continue while ``i`` is below ``max_words_per_sentence["fra"]``."""
    # The leading `True and` does not change the result of the comparison.
    return True and i < max_words_per_sentence["fra"]

# Presumably the call that raises the ValueError shown above; its return value
# is not captured.
tf.while_loop(condition, sampled_softmax_body, [i, total_loss])
答案 0 :(得分:0)
我更改了策略,下面的写法可以正常运行:
# Second attempt: only the counter i is a loop variable; total_loss is an
# ordinary Python list that the body mutates as a side effect.
total_loss, i = list(), tf.constant(0)

def sampled_softmax_body(i):
    """Loop body: compute the sampled-softmax loss for step ``i``.

    Appends the loss tensor to the module-level ``total_loss`` list and
    returns the counter incremented by one.
    """
    # Index i selects one slice along the second axis of h2_decoder_outputs.
    # NOTE(review): presumably the layout is (batch, time, features) -- confirm.
    loss = tf.nn.sampled_softmax_loss(weights=logits_weights, biases=logits_biases, labels=target, inputs=h2_decoder_outputs[:,i,:], num_sampled=n_sampled_softmax, num_classes=n_fra_words, partition_strategy="div")
    # Python-level side effect on a list that lives outside the loop variables.
    # NOTE(review): the output shown later in this post has a single tensor
    # named 'while/...' in the list, which suggests this append runs once
    # while the loop graph is built rather than once per iteration -- confirm
    # against the tf.while_loop documentation. tf.TensorArray passed as a loop
    # variable is the usual way to accumulate per-step values.
    total_loss.append(loss)
    return tf.add(i, 1)

def condition(i):
    """Loop predicate: continue while ``i`` is below ``max_words_per_sentence["fra"]``."""
    return i < max_words_per_sentence["fra"]

# Loop variables and body outputs now both consist of the single counter.
# The value returned by tf.while_loop is not captured here.
tf.while_loop(condition, sampled_softmax_body, [i])
但是我应该如何获取 `total_loss` 的内容?
因为循环体是由 `tf.while_loop` 转换后执行的,我并没有得到预期中逐次追加而成的列表:
In [8]: total_loss
Out[8]: [<tf.Tensor 'while/softmax_cross_entropy_with_logits/Reshape_2:0' shape=(?,) dtype=float32>]
我的目标是通过total_loss