Is tf.gradients thread-safe?

Asked: 2017-08-18 03:43:41

Tags: python multithreading tensorflow

I make several calls to tf.gradients, each of which takes some time, so I would like to run them concurrently. However, when I try this on my graph I get errors. I suspect tf.gradients is not thread-safe, but I have been unable to reproduce the errors with an MWE. I tried both pathos.pools.ThreadPool and pathos.pools.ProcessPool, in my MWE as well as in my real code; only my real code fails. Here is the MWE I tried:

from pathos.pools import ThreadPool, ProcessPool
import tensorflow as tf
import numpy as np

# Three random 10x10 constant tensors and three ops built from them.
Xs = [tf.cast(np.random.random((10, 10)), dtype=tf.float64) for i in range(3)]
Ys = [Xs[0]*Xs[1]*Xs[2], Xs[0]/Xs[1]*Xs[2], Xs[0]/Xs[1]/Xs[2]]

def compute_grad(YX):
    # Each call adds gradient ops to the shared default graph.
    return tf.gradients(YX[0], YX[1])

tp = ThreadPool(3)
res = tp.map(compute_grad, zip(Ys, Xs))
print(res)

Here is part of the traceback I get when I try my real code. This is the ThreadPool version:

File "pathos/threading.py", line 134, in map
    return _pool.map(star(f), zip(*args)) # chunksize
  File "multiprocess/pool.py", line 260, in map
    return self._map_async(func, iterable, mapstar, chunksize).get()
  File "multiprocess/pool.py", line 608, in get
    raise self._value
  File "multiprocess/pool.py", line 119, in worker
    result = (True, func(*args, **kwds))
  File "multiprocess/pool.py", line 44, in mapstar
    return list(map(*args))
  File "pathos/helpers/mp_helper.py", line 15, in <lambda>
    func = lambda args: f(*args)
  File "my_code.py", line 939, in gradients_with_index
    return (tf.gradients(Y, variables), b_idx)
  File "tensorflow/python/ops/gradients_impl.py", line 448, in gradients
    colocate_gradients_with_ops)
  File "tensorflow/python/ops/gradients_impl.py", line 188, in _PendingCount
    between_op_list, between_ops, colocate_gradients_with_ops)
  File "tensorflow/python/ops/control_flow_ops.py", line 1288, in MaybeCreateControlFlowState
    loop_state.AddWhileContext(op, between_op_list, between_ops)
  File "tensorflow/python/ops/control_flow_ops.py", line 1103, in AddWhileContext
    grad_state = GradLoopState(forward_ctxt, outer_grad_state)
  File "tensorflow/python/ops/control_flow_ops.py", line 737, in __init__
    cnt, outer_grad_state)
  File "tensorflow/python/ops/control_flow_ops.py", line 2282, in AddBackPropLoopCounter
    merge_count = merge([enter_count, enter_count])[0]
  File "tensorflow/python/ops/control_flow_ops.py", line 404, in merge
    return gen_control_flow_ops._merge(inputs, name)
  File "tensorflow/python/ops/gen_control_flow_ops.py", line 150, in _merge
    result = _op_def_lib.apply_op("Merge", inputs=inputs, name=name)
  File "tensorflow/python/framework/op_def_library.py", line 767, in apply_op
    op_def=op_def)
  File "tensorflow/python/framework/ops.py", line 2506, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "tensorflow/python/framework/ops.py", line 1273, in __init__
    self._control_flow_context.AddOp(self)
  File "tensorflow/python/ops/control_flow_ops.py", line 2147, in AddOp
    self._AddOpInternal(op)
  File "tensorflow/python/ops/control_flow_ops.py", line 2177, in _AddOpInternal
    self._MaybeAddControlDependency(op)
  File "tensorflow/python/ops/control_flow_ops.py", line 2204, in _MaybeAddControlDependency
    op._add_control_input(self.GetControlPivot().op)
AttributeError: 'NoneType' object has no attribute 'op'

Here is another traceback. Note that the error is different:

Traceback (most recent call last):
  File "tensorflow/python/ops/control_flow_ops.py", line 869, in AddForwardAccumulator
    enter_acc = self.forward_context.AddValue(acc)
  File "tensorflow/python/ops/control_flow_ops.py", line 2115, in AddValue
    self._outer_context.AddInnerOp(enter.op)
  File "tensorflow/python/framework/ops.py", line 3355, in __exit__
    self._graph._pop_control_dependencies_controller(self)
  File "tensorflow/python/framework/ops.py", line 3375, in _pop_control_dependencies_controller
    assert self._control_dependencies_stack[-1] is controller
AssertionError

The ProcessPool version hits this error:

_pickle.PicklingError: Can't pickle <class 'tensorflow.python.util.tf_should_use._add_should_use_warning.<locals>.TFShouldUseWarningWrapper'>: it's not found as tensorflow.python.util.tf_should_use._add_should_use_warning.<locals>.TFShouldUseWarningWrapper

1 Answer:

Answer 0 (score: 2):

The tf.gradients() function is not thread-safe. It makes a series of complicated, non-atomic modifications to your graph, and those modifications are not protected by a lock. In particular, it seems you are more likely to run into problems if you run tf.gradients() concurrently on a graph that contains control flow operations (such as tf.while_loop()).
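If the calls must be issued from multiple threads anyway, one workaround is to serialize the graph-mutating calls behind a single lock. The sketch below illustrates the pattern with a hypothetical stand-in function (`add_grad_ops`, a shared `graph_ops` list, and `graph_lock` are all made up for illustration; real code would wrap the `tf.gradients()` call itself):

```python
import threading

# Stand-in for a shared, non-thread-safe structure such as a TF graph.
graph_ops = []
graph_lock = threading.Lock()

def add_grad_ops(tag, n):
    # Serialize all mutations behind one lock, so concurrent callers
    # cannot interleave their modifications (the analogue of guarding
    # every tf.gradients() call with the same lock).
    with graph_lock:
        for i in range(n):
            graph_ops.append(f"{tag}/grad_{i}")

threads = [threading.Thread(target=add_grad_ops, args=(f"y{k}", 100))
           for k in range(3)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(len(graph_ops))  # 300: every mutation happened, one caller at a time
```

Note that with the lock in place the calls execute one at a time, so this buys correctness, not speed.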

Note that issuing parallel calls to tf.gradients() is unlikely to give you a speedup, even if it were implemented in a thread-safe way. The function performs no I/O and does not call any native methods that release the Python GIL, so execution would most likely be serialized anyway. Implementing multiprocessing-based parallelism would require additional system calls to access the shared graph (plus acquiring/releasing locks), so that is unlikely to be faster either.
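Given that, the simplest safe alternative is to replace the pool's map with a plain sequential loop over the (Y, X) pairs, mirroring the structure of the MWE. The sketch below uses a hypothetical placeholder computation in place of tf.gradients() (the `y * x` body and the small float lists are made-up stand-ins); the point is only the control-flow change from `tp.map(...)` to a loop:

```python
# Hypothetical stand-in for tf.gradients(); real gradients need TensorFlow.
def compute_grad(YX):
    y, x = YX
    return y * x  # placeholder computation, not a real gradient

Ys = [1.0, 2.0, 3.0]
Xs = [10.0, 20.0, 30.0]

# Replacing tp.map(compute_grad, zip(Ys, Xs)) with a plain loop gives
# the same results with no shared-graph races and, per the answer above,
# little or no loss of throughput.
res = [compute_grad(yx) for yx in zip(Ys, Xs)]
print(res)  # [10.0, 40.0, 90.0]
```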