TensorFlow while_loop converts variables to constants?

Posted: 2017-07-04 18:15:42

Tags: python python-2.7 tensorflow

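I am trying to update a 2D tensor inside a nested while_loop(). However, once the variable is passed into the second (inner) loop, I can no longer update it with tf.assign(), because it throws this error:

ValueError: Sliced assignment is only supported for variables

If I create the variable outside the while_loop and use it only in the first loop, it somehow works fine.

How can I modify my 2D tf variable inside the second while loop?
(I am using Python 2.7 and TensorFlow 1.2.)

My code: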

import tensorflow as tf
import numpy as np

tf.reset_default_graph()

BATCH_SIZE = 10
LENGTH_MAX_OUTPUT = 31

it_batch_nr = tf.constant(0)
it_row_nr = tf.Variable(0, dtype=tf.int32)
it_col_nr = tf.constant(0)
cost = tf.constant(0)

it_batch_end = lambda it_batch_nr, cost: tf.less(it_batch_nr, BATCH_SIZE)
it_row_end = lambda it_row_nr, cost_matrix: tf.less(it_row_nr, LENGTH_MAX_OUTPUT+1)

def iterate_batch(it_batch_nr, cost):
    cost_matrix = tf.Variable(np.ones((LENGTH_MAX_OUTPUT+1, LENGTH_MAX_OUTPUT+1)), dtype=tf.float32)
    it_rows, cost_matrix = tf.while_loop(it_row_end, iterate_row, [it_row_nr, cost_matrix])
    cost = cost_matrix[0,0] # IS 1.0, SHOULD BE 100.0
    return tf.add(it_batch_nr,1), cost

def iterate_row(it_row_nr, cost_matrix):
    # THIS THROWS AN ERROR:
    cost_matrix[0,0].assign(100.0)
    return tf.add(it_row_nr,1), cost_matrix

it_batch = tf.while_loop(it_batch_end, iterate_batch, [it_batch_nr, cost])

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
out = sess.run(it_batch)
print(out)

2 answers:

Answer 0 (score: 1)

tf.Variable objects cannot be used as loop variables in a while loop, because loop variables are implemented differently.

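So either create the variable outside the loop and update it yourself with tf.assign in each iteration, or track the updates manually the same way loop variables are tracked (by returning the updated values from the loop lambdas and, in your case, using the values from the inner loop as the new values for the outer loop).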
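The second option (tracking the matrix as an ordinary loop variable) can be sketched as follows. This is a minimal sketch, assuming TF 1.x graph semantics through tensorflow.compat.v1 (TF 1.14+ or 2.x; on TF 1.2 itself a plain import tensorflow as tf should replace the first lines). No tf.Variable is involved: instead of a sliced assignment, each step builds a new tensor whose element [0, 0] is set to 100.0 via a one-hot mask. The mask trick and the name MASK_00 are illustrative assumptions, not part of the answer.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF 1.x semantics: graph mode and Sessions
tf.reset_default_graph()

BATCH_SIZE = 10
LENGTH_MAX_OUTPUT = 31
N = LENGTH_MAX_OUTPUT + 1

# Hypothetical helper: one-hot mask selecting element [0, 0]. Used instead of
# sliced assignment, which only works on tf.Variable objects.
MASK_00 = np.zeros((N, N), dtype=np.float32)
MASK_00[0, 0] = 1.0


def iterate_row(it_row_nr, cost_matrix):
    # Return a NEW tensor equal to cost_matrix except that [0, 0] is 100.0.
    updated = cost_matrix * (1.0 - MASK_00) + 100.0 * MASK_00
    return tf.add(it_row_nr, 1), updated


def iterate_batch(it_batch_nr, cost, cost_matrix):
    _, cost_matrix = tf.while_loop(
        lambda r, m: tf.less(r, N), iterate_row, [tf.constant(0), cost_matrix])
    # The inner loop's final matrix becomes the outer loop's new loop variable.
    return tf.add(it_batch_nr, 1), cost_matrix[0, 0], cost_matrix


init_matrix = tf.ones((N, N), dtype=tf.float32)
loop = tf.while_loop(
    lambda b, c, m: tf.less(b, BATCH_SIZE), iterate_batch,
    [tf.constant(0), tf.constant(0.0), init_matrix])  # cost starts as float32

with tf.Session() as sess:
    batches, final_cost, final_matrix = sess.run(loop)

print(batches, final_cost)
```

Because nothing is mutated in place, no control dependencies are needed: each iteration's output simply feeds the next. The trade-off is that every update produces a new (N, N) tensor, which is harmless at this size.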

Answer 1 (score: 0)

Got this working with @AlexandrePassos's help by placing the variable outside the while_loop. However, I also had to force the ops to execute using tf.control_dependencies() (because the operations are not used directly on the loop variables).
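A minimal sketch of this approach (a reconstruction, not the answerer's own code), again assuming TF 1.x graph semantics through tensorflow.compat.v1: the Variable is created once outside both loops, only plain counters are loop variables, and tf.control_dependencies makes each counter update wait for the sliced assignment and makes the read-back wait for the inner loop to finish.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF 1.x semantics: graph mode, Sessions, ref variables
tf.reset_default_graph()

BATCH_SIZE = 10
LENGTH_MAX_OUTPUT = 31
N = LENGTH_MAX_OUTPUT + 1

# Created once, OUTSIDE both loops; the loop bodies capture it instead of
# receiving it as a loop variable.
cost_matrix = tf.Variable(np.ones((N, N), dtype=np.float32))


def iterate_row(it_row_nr):
    # Sliced assignment is allowed here because cost_matrix is the real Variable.
    assign_op = cost_matrix[0, 0].assign(100.0)
    # Nothing returned from the body consumes assign_op, so make the counter
    # update depend on it; otherwise the assignment may never run.
    with tf.control_dependencies([assign_op]):
        return tf.add(it_row_nr, 1)


def iterate_batch(it_batch_nr, cost):
    rows_done = tf.while_loop(
        lambda r: tf.less(r, N), iterate_row, [tf.constant(0)])
    # Create a fresh read of the Variable that runs only after the inner loop.
    with tf.control_dependencies([rows_done]):
        cost = cost_matrix.read_value()[0, 0]
    return tf.add(it_batch_nr, 1), cost


# The initial cost must be float32 to match the value the body returns;
# tf.constant(0) (int32) would not match.
loop = tf.while_loop(
    lambda b, c: tf.less(b, BATCH_SIZE), iterate_batch,
    [tf.constant(0), tf.constant(0.0)])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    batches, final_cost = sess.run(loop)

print(batches, final_cost)
```

read_value() is used for the read-back rather than indexing the Variable directly, so that the read op is created inside the control-dependency block instead of possibly reusing a value cached before the loops ran.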