数据集API RecursionError:超过最大递归深度

时间:2019-05-08 12:59:15

标签: python-3.x tensorflow tensorflow-datasets

我正在尝试使用tensorflow数据集API和地图函数构建可扩展的最小最大缩放器。

首先,我遍历数据集以找到所有特征的最小值和最大值(3),然后我想使用map函数将最小值/最大值缩放器应用于数据集。

这是我的简单代码。

import numpy as np
import tensorflow as tf

b = np.array([[1, 2, 3], [4, 5, 6], [7,8,9],[10,11,12]])
b_ds = tf.data.Dataset.from_tensor_slices(b).batch(2)

my_iterator = b_ds.make_one_shot_iterator()

def compute_min_max(i, my_min, my_max):
    new_batch = my_iterator.get_next()
    my_min = tf.minimum(my_min,tf.reduce_min(new_batch, axis=0))
    my_max = tf.maximum(my_max,tf.reduce_max(new_batch, axis=0))
    return [i+1, my_min, my_max]

i = tf.constant(0)
feat_min = tf.Variable([10,10,10],dtype=tf.int64)
feat_max = tf.Variable([0,0,0],dtype=tf.int64)

c = lambda i, min, max: i < 2
b = lambda i, min, max: compute_min_max(i, min, max)
res_i, res_min, res_max = tf.while_loop(c, b, loop_vars=[i, feat_min, feat_max])

def min_max_ds(feat):
    return tf.cast(feat-res_min,dtype=tf.float64)/tf.cast(res_max-res_min, dtype=tf.float64)

minmax_scaled_ds = b_ds.map(min_max_ds)

scaled_batch = minmax_scaled_ds.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    init=tf.global_variables_initializer()
    sess.run(init)
    print(sess.run((res_min, res_max, scaled_batch)))

执行此代码时,我得到一个

  

RecursionError:超过最大递归深度

我的猜测是min_max_ds函数大约每批都会调用tf.while_loop语句,但是我无法弄清楚如何冻结res_min和res_max以便它们在min_max_ds函数中用作常量。

1 个答案:

答案 0 :(得分:0)

也许您可以使用以下方式为递归深度设置更高的限制:

sys.setrecursionlimit(10000)

python3.X中的默认值为1000。也许可以使用更大的值。