如何在验证(评估)阶段关闭自动混合精度?

时间:2019-08-06 01:26:15

标签: validation tensorflow precision mixed

我尝试以自动混合精度运行MAC网络(https://github.com/stanfordnlp/mac-network/tree/gqa)。

def addOptimizerOp(self):
    """Create the training optimizer(s) and the global step variable.

    Inside the "trainAddOptimizer" variable scope this:
      * stores a non-trainable int32 variable ``globalStep`` (initialized
        to 0) on ``self``;
      * builds an Adam optimizer using ``self.lr`` and passes it through
        ``tf.train.experimental.enable_mixed_precision_graph_rewrite``;
      * when ``config.subsetOpt`` is set, stores a second Adam optimizer on
        ``self.subsetOptimizer`` whose learning rate is
        ``self.lr * config.subsetOptMult``.

    Returns:
        The optimizer returned by enable_mixed_precision_graph_rewrite.
    """
    with tf.variable_scope("trainAddOptimizer"):
        self.globalStep = tf.Variable(0, dtype = tf.int32, trainable = False, name = "globalStep") # init to 0 every run?
        optimizer = tf.train.AdamOptimizer(learning_rate = self.lr)
        # NOTE(review): per the TensorFlow docs this call appears to enable the
        # auto_mixed_precision graph rewrite for sessions created afterwards,
        # not only for this optimizer's training ops -- confirm before relying
        # on evaluation running in float32.
        optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
        if config.subsetOpt:
            # Only the main optimizer is wrapped; the subset optimizer is plain Adam.
            self.subsetOptimizer = tf.train.AdamOptimizer(learning_rate = self.lr * config.subsetOptMult)

    return optimizer

第一个 epoch 的训练可以正常进行。但是,当模型在验证集上运行评估时,出现了下面的错误。

Training epoch 1...
2019-08-05 14:51:13.625899: I tensorflow/core/grappler/optimizers/auto_mixed_precision.cc:1767] Running auto_mixed_precision graph optimizer
2019-08-05 14:51:13.709959: I tensorflow/core/grappler/optimizers/auto_mixed_precision.cc:1723] Converted 1504/6920 nodes to float16 precision using 150 cast(s) to float16 (excluding Const and Variable casts)
2019-08-05 14:51:16.930248: I tensorflow/stream_executor/platform/default/dso_loader.cc:42] Successfully opened dynamic library libcublas.so.10.0
2019-08-05 14:51:17.331687: I tensorflow/stream_executor/platform/default/dso_loader.cc:42] Successfully opened dynamic library libcudnn.so.7
2019-08-05 14:51:29.378905: I tensorflow/core/grappler/optimizers/auto_mixed_precision.cc:1767] Running auto_mixed_precision graph optimizer
2019-08-05 14:51:29.380633: I tensorflow/core/grappler/optimizers/auto_mixed_precision.cc:1241] No whitelist ops found, nothing to do
eb  1, 10000,(160010 / 943000), t = 0.12 (0.00+0.11), lr 0.0003, l = 2.8493, a = 0.4250, avL = 2.5323, avA = 0.4188, g = 3.7617, emL = 2.3097, emA = 0.4119; gqaExperiment
Restoring EMA weights
2019-08-05 14:51:31.132804: I tensorflow/core/grappler/optimizers/auto_mixed_precision.cc:1767] Running auto_mixed_precision graph optimizer
2019-08-05 14:51:31.136122: I tensorflow/core/grappler/optimizers/auto_mixed_precision.cc:1241] No whitelist ops found, nothing to do
2019-08-05 14:51:32.322369: I tensorflow/core/grappler/optimizers/auto_mixed_precision.cc:1767] Running auto_mixed_precision graph optimizer
2019-08-05 14:51:32.341609: I tensorflow/core/grappler/optimizers/auto_mixed_precision.cc:1723] Converted 661/1848 nodes to float16 precision using 38 cast(s) to float16 (excluding Const and Variable casts)
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py", line 1356, in _do_call
    return fn(*args)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py", line 1341, in _run_fn
    options, feed_dict, fetch_list, target_list, run_metadata)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py", line 1429, in _call_tf_sessionrun
    run_metadata)
tensorflow.python.framework.errors_impl.InvalidArgumentError: 2 root error(s) found.
  (0) Invalid argument: TensorArray dtype is float but Op is trying to write dtype half.
     [[{{node macModel/tower0/encoder/birnnLayer/bidirectional_rnn/fw/fw/while/TensorArrayWrite/TensorArrayWriteV3}}]]
     [[macModel/tower0/MACnetwork/MACCell_3/write/inter2attselfAttention/Softmax/_1661]]
  (1) Invalid argument: TensorArray dtype is float but Op is trying to write dtype half.
     [[{{node macModel/tower0/encoder/birnnLayer/bidirectional_rnn/fw/fw/while/TensorArrayWrite/TensorArrayWriteV3}}]]
0 successful operations.
0 derived errors ignored.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "main.py", line 866, in <module>
    main()
  File "main.py", line 728, in main
    evalRes = runEvaluation(sess, model, data["main"], dataOps, epoch, getPreds = getPreds, prevRes = evalRes)
  File "main.py", line 248, in runEvaluation
    minLoss = prevRes["train"]["minLoss"] if prevRes else float("inf"))
  File "main.py", line 594, in runEpoch
    res = model.runBatch(sess, batch, imagesBatch, train, getPreds, getAtt)
  File "/content/model.py", line 948, in runBatch
    feed_dict = feed)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py", line 950, in run
    run_metadata_ptr)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py", line 1173, in _run
    feed_dict_tensor, options, run_metadata)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py", line 1350, in _do_run
    run_metadata)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py", line 1370, in _do_call
    raise type(e)(node_def, op, message)

我认为错误的原因可能是评估时没有关闭自动混合精度,或者是训练和评估共用了同一个会话和模型。我尝试过 `tf.train.experimental.disable_mixed_precision_graph_rewrite()`,但不知道该如何正确使用它。

我该如何解决?谢谢大家。

def main():
    """Entry point: preprocess data, build the model, then train and evaluate.

    Steps visible in this excerpt:
      1. Append a JSON dump of ``vars(config)`` to ``config.configFile()``.
      2. If ``config.gpus`` is non-empty, record the GPU count and export
         ``CUDA_VISIBLE_DEVICES``.
      3. Preprocess the data and build a ``MACnet`` model.
      4. Create the variable initializer, the savers and the session config.
      5. Open one ``tf.Session``, finalize its graph and load/initialize
         weights.
      6. When ``config.train`` is set, loop over epochs: train, save weights,
         optionally swap in EMA weights, evaluate on ``data["main"]`` and
         ``data["extra"]``, swap the standard weights back, and print results.

    Training and evaluation share the same ``sess`` and ``model`` objects.

    NOTE(review): ``start0``, ``bestEpoch``, ``bestRes`` and ``prevRes`` are
    assigned but not used in the visible portion; presumably the function
    continues beyond this excerpt -- confirm.
    """
    # Append the current configuration to the config file ("a+" = append mode).
    with open(config.configFile(), "a+") as outFile:
        json.dump(vars(config), outFile)

    # set gpus
    if config.gpus != "":
        # config.gpus is split on "," to count the devices.
        config.gpusNum = len(config.gpus.split(","))
        os.environ["CUDA_VISIBLE_DEVICES"] = config.gpus
    # Only log TensorFlow messages at ERROR level.
    tf.logging.set_verbosity(tf.logging.ERROR)

    # process data
    print(bold("Preprocess data..."))
    start = time.time()
    preprocessor = Preprocesser()
    data, embeddings, answerDict, questionDict = preprocessor.preprocessData()
    print("took {} seconds".format(bcolored("{:.2f}".format(time.time() - start), "blue")))

    # Both are passed along as None in this code path.
    nextElement = None
    dataOps = None

    # build model
    print(bold("Building model..."))
    start = time.time()
    model = MACnet(embeddings, answerDict, questionDict, nextElement)
    print("took {} seconds".format(bcolored("{:.2f}".format(time.time() - start), "blue")))

    # initializer
    init = tf.global_variables_initializer()

    # savers
    savers = setSavers(model)
    saver, emaSaver = savers["saver"], savers["emaSaver"]

    # sessionConfig
    sessionConfig = setSession()

    # A single session is used for both training and evaluation below.
    with tf.Session(config = sessionConfig) as sess:

        # ensure no more ops are added after model is built
        sess.graph.finalize()

        # restore / initialize weights, initialize epoch variable
        epoch = loadWeights(sess, saver, init)

        trainRes, evalRes = None, None

        if config.train:
            start0 = time.time()

            bestEpoch = epoch
            bestRes = None
            prevRes = None

            # epoch in [restored + 1, epochs]
            for epoch in range(config.restoreEpoch + 1, config.epochs + 1):
                print(bcolored("Training epoch {}...".format(epoch), "green"))
                start = time.time()

                # train
                # calle = lambda: model.runEpoch(), collectRuntimeStats, writer
                trainingData, alterData = chooseTrainingData(data)
                # maxAcc / minLoss carry over from the previous epoch's result,
                # falling back to 0.0 / +inf on the first iteration.
                trainRes = runEpoch(sess, model, trainingData, dataOps, train = True, epoch = epoch, 
                    saver = saver, alterData = alterData, 
                    maxAcc = trainRes["maxAcc"] if trainRes else 0.0,
                    minLoss = trainRes["minLoss"] if trainRes else float("inf"),)

                # save weights
                saver.save(sess, config.weightsFile(epoch))
                if config.saveSubset:
                    # NOTE(review): subsetSaver is never assigned in this function;
                    # presumably a module-level name or savers["subsetSaver"] -- confirm,
                    # otherwise this line raises NameError when config.saveSubset is set.
                    subsetSaver.save(sess, config.subsetWeightsFile(epoch))                   

                # load EMA weights from the checkpoint just saved, for evaluation
                if config.useEMA:
                    print(bold("Restoring EMA weights"))
                    emaSaver.restore(sess, config.weightsFile(epoch))

                # evaluation (same session and model as training)
                getPreds = config.getPreds or (config.analysisType != "")

                evalRes = runEvaluation(sess, model, data["main"], dataOps, epoch, getPreds = getPreds, prevRes = evalRes)
                extraEvalRes = runEvaluation(sess, model, data["extra"], dataOps, epoch, 
                    evalTrain = not config.extraVal, getPreds = getPreds)

                # restore standard weights
                if config.useEMA:
                    print(bold("Restoring standard weights"))
                    saver.restore(sess, config.weightsFile(epoch))

                print("")

                # Wall-clock time for the whole epoch (train + save + eval).
                epochTime = time.time() - start
                print("took {:.2f} seconds".format(epochTime))

                # print results
                printDatasetResults(trainRes, evalRes, extraEvalRes)

0 个答案:

没有答案