How to pass a dictionary to TensorFlow py_func

Asked: 2018-02-20 09:56:08

Tags: python tensorflow

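When using the TensorFlow Dataset API, the map function accepts a lambda or a Python function.

If the Python function passed to map needs to make database calls, py_func can be used to wrap it. However, when the inp argument of py_func receives a dictionary, it raises a TypeError.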

# Dict in TensorFlow: minimal reproduction (raises the TypeError shown in the traceback below)
import tensorflow as tf

def parse(data):
    # Make some DB calls with data, but here just add the values.
    return data[0] + data[1]

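# Build a dataset of dicts, then pass each dict to py_func through inp (this is the failing call).
iterator = (
    tf.data.Dataset.range(10)
    .map(lambda x: {0: [x * 2], 1: [x ** 2]})
    .map(lambda x: tf.py_func(parse, inp=[x], Tout=tf.int64))
    .make_initializable_iterator()
)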

init_op = iterator.initializer
get_next = iterator.get_next()

with tf.Session() as sess:
    sess.run(init_op)
    print(sess.run(get_next))

Traceback

TypeError                                 Traceback (most recent call last)
~/miniconda3/envs/metadata/lib/python3.6/site-packages/tensorflow/python/framework/tensor_util.py in make_tensor_proto(values, dtype, shape, verify_shape)
    467     try:
--> 468       str_values = [compat.as_bytes(x) for x in proto_values]
    469     except TypeError:

~/miniconda3/envs/metadata/lib/python3.6/site-packages/tensorflow/python/framework/tensor_util.py in <listcomp>(.0)
    467     try:
--> 468       str_values = [compat.as_bytes(x) for x in proto_values]
    469     except TypeError:

~/miniconda3/envs/metadata/lib/python3.6/site-packages/tensorflow/python/util/compat.py in as_bytes(bytes_or_text, encoding)
     64     raise TypeError('Expected binary or unicode string, got %r' %
---> 65                     (bytes_or_text,))
     66

TypeError: Expected binary or unicode string, got {0: <tf.Tensor 'arg0:0' shape=(1,) dtype=int64>, 1: <tf.Tensor 'arg1:0' shape=(1,) dtype=int64>}

A workaround is to pass the dictionary's keys and values as separate arguments.
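The flatten-and-rebuild step of that workaround can be sketched in plain Python, without TensorFlow. The helper names `flatten_dict` and `make_flat_wrapper` are hypothetical and not part of any library; `parse` is the function from the question. The sketch assumes the set of keys is fixed and known when the pipeline is built, so only the values need to travel as positional arguments.

```python
def flatten_dict(d):
    """Split a dict into a sorted key list and a parallel value list.

    Hypothetical helper: sorting fixes the order so values can be
    matched back to their keys later.
    """
    keys = sorted(d)
    return keys, [d[k] for k in keys]


def make_flat_wrapper(func, keys):
    """Wrap func so it accepts values positionally and rebuilds the dict.

    Hypothetical helper: the returned wrapper has the shape of a function
    that receives one positional argument per input.
    """
    def wrapper(*values):
        return func(dict(zip(keys, values)))
    return wrapper


def parse(data):
    # Stand-in for the DB calls in the question: just add the values.
    return data[0] + data[1]


# One record as the first map step would produce it for x = 2
# (plain ints here instead of tensors).
record = {0: 2 * 2, 1: 2 ** 2}

keys, values = flatten_dict(record)
flat_parse = make_flat_wrapper(parse, keys)
result = flat_parse(*values)
print(result)
```

In the TensorFlow 1.4 pipeline from the question, the same idea would presumably mean passing the wrapper as the function and the flat list of value tensors as `inp`, since `tf.py_func` hands each element of `inp` to the function as a separate positional argument. That part is untested here and should be read as an assumption.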

Can py_func pass a dictionary as an argument to the function it calls?

TensorFlow version == 1.4, Python == 3.5

0 answers:

No answers yet.