Error when running the TensorBox ReInspect implementation on TensorFlow

Time: 2016-09-08 09:41:14

Tags: python tensorflow

I am trying to train the TensorBox ReInspect implementation (https://github.com/Russell91/TensorBox/) on my machine using a single GPU (NVidia GeForce GTX 750 Ti). When I run the train.py script:

python train.py --hypes hypes/overfeat_rezoom.json --gpu 0 --logdir output

I get the following error:

Traceback (most recent call last):
  File "train.py", line 543, in <module>
    main()
  File "train.py", line 540, in main
    train(H, test_images=[])
  File "train.py", line 453, in train
    smooth_op, global_step, learning_rate, encoder_net) = build(H, q)
  File "train.py", line 346, in build
    boxes_loss[phase]) = build_forward_backward(H, x, encoder_net, phase, boxes, flags)
  File "train.py", line 245, in build_forward_backward
    pred_confidences, pred_confs_deltas, pred_boxes_deltas) = build_forward(H, x, googlenet, phase, reuse)
  File "train.py", line 167, in build_forward
    lstm_outputs = build_lstm_inner(H, lstm_input)
  File "train.py", line 50, in build_lstm_inner
    state = tf.zeros([batch_size, lstm.state_size])
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/array_ops.py", line 1136, in zeros
    shape = ops.convert_to_tensor(shape, dtype=dtypes.int32, name="shape")
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 628, in convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/constant_op.py", line 180, in _constant_tensor_conversion_function
    return constant(v, dtype=dtype, name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/constant_op.py", line 163, in constant
    tensor_util.make_tensor_proto(value, dtype=dtype, shape=shape))
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/tensor_util.py", line 354, in make_tensor_proto
    nparray = np.array(values, dtype=np_dt)
ValueError: setting an array element with a sequence.

This does not happen when I run the same code on a machine that has no GPU and uses only the CPU.

What causes this error, and is there a way to debug it?
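One way to debug this is to reproduce the failing conversion outside TensorFlow. The traceback shows that tf.zeros passes its shape list through np.array(values, dtype=np_dt) with an int32 dtype, so a minimal sketch using only NumPy, assuming the suspect is that lstm.state_size is a tuple rather than an int, could look like this. LSTMStateTuple below is a hypothetical stand-in for TensorFlow's class of the same name, zeros_shape_error is a hypothetical helper, and the sizes are arbitrary.

```python
import collections

import numpy as np

# Hypothetical stand-in for TensorFlow's LSTMStateTuple(c, h). With
# state_is_tuple=True an LSTM cell's state_size is such a pair, not an int.
LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ["c", "h"])


def zeros_shape_error(batch_size, state_size):
    """Attempt the same conversion tf.zeros applies to its shape argument.

    Returns the ValueError message, or None if the shape converts cleanly.
    """
    try:
        # Mirrors the last traceback frame: np.array(values, dtype=np_dt), np_dt = int32.
        np.array([batch_size, state_size], dtype=np.int32)
    except ValueError as err:
        return str(err)
    return None


# Concatenated state (state_is_tuple=False): state_size is a plain int.
print(zeros_shape_error(1, 500))  # None

# Tuple state (state_is_tuple=True): state_size is a nested pair, so the
# shape list is ragged and NumPy raises the "setting an array element with
# a sequence" ValueError seen in the traceback.
print(zeros_shape_error(1, LSTMStateTuple(250, 250)))
```

In train.py itself, printing lstm.state_size and its type just before the tf.zeros call in build_lstm_inner should show which of the two cases you are in.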

1 answer:

Answer 0 (score: 0)

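The problem seems to be caused by the state_is_tuple flag. When it is True, lstm.state_size is a tuple (for example an LSTMStateTuple of the c and h sizes) instead of a single integer, so [batch_size, lstm.state_size] is a nested list that cannot be converted to an int32 shape, which is what the ValueError reports. See this thread for a solution: https://github.com/Russell91/TensorBox/issues/59#issuecomment-245566316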
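When state_size is a tuple, the zero state has to be built per component instead of as one [batch_size, state_size] tensor. TensorFlow's RNNCell.zero_state(batch_size, dtype) handles this for you. The sketch below only illustrates the idea with NumPy. It is not TensorFlow's implementation, and both LSTMStateTuple and the zero_state helper here are hypothetical stand-ins.

```python
import collections

import numpy as np

# Hypothetical stand-in for TensorFlow's LSTMStateTuple(c, h).
LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ["c", "h"])


def zero_state(batch_size, state_size, dtype=np.float32):
    """Build an all-zeros state whose structure matches state_size.

    An int yields one [batch_size, state_size] array; a (possibly nested)
    tuple yields the same tuple structure with one array per leaf.
    """
    if isinstance(state_size, int):
        # state_is_tuple=False: a single concatenated state array.
        return np.zeros([batch_size, state_size], dtype=dtype)
    # state_is_tuple=True: recurse into each component (c and h, or one entry per layer).
    parts = [zero_state(batch_size, s, dtype) for s in state_size]
    if hasattr(state_size, "_fields"):
        return type(state_size)(*parts)  # keep the namedtuple type
    return tuple(parts)


single = zero_state(1, 500)
print(single.shape)  # (1, 500)

# Two stacked LSTM layers with tuple state, as a MultiRNNCell would report.
stacked = zero_state(1, (LSTMStateTuple(250, 250), LSTMStateTuple(250, 250)))
print(stacked[0].c.shape)  # (1, 250)
```

In train.py this suggests two options. The first is to pass state_is_tuple=False where the LSTM cell (and any MultiRNNCell wrapping it) is created, so that state_size is an int again and the original tf.zeros call works. The second is to replace that call with lstm.zero_state(batch_size, tf.float32), which may require further changes wherever the code treats the state as a single tensor.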