在训练Tensorflow MultivariateNormalDiag的sigma值时失败

时间:2017-07-11 17:41:40

标签: python tensorflow

Tensorflow版本:仅限v1.2.1 CPU OS平台:Win10 Tensorflow安装自:Anaconda

我正在尝试训练模型来学习高斯光滑的sigma参数。但是它给出了错误" InvalidArgumentError(参见上面的回溯):indices [0] = -1不在[0,2)"。

我的代码在这里。

import tensorflow as tf
import tensorflow.contrib.distributions as ds
import numpy as np
import scipy.ndimage.filters as filters


#set parameters-------------------
IMAGE_SIZE = (576, 1024)
KERNEL_SIZE = 3

#build graph--------------
image = tf.placeholder(tf.float32, shape=(None,)+IMAGE_SIZE+(1,))
label = tf.placeholder(tf.float32, shape=(None,)+IMAGE_SIZE+(1,))
sigma = tf.Variable(tf.ones(1))

xinds, yinds = np.unravel_index(range(KERNEL_SIZE*KERNEL_SIZE),                                 (KERNEL_SIZE, KERNEL_SIZE))
inds = (np.column_stack((xinds,yinds))-
                    [(KERNEL_SIZE-1)/2, (KERNEL_SIZE-1)/2]).astype(np.float32)
inds = tf.constant(inds)
loc = tf.zeros(2)
scale_diag = tf.multiply(sigma, tf.ones([2,]))

mvn = ds.MultivariateNormalDiag(
    loc=loc,
    scale_diag=scale_diag)

kernel = mvn.prob(inds)
kernel = tf.reshape(kernel, (KERNEL_SIZE, KERNEL_SIZE, 1, 1))


x = tf.nn.conv2d(image, kernel, strides=[1,1,1,1], padding='SAME')
loss = tf.nn.l2_loss(tf.subtract(x, label))

train_op = tf.train.AdamOptimizer().minimize(loss)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
#start training-------------------------
stimulus = np.zeros(IMAGE_SIZE)
stimulus[300, 500] = 1
test_sigma = (30,30)
filtered = filters.gaussian_filter(stimulus, test_sigma)

for itr in range(10):
    feed_dict = {image: np.expand_dims(np.expand_dims(stimulus, 0), 3), 
                 label: np.expand_dims(np.expand_dims(filtered, 0), 3)}
    sess.run(train_op, feed_dict=feed_dict)
    loss_value = sess.run(loss, feed_dict)
    print('Training loss is %f' % loss_value)

完整的错误消息在这里。

InvalidArgumentError: indices[0] = -1 is not in [0, 2)
     [[Node: gradients/MultivariateNormalDiag_3/prob/Prod_grad/Gather = Gather[Tindices=DT_INT32, Tparams=DT_INT32, validate_indices=true, _device="/job:localhost/replica:0/task:0/cpu:0"](gradients/MultivariateNormalDiag_3/prob/Prod_grad/Shape, gradients/MultivariateNormalDiag_3/prob/Prod_grad/Reshape)]]

Caused by op 'gradients/MultivariateNormalDiag_3/prob/Prod_grad/Gather', defined at:
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\spyder\utils\ipython\start_kernel.py", line 231, in <module>
    main()
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\spyder\utils\ipython\start_kernel.py", line 227, in main
    kernel.start()
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\kernelapp.py", line 477, in start
    ioloop.IOLoop.instance().start()
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\zmq\eventloop\ioloop.py", line 177, in start
    super(ZMQIOLoop, self).start()
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\tornado\ioloop.py", line 832, in start
    self._run_callback(self._callbacks.popleft())
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\tornado\ioloop.py", line 605, in _run_callback
    ret = callback()
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\tornado\stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\kernelbase.py", line 265, in enter_eventloop
    self.eventloop(self)
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\eventloops.py", line 106, in loop_qt5
    return loop_qt4(kernel)
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\eventloops.py", line 99, in loop_qt4
    _loop_qt(kernel.app)
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\eventloops.py", line 83, in _loop_qt
    app.exec_()
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\eventloops.py", line 39, in process_stream_events
    kernel.do_one_iteration()
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\kernelbase.py", line 298, in do_one_iteration
    stream.flush(zmq.POLLIN, 1)
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\zmq\eventloop\zmqstream.py", line 352, in flush
    self._handle_recv()
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\zmq\eventloop\zmqstream.py", line 472, in _handle_recv
    self._run_callback(callback, msg)
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\zmq\eventloop\zmqstream.py", line 414, in _run_callback
    callback(*args, **kwargs)
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\tornado\stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\kernelbase.py", line 283, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\kernelbase.py", line 235, in dispatch_shell
    handler(stream, idents, msg)
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\kernelbase.py", line 399, in execute_request
    user_expressions, allow_stdin)
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\ipkernel.py", line 196, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\zmqshell.py", line 533, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\IPython\core\interactiveshell.py", line 2698, in run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\IPython\core\interactiveshell.py", line 2808, in run_ast_nodes
    if self.run_code(code, result):
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\IPython\core\interactiveshell.py", line 2862, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-1-cb72f96c9e6a>", line 1, in <module>
    runfile('C:/Users/pasca/Documents/7deep driving/attention_map/test_tf_gaussian.py', wdir='C:/Users/pasca/Documents/7deep driving/attention_map')
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\spyder\utils\site\sitecustomize.py", line 880, in runfile
    execfile(filename, namespace)
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\spyder\utils\site\sitecustomize.py", line 102, in execfile
    exec(compile(f.read(), filename, 'exec'), namespace)
  File "C:/Users/pasca/Documents/7deep driving/attention_map/test_tf_gaussian.py", line 40, in <module>
    train_op = tf.train.AdamOptimizer().minimize(loss)
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\optimizer.py", line 315, in minimize
    grad_loss=grad_loss)
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\optimizer.py", line 386, in compute_gradients
    colocate_gradients_with_ops=colocate_gradients_with_ops)
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\ops\gradients_impl.py", line 540, in gradients
    grad_scope, op, func_call, lambda: grad_fn(op, *out_grads))
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\ops\gradients_impl.py", line 346, in _MaybeCompile
    return grad_fn()  # Exit early
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\ops\gradients_impl.py", line 540, in <lambda>
    grad_scope, op, func_call, lambda: grad_fn(op, *out_grads))
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\ops\math_grad.py", line 129, in _ProdGrad
    reduced_num = math_ops.reduce_prod(array_ops.gather(input_shape, reduced))
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\ops\gen_array_ops.py", line 1179, in gather
    validate_indices=validate_indices, name=name)
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 767, in apply_op
    op_def=op_def)
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\ops.py", line 2506, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\ops.py", line 1269, in __init__
    self._traceback = _extract_stack()

...which was originally created as op 'MultivariateNormalDiag_3/prob/Prod', defined at:
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\spyder\utils\ipython\start_kernel.py", line 231, in <module>
    main()
[elided 26 identical lines from previous traceback]
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\spyder\utils\site\sitecustomize.py", line 102, in execfile
    exec(compile(f.read(), filename, 'exec'), namespace)
  File "C:/Users/pasca/Documents/7deep driving/attention_map/test_tf_gaussian.py", line 33, in <module>
    kernel = mvn.prob(inds)
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\ops\distributions\distribution.py", line 712, in prob
    return self._call_prob(value, name)
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\ops\distributions\distribution.py", line 694, in _call_prob
    return self._prob(value, **kwargs)
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\ops\distributions\util.py", line 688, in _fn
    return fn(*args, **kwargs)
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\contrib\distributions\python\ops\mvn_linear_operator.py", line 216, in _prob
    return super(MultivariateNormalLinearOperator, self)._prob(x)
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\ops\distributions\transformed_distribution.py", line 406, in _prob
    prob = math_ops.reduce_prod(prob, self._reduce_event_indices)
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\ops\math_ops.py", line 1392, in reduce_prod
    name=name)
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\ops\gen_math_ops.py", line 1488, in _prod
    keep_dims=keep_dims, name=name)
  File "C:\Users\pasca\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 767, in apply_op
    op_def=op_def)

InvalidArgumentError (see above for traceback): indices[0] = -1 is not in [0, 2)
     [[Node: gradients/MultivariateNormalDiag_3/prob/Prod_grad/Gather = Gather[Tindices=DT_INT32, Tparams=DT_INT32, validate_indices=true, _device="/job:localhost/replica:0/task:0/cpu:0"](gradients/MultivariateNormalDiag_3/prob/Prod_grad/Shape, gradients/MultivariateNormalDiag_3/prob/Prod_grad/Reshape)]]

1 个答案:

答案 0 :(得分:0)

事实证明,Tensorflow Github上有一个commit来解决这个问题。