以下脚本随机崩溃(即,有时崩溃并产生此回溯,大多数情况下它不会崩溃)。该脚本利用多个线程并行训练MNIST softmax模型。
您可以通过运行for ((n=0;n<100;n++)); do python mnist_softmax_parallel_issue.py; done
回溯
external/eigen_archive/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h:125: Eigen::TensorEvaluator<const Eigen::TensorBroadcastingOp<Broadcast, XprType>, Device>::T
ensorEvaluator(const XprType&, const Device&) [with Broadcast = const Eigen::IndexList<Eigen::type2index<1l>, int>; ArgType = const Eigen::TensorMap<Eigen::Tensor<float, 2, 1, long
int>, 16, Eigen::MakePointer>; Device = Eigen::ThreadPoolDevice; Eigen::TensorEvaluator<const Eigen::TensorBroadcastingOp<Broadcast, XprType>, Device>::XprType = Eigen::TensorBroadcastingOp<const Eigen::IndexList<Eigen::type2index<1l>, int>, const Eigen::TensorMap<Eigen::Tensor<float, 2, 1, long int>, 16, Eigen::MakePointer> >]: Assertion input_dims[i] > $' failed.
mnist_softmax_device_issue.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import threading
import numpy as np
import json
import os
import time
FLAGS = None
INTER_OP_PARALLELISM = 76
INTRA_OP_PARALLELISM = 1
BATCH_SIZE = 100
ITERATIONS = 1000
TRAINING_THREADS = 46
threads = [None] * TRAINING_THREADS
def train_function(thread_idx, mnist, sess, train_step, x, y_, y):
iterations = int(ITERATIONS/TRAINING_THREADS)
for i in range(iterations):
batch_xs, batch_ys = mnist.train.next_batch(BATCH_SIZE)
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
def main(_):
mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y = tf.matmul(x, W) + b
y_ = tf.placeholder(tf.float32, [None, 10])
cross_entropy = tf.reduce_mean(
tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
train_step = tf.train.GradientDescentOptimizer(0.5, use_locking=True).minimize(cross_entropy)
sess = tf.InteractiveSession(config=tf.ConfigProto(intra_op_parallelism_threads = INTRA_OP_PARALLELISM, inter_op_parallelism_threads= INTER_OP_PARALLELISM))
sess.run(tf.global_variables_initializer())
for i in range(TRAINING_THREADS):
threads[i] = threading.Thread(target=train_function, args=[i, mnist, sess, train_step, x, y_, y])
for thread in threads:
thread.start()
for thread in threads:
thread.join()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', type=str, default='mnist-data',
help='Directory for storing input data')
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
答案 0 :(得分:0)
mnist.train.next_batch()
不是线程安全的。此后,batch_xs
和batch_ys
可以分别获得形状(0, 784)
和(0, 10)
的矩阵。
当使用零线传递这些矩阵时,会触发上述断言。