如何使用TensorFlow的热切执行来检查张量的内容?

时间:2019-05-15 10:15:02

标签: python tensorflow tensor eager-execution

我通过急切的执行使用TensorFlow 1.12,并且我具有以下(不完整的)函数来检查一些中间张量:

def parse_example(example_proto, width, height, num_classes):
    features = {
        'image/encoded': tf.FixedLenFeature((), tf.string),
        'image/height': tf.FixedLenFeature((), tf.int64),
        'image/width': tf.FixedLenFeature((), tf.int64),
        'image/filename': tf.FixedLenFeature((), tf.string),
        'image/object/bbox/xmin': tf.VarLenFeature(tf.float32),
        'image/object/bbox/xmax': tf.VarLenFeature(tf.float32),
        'image/object/bbox/ymin': tf.VarLenFeature(tf.float32),
        'image/object/bbox/ymax': tf.VarLenFeature(tf.float32),
        'image/object/class/label': tf.VarLenFeature(tf.int64),
        'image/object/class/text': tf.VarLenFeature(tf.string),
        'image/object/mask': tf.VarLenFeature(tf.string),
        'image/depth': tf.FixedLenFeature((), tf.string)
    }

    parsed_example = tf.parse_single_example(example_proto, features)

    #print(tf.sparse_tensor_to_dense(parsed_example['image/object/mask'], default_value=0))

    # Decode image
    image = tf.image.decode_jpeg(parsed_example['image/encoded'])
    parsed_example['image/encoded'] = image

    # Depth + RGBD
    depth = utilities.decode_depth(parsed_example['image/depth'])
    parsed_example['image/depth'] = depth
    rgbd = tf.concat([tf.image.convert_image_dtype(image, tf.float32), depth], axis=2)
    rgbd = tf.reshape(rgbd, shape=tf.stack([height, width, 4]))
    parsed_example['image/rgbd'] = rgbd

    mask = tf.sparse.to_dense(parsed_example['image/object/mask'], default_value="")
    mask = tf.map_fn(utilities.decode_png_mask, mask, dtype=tf.uint8)
    mask = tf.reshape(mask, shape=tf.stack([-1, height, width]), name='mask')
    print(mask)
    sys.exit()

但是,print(mask)仅返回Tensor("mask:0", shape=(?, 1000, 1200), dtype=uint8),而我希望看到实际值。如TensorFlow’s eager execution guide所示,这应该是可能的。我也尝试过tf.print(mask, output_stream=sys.stdout),但是只打印了空白行。 mask.dtypeuint8,因此我想它应该包含整数,因为它具有 a 形状。我还感到奇怪的是,mask.device是空字符串。应该将其存储在某些设备上,对吧?

如何打印mask张量的内容?

1 个答案:

答案 0 :(得分:1)

如果启用了急切执行功能,那么您应该可以致电

mask.numpy() 

返回该张量中的numpy值数组。

我的印象是,启用急切执行时,print也应打印内容,但这可能取决于张量的大小。

无论哪种方式,都值得通过调用以下命令检查是否启用了急切执行:

tf.enable_eager_execution()