来自TF2的tf.print无法打印到输出

时间:2019-12-19 10:44:19

标签: python tensorflow tensorflow-datasets tensorflow2.0

我正在使用TF2,我想在运行在tf.data.Dataset管道中的函数中打印张量。

这是我的代码:

import tensorflow as tf
import sys

sys.stdout = open('tf.log', 'w')

def main():

    ## Dataset generator
    #
    numRows= 100

    indx = tf.reshape([i+1 for i in range(numRows)], [numRows,1])
    features = tf.random.uniform([numRows, 2], minval=1, maxval=10, dtype=tf.int32)

    myData = tf.concat([indx, features], 1)

    ## tf.data.Dataset
    #
    dataset = tf.data.Dataset.from_tensor_slices(myData)

    ## Pipeline
    #
    dataset.map(myFunc)

    ## Run pipeline
    #
    for d in dataset:
        print('--')

def myFunc(t):
    tf.print(t, output_stream=sys.stdout)
    return t

if __name__ == "__main__":
    main()

但是结果,我只得到了这个:

--
--
--
--
--
--
--
--
--
--
--

如何使用tf.print将张量打印到控制台?

1 个答案:

答案 0 :(得分:0)

您没有将map返回的数据集分配给任何对象。简单地做,

dataset = tf.data.Dataset.from_tensor_slices(myData)
dataset = dataset.map(myFunc)