Random crashes when training with multiple threads in TensorFlow

Time: 2017-08-08 16:57:18

Tags: python multithreading tensorflow parallel-processing

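Question

The following script crashes randomly (that is, it sometimes crashes with the traceback below, but most of the time it does not). The script uses multiple threads to train an MNIST softmax model in parallel.

You can easily reproduce the crash by running:

for ((n=0;n<100;n++)); do python mnist_softmax_parallel_issue.py; done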

Traceback

external/eigen_archive/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h:125: Eigen::TensorEvaluator<const Eigen::TensorBroadcastingOp<Broadcast, XprType>, Device>::TensorEvaluator(const XprType&, const Device&) [with Broadcast = const Eigen::IndexList<Eigen::type2index<1l>, int>; ArgType = const Eigen::TensorMap<Eigen::Tensor<float, 2, 1, long int>, 16, Eigen::MakePointer>; Device = Eigen::ThreadPoolDevice; Eigen::TensorEvaluator<const Eigen::TensorBroadcastingOp<Broadcast, XprType>, Device>::XprType = Eigen::TensorBroadcastingOp<const Eigen::IndexList<Eigen::type2index<1l>, int>, const Eigen::TensorMap<Eigen::Tensor<float, 2, 1, long int>, 16, Eigen::MakePointer> >]: Assertion `input_dims[i] > 0' failed.

Source code

mnist_softmax_device_issue.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

from tensorflow.examples.tutorials.mnist import input_data

import tensorflow as tf
import threading
import numpy as np
import json
import os
import time

FLAGS = None

INTER_OP_PARALLELISM = 76
INTRA_OP_PARALLELISM = 1
BATCH_SIZE = 100
ITERATIONS = 1000
TRAINING_THREADS = 46

threads = [None] * TRAINING_THREADS

def train_function(thread_idx, mnist, sess, train_step, x, y_, y):
  iterations = int(ITERATIONS/TRAINING_THREADS)
  for i in range(iterations):
    batch_xs, batch_ys = mnist.train.next_batch(BATCH_SIZE)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

def main(_):
  mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

  x = tf.placeholder(tf.float32, [None, 784])
  W = tf.Variable(tf.zeros([784, 10]))
  b = tf.Variable(tf.zeros([10]))
  y = tf.matmul(x, W) + b

  y_ = tf.placeholder(tf.float32, [None, 10])

  cross_entropy = tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
  train_step = tf.train.GradientDescentOptimizer(0.5, use_locking=True).minimize(cross_entropy)

  sess = tf.InteractiveSession(config=tf.ConfigProto(intra_op_parallelism_threads = INTRA_OP_PARALLELISM, inter_op_parallelism_threads= INTER_OP_PARALLELISM))
  sess.run(tf.global_variables_initializer())

  for i in range(TRAINING_THREADS):
      threads[i] = threading.Thread(target=train_function, args=[i, mnist, sess, train_step, x, y_, y])

  for thread in threads:
      thread.start()
  for thread in threads:
      thread.join()


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument('--data_dir', type=str, default='mnist-data',
                      help='Directory for storing input data')
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

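System information

  • OS platform and distribution (e.g., Linux Ubuntu 16.04): Ubuntu 16.04.2 LTS
  • TensorFlow installed from (source or binary): source
  • TensorFlow version (use command below): 1.3.0-rc2
  • Python version: 2.7.12
  • Bazel version (if compiling from source): 0.4.5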

1 Answer:

Answer 0 (score: 0)

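mnist.train.next_batch() is not thread-safe. As a result, when several threads call it concurrently, batch_xs and batch_ys can end up as matrices of shape (0, 784) and (0, 10), respectively.

When these matrices are passed in with zero rows, the assertion shown above is triggered.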
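
One possible workaround, sketched here under assumptions rather than taken from the original answer, is to serialize the calls to next_batch() with a threading.Lock so that only one thread at a time advances the dataset's internal cursor. The names LockedBatchSource, ToyDataset, and collect_batch_sizes are hypothetical and exist only for illustration. ToyDataset is a stand-in for mnist.train so that the sketch runs without TensorFlow.

```python
import threading


class LockedBatchSource(object):
    """Hypothetical wrapper: serializes calls to any object exposing next_batch()."""

    def __init__(self, dataset):
        self._dataset = dataset
        self._lock = threading.Lock()

    def next_batch(self, batch_size):
        # Only one thread at a time may advance the dataset's internal cursor.
        with self._lock:
            return self._dataset.next_batch(batch_size)


class ToyDataset(object):
    """Hypothetical stand-in for mnist.train: a cursor over a list, no locking."""

    def __init__(self, num_examples):
        self._data = list(range(num_examples))
        self._index = 0

    def next_batch(self, batch_size):
        start = self._index
        if start + batch_size > len(self._data):
            start = 0  # wrap around at the end of an epoch
        self._index = start + batch_size
        return self._data[start:self._index]


def collect_batch_sizes(source, num_threads, iterations, batch_size):
    """Pulls batches from several threads and records the length of each batch."""
    sizes = []
    sizes_lock = threading.Lock()

    def worker():
        for _ in range(iterations):
            batch = source.next_batch(batch_size)
            with sizes_lock:
                sizes.append(len(batch))

    workers = [threading.Thread(target=worker) for _ in range(num_threads)]
    for t in workers:
        t.start()
    for t in workers:
        t.join()
    return sizes


source = LockedBatchSource(ToyDataset(num_examples=1000))
batch_sizes = collect_batch_sizes(source, num_threads=8, iterations=50, batch_size=100)
```

In the script from the question, the same idea would presumably amount to passing LockedBatchSource(mnist.train) to train_function and calling its next_batch(BATCH_SIZE) there, leaving the sess.run call outside the lock so that the training steps can still overlap.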