关于tensorflow.contrib.eager.metrics.Mean()的错误

时间:2019-06-12 07:21:47

标签: tensorflow

我正在研究Google Colabs中的tensorflow教程,并且已按照以下链接中的说明运行了所有内容:

https://www.tensorflow.org/tutorials/eager/custom_training_walkthrough

num_epochs = 201;

for epoch in range(num_epochs):
    epoch_loss_avg = tensorflow.contrib.eager.metrics.Mean();
    epoch_accuracy = tensorflow.contrib.eager.metrics.Accuracy();

    # Training loop - using batches of 32
    for x, y in data_set:
        # Optimize the model
        loss_value, grads = grad(model, x, y);
        optimizer.apply_gradients(zip(grads, model.variables), global_step);

        # Track progress
        epoch_loss_avg(loss_value);
        # compare predicted label to actual label
        epoch_accuracy(tensorflow.argmax(model(x), axis=1, output_type=tensorflow.int32), y);

    # end epoch
    train_loss_results.append(epoch_loss_avg.result());
    train_accuracy_results.append(epoch_accuracy.result());

此代码可以正常工作,但是如果我按如下所示重写它,则会收到如下错误:

非布尔型张量(tf.Tensor:id = 201,shape =(),dtype = float32,numpy = 3.6846912)无法转换为布尔型。

num_epochs = 201;

for epoch in range(num_epochs):
    #epoch_loss_avg = tensorflow.contrib.eager.metrics.Mean();
    #epoch_accuracy = tensorflow.contrib.eager.metrics.Accuracy();

    # Training loop - using batches of 32
    for x, y in data_set:
        # Optimize the model
        loss_value, grads = grad(model, x, y);
        optimizer.apply_gradients(zip(grads, model.variables), global_step);

        # Track progress
        tensorflow.contrib.eager.metrics.Mean(loss_value);
        # compare predicted label to actual label
        tensorflow.contrib.eager.metrics.Accuracy(tensorflow.argmax(model(x), axis=1, output_type=tensorflow.int32), y);

    # end epoch
    train_loss_results.append(epoch_loss_avg.result());
    train_accuracy_results.append(epoch_accuracy.result());

原因是什么?

1 个答案:

答案 0 :(得分:0)

epoch_loss_avg = tensorflow.contrib.eager.metrics.Mean()通过调用epoch_loss_avg(loss_value)创建了一个可用来追踪均值的对象。

但是,如果您改为使用tensorflow.contrib.eager.metrics.Mean(loss_value),则您尝试创建的平均跟踪对象具有浮动的张量,而该浮动张量无法理解。请参阅Mean__init____call__方法(文档和代码)以了解发生了什么。

也不要在python中使用分号:)